 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union, Set
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import gym
 from gym import error, spaces
@@ -74,7 +74,9 @@ def __init__(
 
         self.visual_obs = None
         self._n_agents = -1
-        self._done_agents: Set[int] = set()
+
+        self.agent_mapper = AgentIdIndexMapper()
+
         # Save the step result from the last time all Agents requested decisions.
         self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
@@ -121,6 +123,7 @@ def __init__(
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
         self._previous_step_result = step_result
+        self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -368,52 +371,58 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
                 "The number of agents in the scene does not match the expected number."
             )
 
-        # remove the done Agents
-        indices_to_keep: List[int] = []
-        for index, is_done in enumerate(step_result.done):
-            if not is_done:
-                indices_to_keep.append(index)
+        if step_result.n_agents() - sum(step_result.done) != self._n_agents:
+            raise UnityGymException(
+                "The number of agents in the scene does not match the expected number."
+            )
+
+        for index, agent_id in enumerate(step_result.agent_id):
+            if step_result.done[index]:
+                self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])
 
         # Set the new AgentDone flags to True
         # Note that the corresponding agent_id that gets marked done will be different
         # than the original agent that was done, but this is OK since the gym interface
         # only cares about the ordering.
         for index, agent_id in enumerate(step_result.agent_id):
             if not self._previous_step_result.contains_agent(agent_id):
+                # Register this agent, and get the reward of the previous agent that
+                # was in its index, so that we can return it to the gym.
+                last_reward = self.agent_mapper.register_new_agent_id(agent_id)
                 step_result.done[index] = True
-            if agent_id in self._done_agents:
-                step_result.done[index] = True
-        self._done_agents = set()
+                step_result.reward[index] = last_reward
+
         self._previous_step_result = step_result  # store the new original
 
+        # Get a permutation of the agent IDs so that a given ID stays in the same
+        # index as where it was first seen.
+        new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))
+
         _mask: Optional[List[np.array]] = None
         if step_result.action_mask is not None:
             _mask = []
             for mask_index in range(len(step_result.action_mask)):
-                _mask.append(step_result.action_mask[mask_index][indices_to_keep])
+                _mask.append(step_result.action_mask[mask_index][new_id_order])
         new_obs: List[np.array] = []
         for obs_index in range(len(step_result.obs)):
-            new_obs.append(step_result.obs[obs_index][indices_to_keep])
+            new_obs.append(step_result.obs[obs_index][new_id_order])
         return BatchedStepResult(
             obs=new_obs,
-            reward=step_result.reward[indices_to_keep],
-            done=step_result.done[indices_to_keep],
-            max_step=step_result.max_step[indices_to_keep],
-            agent_id=step_result.agent_id[indices_to_keep],
+            reward=step_result.reward[new_id_order],
+            done=step_result.done[new_id_order],
+            max_step=step_result.max_step[new_id_order],
+            agent_id=step_result.agent_id[new_id_order],
             action_mask=_mask,
         )
 
     def _sanitize_action(self, action: np.array) -> np.array:
-        if self._previous_step_result.n_agents() == self._n_agents:
-            return action
         sanitized_action = np.zeros(
             (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
-        input_index = 0
-        for index in range(self._previous_step_result.n_agents()):
+        for index, agent_id in enumerate(self._previous_step_result.agent_id):
             if not self._previous_step_result.done[index]:
-                sanitized_action[index, :] = action[input_index, :]
-                input_index = input_index + 1
+                array_index = self.agent_mapper.get_gym_index(agent_id)
+                sanitized_action[index, :] = action[array_index, :]
         return sanitized_action
 
     def _step(self, needs_reset: bool = False) -> BatchedStepResult:
@@ -432,7 +441,9 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
                     "The environment does not have the expected amount of agents."
                     + "Some agents did not request decisions at the same time."
                 )
-            self._done_agents.update(list(info.agent_id))
+            for agent_id, reward in zip(info.agent_id, info.reward):
+                self.agent_mapper.mark_agent_done(agent_id, reward)
+
             self._env.step()
             info = self._env.get_step_result(self.brain_name)
         return self._sanitize_info(info)
@@ -499,3 +510,91 @@ def lookup_action(self, action):
         :return: The List containing the branched actions.
         """
         return self.action_lookup[action]
+
+
+class AgentIdIndexMapper:
+    def __init__(self) -> None:
+        self._agent_id_to_gym_index: Dict[int, int] = {}
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        """
+        Provide the initial list of agent ids for the mapper.
+        """
+        for idx, agent_id in enumerate(agent_ids):
+            self._agent_id_to_gym_index[agent_id] = idx
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        """
+        Declare the agent done with the corresponding final reward.
+        """
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        """
+        Add the new agent ID and return the reward to use for the previous agent in this index.
+        """
+        # Any free index is OK here.
+        free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+        self._agent_id_to_gym_index[agent_id] = free_index
+        return last_reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        """
+        Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+        The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+        """
+        # Map the new agent ids to their index
+        new_agent_ids_to_index = {
+            agent_id: idx for idx, agent_id in enumerate(agent_ids)
+        }
+
+        # Make the output list. We don't write to it sequentially, so start with dummy values.
+        new_permutation = [-1] * len(agent_ids)
+
+        # For each agent ID, find the new index of the agent, and write it in the original index.
+        for agent_id, original_index in self._agent_id_to_gym_index.items():
+            new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+        return new_permutation
+
+    def get_gym_index(self, agent_id: int) -> int:
+        """
+        Get the gym index for the current agent.
+        """
+        return self._agent_id_to_gym_index[agent_id]
+
+
+class AgentIdIndexMapperSlow:
+    """
+    Reference implementation of AgentIdIndexMapper.
+    The operations are O(N^2) so it shouldn't be used for large numbers of agents.
+    See AgentIdIndexMapper for method descriptions.
+    """
+
+    def __init__(self) -> None:
+        self._gym_id_order: List[int] = []
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        self._gym_id_order = list(agent_ids)
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        original_index = self._gym_id_order.index(-1)
+        self._gym_id_order[original_index] = agent_id
+        reward = self._done_agents_index_to_last_reward.pop(original_index)
+        return reward
+
+    def get_id_permutation(self, agent_ids):
+        new_id_order = []
+        for agent_id in self._gym_id_order:
+            new_id_order.append(agent_ids.index(agent_id))
+        return new_id_order
+
+    def get_gym_index(self, agent_id: int) -> int:
+        return self._gym_id_order.index(agent_id)
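
A minimal usage sketch of the new AgentIdIndexMapper (illustration only, not part of the commit; the agent ids 10, 11, and 12 are made up for the example). It shows how a replacement agent inherits both the gym index and the final reward of the done agent it replaces, and how get_id_permutation restores the original ordering:

# Sketch only: exercising AgentIdIndexMapper as defined in the diff above.
mapper = AgentIdIndexMapper()
mapper.set_initial_agents([10, 11])             # gym index 0 -> id 10, index 1 -> id 11

mapper.mark_agent_done(10, reward=1.5)          # id 10 finished its episode with reward 1.5
last_reward = mapper.register_new_agent_id(12)  # id 12 takes over the freed index 0
assert last_reward == 1.5
assert mapper.get_gym_index(12) == 0            # the replacement keeps the original slot

# If the simulator reports agents in a different order, get_id_permutation returns
# the indices that restore the original gym ordering.
assert mapper.get_id_permutation([11, 12]) == [1, 0]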