Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6749d38
[WIP] Side Channel initial layout
vincentpierre Nov 22, 2019
15a8b92
Working prototype for raw bytes
vincentpierre Nov 22, 2019
02317df
fixing format mistake
vincentpierre Nov 22, 2019
93d7dd0
Added some errors and some unit tests in C#
vincentpierre Nov 23, 2019
b98cf97
Added the side channel for the Engine Configuration. (#2958)
vincentpierre Nov 25, 2019
b65195e
Logging a message when an unknown side channel number has been receiv…
vincentpierre Nov 25, 2019
d513e2f
Addressing comments
vincentpierre Nov 25, 2019
83beef9
renamings
vincentpierre Nov 25, 2019
5daadb0
renamings
vincentpierre Nov 26, 2019
8980e7a
Adding FloatProperties to the side channels (#2968)
vincentpierre Nov 26, 2019
3aec6b9
renaming m_SideChannelsDict to m_SideChannel
vincentpierre Nov 26, 2019
2dff366
renaming and some comments
vincentpierre Nov 26, 2019
c81b324
renaming and adding a GetAndClearReceivedMessages() in the RawBytesSi…
vincentpierre Nov 26, 2019
37786fd
micro-optimization
vincentpierre Nov 26, 2019
cf867fe
more errors and some nit
vincentpierre Nov 26, 2019
d6cfd52
addressing comments
vincentpierre Nov 26, 2019
12857e2
Using little-endian format in Python
vincentpierre Nov 26, 2019
5b35693
adding some comments
vincentpierre Nov 26, 2019
22b790a
Code comments
vincentpierre Nov 26, 2019
efa6de3
some changes and added the unit tests on both Python and C#
vincentpierre Nov 26, 2019
0d2c337
removing default default in get default
vincentpierre Nov 26, 2019
55b596e
Update UnitySDK/Assets/ML-Agents/Scripts/SideChannel/SideChannel.cs
vincentpierre Nov 26, 2019
14689b0
Update ml-agents-envs/mlagents/envs/side_channel/raw_bytes_channel.py
vincentpierre Nov 26, 2019
bfaa732
addressing comments
vincentpierre Nov 26, 2019
5f1c174
Merge branch 'develop-side-channel' of https://github.com/Unity-Techn…
vincentpierre Nov 26, 2019
c3c2bdf
fixing tests
vincentpierre Nov 26, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using System;
using NUnit.Framework;
using MLAgents;
using System.Collections.Generic;
using System.Text;

namespace MLAgents.Tests
{
public class SideChannelTests
{

// This test side channel only deals in integers
public class TestSideChannel : SideChannel
{

public List<int> m_MessagesReceived = new List<int>();

public override int ChannelType() { return -1; }

public override void OnMessageReceived(byte[] data)
{
m_MessagesReceived.Add(BitConverter.ToInt32(data, 0));
}

public void SendInt(int data)
{
QueueMessageToSend(BitConverter.GetBytes(data));
}
}

[Test]
public void TestIntegerSideChannel()
{
var intSender = new TestSideChannel();
var intReceiver = new TestSideChannel();
var dictSender = new Dictionary<int, SideChannel> { { intSender.ChannelType(), intSender } };
var dictReceiver = new Dictionary<int, SideChannel> { { intReceiver.ChannelType(), intReceiver } };

intSender.SendInt(4);
intSender.SendInt(5);
intSender.SendInt(6);

byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
RpcCommunicator.SendSideChannelData(dictReceiver, fakeData);

Assert.AreEqual(intReceiver.m_MessagesReceived[0], 4);
Assert.AreEqual(intReceiver.m_MessagesReceived[1], 5);
Assert.AreEqual(intReceiver.m_MessagesReceived[2], 6);
}

[Test]
public void TestRawBytesSideChannel()
{
var str1 = "Test string";
var str2 = "Test string, second";

var strSender = new RawBytesChannel();
var strReceiver = new RawBytesChannel();
var dictSender = new Dictionary<int, SideChannel> { { strSender.ChannelType(), strSender } };
var dictReceiver = new Dictionary<int, SideChannel> { { strReceiver.ChannelType(), strReceiver } };

strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
RpcCommunicator.SendSideChannelData(dictReceiver, fakeData);

var messages = strReceiver.ReceiveRawBytes();

Assert.AreEqual(messages.Count, 2);
Assert.AreEqual(Encoding.ASCII.GetString(messages[0]), str1);
Assert.AreEqual(Encoding.ASCII.GetString(messages[1]), str2);

}

}
}
11 changes: 11 additions & 0 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/SideChannelTests.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ void InitializeEnvironment()
Communicator.QuitCommandReceived += OnQuitCommandReceived;
Communicator.ResetCommandReceived += OnResetCommand;
Communicator.RLInputReceived += OnRLInputReceived;
Communicator.RegisterSideChannel(new EngineConfigurationChannel());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,23 @@ static UnityRlInputReflection() {
"ZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb24ucHJvdG8a",
"P21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvZW52aXJvbm1l",
"bnRfcGFyYW1ldGVycy5wcm90bxowbWxhZ2VudHMvZW52cy9jb21tdW5pY2F0",
"b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvIsMDChFVbml0eVJMSW5wdXRQcm90",
"b3Jfb2JqZWN0cy9jb21tYW5kLnByb3RvItkDChFVbml0eVJMSW5wdXRQcm90",
"bxJQCg1hZ2VudF9hY3Rpb25zGAEgAygLMjkuY29tbXVuaWNhdG9yX29iamVj",
"dHMuVW5pdHlSTElucHV0UHJvdG8uQWdlbnRBY3Rpb25zRW50cnkSUAoWZW52",
"aXJvbm1lbnRfcGFyYW1ldGVycxgCIAEoCzIwLmNvbW11bmljYXRvcl9vYmpl",
"Y3RzLkVudmlyb25tZW50UGFyYW1ldGVyc1Byb3RvEhMKC2lzX3RyYWluaW5n",
"GAMgASgIEjMKB2NvbW1hbmQYBCABKA4yIi5jb21tdW5pY2F0b3Jfb2JqZWN0",
"cy5Db21tYW5kUHJvdG8aTQoUTGlzdEFnZW50QWN0aW9uUHJvdG8SNQoFdmFs",
"dWUYASADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5BZ2VudEFjdGlvblBy",
"b3RvGnEKEUFnZW50QWN0aW9uc0VudHJ5EgsKA2tleRgBIAEoCRJLCgV2YWx1",
"ZRgCIAEoCzI8LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxJbnB1dFBy",
"b3RvLkxpc3RBZ2VudEFjdGlvblByb3RvOgI4AUIfqgIcTUxBZ2VudHMuQ29t",
"bXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
"cy5Db21tYW5kUHJvdG8SFAoMc2lkZV9jaGFubmVsGAUgASgMGk0KFExpc3RB",
"Z2VudEFjdGlvblByb3RvEjUKBXZhbHVlGAEgAygLMiYuY29tbXVuaWNhdG9y",
"X29iamVjdHMuQWdlbnRBY3Rpb25Qcm90bxpxChFBZ2VudEFjdGlvbnNFbnRy",
"eRILCgNrZXkYASABKAkSSwoFdmFsdWUYAiABKAsyPC5jb21tdW5pY2F0b3Jf",
"b2JqZWN0cy5Vbml0eVJMSW5wdXRQcm90by5MaXN0QWdlbnRBY3Rpb25Qcm90",
"bzoCOAFCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3Rv",
"Mw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, global::MLAgents.CommunicatorObjects.EnvironmentParametersReflection.Descriptor, global::MLAgents.CommunicatorObjects.CommandReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Parser, new[]{ "AgentActions", "EnvironmentParameters", "IsTraining", "Command", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInputProto.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null),
null, })
}));
}
Expand Down Expand Up @@ -81,6 +82,7 @@ public UnityRLInputProto(UnityRLInputProto other) : this() {
EnvironmentParameters = other.environmentParameters_ != null ? other.EnvironmentParameters.Clone() : null;
isTraining_ = other.isTraining_;
command_ = other.command_;
sideChannel_ = other.sideChannel_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -132,6 +134,17 @@ public bool IsTraining {
}
}

