Skip to content

Commit 6d70c0d

Browse files
author
Chris Elion
committed
Merge remote-tracking branch 'origin/master' into develop-academy-singleton
2 parents 71f2e75 + f404675 commit 6d70c0d

35 files changed

+461
-405
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public void TestStoreInitalize()
6464
storedVectorActions = new[] { 0f, 1f },
6565
};
6666

67-
demoStore.Record(agentInfo);
67+
demoStore.Record(agentInfo, new System.Collections.Generic.List<Sensor.Observation>());
6868
demoStore.Close();
6969
}
7070

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs

Lines changed: 29 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Reflection;
55
using Barracuda;
66
using MLAgents.InferenceBrain;
7+
using System;
78

89
namespace MLAgents.Tests
910
{
@@ -19,16 +20,6 @@ public AgentAction GetAction()
1920
}
2021
}
2122

22-
List<Agent> GetFakeAgentInfos()
23-
{
24-
var goA = new GameObject("goA");
25-
var agentA = goA.AddComponent<TestAgent>();
26-
var goB = new GameObject("goB");
27-
var agentB = goB.AddComponent<TestAgent>();
28-
29-
return new List<Agent> { agentA, agentB };
30-
}
31-
3223
[Test]
3324
public void Construction()
3425
{
@@ -48,25 +39,27 @@ public void ApplyContinuousActionOutput()
4839
shape = new long[] { 2, 3 },
4940
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
5041
};
51-
var agentInfos = GetFakeAgentInfos();
5242

5343
var applier = new ContinuousActionOutputApplier();
54-
applier.Apply(inputTensor, agentInfos);
55-
var agents = agentInfos;
5644

57-
var agent = agents[0] as TestAgent;
58-
Assert.NotNull(agent);
59-
var action = agent.GetAction();
60-
Assert.AreEqual(action.vectorActions[0], 1);
61-
Assert.AreEqual(action.vectorActions[1], 2);
62-
Assert.AreEqual(action.vectorActions[2], 3);
45+
var action0 = new AgentAction();
46+
var action1 = new AgentAction();
47+
var callbacks = new List<AgentIdActionPair>()
48+
{
49+
new AgentIdActionPair{agentId = 0, action = (a) => action0 = a},
50+
new AgentIdActionPair{agentId = 1, action = (a) => action1 = a}
51+
};
52+
53+
applier.Apply(inputTensor, callbacks);
54+
55+
56+
Assert.AreEqual(action0.vectorActions[0], 1);
57+
Assert.AreEqual(action0.vectorActions[1], 2);
58+
Assert.AreEqual(action0.vectorActions[2], 3);
6359

64-
agent = agents[1] as TestAgent;
65-
Assert.NotNull(agent);
66-
action = agent.GetAction();
67-
Assert.AreEqual(action.vectorActions[0], 4);
68-
Assert.AreEqual(action.vectorActions[1], 5);
69-
Assert.AreEqual(action.vectorActions[2], 6);
60+
Assert.AreEqual(action1.vectorActions[0], 4);
61+
Assert.AreEqual(action1.vectorActions[1], 5);
62+
Assert.AreEqual(action1.vectorActions[2], 6);
7063
}
7164

7265
[Test]
@@ -80,49 +73,25 @@ public void ApplyDiscreteActionOutput()
8073
5,
8174
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
8275
};
83-
var agentInfos = GetFakeAgentInfos();
8476
var alloc = new TensorCachingAllocator();
8577
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
86-
applier.Apply(inputTensor, agentInfos);
87-
var agents = agentInfos;
8878

