Skip to content

Commit f404675

Browse files
Decoupling IPolicy from Agent (#3203)
* initial commit
* Fixed the compilation errors
* fixing the tests
* Addressing the comment about the brain parameters
* Fixing typo
* Made timers more accurate
* addressing comments
* Better memory allocation
* Added some docstrings
* Adding better sensor validation
* Wrapped in #if DEBUG and also wrapped GenerateSensorData in a timer
* Timer changes
1 parent 61b0e6e commit f404675

24 files changed

+363
-328
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void TestStoreInitalize()
5454
storedVectorActions = new[] { 0f, 1f },
5555
};
5656

57-
demoStore.Record(agentInfo);
57+
demoStore.Record(agentInfo, new System.Collections.Generic.List<Sensor.Observation>());
5858
demoStore.Close();
5959
}
6060

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs

Lines changed: 29 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Reflection;
55
using Barracuda;
66
using MLAgents.InferenceBrain;
7+
using System;
78

89
namespace MLAgents.Tests
910
{
@@ -19,16 +20,6 @@ public AgentAction GetAction()
1920
}
2021
}
2122

22-
List<Agent> GetFakeAgentInfos()
23-
{
24-
var goA = new GameObject("goA");
25-
var agentA = goA.AddComponent<TestAgent>();
26-
var goB = new GameObject("goB");
27-
var agentB = goB.AddComponent<TestAgent>();
28-
29-
return new List<Agent> { agentA, agentB };
30-
}
31-
3223
[Test]
3324
public void Construction()
3425
{
@@ -48,25 +39,27 @@ public void ApplyContinuousActionOutput()
4839
shape = new long[] { 2, 3 },
4940
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
5041
};
51-
var agentInfos = GetFakeAgentInfos();
5242

5343
var applier = new ContinuousActionOutputApplier();
54-
applier.Apply(inputTensor, agentInfos);
55-
var agents = agentInfos;
5644

57-
var agent = agents[0] as TestAgent;
58-
Assert.NotNull(agent);
59-
var action = agent.GetAction();
60-
Assert.AreEqual(action.vectorActions[0], 1);
61-
Assert.AreEqual(action.vectorActions[1], 2);
62-
Assert.AreEqual(action.vectorActions[2], 3);
45+
var action0 = new AgentAction();
46+
var action1 = new AgentAction();
47+
var callbacks = new List<AgentIdActionPair>()
48+
{
49+
new AgentIdActionPair{agentId = 0, action = (a) => action0 = a},
50+
new AgentIdActionPair{agentId = 1, action = (a) => action1 = a}
51+
};
52+
53+
applier.Apply(inputTensor, callbacks);
54+
55+
56+
Assert.AreEqual(action0.vectorActions[0], 1);
57+
Assert.AreEqual(action0.vectorActions[1], 2);
58+
Assert.AreEqual(action0.vectorActions[2], 3);
6359

64-
agent = agents[1] as TestAgent;
65-
Assert.NotNull(agent);
66-
action = agent.GetAction();
67-
Assert.AreEqual(action.vectorActions[0], 4);
68-
Assert.AreEqual(action.vectorActions[1], 5);
69-
Assert.AreEqual(action.vectorActions[2], 6);
60+
Assert.AreEqual(action1.vectorActions[0], 4);
61+
Assert.AreEqual(action1.vectorActions[1], 5);
62+
Assert.AreEqual(action1.vectorActions[2], 6);
7063
}
7164

7265
[Test]
@@ -80,49 +73,25 @@ public void ApplyDiscreteActionOutput()
8073
5,
8174
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
8275
};
83-
var agentInfos = GetFakeAgentInfos();
8476
var alloc = new TensorCachingAllocator();
8577
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
86-
applier.Apply(inputTensor, agentInfos);
87-
var agents = agentInfos;
8878

