Commit 7b800d1

Author: Ervin T

[bug-fix] When agent isn't training, don't clear update buffer (#5205)

* Don't clear update buffer, but don't append to it either
* Update changelog
* Address comments
* Make experience replay buffer saving more verbose

(cherry picked from commit 63e7ad4)

1 parent 2aaf326, commit 7b800d1
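In short: the trainer used to clear its update buffer whenever it was no longer training, which also threw away the experiences SAC needs for `save_replay_buffer`; it now simply stops appending instead, so the buffer neither grows without bound nor loses what it already holds. A minimal, self-contained sketch of that pattern (a toy stand-in, not ml-agents' actual `RLTrainer`):

```python
# Toy illustration only: a stand-in trainer, not ml-agents' RLTrainer.
from typing import List


class ToyTrainer:
    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.step = 0
        self.update_buffer: List[int] = []  # stands in for the AgentBuffer

    @property
    def should_still_train(self) -> bool:
        return self.step < self.max_steps

    def process_trajectory(self, experiences: List[int]) -> None:
        # New behavior: append only while still training; never clear, so the
        # collected experiences remain available to be saved at the end of the run.
        # (The old behavior cleared the buffer once training stopped.)
        if self.should_still_train:
            self.update_buffer.extend(experiences)
        self.step += len(experiences)


trainer = ToyTrainer(max_steps=20)
for _ in range(5):  # the run keeps producing trajectories past max_steps
    trainer.process_trajectory(list(range(10)))

assert len(trainer.update_buffer) == 20  # growth stops at max_steps, contents survive
```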

File tree: 6 files changed, +63, -19 lines changed


com.unity.ml-agents/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
 and this project adheres to
 [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

+## [1.9.1-preview]
+### Bug Fixes
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Fixed a bug where the SAC replay buffer would not be saved out at the end of a run, even if `save_replay_buffer` was enabled. (#5205)
+
 ## [1.9.0-preview] - 2021-03-17
 ### Major Changes
 #### com.unity.ml-agents (C#)
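For context, the `save_replay_buffer` option mentioned in the entry above is an SAC setting in the trainer configuration; the bug only manifests when it is enabled. A hedged sketch of where it sits, written as the Python equivalent of the usual YAML (the key layout and behavior name are assumptions based on the standard ml-agents config schema, not part of this diff):

```python
# Assumed layout of the trainer config (normally written as YAML); only
# save_replay_buffer is relevant to the fix above.
sac_config = {
    "behaviors": {
        "MyBehavior": {  # hypothetical behavior name
            "trainer_type": "sac",
            "hyperparameters": {
                # write last_replay_buffer.hdf5 at the end of the run so a
                # resumed run can load it back in
                "save_replay_buffer": True,
            },
        }
    }
}
```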

ml-agents/mlagents/trainers/poca/trainer.py

Lines changed: 1 addition & 4 deletions
@@ -166,10 +166,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
         )
         agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages)

-        # Append to update buffer
-        agent_buffer_trajectory.resequence_and_append(
-            self.update_buffer, training_length=self.policy.sequence_length
-        )
+        self._append_to_update_buffer(agent_buffer_trajectory)

         # If this was a terminal trajectory, append stats and reset reward collection
         if trajectory.done_reached:

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 2 additions & 4 deletions
@@ -149,10 +149,8 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
         global_returns = list(np.mean(np.array(tmp_returns, dtype=np.float32), axis=0))
         agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages)
         agent_buffer_trajectory[BufferKey.DISCOUNTED_RETURNS].set(global_returns)
-        # Append to update buffer
-        agent_buffer_trajectory.resequence_and_append(
-            self.update_buffer, training_length=self.policy.sequence_length
-        )
+
+        self._append_to_update_buffer(agent_buffer_trajectory)

         # If this was a terminal trajectory, append stats and reset reward collection
         if trajectory.done_reached:

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 5 additions & 5 deletions
@@ -104,9 +104,12 @@ def save_replay_buffer(self) -> None:
         Save the training buffer's update buffer to a pickle file.
         """
         filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5")
-        logger.info(f"Saving Experience Replay Buffer to {filename}")
+        logger.info(f"Saving Experience Replay Buffer to {filename}...")
         with open(filename, "wb") as file_object:
             self.update_buffer.save_to_file(file_object)
+        logger.info(
+            f"Saved Experience Replay Buffer ({os.path.getsize(filename)} bytes)."
+        )

     def load_replay_buffer(self) -> None:
         """
@@ -175,10 +178,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
                 agent_buffer_trajectory[ObsUtil.get_name_at_next(i)][-1] = obs
             agent_buffer_trajectory[BufferKey.DONE][-1] = False

-        # Append to update buffer
-        agent_buffer_trajectory.resequence_and_append(
-            self.update_buffer, training_length=self.policy.sequence_length
-        )
+        self._append_to_update_buffer(agent_buffer_trajectory)

         if trajectory.done_reached:
             self._update_end_episode_stats(agent_id, self.optimizer)
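Besides routing the append through the new helper, this file's other change makes the save path log the size of the written file. The file written here is what `load_replay_buffer` (visible as context above) reads back when a run is resumed. A hedged sketch of the same round trip, using the `AgentBuffer.save_to_file` / `load_from_file` calls that appear in this hunk; the buffer field and the temporary path are arbitrary choices for the example:

```python
# Sketch: round-trip an update buffer the way the SAC trainer does above.
# Requires ml-agents installed; the field used here is just an example.
import os
import tempfile

from mlagents.trainers.buffer import AgentBuffer, BufferKey

buffer = AgentBuffer()
buffer[BufferKey.ENVIRONMENT_REWARDS].extend([1.0, 0.5, -0.1])

path = os.path.join(tempfile.mkdtemp(), "last_replay_buffer.hdf5")
with open(path, "wb") as f:
    buffer.save_to_file(f)
print(f"Saved {os.path.getsize(path)} bytes")  # mirrors the new, more verbose log line

restored = AgentBuffer()
with open(path, "rb") as f:
    restored.load_from_file(f)
assert restored.num_experiences == 3
```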

ml-agents/mlagents/trainers/tests/test_rl_trainer.py

Lines changed: 35 additions & 4 deletions
@@ -77,8 +77,7 @@ def test_clear_update_buffer():


 @mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
-@mock.patch("mlagents.trainers.trainer.rl_trainer.RLTrainer._clear_update_buffer")
-def test_advance(mocked_clear_update_buffer, mocked_save_model):
+def test_advance(mocked_save_model):
     trainer = create_rl_trainer()
     mock_policy = mock.Mock()
     trainer.add_policy("TestBrain", mock_policy)
@@ -115,9 +114,8 @@ def test_advance(mocked_clear_update_buffer, mocked_save_model):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()

-    # Check that the buffer has been cleared
+    # Check that no model has been saved
     assert not trainer.should_still_train
-    assert mocked_clear_update_buffer.call_count > 0
     assert mocked_save_model.call_count == 0


@@ -181,6 +179,39 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
     mock_add_checkpoint.assert_has_calls(add_checkpoint_calls)


+def test_update_buffer_append():
+    trainer = create_rl_trainer()
+    mock_policy = mock.Mock()
+    trainer.add_policy("TestBrain", mock_policy)
+    trajectory_queue = AgentManagerQueue("testbrain")
+    policy_queue = AgentManagerQueue("testbrain")
+    trainer.subscribe_trajectory_queue(trajectory_queue)
+    trainer.publish_policy_queue(policy_queue)
+    time_horizon = 10
+    trajectory = mb.make_fake_trajectory(
+        length=time_horizon,
+        observation_specs=create_observation_specs_with_shapes([(1,)]),
+        max_step_complete=True,
+        action_spec=ActionSpec.create_discrete((2,)),
+    )
+    agentbuffer_trajectory = trajectory.to_agentbuffer()
+    assert trainer.update_buffer.num_experiences == 0
+
+    # Check that if we append, our update buffer gets longer.
+    # max_steps = 100
+    for i in range(10):
+        trainer._process_trajectory(trajectory)
+        trainer._append_to_update_buffer(agentbuffer_trajectory)
+        assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
+
+    # Check that if we append after stopping training, nothing happens.
+    # We process enough trajectories to hit max steps
+    trainer.set_is_policy_updating(False)
+    trainer._process_trajectory(trajectory)
+    trainer._append_to_update_buffer(agentbuffer_trajectory)
+    assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
+
+
 class RLTrainerWarningTest(unittest.TestCase):
     def test_warning_group_reward(self):
         with self.assertLogs("mlagents.trainers", level="WARN") as cm:
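To run only the new test locally, a standard pytest node selection such as `pytest ml-agents/mlagents/trainers/tests/test_rl_trainer.py::test_update_buffer_append` should work from the repository root (the exact invocation depends on how the packages are installed in your environment).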

ml-agents/mlagents/trainers/trainer/rl_trainer.py

Lines changed: 15 additions & 2 deletions
@@ -245,6 +245,21 @@ def _maybe_write_summary(self, step_after_process: int) -> None:
         if step_after_process >= self._next_summary_step and self.get_step != 0:
             self._write_summary(self._next_summary_step)

+    def _append_to_update_buffer(self, agentbuffer_trajectory: AgentBuffer) -> None:
+        """
+        Append an AgentBuffer to the update buffer. If the trainer isn't training,
+        don't update to avoid a memory leak.
+        """
+        if self.should_still_train:
+            seq_len = (
+                self.trainer_settings.network_settings.memory.sequence_length
+                if self.trainer_settings.network_settings.memory is not None
+                else 1
+            )
+            agentbuffer_trajectory.resequence_and_append(
+                self.update_buffer, training_length=seq_len
+            )
+
     def _maybe_save_model(self, step_after_process: int) -> None:
         """
         If processing the trajectory will make the step exceed the next model write,
@@ -298,5 +313,3 @@ def advance(self) -> None:
                         for q in self.policy_queues:
                             # Get policies that correspond to the policy queue in question
                             q.put(self.get_policy(q.behavior_id))
-        else:
-            self._clear_update_buffer()
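One detail worth noting: the removed call sites passed `training_length=self.policy.sequence_length`, while the shared helper now derives the sequence length from `trainer_settings.network_settings.memory`, falling back to 1 when no recurrent memory is configured. A small sketch of that selection (assuming the `NetworkSettings`/`MemorySettings` dataclasses from `mlagents.trainers.settings`; the concrete values are just examples):

```python
# Sketch of the seq_len selection used by the new helper. Assumes the
# NetworkSettings / MemorySettings classes from mlagents.trainers.settings.
from mlagents.trainers.settings import NetworkSettings


def training_length(network_settings: NetworkSettings) -> int:
    # Recurrent policies are resequenced into fixed-length chunks; everything
    # else is appended with a training length of 1.
    return (
        network_settings.memory.sequence_length
        if network_settings.memory is not None
        else 1
    )


recurrent = NetworkSettings(memory=NetworkSettings.MemorySettings(sequence_length=32))
feed_forward = NetworkSettings()  # memory defaults to None

assert training_length(recurrent) == 32
assert training_length(feed_forward) == 1
```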
