Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ public override ISensor CreateSensor()
{
return new BasicSensor(basicController);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { BasicController.k_Extents };
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public abstract class SensorBase : ISensor
{
/// <summary>
/// Write the observations to the output buffer. This size of the buffer will be product
/// of the sizes returned by <see cref="GetObservationShape"/>.
/// of the Shape array values returned by <see cref="ObservationSpec"/>.
/// </summary>
/// <param name="output"></param>
public abstract void WriteObservation(float[] output);
Expand All @@ -28,7 +28,7 @@ public abstract class SensorBase : ISensor
/// <returns>The number of elements written.</returns>
public virtual int Write(ObservationWriter writer)
{
// TODO reuse buffer for similar agents, don't call GetObservationShape()
// TODO reuse buffer for similar agents
var numFloats = this.ObservationSize();
float[] buffer = new float[numFloats];
WriteObservation(buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,5 @@ public override ISensor CreateSensor()
}
return m_Sensor;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var width = TestTexture.width;
var height = TestTexture.height;
var observationShape = new[] { height, width, 3 };

var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
if (stacks > 1)
{
observationShape[2] *= stacks;
}

return observationShape;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,5 @@ public override ISensor CreateSensor()
return new Match3Sensor(board, ObservationType, SensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var board = GetComponent<AbstractBoard>();
if (board == null)
{
return System.Array.Empty<int>();
}

var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return ObservationType == Match3ObservationType.Vector ?
new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } :
new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize };
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,6 @@ public override ISensor CreateSensor()
return new PhysicsBodySensor(RootBody, Settings, sensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}

// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;

foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
{
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}
}

}
Expand Down
7 changes: 0 additions & 7 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -917,13 +917,6 @@ public ObservationSpec GetObservationSpec()
return m_ObservationSpec;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}

/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ string sensorName
}

#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
{
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_PoseExtractor = poseExtractor;
Expand All @@ -57,7 +57,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin

var numJointExtractorObservations = 0;
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
{
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
Expand All @@ -67,6 +67,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}

#endif

/// <inheritdoc/>
Expand Down Expand Up @@ -126,6 +127,5 @@ public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.PhysicsBodySensor;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,6 @@ public override ISensor CreateSensor()
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}

var poseExtractor = GetPoseExtractor();
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);

var numJointObservations = 0;
foreach (var rb in poseExtractor.GetEnabledRigidbodies())
{
var joint = rb.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}

/// <summary>
/// Get the DisplayNodes of the hierarchy.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class Match3SensorTests
public void TestVectorObservations()
{
var boardString =
@"000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not look correct. (The indent I mean)

@"000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intended?
(There's a few more below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, coding package reformatted them (and a few other files). Will revert.

000
010";
var gameObj = new GameObject("board");
Expand All @@ -29,9 +29,8 @@ public void TestVectorObservations()
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand All @@ -46,11 +45,11 @@ public void TestVectorObservations()
public void TestVectorObservationsSpecial()
{
var boardString =
@"000
@"000
000
010";
var specialString =
@"010
@"010
200
000";

Expand All @@ -63,9 +62,8 @@ public void TestVectorObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * (2 + 3));
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand All @@ -76,12 +74,11 @@ public void TestVectorObservationsSpecial()
SensorTestHelper.CompareObservation(sensor, expectedObs);
}


[Test]
public void TestVisualObservations()
{
var boardString =
@"000
@"000
000
010";
var gameObj = new GameObject("board");
Expand All @@ -92,9 +89,8 @@ public void TestVisualObservations()
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -119,11 +115,11 @@ public void TestVisualObservations()
public void TestVisualObservationsSpecial()
{
var boardString =
@"000
@"000
000
010";
var specialString =
@"010
@"010
200
000";

Expand All @@ -136,9 +132,8 @@ public void TestVisualObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -163,7 +158,7 @@ public void TestVisualObservationsSpecial()
public void TestCompressedVisualObservations()
{
var boardString =
@"000
@"000
000
010";
var gameObj = new GameObject("board");
Expand All @@ -174,9 +169,8 @@ public void TestCompressedVisualObservations()
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -191,17 +185,15 @@ public void TestCompressedVisualObservations()
Assert.AreEqual(expectedPng, pngData);
}



[Test]
public void TestCompressedVisualObservationsSpecial()
{
var boardString =
@"000
@"000
000
010";
var specialString =
@"010
@"010
200
000";

Expand All @@ -214,9 +206,8 @@ public void TestCompressedVisualObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -229,7 +220,6 @@ public void TestCompressedVisualObservationsSpecial()
}
var expectedPng = LoadPNGs(pathPrefix, 2);
Assert.AreEqual(expectedPng, concatenatedPngData);

}

/// <summary>
Expand Down Expand Up @@ -306,7 +296,6 @@ byte[] LoadPNGs(string pathPrefix, int numExpected)
}

return bytesOut.ToArray();

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ public void OneChannelDepthOne()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 1);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}


Expand All @@ -51,9 +50,8 @@ public void OneChannelDepthTwo()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 2);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

[Test]
Expand All @@ -66,8 +64,8 @@ public void TwoChannelsDepthTwoOne()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 3);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);

}

Expand All @@ -81,9 +79,8 @@ public void TwoChannelsDepthThreeThree()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 6);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

}
Expand Down
Loading