Skip to content

Commit 7a6327d

Browse files
author
Chris Elion
authored
Agent.Heuristic takes an float[] (#3765)
1 parent e77b919 commit 7a6327d

File tree

21 files changed

+93
-106
lines changed

21 files changed

+93
-106
lines changed

Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,10 @@ public override void OnEpisodeBegin()
6666
SetResetParameters();
6767
}
6868

69-
public override float[] Heuristic()
69+
public override void Heuristic(float[] actionsOut)
7070
{
71-
var action = new float[2];
72-
73-
action[0] = -Input.GetAxis("Horizontal");
74-
action[1] = Input.GetAxis("Vertical");
75-
return action;
71+
actionsOut[0] = -Input.GetAxis("Horizontal");
72+
actionsOut[1] = Input.GetAxis("Vertical");
7673
}
7774

7875
public void SetBall()

Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,11 @@ void FixedUpdate()
102102
}
103103
}
104104

105-
public override float[] Heuristic()
105+
public override void Heuristic(float[] actionsOut)
106106
{
107-
var action = new float[3];
108-
109-
action[0] = Input.GetAxis("Horizontal");
110-
action[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
111-
action[2] = Input.GetAxis("Vertical");
112-
return action;
107+
actionsOut[0] = Input.GetAxis("Horizontal");
108+
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
109+
actionsOut[2] = Input.GetAxis("Vertical");
113110
}
114111

115112
void Update()

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,27 +207,25 @@ public override void OnActionReceived(float[] vectorAction)
207207
MoveAgent(vectorAction);
208208
}
209209

210-
public override float[] Heuristic()
210+
public override void Heuristic(float[] actionsOut)
211211
{
212-
var action = new float[4];
213212
if (Input.GetKey(KeyCode.D))
214213
{
215-
action[2] = 2f;
214+
actionsOut[2] = 2f;
216215
}
217216
if (Input.GetKey(KeyCode.W))
218217
{
219-
action[0] = 1f;
218+
actionsOut[0] = 1f;
220219
}
221220
if (Input.GetKey(KeyCode.A))
222221
{
223-
action[2] = 1f;
222+
actionsOut[2] = 1f;
224223
}
225224
if (Input.GetKey(KeyCode.S))
226225
{
227-
action[0] = 2f;
226+
actionsOut[0] = 2f;
228227
}
229-
action[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
230-
return action;
228+
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
231229
}
232230

233231
public override void OnEpisodeBegin()

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,25 @@ public override void OnActionReceived(float[] vectorAction)
108108
}
109109
}
110110

111-
public override float[] Heuristic()
111+
public override void Heuristic(float[] actionsOut)
112112
{
113+
actionsOut[0] = k_NoAction;
113114
if (Input.GetKey(KeyCode.D))
114115
{
115-
return new float[] { k_Right };
116+
actionsOut[0] = k_Right;
116117
}
117118
if (Input.GetKey(KeyCode.W))
118119
{
119-
return new float[] { k_Up };
120+
actionsOut[0] = k_Up;
120121
}
121122
if (Input.GetKey(KeyCode.A))
122123
{
123-
return new float[] { k_Left };
124+
actionsOut[0] = k_Left;
124125
}
125126
if (Input.GetKey(KeyCode.S))
126127
{
127-
return new float[] { k_Down };
128+
actionsOut[0] = k_Down;
128129
}
129-
return new float[] { k_NoAction };
130130
}
131131

132132
// to be implemented by the developer

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,25 +91,25 @@ void OnCollisionEnter(Collision col)
9191
}
9292
}
9393

94-
public override float[] Heuristic()
94+
public override void Heuristic(float[] actionsOut)
9595
{
96+
actionsOut[0] = 0;
9697
if (Input.GetKey(KeyCode.D))
9798
{
98-
return new float[] { 3 };
99+
actionsOut[0] = 3;
99100
}
100-
if (Input.GetKey(KeyCode.W))
101+
else if (Input.GetKey(KeyCode.W))
101102
{
102-
return new float[] { 1 };
103+
actionsOut[0] = 1;
103104
}
104-
if (Input.GetKey(KeyCode.A))
105+
else if (Input.GetKey(KeyCode.A))
105106
{
106-
return new float[] { 4 };
107+
actionsOut[0] = 4;
107108
}
108-
if (Input.GetKey(KeyCode.S))
109+
else if (Input.GetKey(KeyCode.S))
109110
{
110-
return new float[] { 2 };
111+
actionsOut[0] = 2;
111112
}
112-
return new float[] { 0 };
113113
}
114114

115115
public override void OnEpisodeBegin()

Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,25 +170,25 @@ public override void OnActionReceived(float[] vectorAction)
170170
AddReward(-1f / maxStep);
171171
}
172172