89-
var agent = agents[0] as TestAgent;
90-
Assert.NotNull(agent);
91-
var action = agent.GetAction();
92-
Assert.AreEqual(action.vectorActions[0], 1);
93-
Assert.AreEqual(action.vectorActions[1], 1);
94-
95-
agent = agents[1] as TestAgent;
96-
Assert.NotNull(agent);
97-
action = agent.GetAction();
98-
Assert.AreEqual(action.vectorActions[0], 1);
99-
Assert.AreEqual(action.vectorActions[1], 2);
100-
alloc.Dispose();
101-
}
102-
103-
[Test]
104-
public void ApplyValueEstimate()
105-
{
106-
var inputTensor = new TensorProxy()
79+
var action0 = new AgentAction();
80+
var action1 = new AgentAction();
81+
var callbacks = new List<AgentIdActionPair>()
10782
{
108-
shape = new long[] { 2, 1 },
109-
data = new Tensor(2, 1, new[] { 0.5f, 8f })
83+
new AgentIdActionPair{agentId = 0, action = (a) => action0 = a},
84+
new AgentIdActionPair{agentId = 1, action = (a) => action1 = a}
11085
};
111-
var agentInfos = GetFakeAgentInfos();
11286

113-
var applier = new ValueEstimateApplier();
114-
applier.Apply(inputTensor, agentInfos);
115-
var agents = agentInfos;
87+
applier.Apply(inputTensor, callbacks);
11688

117-
var agent = agents[0] as TestAgent;
118-
Assert.NotNull(agent);
119-
var action = agent.GetAction();
120-
Assert.AreEqual(action.value, 0.5f);
89+
Assert.AreEqual(action0.vectorActions[0], 1);
90+
Assert.AreEqual(action0.vectorActions[1], 1);
12191

122-
agent = agents[1] as TestAgent;
123-
Assert.NotNull(agent);
124-
action = agent.GetAction();
125-
Assert.AreEqual(action.value, 8);
92+
Assert.AreEqual(action1.vectorActions[0], 1);
93+
Assert.AreEqual(action1.vectorActions[1], 2);
94+
alloc.Dispose();
12695
}
12796
}
12897
}

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace MLAgents.Tests
1010
{
1111
public class EditModeTestInternalBrainTensorGenerator
1212
{
13-
static IEnumerable<Agent> GetFakeAgents()
13+
static List<Agent> GetFakeAgents()
1414
{
1515
var acaGo = new GameObject("TestAcademy");
1616
acaGo.AddComponent<Academy>();
@@ -52,7 +52,6 @@ static IEnumerable<Agent> GetFakeAgents()
5252

5353
agentA.Info = infoA;
5454
agentB.Info = infoB;
55-
5655
return agents;
5756
}
5857

@@ -106,7 +105,14 @@ public void GenerateVectorObservation()
106105
generator.AddSensorIndex(0);
107106
generator.AddSensorIndex(1);
108107
generator.AddSensorIndex(2);
109-
generator.Generate(inputTensor, batchSize, agentInfos);
108+
var agent0 = agentInfos[0];
109+
var agent1 = agentInfos[1];
110+
var inputs = new List<AgentInfoSensorsPair>
111+
{
112+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
113+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
114+
};
115+
generator.Generate(inputTensor, batchSize, inputs);
110116
Assert.IsNotNull(inputTensor.data);
111117
Assert.AreEqual(inputTensor.data[0, 0], 1);
112118
Assert.AreEqual(inputTensor.data[0, 2], 3);
@@ -127,8 +133,14 @@ public void GeneratePreviousActionInput()
127133
var agentInfos = GetFakeAgents();
128134
var alloc = new TensorCachingAllocator();
129135
var generator = new PreviousActionInputGenerator(alloc);
130-
131-
generator.Generate(inputTensor, batchSize, agentInfos);
136+
var agent0 = agentInfos[0];
137+
var agent1 = agentInfos[1];
138+
var inputs = new List<AgentInfoSensorsPair>
139+
{
140+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
141+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
142+
};
143+
generator.Generate(inputTensor, batchSize, inputs);
132144
Assert.IsNotNull(inputTensor.data);
133145
Assert.AreEqual(inputTensor.data[0, 0], 1);
134146
Assert.AreEqual(inputTensor.data[0, 1], 2);
@@ -149,7 +161,16 @@ public void GenerateActionMaskInput()
149161
var agentInfos = GetFakeAgents();
150162
var alloc = new TensorCachingAllocator();
151163
var generator = new ActionMaskInputGenerator(alloc);
152-
generator.Generate(inputTensor, batchSize, agentInfos);
164+
165+
var agent0 = agentInfos[0];
166+
var agent1 = agentInfos[1];
167+
var inputs = new List<AgentInfoSensorsPair>
168+
{
169+
new AgentInfoSensorsPair{agentInfo = agent0.Info, sensors = agent0.sensors},
170+
new AgentInfoSensorsPair{agentInfo = agent1.Info, sensors = agent1.sensors},
171+
};
172+
173+
generator.Generate(inputTensor, batchSize, inputs);
153174
Assert.IsNotNull(inputTensor.data);
154175
Assert.AreEqual(inputTensor.data[0, 0], 1);
155176
Assert.AreEqual(inputTensor.data[0, 4], 1);

UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public override float[] Heuristic()
5050
{
5151
return new float[0];
5252
}
53+
5354
}
5455

5556
public class TestSensor : ISensor

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ namespace MLAgents
1313
/// </summary>
1414
public struct AgentInfo
1515
{
16-
/// <summary>
17-
/// Most recent observations.
18-
/// </summary>
19-
public List<Observation> observations;
2016

2117
/// <summary>
2218
/// Keeps track of the last vector action taken by the Brain.
@@ -456,8 +452,6 @@ void ResetData()
456452
m_Info.storedVectorActions = new float[param.vectorActionSize.Length];
457453
}
458454
}
459-
460-
m_Info.observations = new List<Observation>();
461455
}
462456

