diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs
new file mode 100644
index 0000000000..fdbdc60322
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs
@@ -0,0 +1,114 @@
+using System;
+using NUnit.Framework;
+using MLAgents;
+using System.Collections.Generic;
+using System.Text;
+
+namespace MLAgents.Tests
+{
+    /// <summary>
+    /// Round-trip tests for side channels: messages queued on a sending channel are
+    /// packed by RpcCommunicator.GetSideChannelMessage and handed to a receiving
+    /// channel of the same type by RpcCommunicator.ProcessSideChannelData.
+    /// </summary>
+    public class SideChannelTests
+    {
+
+        // This test side channel only deals in integers
+        public class TestSideChannel : SideChannel
+        {
+
+            public List<int> m_MessagesReceived = new List<int>();
+
+            public override int ChannelType() { return -1; }
+
+            public override void OnMessageReceived(byte[] data)
+            {
+                m_MessagesReceived.Add(BitConverter.ToInt32(data, 0));
+            }
+
+            public void SendInt(int data)
+            {
+                QueueMessageToSend(BitConverter.GetBytes(data));
+            }
+        }
+
+        [Test]
+        public void TestIntegerSideChannel()
+        {
+            var intSender = new TestSideChannel();
+            var intReceiver = new TestSideChannel();
+            var dictSender = new Dictionary<int, SideChannel> { { intSender.ChannelType(), intSender } };
+            var dictReceiver = new Dictionary<int, SideChannel> { { intReceiver.ChannelType(), intReceiver } };
+
+            intSender.SendInt(4);
+            intSender.SendInt(5);
+            intSender.SendInt(6);
+
+            byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
+            RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
+
+            Assert.AreEqual(4, intReceiver.m_MessagesReceived[0]);
+            Assert.AreEqual(5, intReceiver.m_MessagesReceived[1]);
+            Assert.AreEqual(6, intReceiver.m_MessagesReceived[2]);
+        }
+
+        [Test]
+        public void TestRawBytesSideChannel()
+        {
+            var str1 = "Test string";
+            var str2 = "Test string, second";
+
+            var strSender = new RawBytesChannel();
+            var strReceiver = new RawBytesChannel();
+            var dictSender = new Dictionary<int, SideChannel> { { strSender.ChannelType(), strSender } };
+            var dictReceiver = new Dictionary<int, SideChannel> { { strReceiver.ChannelType(), strReceiver } };
+
+            strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
+            strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));
+
+            byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
+            RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
+
+            var messages = strReceiver.GetAndClearReceivedMessages();
+
+            Assert.AreEqual(2, messages.Count);
+            Assert.AreEqual(str1, Encoding.ASCII.GetString(messages[0]));
+            Assert.AreEqual(str2, Encoding.ASCII.GetString(messages[1]));
+        }
+
+        [Test]
+        public void TestFloatPropertiesSideChannel()
+        {
+            var k1 = "gravity";
+            var k2 = "length";
+            int wasCalled = 0;
+
+            var propA = new FloatPropertiesChannel();
+            var propB = new FloatPropertiesChannel();
+            var dictReceiver = new Dictionary<int, SideChannel> { { propA.ChannelType(), propA } };
+            var dictSender = new Dictionary<int, SideChannel> { { propB.ChannelType(), propB } };
+
+            propA.RegisterCallback(k1, f => { wasCalled++; });
+            var tmp = propB.GetPropertyWithDefault(k2, 3.0f);
+            Assert.AreEqual(3.0f, tmp);
+            propB.SetProperty(k2, 1.0f);
+            tmp = propB.GetPropertyWithDefault(k2, 3.0f);
+            Assert.AreEqual(1.0f, tmp);
+
+            byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
+            RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
+
+            tmp = propA.GetPropertyWithDefault(k2, 3.0f);
+            Assert.AreEqual(1.0f, tmp);
+
+            // The callback is registered on propA, so it only fires once propA receives the k1 message.
+            Assert.AreEqual(0, wasCalled);
+            propB.SetProperty(k1, 1.0f);
+            Assert.AreEqual(0, wasCalled);
+            fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
+            RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
+            Assert.AreEqual(1, wasCalled);
+        }
+    }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta
new file mode 100644
index 0000000000..cef0d1104e
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 589f475debcdb479295a24799777b5e5
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: 
{instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs index 9fd45df26b..b89077acfe 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs @@ -136,6 +136,9 @@ public abstract class Academy : MonoBehaviour [Tooltip("List of custom parameters that can be changed in the " + "environment when it resets.")] public ResetParameters resetParameters; + + public IFloatProperties FloatProperties; + public CommunicatorObjects.CustomResetParametersProto customResetParameters; // Fields not provided in the Inspector. @@ -265,6 +268,8 @@ void InitializeEnvironment() m_OriginalMaximumDeltaTime = Time.maximumDeltaTime; InitializeAcademy(); + var floatProperties = new FloatPropertiesChannel(); + FloatProperties = floatProperties; // Try to launch the communicator by using the arguments passed at launch try @@ -316,6 +321,8 @@ void InitializeEnvironment() Communicator.QuitCommandReceived += OnQuitCommandReceived; Communicator.ResetCommandReceived += OnResetCommand; Communicator.RLInputReceived += OnRLInputReceived; + Communicator.RegisterSideChannel(new EngineConfigurationChannel()); + Communicator.RegisterSideChannel(floatProperties); } } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs index 2a9f416de7..563a28907f 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs @@ -29,22 +29,23 @@ static UnityRlInputReflection() { "ZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb24ucHJvdG8a", "P21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvZW52aXJvbm1l", "bnRfcGFyYW1ldGVycy5wcm90bxowbWxhZ2VudHMvZW52cy9jb21tdW5pY2F0", - "b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvIsMDChFVbml0eVJMSW5wdXRQcm90", + 
"b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvItkDChFVbml0eVJMSW5wdXRQcm90", "bxJQCg1hZ2VudF9hY3Rpb25zGAEgAygLMjkuY29tbXVuaWNhdG9yX29iamVj", "dHMuVW5pdHlSTElucHV0UHJvdG8uQWdlbnRBY3Rpb25zRW50cnkSUAoWZW52", "aXJvbm1lbnRfcGFyYW1ldGVycxgCIAEoCzIwLmNvbW11bmljYXRvcl9vYmpl", "Y3RzLkVudmlyb25tZW50UGFyYW1ldGVyc1Byb3RvEhMKC2lzX3RyYWluaW5n", "GAMgASgIEjMKB2NvbW1hbmQYBCABKA4yIi5jb21tdW5pY2F0b3Jfb2JqZWN0", - "cy5Db21tYW5kUHJvdG8aTQoUTGlzdEFnZW50QWN0aW9uUHJvdG8SNQoFdmFs", - "dWUYASADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5BZ2VudEFjdGlvblBy", - "b3RvGnEKEUFnZW50QWN0aW9uc0VudHJ5EgsKA2tleRgBIAEoCRJLCgV2YWx1", - "ZRgCIAEoCzI8LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbnB1dFBy", - "b3RvLkxpc3RBZ2VudEFjdGlvblByb3RvOgI4AUIfqgIcTUxBZ2VudHMuQ29t", - "bXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z")); + "cy5Db21tYW5kUHJvdG8SFAoMc2lkZV9jaGFubmVsGAUgASgMGk0KFExpc3RB", + "Z2VudEFjdGlvblByb3RvEjUKBXZhbHVlGAEgAygLMiYuY29tbXVuaWNhdG9y", + "X29iamVjdHMuQWdlbnRBY3Rpb25Qcm90bxpxChFBZ2VudEFjdGlvbnNFbnRy", + "eRILCgNrZXkYASABKAkSSwoFdmFsdWUYAiABKAsyPC5jb21tdW5pY2F0b3Jf", + "b2JqZWN0cy5Vbml0eVJMSW5wdXRQcm90by5MaXN0QWdlbnRBY3Rpb25Qcm90", + "bzoCOAFCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3Rv", + "Mw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, global::MLAgents.CommunicatorObjects.EnvironmentParametersReflection.Descriptor, global::MLAgents.CommunicatorObjects.CommandReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), 
global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null), null, }) })); } @@ -81,6 +82,7 @@ public UnityRLInputProto(UnityRLInputProto other) : this() { EnvironmentParameters = other.environmentParameters_ != null ? other.EnvironmentParameters.Clone() : null; isTraining_ = other.isTraining_; command_ = other.command_; + sideChannel_ = other.sideChannel_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -132,6 +134,17 @@ public bool IsTraining { } } + /// Field number for the "side_channel" field. 
+ public const int SideChannelFieldNumber = 5; + private pb::ByteString sideChannel_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString SideChannel { + get { return sideChannel_; } + set { + sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLInputProto); @@ -149,6 +162,7 @@ public bool Equals(UnityRLInputProto other) { if (!object.Equals(EnvironmentParameters, other.EnvironmentParameters)) return false; if (IsTraining != other.IsTraining) return false; if (Command != other.Command) return false; + if (SideChannel != other.SideChannel) return false; return Equals(_unknownFields, other._unknownFields); } @@ -159,6 +173,7 @@ public override int GetHashCode() { if (environmentParameters_ != null) hash ^= EnvironmentParameters.GetHashCode(); if (IsTraining != false) hash ^= IsTraining.GetHashCode(); if (Command != 0) hash ^= Command.GetHashCode(); + if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -185,6 +200,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(32); output.WriteEnum((int) Command); } + if (SideChannel.Length != 0) { + output.WriteRawTag(42); + output.WriteBytes(SideChannel); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -203,6 +222,9 @@ public int CalculateSize() { if (Command != 0) { size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command); } + if (SideChannel.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -227,6 +249,9 @@ public void MergeFrom(UnityRLInputProto other) { if (other.Command != 0) { Command = other.Command; } + if (other.SideChannel.Length != 0) { + SideChannel = 
other.SideChannel; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -257,6 +282,10 @@ public void MergeFrom(pb::CodedInputStream input) { command_ = (global::MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum(); break; } + case 42: { + SideChannel = input.ReadBytes(); + break; + } } } } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs index f5b381d2f7..1a5082c0ff 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs @@ -26,19 +26,19 @@ static UnityRlOutputReflection() { string.Concat( "CjhtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", "X291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaM21sYWdlbnRz", - "L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byKj", + "L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byK5", "AgoSVW5pdHlSTE91dHB1dFByb3RvEkwKCmFnZW50SW5mb3MYAiADKAsyOC5j", "b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uQWdlbnRJ", - "bmZvc0VudHJ5GkkKEkxpc3RBZ2VudEluZm9Qcm90bxIzCgV2YWx1ZRgBIAMo", - "CzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFnZW50SW5mb1Byb3RvGm4KD0Fn", - "ZW50SW5mb3NFbnRyeRILCgNrZXkYASABKAkSSgoFdmFsdWUYAiABKAsyOy5j", - "b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uTGlzdEFn", - "ZW50SW5mb1Byb3RvOgI4AUoECAEQAkIfqgIcTUxBZ2VudHMuQ29tbXVuaWNh", - "dG9yT2JqZWN0c2IGcHJvdG8z")); + "bmZvc0VudHJ5EhQKDHNpZGVfY2hhbm5lbBgDIAEoDBpJChJMaXN0QWdlbnRJ", + "bmZvUHJvdG8SMwoFdmFsdWUYASADKAsyJC5jb21tdW5pY2F0b3Jfb2JqZWN0", + "cy5BZ2VudEluZm9Qcm90bxpuCg9BZ2VudEluZm9zRW50cnkSCwoDa2V5GAEg", + "ASgJEkoKBXZhbHVlGAIgASgLMjsuY29tbXVuaWNhdG9yX29iamVjdHMuVW5p", + "dHlSTE91dHB1dFByb3RvLkxpc3RBZ2VudEluZm9Qcm90bzoCOAFKBAgBEAJC", + "H6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new 
pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null), null, }) })); } @@ -72,6 +72,7 @@ public UnityRLOutputProto() { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public UnityRLOutputProto(UnityRLOutputProto other) : this() { agentInfos_ = other.agentInfos_.Clone(); + sideChannel_ = other.sideChannel_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -90,6 +91,17 @@ public UnityRLOutputProto Clone() { get { return agentInfos_; } } + /// Field number for the "side_channel" field. 
+ public const int SideChannelFieldNumber = 3; + private pb::ByteString sideChannel_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString SideChannel { + get { return sideChannel_; } + set { + sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLOutputProto); @@ -104,6 +116,7 @@ public bool Equals(UnityRLOutputProto other) { return true; } if (!AgentInfos.Equals(other.AgentInfos)) return false; + if (SideChannel != other.SideChannel) return false; return Equals(_unknownFields, other._unknownFields); } @@ -111,6 +124,7 @@ public bool Equals(UnityRLOutputProto other) { public override int GetHashCode() { int hash = 1; hash ^= AgentInfos.GetHashCode(); + if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -125,6 +139,10 @@ public override string ToString() { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public void WriteTo(pb::CodedOutputStream output) { agentInfos_.WriteTo(output, _map_agentInfos_codec); + if (SideChannel.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(SideChannel); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -134,6 +152,9 @@ public void WriteTo(pb::CodedOutputStream output) { public int CalculateSize() { int size = 0; size += agentInfos_.CalculateSize(_map_agentInfos_codec); + if (SideChannel.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -146,6 +167,9 @@ public void MergeFrom(UnityRLOutputProto other) { return; } agentInfos_.Add(other.agentInfos_); + if (other.SideChannel.Length != 0) { + SideChannel = other.SideChannel; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, 
other._unknownFields); } @@ -161,6 +185,10 @@ public void MergeFrom(pb::CodedInputStream input) { agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec); break; } + case 26: { + SideChannel = input.ReadBytes(); + break; + } } } } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs index be89c885e9..63834ae5d0 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs @@ -9,6 +9,8 @@ using System.Linq; using UnityEngine; using MLAgents.CommunicatorObjects; +using System.IO; +using Google.Protobuf; namespace MLAgents { @@ -48,6 +50,8 @@ public class RpcCommunicator : ICommunicator /// The communicator parameters sent at construction CommunicatorInitParameters m_CommunicatorInitParameters; + Dictionary m_SideChannels = new Dictionary(); + /// /// Initializes a new instance of the RPCCommunicator class. /// @@ -136,6 +140,7 @@ void UpdateEnvironmentWithInput(UnityRLInputProto rlInput) { SendRLInputReceivedEvent(rlInput.IsTraining); SendCommandEvent(rlInput.Command, rlInput.EnvironmentParameters); + ProcessSideChannelData(m_SideChannels, rlInput.SideChannel.ToArray()); } UnityInputProto Initialize(UnityOutputProto unityOutput, @@ -284,6 +289,9 @@ void SendBatchedMessageHelper() message.RlInitializationOutput = tempUnityRlInitializationOutput; } + byte[] messageAggregated = GetSideChannelMessage(m_SideChannels); + message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated); + var input = Exchange(message); UpdateSentBrainParameters(tempUnityRlInitializationOutput); @@ -434,6 +442,102 @@ void UpdateSentBrainParameters(UnityRLInitializationOutputProto output) #endregion + + #region Handling side channels + + /// + /// Registers a side channel to the communicator. The side channel will exchange + /// messages with its Python equivalent. + /// + /// The side channel to be registered. 
+        public void RegisterSideChannel(SideChannel sideChannel)
+        {
+            if (m_SideChannels.ContainsKey(sideChannel.ChannelType()))
+            {
+                throw new UnityAgentsException(string.Format(
+                    "A side channel with type index {0} is already registered. You cannot register multiple " +
+                    "side channels of the same type.", sideChannel.ChannelType()));
+            }
+            m_SideChannels.Add(sideChannel.ChannelType(), sideChannel);
+        }
+
+        /// <summary>
+        /// Grabs the messages that the registered side channels will send to Python at the current step
+        /// into a single byte array. The message queue of every channel is cleared as a side effect.
+        /// </summary>
+        /// <param name="sideChannels"> A dictionary of channel type to channel.</param>
+        /// <returns>For each queued message: channel type (int32), payload length (int32), payload bytes.</returns>
+        public static byte[] GetSideChannelMessage(Dictionary<int, SideChannel> sideChannels)
+        {
+            using (var memStream = new MemoryStream())
+            {
+                using (var binaryWriter = new BinaryWriter(memStream))
+                {
+                    foreach (var sideChannel in sideChannels.Values)
+                    {
+                        var messageList = sideChannel.MessageQueue;
+                        foreach (var message in messageList)
+                        {
+                            binaryWriter.Write(sideChannel.ChannelType());
+                            binaryWriter.Write(message.Length);
+                            binaryWriter.Write(message);
+                        }
+                        sideChannel.MessageQueue.Clear();
+                    }
+                    return memStream.ToArray();
+                }
+            }
+        }
+
+        /// <summary>
+        /// Separates the data received from Python into individual messages for each registered side channel.
+        /// </summary>
+        /// <param name="sideChannels"> A dictionary of channel type to channel.</param>
+        /// <param name="dataReceived"> The byte array of data received from Python.</param>
+        public static void ProcessSideChannelData(Dictionary<int, SideChannel> sideChannels, byte[] dataReceived)
+        {
+            if (dataReceived.Length == 0)
+            {
+                return;
+            }
+            using (var memStream = new MemoryStream(dataReceived))
+            {
+                using (var binaryReader = new BinaryReader(memStream))
+                {
+                    while (memStream.Position < memStream.Length)
+                    {
+                        int channelType = 0;
+                        byte[] message = null;
+                        try
+                        {
+                            channelType = binaryReader.ReadInt32();
+                            var messageLength = binaryReader.ReadInt32();
+                            message = binaryReader.ReadBytes(messageLength);
+                        }
+                        catch (Exception ex)
+                        {
+                            throw new UnityAgentsException(
+                                "There was a problem reading a message in a SideChannel. Please make sure the " +
+                                "version of MLAgents in Unity is compatible with the Python version. Original error : "
+                                + ex.Message);
+                        }
+                        if (sideChannels.ContainsKey(channelType))
+                        {
+                            sideChannels[channelType].OnMessageReceived(message);
+                        }
+                        else
+                        {
+                            Debug.Log(string.Format(
+                                "Unknown side channel data received. Channel type "
+                                + ": {0}", channelType));
+                        }
+                    }
+                }
+            }
+        }
+
+        #endregion
+
 #if UNITY_EDITOR
 #if UNITY_2017_2_OR_NEWER
         ///
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs b/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
index ef85028440..78cbff6f56 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
@@ -161,5 +161,12 @@ public interface ICommunicator
         /// A key to identify which actions to get
         ///
         Dictionary GetActions(string key);
+
+        ///
+        /// Registers a side channel to the communicator. The side channel will exchange
+        /// messages with its Python equivalent.
+        ///
+        /// The side channel to be registered.
+        void RegisterSideChannel(SideChannel sideChannel);
     }
 }
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta
new file mode 100644
index 0000000000..5678df6fdf
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: cb2f03ed7ea59456380730bd0f9b5bcb
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData: 
+  assetBundleName: 
+  assetBundleVariant: 
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs
new file mode 100644
index 0000000000..f20747dc75
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs
@@ -0,0 +1,48 @@
+using System.Collections.Generic;
+using System.IO;
+using UnityEngine;
+
+namespace MLAgents
+{
+    /// <summary>
+    /// Side channel that reads engine settings from a binary message and applies
+    /// them through the Unity Screen, QualitySettings, Time and Application APIs.
+    /// </summary>
+    public class EngineConfigurationChannel : SideChannel
+    {
+        /// <summary>
+        /// Identifies this channel as <see cref="SideChannelType.EngineSettings"/>.
+        /// </summary>
+        /// <returns>The integer value of SideChannelType.EngineSettings.</returns>
+        public override int ChannelType()
+        {
+            return (int)SideChannelType.EngineSettings;
+        }
+
+        /// <summary>
+        /// Decodes one settings message and applies every value it carries.
+        /// </summary>
+        /// <param name="data">Payload read in this order: width (int32), height (int32),
+        /// quality level (int32), time scale (single), target frame rate (int32).</param>
+        public override void OnMessageReceived(byte[] data)
+        {
+            using (var payload = new MemoryStream(data))
+            using (var reader = new BinaryReader(payload))
+            {
+                // The reads must stay in wire order.
+                int screenWidth = reader.ReadInt32();
+                int screenHeight = reader.ReadInt32();
+                int quality = reader.ReadInt32();
+                float scale = reader.ReadSingle();
+                int frameRate = reader.ReadInt32();
+
+                Screen.SetResolution(screenWidth, screenHeight, false);
+                QualitySettings.SetQualityLevel(quality, true);
+                Time.timeScale = scale;
+                // The capture frame rate is a fixed 60 and is not part of the message.
+                Time.captureFramerate = 60;
+                Application.targetFrameRate = frameRate;
+            }
+        }
+    }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta
new file mode 100644
index 0000000000..8f6335e9b0
--- /dev/null
+++ 
b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/EngineConfigurationChannel.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 18ccdf3ce76784f2db68016fa284c33f
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData: 
+  assetBundleName: 
+  assetBundleVariant: 
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs
new file mode 100644
index 0000000000..99a75c14ef
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs
@@ -0,0 +1,133 @@
+using System.Collections.Generic;
+using System.IO;
+using System;
+using System.Text;
+
+namespace MLAgents
+{
+
+    /// <summary>
+    /// Read/write access to a set of named float properties of the environment.
+    /// </summary>
+    public interface IFloatProperties
+    {
+        /// <summary>
+        /// Sets one of the float properties of the environment. This data will be sent to Python.
+        /// </summary>
+        /// <param name="key"> The string identifier of the property.</param>
+        /// <param name="value"> The float value of the property.</param>
+        void SetProperty(string key, float value);
+
+        /// <summary>
+        /// Get an Environment property with a default value. If there is a value for this property,
+        /// it will be returned, otherwise, the default value will be returned.
+        /// </summary>
+        /// <param name="key"> The string identifier of the property.</param>
+        /// <param name="defaultValue"> The default value of the property.</param>
+        /// <returns> The stored value for the key, or defaultValue when the key is absent.</returns>
+        float GetPropertyWithDefault(string key, float defaultValue);
+
+        /// <summary>
+        /// Registers an action to be performed every time the property is changed.
+        /// </summary>
+        /// <param name="key"> The string identifier of the property.</param>
+        /// <param name="action"> The action that will be performed. Takes a float as input.</param>
+        void RegisterCallback(string key, Action<float> action);
+
+        /// <summary>
+        /// Returns a list of all the string identifiers of the properties currently present.
+        /// </summary>
+        /// <returns> The list of string identifiers </returns>
+        IList<string> ListProperties();
+    }
+
+    /// <summary>
+    /// Side channel implementation of <see cref="IFloatProperties"/>: stores the
+    /// properties locally, queues local changes for Python and applies changes
+    /// received from Python.
+    /// </summary>
+    public class FloatPropertiesChannel : SideChannel, IFloatProperties
+    {
+
+        private Dictionary<string, float> m_FloatProperties = new Dictionary<string, float>();
+        private Dictionary<string, Action<float>> m_RegisteredActions = new Dictionary<string, Action<float>>();
+
+        public override int ChannelType()
+        {
+            return (int)SideChannelType.FloatProperties;
+        }
+
+        public override void OnMessageReceived(byte[] data)
+        {
+            var kv = DeserializeMessage(data);
+            m_FloatProperties[kv.Key] = kv.Value;
+            if (m_RegisteredActions.ContainsKey(kv.Key))
+            {
+                m_RegisteredActions[kv.Key].Invoke(kv.Value);
+            }
+        }
+
+        public void SetProperty(string key, float value)
+        {
+            m_FloatProperties[key] = value;
+            QueueMessageToSend(SerializeMessage(key, value));
+            // A local change also triggers the callback registered for this key.
+            if (m_RegisteredActions.ContainsKey(key))
+            {
+                m_RegisteredActions[key].Invoke(value);
+            }
+        }
+
+        public float GetPropertyWithDefault(string key, float defaultValue)
+        {
+            if (m_FloatProperties.ContainsKey(key))
+            {
+                return m_FloatProperties[key];
+            }
+            else
+            {
+                return defaultValue;
+            }
+        }
+
+        public void RegisterCallback(string key, Action<float> action)
+        {
+            m_RegisteredActions[key] = action;
+        }
+
+        public IList<string> ListProperties()
+        {
+            return new List<string>(m_FloatProperties.Keys);
+        }
+
+        // Message layout: key length (int32), ASCII key bytes, value (single).
+        private static KeyValuePair<string, float> DeserializeMessage(byte[] data)
+        {
+            using (var memStream = new MemoryStream(data))
+            {
+                using (var binaryReader = new BinaryReader(memStream))
+                {
+                    var keyLength = binaryReader.ReadInt32();
+                    var key = Encoding.ASCII.GetString(binaryReader.ReadBytes(keyLength));
+                    var value = binaryReader.ReadSingle();
+                    return new KeyValuePair<string, float>(key, value);
+                }
+            }
+        }
+
+        private static byte[] SerializeMessage(string key, float value)
+        {
+            using (var memStream = new MemoryStream())
+            {
+                using (var binaryWriter = new BinaryWriter(memStream))
+                {
+                    var stringEncoded = Encoding.ASCII.GetBytes(key);
+                    binaryWriter.Write(stringEncoded.Length);
+                    binaryWriter.Write(stringEncoded);
+                    binaryWriter.Write(value);
+                    return memStream.ToArray();
+                }
+            }
+        }
+    }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta
new file mode 100644
index 0000000000..d4b87eb1e4
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/FloatPropertiesChannel.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 452f8b3c01c4642aba645dcf0b6bfc6e
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData: 
+  assetBundleName: 
+  assetBundleVariant: 
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs
new file mode 100644
index 0000000000..8179349aae
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs
@@ -0,0 +1,65 @@
+using System.Collections.Generic;
+namespace MLAgents
+{
+    public class RawBytesChannel : SideChannel
+    {
+
+        private List<byte[]> m_MessagesReceived = new List<byte[]>();
+        private int m_ChannelId;
+
+        /// <summary>
+        /// RawBytesChannel provides a way to exchange raw byte arrays between Unity and Python.
+        /// </summary>
+        /// <param name="channelId"> The identifier for the RawBytesChannel. Must be
+        /// the same on Python and Unity.</param>
+        public RawBytesChannel(int channelId = 0)
+        {
+            m_ChannelId = channelId;
+        }
+        public override int ChannelType()
+        {
+            return (int)SideChannelType.RawBytesChannelStart + m_ChannelId;
+        }
+
+        public override void OnMessageReceived(byte[] data)
+        {
+            m_MessagesReceived.Add(data);
+        }
+
+        /// <summary>
+        /// Sends the byte array message to the Python side channel. The message will be sent
+        /// alongside the simulation step.
+        /// </summary>
+        /// <param name="data"> The byte array of data to send to Python.</param>
+        public void SendRawBytes(byte[] data)
+        {
+            QueueMessageToSend(data);
+        }
+
+        /// <summary>
+        /// Gets the messages that were sent by python since the last call to
+        /// GetAndClearReceivedMessages.
+        /// </summary>
+        /// <returns> a list of byte array messages that Python has sent.</returns>
+        public IList<byte[]> GetAndClearReceivedMessages()
+        {
+            var result = new List<byte[]>();
+            result.AddRange(m_MessagesReceived);
+            m_MessagesReceived.Clear();
+            return result;
+        }
+
+        /// <summary>
+        /// Gets the messages that were sent by python since the last call to
+        /// GetAndClearReceivedMessages. Note that the messages received will not
+        /// be cleared with a call to GetReceivedMessages.
+        /// </summary>
+        /// <returns> a list of byte array messages that Python has sent.</returns>
+        public IList<byte[]> GetReceivedMessages()
+        {
+            var result = new List<byte[]>();
+            result.AddRange(m_MessagesReceived);
+            return result;
+        }
+    }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta
new file mode 100644
index 0000000000..90a49234ba
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/RawBytesChannel.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 40b01e9cdbfd94865b54ebeb4e5aeaa5
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData: 
+  assetBundleName: 
+  assetBundleVariant: 
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs
new file mode 100644
index 0000000000..61e47d5ad3
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs
@@ -0,0 +1,49 @@
+using System.Collections.Generic;
+
+namespace MLAgents
+{
+    public enum SideChannelType
+    {
+        // Invalid side channel
+        Invalid = 0,
+        // Reserved for the FloatPropertiesChannel.
+        FloatProperties = 1,
+        // Reserved for the EngineConfigurationChannel.
+        EngineSettings = 2,
+        // Raw bytes channels should start here to avoid conflicting with other Unity ones.
+        RawBytesChannelStart = 1000,
+        // Custom side channels should start here to avoid conflicting with Unity ones.
+        UserSideChannelStart = 2000,
+    }
+
+    public abstract class SideChannel
+    {
+        // The list of messages (byte arrays) that need to be sent to Python via the communicator.
+        // Should only ever be read and cleared by an ICommunicator object.
+        public List<byte[]> MessageQueue = new List<byte[]>();
+
+        /// <summary>
+        /// An int identifier for the SideChannel. Ensures that there is only ever one side channel
+        /// of each type. Ensures the Unity side channels will be linked to their Python equivalent.
+        /// </summary>
+        /// <returns> The integer identifier of the SideChannel</returns>
+        public abstract int ChannelType();
+
+        /// <summary>
+        /// Is called by the communicator every time a message is received from Python by the SideChannel.
+        /// Can be called multiple times per simulation step if multiple messages were sent.
+        /// </summary>
+        /// <param name="data"> the payload of the message.</param>
+        public abstract void OnMessageReceived(byte[] data);
+
+        /// <summary>
+        /// Queues a message to be sent to Python during the next simulation step.
+        /// </summary>
+        /// <param name="data"> The byte array of data to be sent to Python.</param>
+        protected void QueueMessageToSend(byte[] data)
+        {
+            MessageQueue.Add(data);
+        }
+
+    }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta
new file mode 100644
index 0000000000..c668b0187f
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 77b7d19dd6ce343eeba907540b5a2286
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData: 
+  assetBundleName: 
+  assetBundleVariant: 
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
index bd726152a1..ce8a23635d 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
@@ -22,7 
+22,7 @@ name='mlagents/envs/communicator_objects/unity_rl_input.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n7mlagents/envs/communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/agent_action.proto\x1a?mlagents/envs/communicator_objects/environment_parameters.proto\x1a\x30mlagents/envs/communicator_objects/command.proto\"\xc3\x03\n\x11UnityRLInputProto\x12P\n\ragent_actions\x18\x01 \x03(\x0b\x32\x39.communicator_objects.UnityRLInputProto.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1aq\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12K\n\x05value\x18\x02 \x01(\x0b\x32<.communicator_objects.UnityRLInputProto.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n7mlagents/envs/communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/agent_action.proto\x1a?mlagents/envs/communicator_objects/environment_parameters.proto\x1a\x30mlagents/envs/communicator_objects/command.proto\"\xd9\x03\n\x11UnityRLInputProto\x12P\n\ragent_actions\x18\x01 \x03(\x0b\x32\x39.communicator_objects.UnityRLInputProto.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x12\x14\n\x0cside_channel\x18\x05 \x01(\x0c\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 
\x03(\x0b\x32&.communicator_objects.AgentActionProto\x1aq\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12K\n\x05value\x18\x02 \x01(\x0b\x32<.communicator_objects.UnityRLInputProto.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_command__pb2.DESCRIPTOR,]) @@ -55,8 +55,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=511, - serialized_end=588, + serialized_start=533, + serialized_end=610, ) _UNITYRLINPUTPROTO_AGENTACTIONSENTRY = _descriptor.Descriptor( @@ -92,8 +92,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=590, - serialized_end=703, + serialized_start=612, + serialized_end=725, ) _UNITYRLINPUTPROTO = _descriptor.Descriptor( @@ -131,6 +131,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='side_channel', full_name='communicator_objects.UnityRLInputProto.side_channel', index=4, + number=5, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -144,7 +151,7 @@ oneofs=[ ], serialized_start=252, - serialized_end=703, + serialized_end=725, ) _UNITYRLINPUTPROTO_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2._AGENTACTIONPROTO diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi index 9cefffdced..f3a6107256 100644 --- 
a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi +++ b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi @@ -89,6 +89,7 @@ class UnityRLInputProto(google___protobuf___message___Message): is_training = ... # type: builtin___bool command = ... # type: mlagents___envs___communicator_objects___command_pb2___CommandProto + side_channel = ... # type: builtin___bytes @property def agent_actions(self) -> typing___MutableMapping[typing___Text, UnityRLInputProto.ListAgentActionProto]: ... @@ -102,6 +103,7 @@ class UnityRLInputProto(google___protobuf___message___Message): environment_parameters : typing___Optional[mlagents___envs___communicator_objects___environment_parameters_pb2___EnvironmentParametersProto] = None, is_training : typing___Optional[builtin___bool] = None, command : typing___Optional[mlagents___envs___communicator_objects___command_pb2___CommandProto] = None, + side_channel : typing___Optional[builtin___bytes] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLInputProto: ... @@ -109,7 +111,7 @@ class UnityRLInputProto(google___protobuf___message___Message): def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): def HasField(self, field_name: typing_extensions___Literal[u"environment_parameters"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",u"command",u"environment_parameters",u"is_training"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",u"command",u"environment_parameters",u"is_training",u"side_channel"]) -> None: ... else: def HasField(self, field_name: typing_extensions___Literal[u"environment_parameters",b"environment_parameters"]) -> builtin___bool: ... 
- def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",b"agent_actions",u"command",b"command",u"environment_parameters",b"environment_parameters",u"is_training",b"is_training"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"agent_actions",b"agent_actions",u"command",b"command",u"environment_parameters",b"environment_parameters",u"is_training",b"is_training",u"side_channel",b"side_channel"]) -> None: ... diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py index 6bf33beb6b..064a4dde07 100644 --- a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py +++ b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py @@ -20,7 +20,7 @@ name='mlagents/envs/communicator_objects/unity_rl_output.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/agent_info.proto\"\xa3\x02\n\x12UnityRLOutputProto\x12L\n\nagentInfos\x18\x02 \x03(\x0b\x32\x38.communicator_objects.UnityRLOutputProto.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1an\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.communicator_objects.UnityRLOutputProto.ListAgentInfoProto:\x02\x38\x01J\x04\x08\x01\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/agent_info.proto\"\xb9\x02\n\x12UnityRLOutputProto\x12L\n\nagentInfos\x18\x02 \x03(\x0b\x32\x38.communicator_objects.UnityRLOutputProto.AgentInfosEntry\x12\x14\n\x0cside_channel\x18\x03 
\x01(\x0c\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1an\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.communicator_objects.UnityRLOutputProto.ListAgentInfoProto:\x02\x38\x01J\x04\x08\x01\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2.DESCRIPTOR,]) @@ -53,8 +53,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=236, - serialized_end=309, + serialized_start=258, + serialized_end=331, ) _UNITYRLOUTPUTPROTO_AGENTINFOSENTRY = _descriptor.Descriptor( @@ -90,8 +90,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=311, - serialized_end=421, + serialized_start=333, + serialized_end=443, ) _UNITYRLOUTPUTPROTO = _descriptor.Descriptor( @@ -108,6 +108,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='side_channel', full_name='communicator_objects.UnityRLOutputProto.side_channel', index=1, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -121,7 +128,7 @@ oneofs=[ ], serialized_start=136, - serialized_end=427, + serialized_end=449, ) _UNITYRLOUTPUTPROTO_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2._AGENTINFOPROTO diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi index a1e289f0fc..8a1548b998 100644 --- a/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi +++ 
b/ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi @@ -79,6 +79,7 @@ class UnityRLOutputProto(google___protobuf___message___Message): def HasField(self, field_name: typing_extensions___Literal[u"value",b"value"]) -> builtin___bool: ... def ClearField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> None: ... + side_channel = ... # type: builtin___bytes @property def agentInfos(self) -> typing___MutableMapping[typing___Text, UnityRLOutputProto.ListAgentInfoProto]: ... @@ -86,12 +87,13 @@ class UnityRLOutputProto(google___protobuf___message___Message): def __init__(self, *, agentInfos : typing___Optional[typing___Mapping[typing___Text, UnityRLOutputProto.ListAgentInfoProto]] = None, + side_channel : typing___Optional[builtin___bytes] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLOutputProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",u"side_channel"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",b"agentInfos"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"agentInfos",b"agentInfos",u"side_channel",b"side_channel"]) -> None: ... 
diff --git a/ml-agents-envs/mlagents/envs/environment.py b/ml-agents-envs/mlagents/envs/environment.py index 7af75cd19e..90424b1b63 100644 --- a/ml-agents-envs/mlagents/envs/environment.py +++ b/ml-agents-envs/mlagents/envs/environment.py @@ -6,6 +6,7 @@ import subprocess from typing import Dict, List, Optional, Any +from mlagents.envs.side_channel.side_channel import SideChannel from mlagents.envs.base_unity_environment import BaseUnityEnvironment from mlagents.envs.timers import timed, hierarchical_timer from .brain import AllBrainInfo, BrainInfo, BrainParameters @@ -32,6 +33,7 @@ from .rpc_communicator import RpcCommunicator from sys import platform import signal +import struct logging.basicConfig(level=logging.INFO) logger = logging.getLogger("mlagents.envs") @@ -52,6 +54,7 @@ def __init__( no_graphics: bool = False, timeout_wait: int = 60, args: Optional[List[str]] = None, + side_channels: Optional[List[SideChannel]] = None, ): """ Starts a new unity environment and establishes a connection with the environment. @@ -66,6 +69,7 @@ def __init__( :int timeout_wait: Time (in seconds) to wait for connection from environment. :bool train_mode: Whether to run in training mode, speeding up the simulation, by default. 
:list args: Addition Unity command line arguments + :list side_channels: Additional side channel for no-rl communication with Unity """ args = args or [] atexit.register(self._close) @@ -79,6 +83,16 @@ def __init__( self.timeout_wait: int = timeout_wait self.communicator = self.get_communicator(worker_id, base_port, timeout_wait) self.worker_id = worker_id + self.side_channels: Dict[int, SideChannel] = {} + if side_channels is not None: + for _sc in side_channels: + if _sc.channel_type in self.side_channels: + raise UnityEnvironmentException( + "There cannot be two side channels with the same channel type {0}.".format( + _sc.channel_type + ) + ) + self.side_channels[_sc.channel_type] = _sc # If the environment name is None, a new environment will not be launched # and the communicator will directly try to connect to an existing unity environment. @@ -527,8 +541,50 @@ def _get_state(self, output: UnityRLOutputProto) -> AllBrainInfo: _data[brain_name] = BrainInfo.from_agent_proto( self.worker_id, agent_info_list, self.brains[brain_name] ) + self._parse_side_channel_message(self.side_channels, output.side_channel) return _data + @staticmethod + def _parse_side_channel_message( + side_channels: Dict[int, SideChannel], data: bytearray + ) -> None: + offset = 0 + while offset < len(data): + try: + channel_type, message_len = struct.unpack_from(" bytearray: + result = bytearray() + for channel_type, channel in side_channels.items(): + for message in channel.message_queue: + result += struct.pack(" None: init_output = output.rl_initialization_output @@ -563,6 +619,7 @@ def _generate_step_input( action.value = float(value[b][i]) rl_in.agent_actions[b].value.extend([action]) rl_in.command = 0 + rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels)) return self.wrap_unity_input(rl_in) def _generate_reset_input( @@ -578,6 +635,7 @@ def _generate_reset_input( custom_reset_parameters ) rl_in.command = 1 + rl_in.side_channel = 
bytes(self._generate_side_channel_data(self.side_channels)) return self.wrap_unity_input(rl_in) def send_academy_parameters( diff --git a/ml-agents-envs/mlagents/envs/side_channel/__init__.py b/ml-agents-envs/mlagents/envs/side_channel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py b/ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py new file mode 100644 index 0000000000..26ce2763ef --- /dev/null +++ b/ml-agents-envs/mlagents/envs/side_channel/engine_configuration_channel.py @@ -0,0 +1,61 @@ +from mlagents.envs.side_channel.side_channel import SideChannel, SideChannelType +from mlagents.envs.exception import UnityCommunicationException +import struct + + +class EngineConfigurationChannel(SideChannel): + """ + This is the SideChannel for engine configuration exchange. The data in the + engine configuration is as follows : + - int width; + - int height; + - int qualityLevel; + - float timeScale; + - int targetFrameRate; + """ + + @property + def channel_type(self) -> int: + return SideChannelType.EngineSettings + + def on_message_received(self, data: bytearray) -> None: + """ + Is called by the environment to the side channel. Can be called + multiple times per step if multiple messages are meant for that + SideChannel. + Note that Python should never receive an engine configuration from + Unity + """ + raise UnityCommunicationException( + "The EngineConfigurationChannel received a message from Unity, " + + "this should not have happend." + ) + + def set_configuration( + self, + width: int = 80, + height: int = 80, + quality_level: int = 1, + time_scale: float = 20.0, + target_frame_rate: int = -1, + ) -> None: + """ + Sets the engine configuration. Takes as input the configurations of the + engine. + :param width: Defines the width of the display. Default 80. + :param height: Defines the height of the display. Default 80. 
+ :param quality_level: Defines the quality level of the simulation. + Default 1. + :param time_scale: Defines the multiplier for the deltatime in the + simulation. If set to a higher value, time will pass faster in the + simulation but the physics might break. Default 20. + :param target_frame_rate: Instructs simulation to try to render at a + specified frame rate. Default -1. + """ + data = bytearray() + data += struct.pack(" int: + return SideChannelType.FloatProperties + + def on_message_received(self, data: bytearray) -> None: + """ + Is called by the environment to the side channel. Can be called + multiple times per step if multiple messages are meant for that + SideChannel. + Each message carries a single float property (key and value), which + is stored locally. + """ + k, v = self.deserialize_float_prop(data) + self._float_properties[k] = v + + def set_property(self, key: str, value: float) -> None: + """ + Sets a property in the Unity Environment. + :param key: The string identifier of the property. + :param value: The float value of the property. + """ + self._float_properties[key] = value + super().queue_message_to_send(self.serialize_float_prop(key, value)) + + def get_property(self, key: str) -> Optional[float]: + """ + Gets a property in the Unity Environment. If the property was not + found, will return None. + :param key: The string identifier of the property. + :return: The float value of the property or None. + """ + return self._float_properties.get(key) + + def list_properties(self) -> List[str]: + """ + Returns a list of all the string identifiers of the properties + currently present in the Unity Environment.
+ """ + return self._float_properties.keys() + + @staticmethod + def serialize_float_prop(key: str, value: float) -> bytearray: + result = bytearray() + encoded_key = key.encode("ascii") + result += struct.pack(" Tuple[str, float]: + offset = 0 + encoded_key_len = struct.unpack_from(" int: + return SideChannelType.RawBytesChannelStart + self._channel_id + + def on_message_received(self, data: bytearray) -> None: + """ + Is called by the environment to the side channel. Can be called + multiple times per step if multiple messages are meant for that + SideChannel. + """ + self._received_messages.append(data) + + def get_and_clear_received_messages(self) -> List[bytearray]: + """ + Returns a list of the bytearrays received from the environment. + """ + result = list(self._received_messages) + self._received_messages = [] + return result + + def send_raw_data(self, data: bytearray) -> None: + """ + Queues a message to be sent by the environment at the next call to + step. + """ + super().queue_message_to_send(data) diff --git a/ml-agents-envs/mlagents/envs/side_channel/side_channel.py b/ml-agents-envs/mlagents/envs/side_channel/side_channel.py new file mode 100644 index 0000000000..4a1c611612 --- /dev/null +++ b/ml-agents-envs/mlagents/envs/side_channel/side_channel.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from enum import IntEnum + + +class SideChannelType(IntEnum): + FloatProperties = 1 + EngineSettings = 2 + # Raw bytes channels should start here to avoid conflicting with other + # Unity ones. + RawBytesChannelStart = 1000 + # custom side channels should start here to avoid conflicting with Unity + # ones. + UserSideChannelStart = 2000 + + +class SideChannel(ABC): + """ + The side channel just gets access to a bytes buffer that will be shared + between C# and Python. 
For example, we will create a specific side channel + for properties that will be a list of string (fixed size) to float number, + that can be modified by both C# and Python.
All side channels are passed + to the Env object at construction. + """ + + def __init__(self): + self.message_queue = [] + + def queue_message_to_send(self, data: bytearray) -> None: + """ + Queues a message to be sent by the environment at the next call to + step. + """ + self.message_queue.append(data) + + @abstractmethod + def on_message_received(self, data: bytearray) -> None: + """ + Is called by the environment to the side channel. Can be called + multiple times per step if multiple messages are meant for that + SideChannel. + """ + pass + + @property + @abstractmethod + def channel_type(self) -> int: + """ + :return: The type of side channel used. Will influence how the data is + processed in the environment. + """ + pass diff --git a/ml-agents-envs/mlagents/envs/tests/test_side_channel.py b/ml-agents-envs/mlagents/envs/tests/test_side_channel.py new file mode 100644 index 0000000000..19ca3a0e4c --- /dev/null +++ b/ml-agents-envs/mlagents/envs/tests/test_side_channel.py @@ -0,0 +1,91 @@ +import struct +from mlagents.envs.side_channel.side_channel import SideChannel +from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel +from mlagents.envs.side_channel.raw_bytes_channel import RawBytesChannel +from mlagents.envs.environment import UnityEnvironment + + +class IntChannel(SideChannel): + def __init__(self): + self.list_int = [] + super().__init__() + + @property + def channel_type(self): + return -1 + + def on_message_received(self, data): + val = struct.unpack_from(" agentInfos = 2; + bytes side_channel = 3; }