89-
var agent = agents[0] as TestAgent;
90-
Assert.NotNull(agent);
91-
var action = agent.GetAction();
92-
Assert.AreEqual(action.vectorActions[0], 1);
93-
Assert.AreEqual(action.vectorActions[1], 1);
94-
95-
agent = agents[1] as TestAgent;
96-
Assert.NotNull(agent);
97-
action = agent.GetAction();
98-
Assert.AreEqual(action.vectorActions[0], 1);
99-
Assert.AreEqual(action.vectorActions[1], 2);
100-
alloc.Dispose();
101-
}
102-
103-
[Test]
104-
public void ApplyValueEstimate()
105-
{
106-
var inputTensor = new TensorProxy()
79+
var action0 = new AgentAction();
80+
var action1 = new AgentAction();
81+
var callbacks = new List<AgentIdActionPair>()
10782
{
108-
shape = new long[] { 2, 1 },
109-
data = new Tensor(2, 1, new[] { 0.5f, 8f })
83+
new AgentIdActionPair{agentId = 0, action = (a) => action0 = a},
84+
new AgentIdActionPair{agentId = 1, action = (a) => action1 = a}
11085
};
111-
var agentInfos = GetFakeAgentInfos();
11286

113-
var applier = new ValueEstimateApplier();
114-
applier.Apply(inputTensor, agentInfos);
115-
var agents = agentInfos;
87+
applier.Apply(inputTensor, callbacks);
11688

117-
var agent = agents[0] as TestAgent;
118-
Assert.NotNull(agent);
119-
var action = agent.GetAction();
120-
Assert.AreEqual(action.value, 0.5f);
89+
Assert.AreEqual(action0.vectorActions[0], 1);
90+
Assert.AreEqual(action0.vectorActions[1], 1);
12191

122-
agent = agents[1] as TestAgent;
123-
Assert.NotNull(agent);
124-
action = agent.GetAction();
125-
Assert.AreEqual(action.value, 8);
92+
Assert.AreEqual(action1.vectorActions[0], 1);
93+
Assert.AreEqual(action1.vectorActions[1], 2);
94+
alloc.Dispose();
12695
}
12796
}
12897
}

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void SetUp()
2020
}
2121
}
2222

23-
static IEnumerable<Agent> GetFakeAgents()
23+
static List<Agent> GetFakeAgents()
2424
{
2525
var goA = new GameObject("goA");
2626
var bpA = goA.AddComponent<BehaviorParameters>();
@@ -58,7 +58,6 @@ static IEnumerable<Agent> GetFakeAgents()
5858

5959
agentA.Info = infoA;
6060
agentB.Info = infoB;
61-
6261
return agents;
6362
}
6463

@@ -112,7 +111,14 @@ public void GenerateVectorObservation()
112111
generator.AddSensorIndex(0);
113112
generator.AddSensorIndex(1);
114113
generator.AddSensorIndex(2);
115-
generator.Generate(inputTensor, batchSize, agentInfos);
114+
var agent0 = agentInfos[0];
115+
var agent1 = agentInfos[1];
116+
var inputs = new List<AgentInfoSensorsPair>
117+
{
118+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
119+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
120+
};
121+
generator.Generate(inputTensor, batchSize, inputs);
116122
Assert.IsNotNull(inputTensor.data);
117123
Assert.AreEqual(inputTensor.data[0, 0], 1);
118124
Assert.AreEqual(inputTensor.data[0, 2], 3);
@@ -133,8 +139,14 @@ public void GeneratePreviousActionInput()
133139
var agentInfos = GetFakeAgents();
134140
var alloc = new TensorCachingAllocator();
135141
var generator = new PreviousActionInputGenerator(alloc);
136-
137-
generator.Generate(inputTensor, batchSize, agentInfos);
142+
var agent0 = agentInfos[0];
143+
var agent1 = agentInfos[1];
144+
var inputs = new List<AgentInfoSensorsPair>
145+
{
146+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
147+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
148+
};
149+
generator.Generate(inputTensor, batchSize, inputs);
138150
Assert.IsNotNull(inputTensor.data);
139151
Assert.AreEqual(inputTensor.data[0, 0], 1);
140152
Assert.AreEqual(inputTensor.data[0, 1], 2);
@@ -155,7 +167,16 @@ public void GenerateActionMaskInput()
155167
var agentInfos = GetFakeAgents();
156168
var alloc = new TensorCachingAllocator();
157169
var generator = new ActionMaskInputGenerator(alloc);
158-
generator.Generate(inputTensor, batchSize, agentInfos);
170+
171+
var agent0 = agentInfos[0];
172+
var agent1 = agentInfos[1];
173+
var inputs = new List<AgentInfoSensorsPair>
174+
{
175+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
176+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
177+
};
178+
179+
generator.Generate(inputTensor, batchSize, inputs);
159180
Assert.IsNotNull(inputTensor.data);
160181
Assert.AreEqual(inputTensor.data[0, 0], 1);
161182
Assert.AreEqual(inputTensor.data[0, 4], 1);

UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public override float[] Heuristic()
5050
{
5151
return new float[0];
5252
}
53+
5354
}
5455

5556
public class TestSensor : ISensor

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ namespace MLAgents
1313
/// </summary>
1414
public struct AgentInfo
1515
{
16-
/// <summary>
17-
/// Most recent observations.
18-
/// </summary>
19-
public List<Observation> observations;
2016

2117
/// <summary>
2218
/// Keeps track of the last vector action taken by the Brain.
@@ -449,8 +445,6 @@ void ResetData()
449445
m_Info.storedVectorActions = new float[param.vectorActionSize.Length];
450446
}
451447
}
452-
453-
m_Info.observations = new List<Observation>();
454448
}
455449

456450
/// <summary>
@@ -490,7 +484,7 @@ public void InitializeSensors()
490484
{
491485
// Get all attached sensor components
492486
SensorComponent[] attachedSensorComponents;
493-
if(m_PolicyFactory.useChildSensors)
487+
if (m_PolicyFactory.useChildSensors)
494488
{
495489
attachedSensorComponents = GetComponentsInChildren<SensorComponent>();
496490
}
@@ -531,17 +525,6 @@ public void InitializeSensors()
531525
Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
532526
}
533527
#endif
534-
// Create a buffer for writing uncompressed (i.e. float) sensor data to
535-
int numFloatObservations = 0;
536-
for (var i = 0; i < sensors.Count; i++)
537-
{
538-
if (sensors[i].GetCompressionType() == SensorCompressionType.None)
539-
{
540-
numFloatObservations += sensors[i].ObservationSize();
541-
}
542-
}
543-
544-
m_VectorSensorBuffer = new float[numFloatObservations];
545528
}
546529

547530
/// <summary>
@@ -555,7 +538,6 @@ void SendInfoToBrain()
555538
}
556539

557540
m_Info.storedVectorActions = m_Action.vectorActions;
558-
m_Info.observations.Clear();
559541
m_ActionMasker.ResetMask();
560542
UpdateSensors();
561543
using (TimerStack.Instance.Scoped("CollectObservations"))
@@ -571,18 +553,23 @@ void SendInfoToBrain()
571553
m_Info.maxStepReached = m_MaxStepReached;
572554
m_Info.id = m_Id;
573555

574-
m_Brain.RequestDecision(this);
556+
m_Brain.RequestDecision(m_Info, sensors, UpdateAgentAction);
575557

576558
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
577559
{
578-
// This is a bit of a hack - if we're in inference mode, observations won't be generated
579-
// But we need these to be generated for the recorder. So generate them here.
580-
if (m_Info.observations.Count == 0)
560+
561+
if (m_VectorSensorBuffer == null)
581562
{
582-
GenerateSensorData();
563+
// Create a buffer for writing uncompressed (i.e. float) sensor data to
564+
m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()];
583565
}
584566

585-
m_Recorder.WriteExperience(m_Info);
567+
// This is a bit of a hack - if we're in inference mode, observations won't be generated
568+
// But we need these to be generated for the recorder. So generate them here.
569+
var observations = new List<Observation>();
570+
GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations);
571+
572+
m_Recorder.WriteExperience(m_Info, observations);
586573
}
587574

588575
}
@@ -596,12 +583,17 @@ void UpdateSensors()
596583
}
597584

