Skip to content

Commit 3e49abb

Browse files
author
Chris Elion
authored
Add SensorShapeValidator unit (#3504)
1 parent 608fdcf commit 3e49abb

File tree

3 files changed

+158
-2
lines changed

3 files changed

+158
-2
lines changed

com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ public void ValidateSensors(List<ISensor> sensors)
2727
// Check for compatibility with the other Agents' Sensors
2828
// TODO make sure this only checks once per agent
2929
Debug.Assert(m_SensorShapes.Count == sensors.Count, $"Number of Sensors must match. {m_SensorShapes.Count} != {sensors.Count}");
30-
for (var i = 0; i < m_SensorShapes.Count; i++)
30+
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
3131
{
3232
var cachedShape = m_SensorShapes[i];
3333
var sensorShape = sensors[i].GetObservationShape();
3434
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
35-
for (var j = 0; j < cachedShape.Length; j++)
35+
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
3636
{
3737
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes much match.");
3838
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
using System.Collections.Generic;
2+
using NUnit.Framework;
3+
using UnityEngine;
4+
using UnityEngine.TestTools;
5+
6+
namespace MLAgents.Tests
7+
{
8+
public class DummySensor : ISensor
9+
{
10+
string m_Name = "DummySensor";
11+
int[] m_Shape;
12+
13+
public DummySensor(int dim1)
14+
{
15+
m_Shape = new[] { dim1 };
16+
}
17+
18+
public DummySensor(int dim1, int dim2)
19+
{
20+
m_Shape = new[] { dim1, dim2, };
21+
}
22+
23+
public DummySensor(int dim1, int dim2, int dim3)
24+
{
25+
m_Shape = new[] { dim1, dim2, dim3};
26+
}
27+
28+
public string GetName()
29+
{
30+
return m_Name;
31+
}
32+
33+
public int[] GetObservationShape()
34+
{
35+
return m_Shape;
36+
}
37+
38+
public byte[] GetCompressedObservation()
39+
{
40+
return null;
41+
}
42+
43+
public int Write(WriteAdapter adapter)
44+
{
45+
return this.ObservationSize();
46+
}
47+
48+
public void Update() { }
49+
50+
public SensorCompressionType GetCompressionType()
51+
{
52+
return SensorCompressionType.None;
53+
}
54+
}
55+
56+
public class SensorShapeValidatorTests
57+
{
58+
59+
[Test]
60+
public void TestShapesAgree()
61+
{
62+
var validator = new SensorShapeValidator();
63+
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
64+
validator.ValidateSensors(sensorList1);
65+
66+
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
67+
validator.ValidateSensors(sensorList2);
68+
}
69+
70+
[Test]
71+
public void TestNumSensorMismatch()
72+
{
73+
var validator = new SensorShapeValidator();
74+
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
75+
validator.ValidateSensors(sensorList1);
76+
77+
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), };
78+
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
79+
validator.ValidateSensors(sensorList2);
80+
81+
// Add the sensors in the other order
82+
validator = new SensorShapeValidator();
83+
validator.ValidateSensors(sensorList2);
84+
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
85+
validator.ValidateSensors(sensorList1);
86+
}
87+
[Test]
88+
public void TestDimensionMismatch()
89+
{
90+
var validator = new SensorShapeValidator();
91+
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
92+
validator.ValidateSensors(sensorList1);
93+
94+
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
95+
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
96+
validator.ValidateSensors(sensorList2);
97+
98+
// Add the sensors in the other order
99+
validator = new SensorShapeValidator();
100+
validator.ValidateSensors(sensorList2);
101+
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
102+
validator.ValidateSensors(sensorList1);
103+
}
104+
105+
[Test]
106+
public void TestSizeMismatch()
107+
{
108+
var validator = new SensorShapeValidator();
109+
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
110+
validator.ValidateSensors(sensorList1);
111+
112+
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
113+
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
114+
validator.ValidateSensors(sensorList2);
115+
116+
// Add the sensors in the other order
117+
validator = new SensorShapeValidator();
118+
validator.ValidateSensors(sensorList2);
119+
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
120+
validator.ValidateSensors(sensorList1);
121+
}
122+
123+
[Test]
124+
public void TestEverythingMismatch()
125+
{
126+
var validator = new SensorShapeValidator();
127+
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
128+
validator.ValidateSensors(sensorList1);
129+
130+
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
131+
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
132+
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
133+
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
134+
validator.ValidateSensors(sensorList2);
135+
136+
// Add the sensors in the other order
137+
validator = new SensorShapeValidator();
138+
validator.ValidateSensors(sensorList2);
139+
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
140+
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
141+
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
142+
validator.ValidateSensors(sensorList1);
143+
}
144+
}
145+
}

com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)