-
Notifications
You must be signed in to change notification settings - Fork 4.4k
[MLA-1634] Compression spec #5164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
099f8b9
7d30a97
53cc277
6c68ea1
326ed43
80c974b
35aea24
d2aaca4
9097be8
2a9079b
3422b8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -342,7 +342,8 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat | |
| var obsSpec = sensor.GetObservationSpec(); | ||
| var shape = obsSpec.Shape; | ||
| ObservationProto observationProto = null; | ||
| var compressionType = sensor.GetCompressionType(); | ||
| var compressionSpec = sensor.GetCompressionSpec(); | ||
| var compressionType = compressionSpec.SensorCompressionType; | ||
| // Check capabilities if we need to concatenate PNGs | ||
| if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3) | ||
| { | ||
|
|
@@ -365,7 +366,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat | |
| if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3) | ||
| { | ||
| var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping; | ||
| var isTrivialMapping = IsTrivialMapping(sensor); | ||
| var isTrivialMapping = compressionSpec.IsTrivialMapping(); | ||
| if (!trainerCanHandleMapping && !isTrivialMapping) | ||
| { | ||
| if (!s_HaveWarnedTrainerCapabilitiesMapping) | ||
|
|
@@ -411,18 +412,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat | |
| throw new UnityAgentsException( | ||
| $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " + | ||
| "You must return a byte[]. If you don't want to use compressed observations, " + | ||
| "return SensorCompressionType.None from GetCompressionType()." | ||
| "return CompressionSpec.Default() from GetCompressionSpec()." | ||
| ); | ||
| } | ||
| observationProto = new ObservationProto | ||
| { | ||
| CompressedData = ByteString.CopyFrom(compressedObs), | ||
| CompressionType = (CompressionTypeProto)sensor.GetCompressionType(), | ||
| CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType, | ||
| }; | ||
| var compressibleSensor = sensor as ISparseChannelSensor; | ||
| if (compressibleSensor != null) | ||
| if (compressionSpec.CompressedChannelMapping != null) | ||
| { | ||
| observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping()); | ||
| observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -488,34 +488,6 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps) | |
| }; | ||
| } | ||
|
|
||
| internal static bool IsTrivialMapping(ISensor sensor) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved this logic to CompressionSpec |
||
| { | ||
| var compressibleSensor = sensor as ISparseChannelSensor; | ||
| if (compressibleSensor is null) | ||
| { | ||
| return true; | ||
| } | ||
| var mapping = compressibleSensor.GetCompressedChannelMapping(); | ||
| if (mapping == null) | ||
| { | ||
| return true; | ||
| } | ||
| // check if mapping equals zero mapping | ||
| if (mapping.Length == 3 && mapping.All(m => m == 0)) | ||
| { | ||
| return true; | ||
| } | ||
| // check if mapping equals identity mapping | ||
| for (var i = 0; i < mapping.Length; i++) | ||
| { | ||
| if (mapping[i] != i) | ||
| { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| #region Analytics | ||
| internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent( | ||
| this TrainingEnvironmentInitialized inputProto) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| using System.Linq; | ||
| namespace Unity.MLAgents.Sensors | ||
| { | ||
| /// <summary> | ||
| /// The compression setting for visual/camera observations. | ||
| /// </summary> | ||
| public enum SensorCompressionType | ||
| { | ||
| /// <summary> | ||
| /// No compression. Data is preserved as float arrays. | ||
| /// </summary> | ||
| None, | ||
|
|
||
| /// <summary> | ||
| /// PNG format. Data will be stored in binary format. | ||
| /// </summary> | ||
| PNG | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// A description of the compression used for observations. | ||
| /// </summary> | ||
| /// <remarks> | ||
| /// Most ISensor implementations can't take advantage of compression, | ||
| /// and should return CompressionSpec.Default() from their ISensor.GetCompressionSpec() methods. | ||
| /// Visual observations, or mulitdimensional categorical observations (for example, image segmentation | ||
| /// or the piece types in a match-3 game board) can use PNG compression reduce the amount of | ||
| /// data transferred between Unity and the trainer. | ||
| /// </remarks> | ||
| public struct CompressionSpec | ||
| { | ||
| internal SensorCompressionType m_SensorCompressionType; | ||
|
|
||
| /// <summary> | ||
| /// The compression type that the sensor will use for its observations. | ||
| /// </summary> | ||
| public SensorCompressionType SensorCompressionType | ||
| { | ||
| get => m_SensorCompressionType; | ||
| } | ||
|
|
||
| internal int[] m_CompressedChannelMapping; | ||
|
|
||
| /// The mapping of the channels in compressed data to the actual channel after decompression. | ||
| /// The mapping is a list of integer index with the same length as | ||
| /// the number of output observation layers (channels), including padding if there's any. | ||
| /// Each index indicates the actual channel the layer will go into. | ||
| /// Layers with the same index will be averaged, and layers with negative index will be dropped. | ||
| /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1] | ||
| /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1] | ||
| public int[] CompressedChannelMapping | ||
| { | ||
| get => m_CompressedChannelMapping; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Return a CompressionSpec indicating no compression. This is recommended for most sensors. | ||
| /// </summary> | ||
| /// <returns></returns> | ||
| public static CompressionSpec Default() | ||
| { | ||
| return new CompressionSpec | ||
| { | ||
| m_SensorCompressionType = SensorCompressionType.None, | ||
| m_CompressedChannelMapping = null | ||
| }; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Return a CompressionSpec indicating possible compression. | ||
| /// </summary> | ||
| /// <param name="sensorCompressionType">The compression type to use.</param> | ||
| /// <param name="compressedChannelMapping">Optional mapping mapping of the channels in compressed data to the actual channel after decompression.</param> | ||
| /// <returns></returns> | ||
| public static CompressionSpec Compressed(SensorCompressionType sensorCompressionType, int[] compressedChannelMapping = null) | ||
| { | ||
| return new CompressionSpec | ||
| { | ||
| m_SensorCompressionType = sensorCompressionType, | ||
| m_CompressedChannelMapping = compressedChannelMapping | ||
| }; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. technically you can pass in SensorCompressionType.None and non-null mapping? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, you can. I think I'll just make this the constructor instead of using a static method. |
||
| } | ||
|
|
||
| /// <summary> | ||
| /// Return whether the compressed channel mapping is "trivial"; if so it doesn't need to be sent to the | ||
| /// trainer. | ||
| /// </summary> | ||
| /// <returns></returns> | ||
| internal bool IsTrivialMapping() | ||
| { | ||
| var mapping = CompressedChannelMapping; | ||
| if (mapping == null) | ||
| { | ||
| return true; | ||
| } | ||
| // check if mapping equals zero mapping | ||
| if (mapping.Length == 3 && mapping.All(m => m == 0)) | ||
| { | ||
| return true; | ||
| } | ||
| // check if mapping equals identity mapping | ||
| for (var i = 0; i < mapping.Length; i++) | ||
| { | ||
| if (mapping[i] != i) | ||
| { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A potential improvement of this is to have a m_CompressionSpec and return that like we do with
m_ObservationSpec.Not sure if it's worthy though given this looks pretty light weight.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We cached the shapes previously, because they required allocating memory. I kept that pattern for the
ObservationSpecs - it's not really necessary for performance (struct with InplaceArray, so no allocations), but it's not a bad idea since they shouldn't change at runtime. I don't think it's necessary for CompressionSpecs.