598585
/// <summary>
599-
/// Generate data for each sensor and store it on the Agent's AgentInfo.
586+
/// Generate data for each sensor and store it in the observations input.
600587
/// NOTE: At the moment, this is only called during training or when using a DemonstrationRecorder;
601588
/// during inference the Sensors are used to write directly to the Tensor data. This will likely change in the
602589
/// future to be controlled by the type of brain being used.
603590
/// </summary>
604-
public void GenerateSensorData()
591+
/// <param name="sensors"> List of ISensors that will be used to generate the data.</param>
592+
/// <param name="buffer"> A float array that will be used as a buffer when generating the observations. Must
593+
/// be at least the same length as the total number of uncompressed floats in the observations</param>
594+
/// <param name="adapter"> The WriteAdapter that will be used to write the ISensor data to the observations</param>
595+
/// <param name="observations"> A list of observation outputs. This argument will be modified by this method.</param>
596+
public static void GenerateSensorData(List<ISensor> sensors, float[] buffer, WriteAdapter adapter, List<Observation> observations)
605597
{
606598
int floatsWritten = 0;
607599
// Generate data for all Sensors
@@ -611,15 +603,15 @@ public void GenerateSensorData()
611603
if (sensor.GetCompressionType() == SensorCompressionType.None)
612604
{
613605
// TODO handle in communicator code instead
614-
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, sensor.GetObservationShape(), floatsWritten);
615-
var numFloats = sensor.Write(m_WriteAdapter);
606+
adapter.SetTarget(buffer, sensor.GetObservationShape(), floatsWritten);
607+
var numFloats = sensor.Write(adapter);
616608
var floatObs = new Observation
617609
{
618-
FloatData = new ArraySegment<float>(m_VectorSensorBuffer, floatsWritten, numFloats),
610+
FloatData = new ArraySegment<float>(buffer, floatsWritten, numFloats),
619611
Shape = sensor.GetObservationShape(),
620612
CompressionType = sensor.GetCompressionType()
621613
};
622-
m_Info.observations.Add(floatObs);
614+
observations.Add(floatObs);
623615
floatsWritten += numFloats;
624616
}
625617
else
@@ -630,7 +622,7 @@ public void GenerateSensorData()
630622
Shape = sensor.GetObservationShape(),
631623
CompressionType = sensor.GetCompressionType()
632624
};
633-
m_Info.observations.Add(compressedObs);
625+
observations.Add(compressedObs);
634626
}
635627
}
636628
}

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System.IO.Abstractions;
22
using System.Text.RegularExpressions;
33
using UnityEngine;
4+
using System.Collections.Generic;
5+
using MLAgents.Sensor;
46

57
namespace MLAgents
68
{
@@ -68,9 +70,9 @@ public static string SanitizeName(string demoName, int maxNameLength)
6870
/// <summary>
6971
/// Forwards AgentInfo to Demonstration Store.
7072
/// </summary>
71-
public void WriteExperience(AgentInfo info)
73+
public void WriteExperience(AgentInfo info, List<Observation> observations)
7274
{
73-
m_DemoStore.Record(info);
75+
m_DemoStore.Record(info, observations);
7476
}
7577

7678
public void Close()

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using System.IO;
22
using System.IO.Abstractions;
33
using Google.Protobuf;
4-
using UnityEngine;
4+
using System.Collections.Generic;
5+
using MLAgents.Sensor;
56

67
namespace MLAgents
78
{
@@ -91,7 +92,7 @@ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
9192
/// <summary>
9293
/// Write AgentInfo experience to file.
9394
/// </summary>
94-
public void Record(AgentInfo info)
95+
public void Record(AgentInfo info, List<Observation> observations)
9596
{
9697
// Increment meta-data counters.
9798
m_MetaData.numberExperiences++;
@@ -102,7 +103,7 @@ public void Record(AgentInfo info)
102103
}
103104

104105
// Write AgentInfo to file.
105-
var agentProto = info.ToInfoActionPairProto();
106+
var agentProto = info.ToInfoActionPairProto(observations);
106107
agentProto.WriteDelimitedTo(m_Writer);
107108
}
108109

0 commit comments

Comments
 (0)