From e15cc144aab5b7d0c57ad258ddaf6fb2ee070fe4 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 22 Mar 2021 14:39:34 -0700 Subject: [PATCH 1/3] Remove SensorComponent.GetObservationShape() --- .../Basic/Scripts/BasicSensorComponent.cs | 6 --- .../SharedAssets/Scripts/SensorBase.cs | 4 +- .../TestTextureSensorComponent.cs | 16 ------ .../Runtime/Match3/Match3SensorComponent.cs | 14 ----- .../ArticulationBodySensorComponent.cs | 20 ------- .../Runtime/Sensors/GridSensor.cs | 7 --- .../Runtime/Sensors/PhysicsBodySensor.cs | 6 +-- .../Sensors/RigidBodySensorComponent.cs | 20 ------- .../Tests/Editor/Match3/Match3SensorTests.cs | 53 ++++++++----------- .../Editor/Sensors/ChannelHotShapeTests.cs | 19 +++---- .../Tests/Editor/Sensors/ChannelShapeTests.cs | 15 +++--- .../Editor/Sensors/GridSensorTestUtils.cs | 27 ---------- .../Sensors/ArticulationBodySensorTests.cs | 6 +-- .../Runtime/Sensors/RigidBodySensorTests.cs | 6 +-- .../Runtime/Communicator/GrpcExtensions.cs | 2 +- .../Runtime/Sensors/BufferSensor.cs | 2 - .../Runtime/Sensors/BufferSensorComponent.cs | 6 --- .../Runtime/Sensors/CameraSensor.cs | 1 - .../Runtime/Sensors/CameraSensorComponent.cs | 15 ------ .../RayPerceptionSensorComponentBase.cs | 14 ----- .../Reflection/ReflectionSensorBase.cs | 1 - .../Runtime/Sensors/RenderTextureSensor.cs | 1 - .../Sensors/RenderTextureSensorComponent.cs | 16 ------ .../Runtime/Sensors/SensorComponent.cs | 7 --- .../Editor/Analytics/TrainingAnalyticsTest.cs | 1 - .../Editor/Inference/ParameterLoaderTest.cs | 7 +-- .../Tests/Editor/MLAgentsEditModeTest.cs | 1 - .../Tests/Runtime/RuntimeAPITest.cs | 11 ---- .../Tests/Runtime/Sensor/BufferSensorTest.cs | 4 +- .../Sensor/CameraSensorComponentTest.cs | 7 +-- .../RenderTextureSensorComponentTests.cs | 5 +- .../Sensor/SensorShapeValidatorTests.cs | 2 +- .../Runtime/Sensor/StackingSensorTests.cs | 1 - .../Tests/Runtime/Utils/TestClasses.cs | 1 - 34 files changed, 58 insertions(+), 266 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs index a62d2819a3..cddbc27439 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs @@ -19,12 +19,6 @@ public override ISensor CreateSensor() { return new BasicSensor(basicController); } - - /// - public override int[] GetObservationShape() - { - return new[] { BasicController.k_Extents }; - } } /// diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs index 7bca30e7c9..31bc7c8556 100644 --- a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs +++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs @@ -9,7 +9,7 @@ public abstract class SensorBase : ISensor { /// /// Write the observations to the output buffer. This size of the buffer will be product - /// of the sizes returned by . + /// of the Shape array values returned by . /// /// public abstract void WriteObservation(float[] output); @@ -28,7 +28,7 @@ public abstract class SensorBase : ISensor /// The number of elements written. public virtual int Write(ObservationWriter writer) { - // TODO reuse buffer for similar agents, don't call GetObservationShape() + // TODO reuse buffer for similar agents var numFloats = this.ObservationSize(); float[] buffer = new float[numFloats]; WriteObservation(buffer); diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs index 093e6652ba..504e0c3d19 100644 --- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs +++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs @@ -32,21 +32,5 @@ public override ISensor CreateSensor() } return m_Sensor; } - - /// - public override int[] GetObservationShape() - { - var width = TestTexture.width; - var height = TestTexture.height; - var observationShape = new[] { height, width, 3 }; - - var stacks = ObservationStacks > 1 ? ObservationStacks : 1; - if (stacks > 1) - { - observationShape[2] *= stacks; - } - - return observationShape; - } } diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs index 0467d5025b..4dbc1303c2 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs @@ -27,19 +27,5 @@ public override ISensor CreateSensor() return new Match3Sensor(board, ObservationType, SensorName); } - /// - public override int[] GetObservationShape() - { - var board = GetComponent(); - if (board == null) - { - return System.Array.Empty(); - } - - var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1; - return ObservationType == Match3ObservationType.Vector ? - new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } : - new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize }; - } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs index 58619245ec..b8b2ac8017 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -21,26 +21,6 @@ public override ISensor CreateSensor() return new PhysicsBodySensor(RootBody, Settings, sensorName); } - /// - public override int[] GetObservationShape() - { - if (RootBody == null) - { - return new[] { 0 }; - } - - // TODO static method in PhysicsBodySensor? - // TODO only update PoseExtractor when body changes? - var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); - var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); - var numJointObservations = 0; - - foreach(var articBody in poseExtractor.GetEnabledArticulationBodies()) - { - numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings); - } - return new[] { numPoseObservations + numJointObservations }; - } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs index ad3886ace7..8224139d5e 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs @@ -917,13 +917,6 @@ public ObservationSpec GetObservationSpec() return m_ObservationSpec; } - /// - public override int[] GetObservationShape() - { - var shape = m_ObservationSpec.Shape; - return new int[] { shape[0], shape[1], shape[2] }; - } - /// public int Write(ObservationWriter writer) { diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs index 32d35603bf..edcfd16966 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs @@ -48,7 +48,7 @@ string sensorName } #if UNITY_2020_1_OR_NEWER - public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) + public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null) { var poseExtractor = new ArticulationBodyPoseExtractor(rootBody); m_PoseExtractor = poseExtractor; @@ -57,7 +57,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin var numJointExtractorObservations = 0; m_JointExtractors = new List(poseExtractor.NumEnabledPoses); - foreach(var articBody in poseExtractor.GetEnabledArticulationBodies()) + foreach (var articBody in poseExtractor.GetEnabledArticulationBodies()) { var jointExtractor = new ArticulationBodyJointExtractor(articBody); numJointExtractorObservations += jointExtractor.NumObservations(settings); @@ -67,6 +67,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations); } + #endif /// @@ -126,6 +127,5 @@ public BuiltInSensorType GetBuiltInSensorType() { return BuiltInSensorType.PhysicsBodySensor; } - } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index 21e0fa0586..03da767b8b 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -45,26 +45,6 @@ public override ISensor CreateSensor() return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName); } - /// - public override int[] GetObservationShape() - { - if (RootBody == null) - { - return new[] { 0 }; - } - - var poseExtractor = GetPoseExtractor(); - var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); - - var numJointObservations = 0; - foreach (var rb in poseExtractor.GetEnabledRigidbodies()) - { - var joint = rb.GetComponent(); - numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings); - } - return new[] { numPoseObservations + numJointObservations }; - } - /// /// Get the DisplayNodes of the hierarchy. /// diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index 9013d3d73f..bee5d183da 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -18,7 +18,7 @@ public class Match3SensorTests public void TestVectorObservations() { var boardString = - @"000 +@"000 000 010"; var gameObj = new GameObject("board"); @@ -29,9 +29,8 @@ public void TestVectorObservations() sensorComponent.ObservationType = Match3ObservationType.Vector; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3 * 3 * 2 }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3 * 3 * 2); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -46,11 +45,11 @@ public void TestVectorObservations() public void TestVectorObservationsSpecial() { var boardString = - @"000 +@"000 000 010"; var specialString = - @"010 +@"010 200 000"; @@ -63,9 +62,8 @@ public void TestVectorObservationsSpecial() sensorComponent.ObservationType = Match3ObservationType.Vector; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3 * 3 * (2 + 3) }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3 * 3 * (2 + 3)); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); var expectedObs = new float[] { @@ -76,12 +74,11 @@ public void TestVectorObservationsSpecial() SensorTestHelper.CompareObservation(sensor, expectedObs); } - [Test] public void TestVisualObservations() { var boardString = - @"000 +@"000 000 010"; var gameObj = new GameObject("board"); @@ -92,9 +89,8 @@ public void TestVisualObservations() sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3, 3, 2 }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3, 3, 2); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType); @@ -119,11 +115,11 @@ public void TestVisualObservations() public void TestVisualObservationsSpecial() { var boardString = - @"000 +@"000 000 010"; var specialString = - @"010 +@"010 200 000"; @@ -136,9 +132,8 @@ public void TestVisualObservationsSpecial() sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3, 3, 2 + 3 }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3, 3, 2 + 3); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType); @@ -163,7 +158,7 @@ public void TestVisualObservationsSpecial() public void TestCompressedVisualObservations() { var boardString = - @"000 +@"000 000 010"; var gameObj = new GameObject("board"); @@ -174,9 +169,8 @@ public void TestCompressedVisualObservations() sensorComponent.ObservationType = Match3ObservationType.CompressedVisual; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3, 3, 2 }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3, 3, 2); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType); @@ -191,17 +185,15 @@ public void TestCompressedVisualObservations() Assert.AreEqual(expectedPng, pngData); } - - [Test] public void TestCompressedVisualObservationsSpecial() { var boardString = - @"000 +@"000 000 010"; var specialString = - @"010 +@"010 200 000"; @@ -214,9 +206,8 @@ public void TestCompressedVisualObservationsSpecial() sensorComponent.ObservationType = Match3ObservationType.CompressedVisual; var sensor = sensorComponent.CreateSensor(); - var expectedShape = new[] { 3, 3, 2 + 3 }; - Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape()); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(3, 3, 2 + 3); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType); @@ -229,7 +220,6 @@ public void TestCompressedVisualObservationsSpecial() } var expectedPng = LoadPNGs(pathPrefix, 2); Assert.AreEqual(expectedPng, concatenatedPngData); - } /// @@ -306,7 +296,6 @@ byte[] LoadPNGs(string pathPrefix, int numExpected) } return bytesOut.ToArray(); - } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs index 3faa0ce49a..dd5fcc18bc 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs @@ -34,9 +34,8 @@ public void OneChannelDepthOne() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 1 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); - + var expectedShape = new InplaceArray(10, 10, 1); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); } @@ -51,9 +50,8 @@ public void OneChannelDepthTwo() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 2 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); - + var expectedShape = new InplaceArray(10, 10, 2); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); } [Test] @@ -66,8 +64,8 @@ public void TwoChannelsDepthTwoOne() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 3 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); + var expectedShape = new InplaceArray(10, 10, 3); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); } @@ -81,9 +79,8 @@ public void TwoChannelsDepthThreeThree() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 6 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); - + var expectedShape = new InplaceArray(10, 10, 6); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs index 1f71f827b7..7234e54a77 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs @@ -34,8 +34,9 @@ public void OneChannel() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 1 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); + var expectedShape = new InplaceArray(10, 10, 1); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); + } [Test] @@ -48,8 +49,9 @@ public void TwoChannel() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 2 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); + var expectedShape = new InplaceArray(10, 10, 2); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); + } [Test] @@ -62,8 +64,9 @@ public void SevenChannel() 1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors); gridSensor.Start(); - int[] expectedShape = { 10, 10, 7 }; - GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape()); + var expectedShape = new InplaceArray(10, 10, 7); + Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape); + } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs index eef171784c..9d7de21023 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs @@ -66,33 +66,6 @@ public static float[][] DuplicateArray(float[] array, int numCopies) return duplicated; } - /// - /// Asserts that 2 int arrays are the same - /// - /// The expected array - /// The actual array - public static void AssertArraysAreEqual(int[] expected, int[] actual) - { - Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same"); - for (int i = 0; i < actual.Length; i++) - { - Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected)); - } - } - - /// - /// Asserts that 2 float arrays are the same - /// - /// The expected array - /// The actual array - public static void AssertArraysAreEqual(float[] expected, float[] actual) - { - Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same"); - for (int i = 0; i < actual.Length; i++) - { - Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected)); - } - } /// /// Asserts that the sub-arrays of the total array are equal to specific subarrays at specific subarray indicies and equal to a default everywhere else. diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs index 5c0684f1ca..c88c494b0a 100644 --- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs @@ -42,7 +42,7 @@ public void TestSingleBody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); } [Test] @@ -110,7 +110,7 @@ public void TestBodiesWithJoint() #endif }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); // Update the settings to only process joint observations sensorComponent.Settings = new PhysicsSensorSettings @@ -133,7 +133,7 @@ public void TestBodiesWithJoint() 0f, // joint2.force }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); } } } diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs index 99911147cd..3bf956c210 100644 --- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs @@ -56,7 +56,7 @@ public void TestSingleRigidbody() // The root body is ignored since it always generates identity values // and there are no other bodies to generate observations. var expected = new float[0]; - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); SensorTestHelper.CompareObservation(sensor, expected); } @@ -115,7 +115,7 @@ public void TestBodiesWithJoint() -1f, 1f, 0f, // Attached vel 0f, -1f, 1f // Leaf vel }; - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); SensorTestHelper.CompareObservation(sensor, expected); // Update the settings to only process joint observations @@ -136,7 +136,7 @@ public void TestBodiesWithJoint() 0f, 0f, 0f, // joint2.torque }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); + Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); } } diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 8005f3c416..812efe7e46 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -532,7 +532,7 @@ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEv NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits, }; } - #endregion + #endregion } } diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs index 5c08546da9..09db85d224 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs @@ -110,7 +110,5 @@ public BuiltInSensorType GetBuiltInSensorType() { return BuiltInSensorType.BufferSensor; } - } - } diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs index 26ddc62f0d..2bf357b47e 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs @@ -55,12 +55,6 @@ public override ISensor CreateSensor() return m_Sensor; } - /// - public override int[] GetObservationShape() - { - return new[] { MaxNumObservables, ObservableSize }; - } - /// /// Appends an observation to the buffer. If the buffer is full (maximum number /// of observation is reached) the observation will be ignored. the length of diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 6dae738c5d..3cede2408d 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -197,6 +197,5 @@ public BuiltInSensorType GetBuiltInSensorType() { return BuiltInSensorType.CameraSensor; } - } } diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs index ebb9d9dc73..0c677ee585 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -117,21 +117,6 @@ public override ISensor CreateSensor() return m_Sensor; } - /// - /// Computes the observation shape of the sensor. - /// - /// The observation shape of the associated object. - public override int[] GetObservationShape() - { - var stacks = ObservationStacks > 1 ? ObservationStacks : 1; - var cameraSensorshape = CameraSensor.GenerateShape(m_Width, m_Height, Grayscale); - if (stacks > 1) - { - cameraSensorshape[cameraSensorshape.Length - 1] *= stacks; - } - return cameraSensorshape; - } - /// /// Update fields that are safe to change on the Sensor at runtime. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs index d1a332e525..d80bcfc1b5 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs @@ -219,20 +219,6 @@ internal static float[] GetRayAngles(int raysPerDirection, float maxRayDegrees) return anglesOut; } - /// - /// Returns the observation shape for this raycast sensor which depends on the number - /// of tags for detected objects and the number of rays. - /// - /// - public override int[] GetObservationShape() - { - var numRays = 2 * RaysPerDirection + 1; - var numTags = m_DetectableTags?.Count ?? 0; - var obsSize = (numTags + 2) * numRays; - var stacks = ObservationStacks > 1 ? ObservationStacks : 1; - return new[] { obsSize * stacks }; - } - /// /// Get the RayPerceptionInput that is used by the . /// diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index 9d77f7794e..9a0219146e 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -107,6 +107,5 @@ public BuiltInSensorType GetBuiltInSensorType() { return BuiltInSensorType.ReflectionSensor; } - } } diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs index 745ad023ae..7aec8d57eb 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs @@ -99,7 +99,6 @@ public BuiltInSensorType GetBuiltInSensorType() return BuiltInSensorType.RenderTextureSensor; } - /// /// Converts a RenderTexture to a 2D texture. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs index d22a3ffcaf..32c6bec07d 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs @@ -92,22 +92,6 @@ public override ISensor CreateSensor() return m_Sensor; } - /// - public override int[] GetObservationShape() - { - var width = RenderTexture != null ? RenderTexture.width : 0; - var height = RenderTexture != null ? RenderTexture.height : 0; - var observationShape = new[] { height, width, Grayscale ? 1 : 3 }; - - var stacks = ObservationStacks > 1 ? ObservationStacks : 1; - if (stacks > 1) - { - observationShape[2] *= stacks; - } - - return observationShape; - } - /// /// Update fields that are safe to change on the Sensor at runtime. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs index a217e5d159..b270cdd646 100644 --- a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs @@ -14,12 +14,5 @@ public abstract class SensorComponent : MonoBehaviour /// /// Created ISensor object. public abstract ISensor CreateSensor(); - - /// - /// Returns the shape of the sensor observations that will be created. - /// - /// Shape of the sensor observation. - public abstract int[] GetObservationShape(); - } } diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs index e116c0108b..4bfa8196ac 100644 --- a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs @@ -77,7 +77,6 @@ public void TestEnableAnalytics() #else Assert.IsFalse(TrainingAnalytics.EnableAnalytics()); #endif - } } } diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs index 540671a345..059f1a0a13 100644 --- a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs @@ -18,13 +18,8 @@ public override ISensor CreateSensor() { return Sensor; } - - public override int[] GetObservationShape() - { - var shape = Sensor.GetObservationSpec().Shape; - return new int[] { shape[0], shape[1], shape[2] }; - } } + public class Test3DSensor : ISensor, IBuiltInSensor { int m_Width; diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index fa160c312d..221fe71cf7 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -13,7 +13,6 @@ namespace Unity.MLAgents.Tests { - [TestFixture] public class EditModeTestGeneration { diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index 1905f3dbd8..ea0755856b 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -37,17 +37,6 @@ public override ISensor CreateSensor() var wrappedSensor = wrappedComponent.CreateSensor(); return new StackingSensor(wrappedSensor, numStacks); } - - public override int[] GetObservationShape() - { - int[] shape = (int[])wrappedComponent.GetObservationShape().Clone(); - for (var i = 0; i < shape.Length; i++) - { - shape[i] *= numStacks; - } - - return shape; - } } public class RuntimeApiTest diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs index 376984e578..c17f6cecb5 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs @@ -56,7 +56,7 @@ public void TestBufferSensorComponent() bufferComponent.SensorName = "TestName"; var sensor = bufferComponent.CreateSensor(); - var shape = bufferComponent.GetObservationShape(); + var shape = sensor.GetObservationSpec().Shape; Assert.AreEqual(shape[0], 20); Assert.AreEqual(shape[1], 4); @@ -68,7 +68,7 @@ public void TestBufferSensorComponent() var obsWriter = new ObservationWriter(); var obs = sensor.GetObservationProto(obsWriter); - Assert.AreEqual(shape, obs.Shape); + Assert.AreEqual(shape, InplaceArray.FromList(obs.Shape)); Assert.AreEqual(obs.DimensionProperties.Count, 2); Assert.AreEqual(sensor.GetName(), "TestName"); diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs index 5757daf6e0..5bb2c74fe1 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs @@ -29,12 +29,9 @@ public void TestCameraSensorComponent() cameraComponent.Grayscale = grayscale; cameraComponent.CompressionType = compression; - var expectedShape = new[] { height, width, grayscale ? 1 : 3 }; - Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape()); - var sensor = cameraComponent.CreateSensor(); - var expectedShapeInplace = new InplaceArray(height, width, grayscale ? 1 : 3); - Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape); + var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs index d28dd0bd1b..d21e544dd7 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs @@ -26,11 +26,10 @@ public void TestRenderTextureSensorComponent() renderTexComponent.Grayscale = grayscale; renderTexComponent.CompressionType = compression; - var expectedShape = new[] { height, width, grayscale ? 1 : 3 }; - Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape()); + var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3); var sensor = renderTexComponent.CreateSensor(); - Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType()); } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs index d03b2a4a4d..84b69b6172 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs @@ -58,7 +58,6 @@ public CompressionSpec GetCompressionSpec() public class SensorShapeValidatorTests { - [Test] public void TestShapesAgree() { @@ -87,6 +86,7 @@ public void TestNumSensorMismatch() LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3"); validator.ValidateSensors(sensorList1); } + [Test] public void TestDimensionMismatch() { diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs index c959192720..df5644e999 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs @@ -166,7 +166,6 @@ public string GetName() { return "Dummy"; } - } [Test] diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs index ba5013900a..f17b88279e 100644 --- a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs +++ b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs @@ -165,6 +165,5 @@ public void Reset() public class TestClasses { - } } From d53481af6379d838de6764d4552e9015545d8395 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 22 Mar 2021 14:45:15 -0700 Subject: [PATCH 2/3] changelog and migration --- com.unity.ml-agents/CHANGELOG.md | 1 + docs/Migrating.md | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index a1e2d3829e..e0e19b9aa1 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -22,6 +22,7 @@ details. and `IDimensionPropertiesSensor` interfaces were removed. (#5127) - `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor` interface was removed. (#5164) +- The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/docs/Migrating.md b/docs/Migrating.md index 3c98b05498..9971e96d6b 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -44,9 +44,11 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) actionMask.SetActionEnabled(branch, 3, false); } ``` +### IActuator changes - The `IActuator` interface now implements `IHeuristicProvider`. Please add the corresponding `Heuristic(in ActionBuffers)` method to your custom Actuator classes. +### ISensor and SensorComponent changes - The `ISensor.GetObservationShape()` method and `ITypedSensor` and `IDimensionPropertiesSensor` interfaces were removed, and `GetObservationSpec()` was added. You can use `ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to @@ -88,6 +90,8 @@ public CompressionSpec GetCompressionSpec() } ``` +- The abstract method `SensorComponent.GetObservationShape()` was removed. + ## Migrating to Release 13 ### Implementing IHeuristic in your IActuator implementations - If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator From 685e177f79cba83ddb070cc44f077ca20d718278 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 22 Mar 2021 15:00:42 -0700 Subject: [PATCH 3/3] fix raw strings --- .../Tests/Editor/Match3/Match3SensorTests.cs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index bee5d183da..80438f857e 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -18,7 +18,7 @@ public class Match3SensorTests public void TestVectorObservations() { var boardString = -@"000 + @"000 000 010"; var gameObj = new GameObject("board"); @@ -45,11 +45,11 @@ public void TestVectorObservations() public void TestVectorObservationsSpecial() { var boardString = -@"000 + @"000 000 010"; var specialString = -@"010 + @"010 200 000"; @@ -78,7 +78,7 @@ public void TestVectorObservationsSpecial() public void TestVisualObservations() { var boardString = -@"000 + @"000 000 010"; var gameObj = new GameObject("board"); @@ -115,11 +115,11 @@ public void TestVisualObservations() public void TestVisualObservationsSpecial() { var boardString = -@"000 + @"000 000 010"; var specialString = -@"010 + @"010 200 000"; @@ -158,7 +158,7 @@ public void TestVisualObservationsSpecial() public void TestCompressedVisualObservations() { var boardString = -@"000 + @"000 000 010"; var gameObj = new GameObject("board"); @@ -189,11 +189,11 @@ public void TestCompressedVisualObservations() public void TestCompressedVisualObservationsSpecial() { var boardString = -@"000 + @"000 000 010"; var specialString = -@"010 + @"010 200 000";