
Commit c359490

Author: Ervin T
[bug-fix] Fix issue with SAC updating too much on resume (#4038)
1 parent 575b240 commit c359490

File tree

3 files changed: +31 −7 lines


com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
 - When trying to load/resume from a checkpoint created with an earlier version of ML-Agents,
   a warning will be thrown. (#4035)
 ### Bug Fixes
+- Fixed an issue where SAC would perform too many model updates when resuming from a
+  checkpoint, and too few when using `buffer_init_steps`. (#4038)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
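
For context, the SAC trainer throttles model updates so that roughly one update is performed per `steps_per_update` environment steps, by comparing the number of collected steps to the number of updates already done. The sketch below is a simplified stand-in for that ratio check, not the trainer code itself (the real loop lives in `SACTrainer._update_sac_policy`); `run_updates`, `collected_steps`, `updates_done`, and `do_update` are illustrative names, not ML-Agents APIs. Per this commit, steps gathered during the `buffer_init_steps` warm-up are excluded from the ratio.

from typing import Callable

def run_updates(
    collected_steps: int,
    updates_done: int,
    steps_per_update: float,
    buffer_init_steps: int,
    do_update: Callable[[], None],
) -> int:
    """Run optimizer updates until the steps-to-updates ratio drops back to the target.

    Mirrors the new while-condition in SACTrainer._update_sac_policy: steps gathered
    during the initial buffer_init_steps exploration phase no longer count toward the ratio.
    """
    while (collected_steps - buffer_init_steps) / updates_done > steps_per_update:
        do_update()          # one SAC model update
        updates_done += 1    # the real trainer tracks this as self.update_steps
    return updates_done

# Example: 150 total steps, of which 50 were warm-up; target is 1 update per 10 steps.
print(run_updates(150, 1, 10.0, 50, do_update=lambda: None))  # 10 -> nine updates ran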

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 13 additions & 7 deletions
@@ -65,9 +65,9 @@ def __init__(
         )
         self.step = 0

-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1

         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
@@ -229,7 +229,9 @@ def _update_sac_policy(self) -> bool:
         )

         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
@@ -282,9 +284,8 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():
@@ -327,6 +328,11 @@ def add_policy(
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
        self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
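
The important change above is in `add_policy`: on resume, the update counters are seeded as if updates had already happened at the configured ratio, instead of starting near 1, so the ratio check does not see a huge backlog and fire a burst of catch-up updates right after loading a checkpoint. The snippet below just replays that arithmetic with illustrative numbers (not taken from a real run).

resumed_step = 200                    # step count restored from the checkpoint
steps_per_update = 20.0
buffer_init_steps = 0

# Seeding added to add_policy: pretend updates already happened at the right ratio.
update_steps = int(max(1, resumed_step / steps_per_update))       # -> 10

# Old behaviour for comparison: the counter effectively started at 1.
old_update_steps = 1

new_backlog = (resumed_step - buffer_init_steps) / update_steps > steps_per_update
old_backlog = (resumed_step - buffer_init_steps) / old_update_steps > steps_per_update
print(update_steps, new_backlog, old_backlog)   # 10 False True -> no catch-up burst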

ml-agents/mlagents/trainers/tests/test_sac.py

Lines changed: 16 additions & 0 deletions
@@ -151,6 +151,7 @@ def test_advance(dummy_config):
         discrete_action=False, visual_inputs=0, vec_obs_size=6
     )
     dummy_config.hyperparameters.steps_per_update = 20
+    dummy_config.hyperparameters.reward_signal_steps_per_update = 20
     dummy_config.hyperparameters.buffer_init_steps = 0
     trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
     policy = trainer.create_policy(brain_params.brain_name, brain_params)
@@ -220,6 +221,21 @@ def test_advance(dummy_config):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()

+    # Call add_policy and check that we update the correct number of times.
+    # This is to emulate a load from checkpoint.
+    policy = trainer.create_policy(brain_params.brain_name, brain_params)
+    policy.get_current_step = lambda: 200
+    trainer.add_policy(brain_params.brain_name, policy)
+    trainer.optimizer.update = mock.Mock()
+    trainer.optimizer.update_reward_signals = mock.Mock()
+    trainer.optimizer.update_reward_signals.return_value = {}
+    trainer.optimizer.update.return_value = {}
+    trajectory_queue.put(trajectory)
+    trainer.advance()
+    # Make sure we did exactly 1 update
+    assert trainer.optimizer.update.call_count == 1
+    assert trainer.optimizer.update_reward_signals.call_count == 1
+

 if __name__ == "__main__":
     pytest.main()
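
A rough check of why the test above expects exactly one update after resuming at step 200: `add_policy` seeds the update counter to 200 / 20 = 10, and the single queued trajectory only moves the step count slightly past 200. The arithmetic below assumes the fake trajectory contributes 15 steps (its length is not shown in this diff) and that the trainer increments its update counter once per update, which the ratio-based loop implies but does not show here.

steps_per_update = 20
resumed_step = 200
trajectory_len = 15                                   # assumed, not taken from the diff

update_steps = int(max(1, resumed_step / steps_per_update))       # seeded to 10
step = resumed_step + trajectory_len                              # 215 after one advance()

updates = 0
while step / update_steps > steps_per_update:         # buffer_init_steps is 0 in the test
    updates += 1
    update_steps += 1

print(updates)  # 1 -> 215 / 10 = 21.5 > 20 fires once; 215 / 11 ≈ 19.5 stops the loop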
