diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py index 445a7f4c0c..95d6cd399f 100644 --- a/ml-agents/mlagents/trainers/ppo/trainer.py +++ b/ml-agents/mlagents/trainers/ppo/trainer.py @@ -258,6 +258,7 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None: if not isinstance(policy, PPOPolicy): raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()") self.policy = policy + self.step = policy.get_current_step() def get_policy(self, name_behavior_id: str) -> TFPolicy: """ diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py index b64a919ef1..bd6ff4d559 100644 --- a/ml-agents/mlagents/trainers/sac/trainer.py +++ b/ml-agents/mlagents/trainers/sac/trainer.py @@ -340,6 +340,7 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None: if not isinstance(policy, SACPolicy): raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()") self.policy = policy + self.step = policy.get_current_step() def get_policy(self, name_behavior_id: str) -> TFPolicy: """