463457
/// <summary>
@@ -497,7 +491,7 @@ public void InitializeSensors()
497491
{
498492
// Get all attached sensor components
499493
SensorComponent[] attachedSensorComponents;
500-
if(m_PolicyFactory.useChildSensors)
494+
if (m_PolicyFactory.useChildSensors)
501495
{
502496
attachedSensorComponents = GetComponentsInChildren<SensorComponent>();
503497
}
@@ -538,17 +532,6 @@ public void InitializeSensors()
538532
Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
539533
}
540534
#endif
541-
// Create a buffer for writing uncompressed (i.e. float) sensor data to
542-
int numFloatObservations = 0;
543-
for (var i = 0; i < sensors.Count; i++)
544-
{
545-
if (sensors[i].GetCompressionType() == SensorCompressionType.None)
546-
{
547-
numFloatObservations += sensors[i].ObservationSize();
548-
}
549-
}
550-
551-
m_VectorSensorBuffer = new float[numFloatObservations];
552535
}
553536

554537
/// <summary>
@@ -562,7 +545,6 @@ void SendInfoToBrain()
562545
}
563546

564547
m_Info.storedVectorActions = m_Action.vectorActions;
565-
m_Info.observations.Clear();
566548
m_ActionMasker.ResetMask();
567549
UpdateSensors();
568550
using (TimerStack.Instance.Scoped("CollectObservations"))
@@ -578,18 +560,23 @@ void SendInfoToBrain()
578560
m_Info.maxStepReached = m_MaxStepReached;
579561
m_Info.id = m_Id;
580562

581-
m_Brain.RequestDecision(this);
563+
m_Brain.RequestDecision(m_Info, sensors, UpdateAgentAction);
582564

583565
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
584566
{
585-
// This is a bit of a hack - if we're in inference mode, observations won't be generated
586-
// But we need these to be generated for the recorder. So generate them here.
587-
if (m_Info.observations.Count == 0)
567+
568+
if (m_VectorSensorBuffer == null)
588569
{
589-
GenerateSensorData();
570+
// Create a buffer for writing uncompressed (i.e. float) sensor data to
571+
m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()];
590572
}
591573

592-
m_Recorder.WriteExperience(m_Info);
574+
// This is a bit of a hack - if we're in inference mode, observations won't be generated
575+
// But we need these to be generated for the recorder. So generate them here.
576+
var observations = new List<Observation>();
577+
GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations);
578+
579+
m_Recorder.WriteExperience(m_Info, observations);
593580
}
594581

595582
}
@@ -603,12 +590,17 @@ void UpdateSensors()
603590
}
604591