/// <summary>Field number for the "side_channel" field.</summary>
public const int SideChannelFieldNumber = 5;
private pb::ByteString sideChannel_ = pb::ByteString.Empty;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pb::ByteString SideChannel {
get { return sideChannel_; }
set {
sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLInputProto);
Expand All @@ -149,6 +162,7 @@ public bool Equals(UnityRLInputProto other) {
if (!object.Equals(EnvironmentParameters, other.EnvironmentParameters)) return false;
if (IsTraining != other.IsTraining) return false;
if (Command != other.Command) return false;
if (SideChannel != other.SideChannel) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -159,6 +173,7 @@ public override int GetHashCode() {
if (environmentParameters_ != null) hash ^= EnvironmentParameters.GetHashCode();
if (IsTraining != false) hash ^= IsTraining.GetHashCode();
if (Command != 0) hash ^= Command.GetHashCode();
if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand All @@ -185,6 +200,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(32);
output.WriteEnum((int) Command);
}
if (SideChannel.Length != 0) {
output.WriteRawTag(42);
output.WriteBytes(SideChannel);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -203,6 +222,9 @@ public int CalculateSize() {
if (Command != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command);
}
if (SideChannel.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -227,6 +249,9 @@ public void MergeFrom(UnityRLInputProto other) {
if (other.Command != 0) {
Command = other.Command;
}
if (other.SideChannel.Length != 0) {
SideChannel = other.SideChannel;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -257,6 +282,10 @@ public void MergeFrom(pb::CodedInputStream input) {
command_ = (global::MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum();
break;
}
case 42: {
SideChannel = input.ReadBytes();
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ static UnityRlOutputReflection() {
string.Concat(
"CjhtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js",
"X291dHB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaM21sYWdlbnRz",
"L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byKj",
"L2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdlbnRfaW5mby5wcm90byK5",
"AgoSVW5pdHlSTE91dHB1dFByb3RvEkwKCmFnZW50SW5mb3MYAiADKAsyOC5j",
"b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uQWdlbnRJ",
"bmZvc0VudHJ5GkkKEkxpc3RBZ2VudEluZm9Qcm90bxIzCgV2YWx1ZRgBIAMo",
"CzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFnZW50SW5mb1Byb3RvGm4KD0Fn",
"ZW50SW5mb3NFbnRyeRILCgNrZXkYASABKAkSSgoFdmFsdWUYAiABKAsyOy5j",
"b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0UHJvdG8uTGlzdEFn",
"ZW50SW5mb1Byb3RvOgI4AUoECAEQAkIfqgIcTUxBZ2VudHMuQ29tbXVuaWNh",
"dG9yT2JqZWN0c2IGcHJvdG8z"));
"bmZvc0VudHJ5EhQKDHNpZGVfY2hhbm5lbBgDIAEoDBpJChJMaXN0QWdlbnRJ",
"bmZvUHJvdG8SMwoFdmFsdWUYASADKAsyJC5jb21tdW5pY2F0b3Jfb2JqZWN0",
"cy5BZ2VudEluZm9Qcm90bxpuCg9BZ2VudEluZm9zRW50cnkSCwoDa2V5GAEg",
"ASgJEkoKBXZhbHVlGAIgASgLMjsuY29tbXVuaWNhdG9yX29iamVjdHMuVW5p",
"dHlSTE91dHB1dFByb3RvLkxpc3RBZ2VudEluZm9Qcm90bzoCOAFKBAgBEAJC",
"H6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Parser, new[]{ "AgentInfos", "SideChannel" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null),
null, })
}));
}
Expand Down Expand Up @@ -72,6 +72,7 @@ public UnityRLOutputProto() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLOutputProto(UnityRLOutputProto other) : this() {
agentInfos_ = other.agentInfos_.Clone();
sideChannel_ = other.sideChannel_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand All @@ -90,6 +91,17 @@ public UnityRLOutputProto Clone() {
get { return agentInfos_; }
}

/// <summary>Field number for the "side_channel" field.</summary>
public const int SideChannelFieldNumber = 3;
private pb::ByteString sideChannel_ = pb::ByteString.Empty;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pb::ByteString SideChannel {
get { return sideChannel_; }
set {
sideChannel_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLOutputProto);
Expand All @@ -104,13 +116,15 @@ public bool Equals(UnityRLOutputProto other) {
return true;
}
if (!AgentInfos.Equals(other.AgentInfos)) return false;
if (SideChannel != other.SideChannel) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= AgentInfos.GetHashCode();
if (SideChannel.Length != 0) hash ^= SideChannel.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand All @@ -125,6 +139,10 @@ public override string ToString() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
agentInfos_.WriteTo(output, _map_agentInfos_codec);
if (SideChannel.Length != 0) {
output.WriteRawTag(26);
output.WriteBytes(SideChannel);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -134,6 +152,9 @@ public void WriteTo(pb::CodedOutputStream output) {
public int CalculateSize() {
int size = 0;
size += agentInfos_.CalculateSize(_map_agentInfos_codec);
if (SideChannel.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeBytesSize(SideChannel);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -146,6 +167,9 @@ public void MergeFrom(UnityRLOutputProto other) {
return;
}
agentInfos_.Add(other.agentInfos_);
if (other.SideChannel.Length != 0) {
SideChannel = other.SideChannel;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand All @@ -161,6 +185,10 @@ public void MergeFrom(pb::CodedInputStream input) {
agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec);
break;
}
case 26: {
SideChannel = input.ReadBytes();
break;
}
}
}
}
Expand Down
Loading