diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayCollabAgent.cs b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayCollabAgent.cs
index da02be7e3c..be771392a5 100644
--- a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayCollabAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayCollabAgent.cs
@@ -15,6 +15,17 @@ public class HallwayCollabAgent : HallwayAgent
     [HideInInspector]
     public int selection = 0;
+
+    public override void Initialize()
+    {
+        base.Initialize();
+        if (isSpotter)
+        {
+            var teamManager = new HallwayTeamManager();
+            SetTeamManager(teamManager);
+            teammate.SetTeamManager(teamManager);
+        }
+    }
 
     public override void OnEpisodeBegin()
     {
         m_Message = -1;
diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs
new file mode 100644
index 0000000000..83c719e5c0
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs
@@ -0,0 +1,25 @@
+using System.Collections.Generic;
+using Unity.MLAgents;
+using Unity.MLAgents.Extensions.Teams;
+using Unity.MLAgents.Sensors;
+
+public class HallwayTeamManager : BaseTeamManager
+{
+    List<Agent> m_AgentList = new List<Agent> { };
+
+
+    public override void RegisterAgent(Agent agent)
+    {
+        m_AgentList.Add(agent);
+    }
+
+    public override void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
+    {
+        agent.SendDoneToTrainer();
+    }
+
+    public override void AddTeamReward(float reward)
+    {
+
+    }
+}
diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs.meta b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs.meta
new file mode 100644
index 0000000000..43150bc2b0
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayTeamManager.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 8b67166b7adef46febf8b570f92c400d
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/Project/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayCollab.onnx.meta b/Project/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayCollab.onnx.meta
index ceecd608ba..b5ffac6e9c 100644
--- a/Project/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayCollab.onnx.meta
+++ b/Project/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayCollab.onnx.meta
@@ -4,6 +4,7 @@ ScriptedImporter:
   fileIDToRecycleName:
     11400000: main obj
     11400002: model data
+    2186277476908879412: ImportLogs
   externalObjects: {}
   userData:
   assetBundleName:
diff --git a/com.unity.ml-agents.extensions/Runtime/Teams.meta b/com.unity.ml-agents.extensions/Runtime/Teams.meta
new file mode 100644
index 0000000000..7d905a7387
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Teams.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 77124df6c18c4f669052016b3116147e
+timeCreated: 1610064454
\ No newline at end of file
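The Hallway example above creates one HallwayTeamManager inside the spotter's Initialize() and hands the same instance to its teammate. A minimal sketch of the equivalent wiring done from a scene-level setup script instead; the HallwayAreaSetup component and its agents field are hypothetical and not part of this change, and the only requirement it illustrates is that every agent in the group receives the same manager instance:

    // Hypothetical helper component, shown only to illustrate SetTeamManager usage.
    using UnityEngine;

    public class HallwayAreaSetup : MonoBehaviour
    {
        public HallwayCollabAgent[] agents;   // assigned in the Inspector

        void Awake()
        {
            var teamManager = new HallwayTeamManager();
            foreach (var agent in agents)
            {
                agent.SetTeamManager(teamManager);
            }
        }
    }
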
diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
new file mode 100644
index 0000000000..5d50e64613
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
@@ -0,0 +1,35 @@
+using System.Collections.Generic;
+using Unity.MLAgents;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Extensions.Teams
+{
+    public class BaseTeamManager : ITeamManager
+    {
+        readonly string m_Id = System.Guid.NewGuid().ToString();
+
+        public virtual void RegisterAgent(Agent agent)
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
+        {
+            // Possible implementation - save reference to Agent's IPolicy so that we can repeatedly
+            // call IPolicy.RequestDecision on behalf of the Agent after it's dead
+            // If so, we'll need dummy sensor impls with the same shape as the originals.
+            throw new System.NotImplementedException();
+        }
+
+        public virtual void AddTeamReward(float reward)
+        {
+
+        }
+
+        public string GetId()
+        {
+            return m_Id;
+        }
+
+    }
+}
diff --git a/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta
new file mode 100644
index 0000000000..4c421e6761
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b2967f9c3bd4449a98ad309085094769
+timeCreated: 1610064493
\ No newline at end of file
diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
index f9720e5c7e..c47df58ba5 100644
--- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
+++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
@@ -26,7 +26,6 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
         const string k_InferenceDeviceName = "m_InferenceDevice";
         const string k_BehaviorTypeName = "m_BehaviorType";
         const string k_TeamIdName = "TeamId";
-        const string k_GroupIdName = "GroupId";
         const string k_UseChildSensorsName = "m_UseChildSensors";
         const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";
 
@@ -68,7 +67,6 @@ public override void OnInspectorGUI()
             }
             needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
 
-            EditorGUILayout.PropertyField(so.FindProperty(k_GroupIdName));
            EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
             EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
             {
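BaseTeamManager above leaves AddTeamReward empty and throws from the other virtual methods, so concrete managers are expected to override them. A sketch of one possible subclass that fans a team reward out to every registered agent through the existing Agent.AddReward API; the class and field names are illustrative and not part of this change:

    using System.Collections.Generic;
    using Unity.MLAgents;
    using Unity.MLAgents.Extensions.Teams;
    using Unity.MLAgents.Sensors;

    public class RewardSharingTeamManager : BaseTeamManager
    {
        readonly List<Agent> m_Agents = new List<Agent>();

        public override void RegisterAgent(Agent agent)
        {
            m_Agents.Add(agent);
        }

        public override void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
        {
            // Same behaviour as the Hallway example: flush the final decision immediately.
            agent.SendDoneToTrainer();
        }

        public override void AddTeamReward(float reward)
        {
            // Every registered teammate receives the same scalar reward.
            foreach (var agent in m_Agents)
            {
                agent.AddReward(reward);
            }
        }
    }
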
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 347dc16ec8..28446a966a 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -50,6 +50,11 @@ internal struct AgentInfo
         /// </summary>
         public int episodeId;
 
+        /// <summary>
+        /// Team Manager identifier.
+        /// </summary>
+        public string teamManagerId;
+
         public void ClearActions()
         {
             storedActions.Clear();
@@ -312,6 +317,8 @@ internal struct AgentParameters
         /// </summary>
         float[] m_LegacyActionCache;
 
+        private ITeamManager m_TeamManager;
+
         /// <summary>
         /// Called when the attached [GameObject] becomes enabled and active.
         /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
@@ -443,6 +450,11 @@ public void LazyInitialize()
                 new int[m_ActuatorManager.NumDiscreteActions]
             );
 
+            if (m_TeamManager != null)
+            {
+                m_Info.teamManagerId = m_TeamManager.GetId();
+            }
+
            // The first time the Academy resets, all Agents in the scene will be
            // forced to reset through the <see cref="AgentForceReset"/> event.
            // To avoid the Agent resetting twice, the Agents will not begin their
@@ -459,7 +471,7 @@
         /// <summary>
         /// The reason that the Agent has been set to "done".
         /// </summary>
-        enum DoneReason
+        public enum DoneReason
         {
             /// <summary>
             /// The episode was ended manually by calling <see cref="EndEpisode"/>.
@@ -535,9 +547,17 @@ void NotifyAgentDone(DoneReason doneReason)
                 }
             }
             // Request the last decision with no callbacks
-            // We request a decision so Python knows the Agent is done immediately
-            m_Brain?.RequestDecision(m_Info, sensors);
-            ResetSensors();
+            if (m_TeamManager != null)
+            {
+                // Send final observations to TeamManager if it exists.
+                // The TeamManager is responsible for keeping track of the Agent after it's
+                // done, including propagating any "posthumous" rewards.
+                m_TeamManager.OnAgentDone(this, doneReason, sensors);
+            }
+            else
+            {
+                SendDoneToTrainer();
+            }
 
             // We also have to write any to any DemonstationStores so that they get the "done" flag.
             foreach (var demoWriter in DemonstrationWriters)
@@ -560,6 +580,13 @@ void NotifyAgentDone(DoneReason doneReason)
             m_Info.storedActions.Clear();
         }
 
+        public void SendDoneToTrainer()
+        {
+            // We request a decision so Python knows the Agent is done immediately
+            m_Brain?.RequestDecision(m_Info, sensors);
+            ResetSensors();
+        }
+
         /// <summary>
         /// Updates the Model assigned to this Agent instance.
         /// </summary>
@@ -1344,5 +1371,12 @@ void DecideAction()
             m_Info.CopyActions(actions);
             m_ActuatorManager.UpdateActions(actions);
         }
+
+        public void SetTeamManager(ITeamManager teamManager)
+        {
+            m_TeamManager = teamManager;
+            m_Info.teamManagerId = teamManager?.GetId();
+            teamManager?.RegisterAgent(this);
+        }
     }
 }
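With the Agent.cs change above, NotifyAgentDone no longer sends the final decision itself when a team manager is attached: it hands the agent, the now-public DoneReason, and the final sensors to OnAgentDone, and the manager decides when to call SendDoneToTrainer. A sketch of the deferred pattern the in-code comment alludes to, holding each agent's final decision until the whole team has finished; the class and member names are hypothetical and not part of this change:

    using System.Collections.Generic;
    using Unity.MLAgents;
    using Unity.MLAgents.Extensions.Teams;
    using Unity.MLAgents.Sensors;

    public class PendingDoneTeamManager : BaseTeamManager
    {
        readonly List<Agent> m_Team = new List<Agent>();
        readonly List<Agent> m_Pending = new List<Agent>();

        public override void RegisterAgent(Agent agent)
        {
            m_Team.Add(agent);
        }

        public override void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
        {
            // Hold each agent's final decision until every teammate is done, then flush them together.
            m_Pending.Add(agent);
            if (m_Pending.Count == m_Team.Count)
            {
                foreach (var pending in m_Pending)
                {
                    pending.SendDoneToTrainer();
                }
                m_Pending.Clear();
            }
        }
    }
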
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index 8e62032941..139432744c 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -67,6 +67,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
                 agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
             }
 
+            if (ai.teamManagerId != null)
+            {
+                agentInfoProto.TeamManagerId = ai.teamManagerId;
+            }
+
             return agentInfoProto;
         }
 
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
index 5e7232b47b..3de2bb3ede 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
@@ -26,17 +26,18 @@ static AgentInfoReflection() {
       string.Concat(
         "CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
         "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
-        "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
+        "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
         "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
         "ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
         "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
-        "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
-        "EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
-        "LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+        "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
+        "X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
+        "SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
+        "YgZwcm90bzM="));
       descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
           new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
           new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
-            new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
+            new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null)
           }));
     }
     #endregion
@@ -74,6 +75,7 @@ public AgentInfoProto(AgentInfoProto other) : this() {
       id_ = other.id_;
       actionMask_ = other.actionMask_.Clone();
       observations_ = other.observations_.Clone();
+      teamManagerId_ = other.teamManagerId_;
       _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
     }
 
@@ -146,6 +148,17 @@ public int Id {
      get { return observations_; }
    }
 
+    /// <summary>Field number for the "team_manager_id" field.</summary>
+    public const int TeamManagerIdFieldNumber = 14;
+    private string teamManagerId_ = "";
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public string TeamManagerId {
+      get { return teamManagerId_; }
+      set {
+        teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+      }
+    }
+
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public override bool Equals(object other) {
      return Equals(other as AgentInfoProto);
@@ -165,6 +178,7 @@ public bool Equals(AgentInfoProto other) {
      if (Id != other.Id) return false;
      if(!actionMask_.Equals(other.actionMask_)) return false;
      if(!observations_.Equals(other.observations_)) return false;
+     if (TeamManagerId != other.TeamManagerId) return false;
      return Equals(_unknownFields, other._unknownFields);
    }
 
@@ -177,6 +191,7 @@ public override int GetHashCode() {
      if (Id != 0) hash ^= Id.GetHashCode();
      hash ^= actionMask_.GetHashCode();
      hash ^= observations_.GetHashCode();
+     if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
      if (_unknownFields != null) {
        hash ^= _unknownFields.GetHashCode();
      }
@@ -208,6 +223,10 @@ public void WriteTo(pb::CodedOutputStream output) {
      }
      actionMask_.WriteTo(output, _repeated_actionMask_codec);
      observations_.WriteTo(output, _repeated_observations_codec);
+     if (TeamManagerId.Length != 0) {
+       output.WriteRawTag(114);
+       output.WriteString(TeamManagerId);
+     }
      if (_unknownFields != null) {
        _unknownFields.WriteTo(output);
      }
@@ -230,6 +249,9 @@ public int CalculateSize() {
      }
      size += actionMask_.CalculateSize(_repeated_actionMask_codec);
      size += observations_.CalculateSize(_repeated_observations_codec);
+     if (TeamManagerId.Length != 0) {
+       size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
+     }
      if (_unknownFields != null) {
        size += _unknownFields.CalculateSize();
      }
@@ -255,6 +277,9 @@ public void MergeFrom(AgentInfoProto other) {
      }
      actionMask_.Add(other.actionMask_);
      observations_.Add(other.observations_);
+     if (other.TeamManagerId.Length != 0) {
+       TeamManagerId = other.TeamManagerId;
+     }
      _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
    }
 
@@ -291,6 +316,10 @@ public void MergeFrom(pb::CodedInputStream input) {
            observations_.AddEntriesFrom(input, _repeated_observations_codec);
            break;
          }
+         case 114: {
+           TeamManagerId = input.ReadString();
+           break;
+         }
        }
      }
    }
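A quick way to exercise the regenerated bindings above is a serialize/parse round trip of the new field; ToByteArray and Parser.ParseFrom come from the Google.Protobuf runtime the package already uses, while the helper class below is hypothetical:

    using Google.Protobuf;
    using Unity.MLAgents.CommunicatorObjects;

    static class TeamManagerIdProtoCheck
    {
        public static bool RoundTrips(string teamManagerId)
        {
            var info = new AgentInfoProto { Id = 7, TeamManagerId = teamManagerId };
            var parsed = AgentInfoProto.Parser.ParseFrom(info.ToByteArray());
            // An empty string is never written to the wire, but it still compares equal here.
            return parsed.TeamManagerId == info.TeamManagerId;
        }
    }
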
diff --git a/com.unity.ml-agents/Runtime/ITeamManager.cs b/com.unity.ml-agents/Runtime/ITeamManager.cs
new file mode 100644
index 0000000000..c33bafaafa
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/ITeamManager.cs
@@ -0,0 +1,14 @@
+using System.Collections.Generic;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents
+{
+    public interface ITeamManager
+    {
+        string GetId();
+
+        void RegisterAgent(Agent agent);
+        // TODO not sure this is all the info we need, maybe pass a class/struct instead.
+        void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors);
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/ITeamManager.cs.meta b/com.unity.ml-agents/Runtime/ITeamManager.cs.meta
new file mode 100644
index 0000000000..0889f5a46b
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/ITeamManager.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 75810d91665e4477977eb78c9b15aeb3
+timeCreated: 1610057818
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
index 33be92c988..4a74437fb9 100644
--- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
@@ -144,12 +144,6 @@ public string BehaviorName
         [HideInInspector, SerializeField, FormerlySerializedAs("m_TeamID")]
         public int TeamId;
 
-        /// <summary>
-        /// The group ID for this behavior.
-        /// </summary>
-        [HideInInspector, SerializeField]
-        [Tooltip("Assign the same Group ID to all Agents in the same Area.")]
-        public int GroupId;
 
         // TODO properties here instead of Agent
         [FormerlySerializedAs("m_useChildSensors")]
@@ -200,7 +194,7 @@ public ObservableAttributeOptions ObservableAttributeHandling
         /// </summary>
         public string FullyQualifiedBehaviorName
         {
-            get { return m_BehaviorName + "?team=" + TeamId + "&group=" + GroupId; }
+            get { return m_BehaviorName + "?team=" + TeamId; }
         }
 
         internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)
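With GroupId removed, FullyQualifiedBehaviorName reduces to "<behavior name>?team=<team id>". A minimal probe showing the resulting string; the component below is hypothetical and only illustrates the property:

    using Unity.MLAgents.Policies;
    using UnityEngine;

    public class BehaviorNameProbe : MonoBehaviour
    {
        void Start()
        {
            var bp = GetComponent<BehaviorParameters>();
            // For a behavior named "Hallway" on team 0 this logs "Hallway?team=0";
            // before this change it would have been "Hallway?team=0&group=0".
            Debug.Log(bp.FullyQualifiedBehaviorName);
        }
    }
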
""" - def __init__(self, obs, reward, interrupted, agent_id): + def __init__(self, obs, reward, interrupted, agent_id, team_manager_id=None): self.obs: List[np.ndarray] = obs self.reward: np.ndarray = reward self.interrupted: np.ndarray = interrupted self.agent_id: np.ndarray = agent_id self._agent_id_to_index: Optional[Dict[AgentId, int]] = None + self.team_manager_id: Optional[List[str]] = team_manager_id @property def agent_id_to_index(self) -> Dict[AgentId, int]: @@ -218,11 +227,15 @@ def __getitem__(self, agent_id: AgentId) -> TerminalStep: agent_obs = [] for batched_obs in self.obs: agent_obs.append(batched_obs[agent_index]) + team_manager_id = None + if self.team_manager_id is not None and self.team_manager_id != "": + team_manager_id = self.team_manager_id[agent_index] return TerminalStep( obs=agent_obs, reward=self.reward[agent_index], interrupted=self.interrupted[agent_index], agent_id=agent_id, + team_manager_id=team_manager_id, ) def __iter__(self) -> Iterator[Any]: @@ -242,6 +255,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": reward=np.zeros(0, dtype=np.float32), interrupted=np.zeros(0, dtype=np.bool), agent_id=np.zeros(0, dtype=np.int32), + team_manager_id=None, ) diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py index e128cc76d8..2318a04c47 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/agent_info.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\tJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,]) @@ -76,6 +76,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6, + number=14, type=9, cpp_type=9, label=1, + has_default_value=False, 
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
index e128cc76d8..2318a04c47 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
@@ -20,7 +20,7 @@
   name='mlagents_envs/communicator_objects/agent_info.proto',
   package='communicator_objects',
   syntax='proto3',
-  serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+  serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\tJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
   ,
   dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])
 
@@ -76,6 +76,13 @@
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None,
       options=None, file=DESCRIPTOR),
+    _descriptor.FieldDescriptor(
+      name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
+      number=14, type=9, cpp_type=9, label=1,
+      has_default_value=False, default_value=_b("").decode('utf-8'),
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None,
+      options=None, file=DESCRIPTOR),
   ],
   extensions=[
   ],
@@ -89,7 +96,7 @@
   oneofs=[
   ],
   serialized_start=132,
-  serialized_end=341,
+  serialized_end=366,
 )
 
 _AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents__envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
index fcf93b7c7f..33b4ce779d 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
@@ -20,6 +20,7 @@ from mlagents_envs.communicator_objects.observation_pb2 import (
 from typing import (
     Iterable as typing___Iterable,
     Optional as typing___Optional,
+    Text as typing___Text,
 )
 
 from typing_extensions import (
@@ -40,6 +41,7 @@ class AgentInfoProto(google___protobuf___message___Message):
     max_step_reached = ... # type: builtin___bool
     id = ... # type: builtin___int
     action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
+    team_manager_id = ... # type: typing___Text
 
     @property
     def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...
@@ -52,12 +54,13 @@ class AgentInfoProto(google___protobuf___message___Message):
         id : typing___Optional[builtin___int] = None,
         action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
         observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
+        team_manager_id : typing___Optional[typing___Text] = None,
     ) -> None: ...
     @classmethod
     def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...
     def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
     def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
     if sys.version_info >= (3,):
-        def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
+        def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward",u"team_manager_id"]) -> None: ...
     else:
-        def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...
+        def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward",u"team_manager_id",b"team_manager_id"]) -> None: ...
diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py
index 2f21615d1f..35a1121fe4 100644
--- a/ml-agents-envs/mlagents_envs/rpc_utils.py
+++ b/ml-agents-envs/mlagents_envs/rpc_utils.py
@@ -309,9 +309,24 @@ def steps_from_proto(
     decision_rewards = np.array(
         [agent_info.reward for agent_info in decision_agent_info_list], dtype=np.float32
     )
+    decision_team_manager = [
+        agent_info.team_manager_id
+        for agent_info in decision_agent_info_list
+        if agent_info.team_manager_id is not None
+    ]
+    if len(decision_team_manager) == 0:
+        decision_team_manager = None
+
     terminal_rewards = np.array(
         [agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32
     )
+    terminal_team_manager = [
+        agent_info.team_manager_id
+        for agent_info in terminal_agent_info_list
+        if agent_info.team_manager_id is not None
+    ]
+    if len(terminal_team_manager) == 0:
+        terminal_team_manager = None
 
     _raise_on_nan_and_inf(decision_rewards, "rewards")
     _raise_on_nan_and_inf(terminal_rewards, "rewards")
@@ -349,9 +364,19 @@ def steps_from_proto(
         action_mask = np.split(action_mask, indices, axis=1)
     return (
         DecisionSteps(
-            decision_obs_list, decision_rewards, decision_agent_id, action_mask
+            decision_obs_list,
+            decision_rewards,
+            decision_agent_id,
+            action_mask,
+            decision_team_manager,
+        ),
+        TerminalSteps(
+            terminal_obs_list,
+            terminal_rewards,
+            max_step,
+            terminal_agent_id,
+            terminal_team_manager,
         ),
-        TerminalSteps(terminal_obs_list, terminal_rewards, max_step, terminal_agent_id),
     )
 
diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py
index 1d5a85116c..e56af89bef 100644
--- a/ml-agents/mlagents/trainers/agent_processor.py
+++ b/ml-agents/mlagents/trainers/agent_processor.py
@@ -51,8 +51,11 @@ def __init__(
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
         self.last_experience: Dict[str, AgentExperience] = {}
         self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {}
-        # current_obs is used to collect the last seen obs of all the agents, and assemble the next_collab_obs.
-        self.current_obs: Dict[str, List[np.ndarray]] = {}
+        # current_group_obs is used to collect the last seen obs of all the agents in the same group,
+        # and assemble the next_collab_obs.
+        self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
+            lambda: defaultdict(list)
+        )
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -103,7 +106,7 @@ def add_experiences(
             local_id = terminal_step.agent_id
             global_id = get_global_agent_id(worker_id, local_id)
             self._assemble_trajectory(terminal_step, global_id)
-        self.current_obs.clear()
+        self.current_group_obs.clear()
 
         # Clean the last experience dictionary for terminal steps
         for terminal_step in terminal_steps.values():
@@ -122,7 +125,7 @@ def add_experiences(
             local_id = ongoing_step.agent_id
             global_id = get_global_agent_id(worker_id, local_id)
             self._assemble_trajectory(ongoing_step, global_id)
-        self.current_obs.clear()
+        self.current_group_obs.clear()
 
         for _gid in action_global_agent_ids:
             # If the ID doesn't have a last step result, the agent just reset,
@@ -177,7 +180,8 @@ def _process_step(
             interrupted=interrupted,
             memory=memory,
         )
-        self.current_obs[global_id] = step.obs
+        if step.team_manager_id is not None:
+            self.current_group_obs[step.team_manager_id][global_id] += step.obs
         self.last_experience[global_id] = experience
 
     def _assemble_trajectory(
@@ -206,10 +210,8 @@ def _assemble_trajectory(
         ):
             next_obs = step.obs
             next_collab_obs = []
-            for _id, _exp in self.current_obs.items():
-                if _id == global_id:
-                    continue
-                else:
+            for _id, _exp in self.current_group_obs[step.team_manager_id].items():
+                if _id != global_id:
                     next_collab_obs.append(_exp)
 
             trajectory = Trajectory(
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
index 403540a6c5..ae82b0d6fd 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
@@ -19,4 +19,5 @@ message AgentInfoProto {
   repeated bool action_mask = 11;
   reserved 12; // deprecated CustomObservationProto custom_observation = 12;
   repeated ObservationProto observations = 13;
+  string team_manager_id = 14;
 }
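As a sanity check tying the new proto field to the regenerated C# above: team_manager_id is field 14 with wire type 2 (a length-delimited string), so its tag byte is (14 << 3) | 2 == 114, which matches output.WriteRawTag(114) in WriteTo and the case 114 label in MergeFrom. A tiny illustrative constant block (not part of this change) makes the arithmetic explicit:

    static class AgentInfoWireFormat
    {
        public const int TeamManagerIdFieldNumber = 14;
        public const int WireTypeLengthDelimited = 2;
        // (14 << 3) | 2 == 114, the tag byte used by the generated serializer.
        public const int TeamManagerIdTag =
            (TeamManagerIdFieldNumber << 3) | WireTypeLengthDelimited;
    }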