@@ -56,9 +56,9 @@ def __init__(self, seed, brain, trainer_parameters):
5656 self .seed = seed
5757 self .brain = brain
5858 self .use_recurrent = trainer_parameters ["use_recurrent" ]
59- self .memory_dict : Dict [int , np .ndarray ] = {}
59+ self .memory_dict : Dict [str , np .ndarray ] = {}
6060 self .num_branches = len (self .brain .vector_action_space_size )
61- self .previous_action_dict : Dict [int , np .array ] = {}
61+ self .previous_action_dict : Dict [str , np .array ] = {}
6262 self .normalize = trainer_parameters .get ("normalize" , False )
6363 self .use_continuous_act = brain .vector_action_space_type == "continuous"
6464 if self .use_continuous_act :
@@ -181,14 +181,14 @@ def make_empty_memory(self, num_agents):
181181 return np .zeros ((num_agents , self .m_size ), dtype = np .float )
182182
def save_memories(
    self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
) -> None:
    """Record one row of ``memory_matrix`` per agent in ``self.memory_dict``.

    The agent id at position ``i`` of ``agent_ids`` is paired with
    ``memory_matrix[i, :]``. Existing entries for the same id are
    overwritten.

    :param agent_ids: String agent identifiers, in row order.
    :param memory_matrix: Array indexed as ``[row, :]``; when ``None``
        the call is a no-op and ``self.memory_dict`` is left untouched.
    """
    if memory_matrix is None:
        return
    for index, agent_id in enumerate(agent_ids):
        # NOTE: basic slicing yields a view, so the stored row shares
        # its buffer with memory_matrix rather than copying it.
        self.memory_dict[agent_id] = memory_matrix[index, :]
190190
191- def retrieve_memories (self , agent_ids : List [int ]) -> np .ndarray :
191+ def retrieve_memories (self , agent_ids : List [str ]) -> np .ndarray :
192192 memory_matrix = np .zeros ((len (agent_ids ), self .m_size ), dtype = np .float )
193193 for index , agent_id in enumerate (agent_ids ):
194194 if agent_id in self .memory_dict :
@@ -209,14 +209,14 @@ def make_empty_previous_action(self, num_agents):
209209 return np .zeros ((num_agents , self .num_branches ), dtype = np .int )
210210
def save_previous_action(
    self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
) -> None:
    """Record one row of ``action_matrix`` per agent in ``self.previous_action_dict``.

    The agent id at position ``i`` of ``agent_ids`` is paired with
    ``action_matrix[i, :]``. Existing entries for the same id are
    overwritten.

    :param agent_ids: String agent identifiers, in row order.
    :param action_matrix: Array indexed as ``[row, :]``; when ``None``
        the call is a no-op and ``self.previous_action_dict`` is left
        untouched.
    """
    if action_matrix is None:
        return
    for index, agent_id in enumerate(agent_ids):
        # NOTE: basic slicing yields a view, so the stored row shares
        # its buffer with action_matrix rather than copying it.
        self.previous_action_dict[agent_id] = action_matrix[index, :]
218218
219- def retrieve_previous_action (self , agent_ids : List [int ]) -> np .ndarray :
219+ def retrieve_previous_action (self , agent_ids : List [str ]) -> np .ndarray :
220220 action_matrix = np .zeros ((len (agent_ids ), self .num_branches ), dtype = np .int )
221221 for index , agent_id in enumerate (agent_ids ):
222222 if agent_id in self .previous_action_dict :
0 commit comments