2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -32,6 +32,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
 - Unity Player logs are now written out to the results directory. (#3877)
 - Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
 ### Bug Fixes
+- Fixed an issue where SAC would perform a large number of model updates when resuming from a
+  checkpoint (#4038)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
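For context on the fix below: SAC schedules policy updates by comparing step / update_steps against steps_per_update, and before this change update_steps was initialized from buffer_init_steps alone, so a resumed run with a large restored step count would "catch up" with hundreds of updates at once. A rough sketch of that arithmetic with illustrative numbers (not the trainer's code; the real loop is in _update_sac_policy below):

```python
# Illustrative sketch only -- not ml-agents code. Shows why a resumed run with a
# large step count triggered a burst of updates under the old initialization.
steps_per_update = 20
step = 10_000      # step count restored from a checkpoint
update_steps = 1   # old init: max(1, buffer_init_steps), here buffer_init_steps = 0

catch_up_updates = 0
while step / update_steps > steps_per_update:  # the old loop condition
    update_steps += 1        # the trainer bumps this once per completed update
    catch_up_updates += 1

print(catch_up_updates)  # 499 back-to-back updates before normal pacing resumes
```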
20 changes: 13 additions & 7 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -65,9 +65,9 @@ def __init__(
         )
         self.step = 0
 
-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1
 
         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
@@ -229,7 +229,9 @@ def _update_sac_policy(self) -> bool:
         )
 
         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
@@ -282,9 +284,8 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():
@@ -327,6 +328,11 @@ def add_policy(
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
         self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()
 
     def get_policy(self, name_behavior_id: str) -> TFPolicy:
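Taken together, the trainer.py changes do three things: subtract buffer_init_steps from the step count in both ratio checks (those steps are exploration-only and should not count toward the update ratio), keep the counters initialized to 1 to avoid dividing by zero, and re-derive the counters from the checkpoint step in add_policy under the assumption that updates previously happened at the configured ratio. Below is a self-contained sketch of that scheme with simplified, hypothetical names; it is not the SACTrainer API, just the scheduling logic in isolation:

```python
# Simplified sketch of SAC's update scheduling after this change. Hypothetical
# names; the real logic lives in SACTrainer.__init__, _update_sac_policy, and
# add_policy in ml-agents/mlagents/trainers/sac/trainer.py.
class UpdateScheduler:
    def __init__(self, steps_per_update: float, buffer_init_steps: int):
        self.steps_per_update = steps_per_update
        self.buffer_init_steps = buffer_init_steps
        self.update_steps = 1  # start at 1 so the ratio check never divides by zero

    def resume(self, checkpoint_step: int) -> None:
        # Assume updates happened at the configured ratio before the checkpoint,
        # so a resumed run doesn't replay them all at once.
        self.update_steps = int(max(1, checkpoint_step / self.steps_per_update))

    def pending_updates(self, step: int) -> int:
        # buffer_init_steps are exploration-only and don't count toward the ratio.
        count = 0
        while (step - self.buffer_init_steps) / self.update_steps > self.steps_per_update:
            self.update_steps += 1
            count += 1
        return count


sched = UpdateScheduler(steps_per_update=20, buffer_init_steps=0)
sched.resume(checkpoint_step=200)        # update_steps becomes 10
print(sched.pending_updates(step=220))   # 1 -- only the 20 post-resume steps count
```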
16 changes: 16 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_sac.py
@@ -151,6 +151,7 @@ def test_advance(dummy_config):
         discrete_action=False, visual_inputs=0, vec_obs_size=6
     )
     dummy_config.hyperparameters.steps_per_update = 20
+    dummy_config.hyperparameters.reward_signal_steps_per_update = 20
     dummy_config.hyperparameters.buffer_init_steps = 0
     trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
     policy = trainer.create_policy(brain_params.brain_name, brain_params)
@@ -220,6 +221,21 @@ def test_advance(dummy_config):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()
 
+    # Call add_policy and check that we update the correct number of times.
+    # This is to emulate a load from checkpoint.
+    policy = trainer.create_policy(brain_params.brain_name, brain_params)
+    policy.get_current_step = lambda: 200
+    trainer.add_policy(brain_params.brain_name, policy)
+    trainer.optimizer.update = mock.Mock()
+    trainer.optimizer.update_reward_signals = mock.Mock()
+    trainer.optimizer.update_reward_signals.return_value = {}
+    trainer.optimizer.update.return_value = {}
+    trajectory_queue.put(trajectory)
+    trainer.advance()
+    # Make sure we did exactly 1 update
+    assert trainer.optimizer.update.call_count == 1
+    assert trainer.optimizer.update_reward_signals.call_count == 1
+
 
 if __name__ == "__main__":
     pytest.main()
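The numbers in the new assertions line up with the add_policy change: emulating a resume at step 200 with steps_per_update = 20 initializes update_steps to 10, so the next trajectory should trip the ratio check exactly once. A back-of-the-envelope check of that arithmetic (the 215 below is illustrative; the actual post-trajectory step depends on the fake trajectory's length in the test):

```python
# Illustrative arithmetic mirroring the assertions above, not the test itself.
steps_per_update = 20
update_steps = max(1, 200 // 20)   # 10, as add_policy derives it on resume
step = 215                         # illustrative step count after one more trajectory

updates = 0
while (step - 0) / update_steps > steps_per_update:  # buffer_init_steps = 0 in the test
    update_steps += 1
    updates += 1

assert updates == 1  # matches trainer.optimizer.update.call_count == 1
```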