Neodroid  0.2.0
Machine Learning Environment Prototyping Tool
ObjectiveFunction.cs
Go to the documentation of this file.
1 using System;
2 using System.Globalization;
8 using UnityEngine;
9 
namespace droid.Runtime.Prototyping.Evaluation {
  /// <summary>
  /// Abstract base for objective (reward) functions attached to a prototyping environment.
  /// Subclasses supply the task-specific reward through <see cref="InternalEvaluate"/> and reset their
  /// own state in <see cref="InternalReset"/>. This base class adds an episode-length cut-off and keeps
  /// track of the last emitted signal and the accumulated episode return.
  /// </summary>
  [Serializable]
  public abstract class ObjectiveFunction : PrototypingGameObject,
                                            //IHasRegister<Term>,
                                            //IResetable,
                                            IObjective {
    /// <summary>
    /// Reward value configured for a solved episode. Not read by this base class; available to subclasses.
    /// </summary>
    [SerializeField]
    protected float _solved_reward = 1.0f;

    /// <summary>
    /// Reward that <see cref="Evaluate"/> emits once the maximum episode length is reached.
    /// Protected, like the other reward fields, so subclasses can use it as well.
    /// </summary>
    [SerializeField]
    protected float _failed_reward = -1.0f;

    /// <summary>
    /// Default reward value (presumably a small per-step penalty — confirm in subclasses).
    /// Not read by this base class.
    /// </summary>
    [SerializeField]
    protected float _default_reward = -0.001f;

    /// <summary>
    /// Sum of every signal returned by <see cref="Evaluate"/> since the last <see cref="EnvironmentReset"/>.
    /// </summary>
    [SerializeField]
    public float _Episode_Return;

    /// <summary>Type name reported by this prototyping component; empty for the base objective.</summary>
    public override string PrototypingTypeName { get { return ""; } }

    /// <summary>
    /// Environment whose frame counter drives the episode-length cut-off and which is terminated
    /// when that limit is reached. Filled in by <see cref="Setup"/> if left unassigned.
    /// </summary>
    public AbstractPrototypingEnvironment ParentEnvironment {
      get { return this._environment; }
      set { this._environment = value; }
    }

    // Extra-term registration (IHasRegister<Term>) is currently disabled; kept for reference.
    /*
    public virtual void Register(Term term) { this.Register(term, term.Identifier); }


    public void Register(Term term, string identifier) {
      if (!this._Extra_Terms_Dict.ContainsKey(identifier)) {
#if NEODROID_DEBUG
        if (this.Debugging) {
          Debug.Log($"ObjectiveFunction {this.name} has registered term {identifier}");
        }
#endif

        this._Extra_Terms_Dict.Add(identifier, term);
        this._Extra_Term_Weights.Add(term, 1);
      } else {
        Debug.LogWarning($"WARNING! Please check for duplicates, ObjectiveFunction {this.name} already has term {identifier} registered");
      }
    }

    public void UnRegister(Term term, string identifier) {
      if (this._Extra_Terms_Dict.ContainsKey(identifier)) {
#if NEODROID_DEBUG
        if (this.Debugging) {
          Debug.Log($"ObjectiveFunction {this.name} unregistered term {identifier}");
        }
#endif

        this._Extra_Term_Weights.Remove(this._Extra_Terms_Dict[identifier]);
        this._Extra_Terms_Dict.Remove(identifier);
      }
    }

    public void UnRegister(Term term) { this.UnRegister(term, term.Identifier); }
    */

    /// <summary>
    /// Threshold value exposed to callers and subclasses. Not read by this base class.
    /// </summary>
    public float SolvedThreshold {
      get { return this._solved_threshold; }
      set { this._solved_threshold = value; }
    }

    /// <summary>
    /// Computes the signal for the current frame.
    /// </summary>
    /// <remarks>
    /// The value from <see cref="InternalEvaluate"/> is used unless <see cref="EpisodeLength"/> is
    /// positive and the parent environment's current frame number has reached it. In that case the
    /// signal is replaced by <c>_failed_reward</c> and the environment is asked to terminate.
    /// The parent environment is only dereferenced when <see cref="EpisodeLength"/> is positive, so it
    /// must be assigned in that case. The returned value is also stored as the last signal and added
    /// to <see cref="_Episode_Return"/>.
    /// </remarks>
    /// <returns>The signal for this frame.</returns>
    public float Evaluate() {
      var signal = 0.0f;
      signal += this.InternalEvaluate();
      //signal += this.EvaluateExtraTerms();

      //signal = signal * Mathf.Pow(this._internal_discount_factor, this._environment.CurrentFrameNumber);

      if (this.EpisodeLength > 0 && this._environment.CurrentFrameNumber >= this.EpisodeLength) {
#if NEODROID_DEBUG
        if (this.Debugging) {
          Debug.Log($"Maximum episode length reached, Length {this._environment.CurrentFrameNumber}");
        }
#endif

        // Running out of time counts as failure and overrides whatever InternalEvaluate reported.
        signal = this._failed_reward;

        this._environment.Terminate("Maximum episode length reached");
      }

#if NEODROID_DEBUG
      if (this.Debugging) {
        Debug.Log(signal);
      }
#endif

      this._last_signal = signal;

      this._Episode_Return += signal;

      return signal;
    }

    /// <summary>
    /// Zeroes the last signal and the episode return, then lets the subclass reset via
    /// <see cref="InternalReset"/>.
    /// </summary>
    public void EnvironmentReset() {
      this._last_signal = 0;
      this._Episode_Return = 0;
      this.InternalReset();
    }

    /// <summary>
    /// Currently a no-op; clearing of the (disabled) extra-term collections is commented out.
    /// </summary>
    protected override void Clear() {
      /*
      this._Extra_Term_Weights.Clear();
      this._Extra_Terms_Dict.Clear();
      */
    }

    /// <summary>
    /// Resolves the parent environment and then calls <see cref="PostSetup"/>.
    /// If no environment was assigned, the first <c>AbstractPrototypingEnvironment</c> returned by
    /// <c>FindObjectOfType</c> is used; it may still be null afterwards. Sealed; subclasses extend
    /// setup through <see cref="PostSetup"/>.
    /// </summary>
    protected sealed override void Setup() {
      //foreach (var go in this._extra_terms_external)
      //  this.Register(go);

      if (this.ParentEnvironment == null) {
        this.ParentEnvironment = FindObjectOfType<AbstractPrototypingEnvironment>();
      }

      this.PostSetup();
    }

    /// <summary>
    /// Hook for subclass initialisation, invoked at the end of <see cref="Setup"/>.
    /// </summary>
    protected virtual void PostSetup() { }

    /// <summary>
    /// Sends <c>"{last_signal}, {episode_return}"</c> to <paramref name="recipient"/>.
    /// Both numbers are formatted with the invariant culture so the text is identical on every locale;
    /// otherwise a comma decimal separator would collide with the field separator.
    /// </summary>
    /// <param name="recipient">Poller that receives the formatted signal string.</param>
    public void SignalString(DataPoller recipient) {
      var last_signal = this._last_signal.ToString(CultureInfo.InvariantCulture);
      var episode_return = this._Episode_Return.ToString(CultureInfo.InvariantCulture);
      recipient.PollData($"{last_signal}, {episode_return}");
    }

    /// <summary>No-op override; this component does not register itself anywhere.</summary>
    protected override void RegisterComponent() { }

    /// <summary>No-op override; counterpart of <see cref="RegisterComponent"/>.</summary>
    protected override void UnRegisterComponent() { }

    /// <summary>
    /// Task-specific signal for the current frame. Called by <see cref="Evaluate"/>, which may override
    /// the result when the episode-length limit is reached.
    /// </summary>
    /// <returns>The subclass's signal for this frame.</returns>
    public abstract float InternalEvaluate();

    /// <summary>
    /// Resets subclass-specific state. Called by <see cref="EnvironmentReset"/>.
    /// </summary>
    public abstract void InternalReset();

    /*
    public virtual void AdjustExtraTermsWeights(Term term, float new_weight) {
      if (this._Extra_Term_Weights.ContainsKey(term)) {
        this._Extra_Term_Weights[term] = new_weight;
      }
    }

    public virtual float EvaluateExtraTerms() {
      float extra_terms_output = 0;
      foreach (var term in this._Extra_Terms_Dict.Values) {
#if NEODROID_DEBUG
        if (this.Debugging) {
          Debug.Log($"Extra term: {term}");
        }
#endif

        extra_terms_output += this._Extra_Term_Weights[term] * term.Evaluate();
      }

#if NEODROID_DEBUG
      if (this.Debugging) {
        Debug.Log($"Extra terms signal: {extra_terms_output}");
      }
#endif
      return extra_terms_output;
    }
    */

    #region Fields

    //[SerializeField]float _internal_discount_factor = 1.0f;

    /// <summary>Backing field of <see cref="ParentEnvironment"/>.</summary>
    [Header("References", order = 100)]
    [SerializeField]
    AbstractPrototypingEnvironment _environment = null;

    //[SerializeField] Term[] _extra_terms_external;

    //[SerializeField] protected Dictionary<string, Term> _Extra_Terms_Dict = new Dictionary<string, Term>();

    //[SerializeField] protected Dictionary<Term, float> _Extra_Term_Weights = new Dictionary<Term, float>();

    /// <summary>Backing field of <see cref="SolvedThreshold"/>.</summary>
    [Header("General", order = 101)]
    [SerializeField]
    float _solved_threshold = 0f;

    /// <summary>Last value returned by <see cref="Evaluate"/>; zeroed by <see cref="EnvironmentReset"/>.</summary>
    [SerializeField]
    float _last_signal = 0f;

    /// <summary>Backing field of <see cref="EpisodeLength"/>.</summary>
    [SerializeField]
    int _episode_length = 1000;

    /// <summary>
    /// Maximum number of frames per episode. Values &lt;= 0 disable the cut-off in <see cref="Evaluate"/>.
    /// </summary>
    public int EpisodeLength { get { return this._episode_length; } set { this._episode_length = value; } }

    /// <summary>Description of the signal's value space; not assigned by this base class.</summary>
    public Space1 SignalSpace { get; set; }

    #endregion
  }
}