173-
public override float[] Heuristic()
173+
public override void Heuristic(float[] actionsOut)
174174
{
175+
actionsOut[0] = 0;
175176
if (Input.GetKey(KeyCode.D))
176177
{
177-
return new float[] { 3 };
178+
actionsOut[0] = 3;
178179
}
179-
if (Input.GetKey(KeyCode.W))
180+
else if (Input.GetKey(KeyCode.W))
180181
{
181-
return new float[] { 1 };
182+
actionsOut[0] = 1;
182183
}
183-
if (Input.GetKey(KeyCode.A))
184+
else if (Input.GetKey(KeyCode.A))
184185
{
185-
return new float[] { 4 };
186+
actionsOut[0] = 4;
186187
}
187-
if (Input.GetKey(KeyCode.S))
188+
else if (Input.GetKey(KeyCode.S))
188189
{
189-
return new float[] { 2 };
190+
actionsOut[0] = 2;
190191
}
191-
return new float[] { 0 };
192192
}
193193

194194
/// <summary>

Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,25 @@ public override void OnActionReceived(float[] vectorAction)
6161
MoveAgent(vectorAction);
6262
}
6363

64-
public override float[] Heuristic()
64+
public override void Heuristic(float[] actionsOut)
6565
{
66+
actionsOut[0] = 0;
6667
if (Input.GetKey(KeyCode.D))
6768
{
68-
return new float[] { 3 };
69+
actionsOut[0] = 3;
6970
}
70-
if (Input.GetKey(KeyCode.W))
71+
else if (Input.GetKey(KeyCode.W))
7172
{
72-
return new float[] { 1 };
73+
actionsOut[0] = 1;
7374
}
74-
if (Input.GetKey(KeyCode.A))
75+
else if (Input.GetKey(KeyCode.A))
7576
{
76-
return new float[] { 4 };
77+
actionsOut[0] = 4;
7778
}
78-
if (Input.GetKey(KeyCode.S))
79+
else if (Input.GetKey(KeyCode.S))
7980
{
80-
return new float[] { 2 };
81+
actionsOut[0] = 2;
8182
}
82-
return new float[] { 0 };
8383
}
8484

8585
public override void OnEpisodeBegin()

Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,37 +112,35 @@ public override void OnActionReceived(float[] vectorAction)
112112
MoveAgent(vectorAction);
113113
}
114114

115-
public override float[] Heuristic()
115+
public override void Heuristic(float[] actionsOut)
116116
{
117-
var action = new float[3];
118117
//forward
119118
if (Input.GetKey(KeyCode.W))
120119
{
121-
action[0] = 1f;
120+
actionsOut[0] = 1f;
122121
}
123122
if (Input.GetKey(KeyCode.S))
124123
{
125-
action[0] = 2f;
124+
actionsOut[0] = 2f;
126125
}
127126
//rotate
128127
if (Input.GetKey(KeyCode.A))
129128
{
130-
action[2] = 1f;
129+
actionsOut[2] = 1f;
131130
}
132131
if (Input.GetKey(KeyCode.D))
133132
{
134-
action[2] = 2f;
133+
actionsOut[2] = 2f;
135134
}
136135
//right
137136
if (Input.GetKey(KeyCode.E))
138137
{
139-
action[1] = 1f;
138+
actionsOut[1] = 1f;
140139
}
141140
if (Input.GetKey(KeyCode.Q))
142141
{
143-
action[1] = 2f;
142+
actionsOut[1] = 2f;
144143
}
145-
return action;
146144
}
147145
/// <summary>
148146
/// Used to provide a "kick" to the ball.

Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,11 @@ public override void OnActionReceived(float[] vectorAction)
8686
m_TextComponent.text = score.ToString();
8787
}
8888

89-
public override float[] Heuristic()
89+
public override void Heuristic(float[] actionsOut)
9090
{
91-
var action = new float[3];
92-
93-
action[0] = Input.GetAxis("Horizontal"); // Racket Movement
94-
action[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
95-
action[2] = Input.GetAxis("Vertical"); // Racket Rotation
96-
return action;
91+
actionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
92+
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
93+
actionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
9794
}
9895

9996
public override void OnEpisodeBegin()

Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,27 +241,25 @@ public override void OnActionReceived(float[] vectorAction)
241241
}
242242
}
243243

244-
public override float[] Heuristic()
244+
public override void Heuristic(float[] actionsOut)
245245
{
246-
var action = new float[4];
247246
if (Input.GetKey(KeyCode.D))
248247
{
249-
action[1] = 2f;
248+
actionsOut[1] = 2f;
250249
}
251250
if (Input.GetKey(KeyCode.W))
252251
{
253-
action[0] = 1f;
252+
actionsOut[0] = 1f;
254253
}
255254
if (Input.GetKey(KeyCode.A))
256255
{
257-
action[1] = 1f;
256+
actionsOut[1] = 1f;
258257
}
259258
if (Input.GetKey(KeyCode.S))
260259
{
261-
action[0] = 2f;
260+
actionsOut[0] = 2f;
262261
}
263-
action[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
264-
return action;
262+
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
265263
}
266264

267265
// Detect when the agent hits the goal

0 commit comments

Comments
 (0)