605592
/// <summary>
606-
/// Generate data for each sensor and store it on the Agent's AgentInfo.
593+
/// Generate data for each sensor and store it in the observations input.
607594
/// NOTE: At the moment, this is only called during training or when using a DemonstrationRecorder;
608595
/// during inference the Sensors are used to write directly to the Tensor data. This will likely change in the
609596
/// future to be controlled by the type of brain being used.
610597
/// </summary>
611-
public void GenerateSensorData()
598+
/// <param name="sensors"> List of ISensors that will be used to generate the data.</param>
599+
/// <param name="buffer"> A float array that will be used as a buffer when generating the observations. Must
600+
/// be at least the same length as the total number of uncompressed floats in the observations</param>
601+
/// <param name="adapter"> The WriteAdapter that will be used to write the ISensor data to the observations</param>
602+
/// <param name="observations"> A list of observation outputs. This argument will be modified by this method.</param>
603+
public static void GenerateSensorData(List<ISensor> sensors, float[] buffer, WriteAdapter adapter, List<Observation> observations)
612604
{
613605
int floatsWritten = 0;
614606
// Generate data for all Sensors
@@ -618,15 +610,15 @@ public void GenerateSensorData()
618610
if (sensor.GetCompressionType() == SensorCompressionType.None)
619611
{
620612
// TODO handle in communicator code instead
621-
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, sensor.GetObservationShape(), floatsWritten);
622-
var numFloats = sensor.Write(m_WriteAdapter);
613+
adapter.SetTarget(buffer, sensor.GetObservationShape(), floatsWritten);
614+
var numFloats = sensor.Write(adapter);
623615
var floatObs = new Observation
624616
{
625-
FloatData = new ArraySegment<float>(m_VectorSensorBuffer, floatsWritten, numFloats),
617+
FloatData = new ArraySegment<float>(buffer, floatsWritten, numFloats),
626618
Shape = sensor.GetObservationShape(),
627619
CompressionType = sensor.GetCompressionType()
628620
};
629-
m_Info.observations.Add(floatObs);
621+
observations.Add(floatObs);
630622
floatsWritten += numFloats;
631623
}
632624
else
@@ -637,7 +629,7 @@ public void GenerateSensorData()
637629
Shape = sensor.GetObservationShape(),
638630
CompressionType = sensor.GetCompressionType()
639631
};
640-
m_Info.observations.Add(compressedObs);
632+
observations.Add(compressedObs);
641633
}
642634
}
643635
}

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System.IO.Abstractions;
22
using System.Text.RegularExpressions;
33
using UnityEngine;
4+
using System.Collections.Generic;
5+
using MLAgents.Sensor;
46

57
namespace MLAgents
68
{
@@ -68,9 +70,9 @@ public static string SanitizeName(string demoName, int maxNameLength)
6870
/// <summary>
6971
/// Forwards AgentInfo to Demonstration Store.
7072
/// </summary>
71-
public void WriteExperience(AgentInfo info)
73+
public void WriteExperience(AgentInfo info, List<Observation> observations)
7274
{
73-
m_DemoStore.Record(info);
75+
m_DemoStore.Record(info, observations);
7476
}
7577

7678
public void Close()

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using System.IO;
22
using System.IO.Abstractions;
33
using Google.Protobuf;
4-
using UnityEngine;
4+
using System.Collections.Generic;
5+
using MLAgents.Sensor;
56

67
namespace MLAgents
78
{
@@ -91,7 +92,7 @@ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
9192
/// <summary>
9293
/// Write AgentInfo experience to file.
9394
/// </summary>
94-
public void Record(AgentInfo info)
95+
public void Record(AgentInfo info, List<Observation> observations)
9596
{
9697
// Increment meta-data counters.
9798
m_MetaData.numberExperiences++;
@@ -102,7 +103,7 @@ public void Record(AgentInfo info)
102103
}
103104

104105
// Write AgentInfo to file.
105-
var agentProto = info.ToInfoActionPairProto();
106+
var agentProto = info.ToInfoActionPairProto(observations);
106107
agentProto.WriteDelimitedTo(m_Writer);
107108
}
108109

0 commit comments

Comments
 (0)