@@ -247,25 +247,19 @@ def advance(self) -> None:
 
         next_learning_team = self.controller.get_learning_team
 
-        # CASE 1: Current learning team is managed by this GhostTrainer.
-        # If the learning team changes, the following loop over queues will push the
-        # new policy into the policy queue for the new learning agent if
-        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
-        # CASE 2: Current learning team is managed by a different GhostTrainer.
-        # If the learning team changes to a team managed by this GhostTrainer, this loop
-        # will push the current_snapshot into the correct queue. Otherwise,
-        # it will continue skipping and swap_snapshot will continue to handle
-        # pushing fixed snapshots
-        # Case 3: No team change. The if statement just continues to push the policy
+        # Case 1: No team change. The if statement just continues to push the policy
         # into the correct queue (or not if not learning team).
         for brain_name in self._internal_policy_queues:
             internal_policy_queue = self._internal_policy_queues[brain_name]
             try:
                 policy = internal_policy_queue.get_nowait()
                 self.current_policy_snapshot[brain_name] = policy.get_weights()
             except AgentManagerQueue.Empty:
-                pass
-            if next_learning_team in self._team_to_name_to_policy_queue:
+                continue
+            if (
+                self._learning_team == next_learning_team
+                and next_learning_team in self._team_to_name_to_policy_queue
+            ):
                 name_to_policy_queue = self._team_to_name_to_policy_queue[
                     next_learning_team
                 ]
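
To make the reworked Case 1 flow concrete, here is a minimal, self-contained sketch of the same control flow. It is not the ml-agents implementation: queue.Queue stands in for AgentManagerQueue, policies are plain dicts of weights, and the names (internal_policy_queues, team_to_name_to_policy_queue, forward_fresh_weights) are illustrative only. The point is that an empty internal queue now skips the brain entirely (continue instead of pass), and fresh weights are only published while this trainer's team is still the learning team.

import queue
from typing import Dict

internal_policy_queues: Dict[str, "queue.Queue[dict]"] = {"striker": queue.Queue()}
team_to_name_to_policy_queue: Dict[int, Dict[str, "queue.Queue[dict]"]] = {
    0: {"striker": queue.Queue()}
}
current_policy_snapshot: Dict[str, dict] = {"striker": {"w": 0.0}}


def forward_fresh_weights(learning_team: int, next_learning_team: int) -> None:
    for brain_name, internal_q in internal_policy_queues.items():
        try:
            fresh_weights = internal_q.get_nowait()
        except queue.Empty:
            # The fix: skip this brain when no new weights arrived, instead of
            # falling through (the old `pass`) and pushing an unchanged snapshot.
            continue
        current_policy_snapshot[brain_name] = fresh_weights
        # Only publish when the learning team did not change and it is a team
        # served by this trainer.
        if (
            learning_team == next_learning_team
            and next_learning_team in team_to_name_to_policy_queue
        ):
            team_to_name_to_policy_queue[next_learning_team][brain_name].put(
                fresh_weights
            )


internal_policy_queues["striker"].put({"w": 1.0})
forward_fresh_weights(learning_team=0, next_learning_team=0)
print(team_to_name_to_policy_queue[0]["striker"].get_nowait())  # -> {'w': 1.0}
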
@@ -277,6 +271,28 @@ def advance(self) -> None:
                     policy.load_weights(self.current_policy_snapshot[brain_name])
                     name_to_policy_queue[brain_name].put(policy)
 
+        # CASE 2: Current learning team is managed by this GhostTrainer.
+        # If the learning team changes, the following loop over queues will push the
+        # new policy into the policy queue for the new learning agent if
+        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
+        # CASE 3: Current learning team is managed by a different GhostTrainer.
+        # If the learning team changes to a team managed by this GhostTrainer, this loop
+        # will push the current_snapshot into the correct queue. Otherwise,
+        # it will continue skipping and swap_snapshot will continue to handle
+        # pushing fixed snapshots
+        if (
+            self._learning_team != next_learning_team
+            and next_learning_team in self._team_to_name_to_policy_queue
+        ):
+            name_to_policy_queue = self._team_to_name_to_policy_queue[
+                next_learning_team
+            ]
+            for brain_name in name_to_policy_queue:
+                behavior_id = create_name_behavior_id(brain_name, next_learning_team)
+                policy = self.get_policy(behavior_id)
+                policy.load_weights(self.current_policy_snapshot[brain_name])
+                name_to_policy_queue[brain_name].put(policy)
+
         # Note save and swap should be on different step counters.
         # We don't want to save unless the policy is learning.
         if self.get_step - self.last_save > self.steps_between_save:
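
The new CASE 2/3 block only fires on an actual team change. As a rough illustration of that hand-off (again a sketch with stand-in names such as handle_team_change, plain dict weights, and queue.Queue instead of AgentManagerQueue, not the ml-agents API), the snapshot saved by the loop above is pushed once into the queues of the newly learning team, and nothing is pushed when the new learning team belongs to a different GhostTrainer, leaving swap_snapshot to keep serving fixed snapshots for it.

import queue
from typing import Dict

# This trainer manages team 1; team 2 is handled by another trainer.
team_to_name_to_policy_queue: Dict[int, Dict[str, "queue.Queue[dict]"]] = {
    1: {"striker": queue.Queue()}
}
current_policy_snapshot: Dict[str, dict] = {"striker": {"w": 1.0}}


def handle_team_change(learning_team: int, next_learning_team: int) -> None:
    if (
        learning_team != next_learning_team
        and next_learning_team in team_to_name_to_policy_queue
    ):
        for brain_name, q in team_to_name_to_policy_queue[next_learning_team].items():
            # Hand the latest snapshot to the team that just became the learner.
            q.put(dict(current_policy_snapshot[brain_name]))


handle_team_change(learning_team=0, next_learning_team=1)  # managed team: pushes
print(team_to_name_to_policy_queue[1]["striker"].get_nowait())  # -> {'w': 1.0}

handle_team_change(learning_team=0, next_learning_team=2)  # unmanaged team: no-op
print(team_to_name_to_policy_queue[1]["striker"].empty())  # -> True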