@@ -77,8 +77,7 @@ def test_clear_update_buffer():
 
 
 @mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
-@mock.patch("mlagents.trainers.trainer.rl_trainer.RLTrainer._clear_update_buffer")
-def test_advance(mocked_clear_update_buffer, mocked_save_model):
+def test_advance(mocked_save_model):
     trainer = create_rl_trainer()
     mock_policy = mock.Mock()
     trainer.add_policy("TestBrain", mock_policy)
@@ -115,9 +114,8 @@ def test_advance(mocked_clear_update_buffer, mocked_save_model):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()
 
-    # Check that the buffer has been cleared
+    # Check that no model has been saved
     assert not trainer.should_still_train
-    assert mocked_clear_update_buffer.call_count > 0
     assert mocked_save_model.call_count == 0
 
 
@@ -181,6 +179,39 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
     mock_add_checkpoint.assert_has_calls(add_checkpoint_calls)
 
 
+def test_update_buffer_append():
+    trainer = create_rl_trainer()
+    mock_policy = mock.Mock()
+    trainer.add_policy("TestBrain", mock_policy)
+    trajectory_queue = AgentManagerQueue("testbrain")
+    policy_queue = AgentManagerQueue("testbrain")
+    trainer.subscribe_trajectory_queue(trajectory_queue)
+    trainer.publish_policy_queue(policy_queue)
+    time_horizon = 10
+    trajectory = mb.make_fake_trajectory(
+        length=time_horizon,
+        observation_specs=create_observation_specs_with_shapes([(1,)]),
+        max_step_complete=True,
+        action_spec=ActionSpec.create_discrete((2,)),
+    )
+    agentbuffer_trajectory = trajectory.to_agentbuffer()
+    assert trainer.update_buffer.num_experiences == 0
+
+    # Check that if we append, our update buffer gets longer.
+    # max_steps = 100
+    for i in range(10):
+        trainer._process_trajectory(trajectory)
+        trainer._append_to_update_buffer(agentbuffer_trajectory)
+        assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
+
+    # Check that if we append after stopping training, nothing happens.
+    # We process enough trajectories to hit max steps
+    trainer.set_is_policy_updating(False)
+    trainer._process_trajectory(trajectory)
+    trainer._append_to_update_buffer(agentbuffer_trajectory)
+    assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
+
+
 class RLTrainerWarningTest(unittest.TestCase):
     def test_warning_group_reward(self):
         with self.assertLogs("mlagents.trainers", level="WARN") as cm:
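
For readers skimming the diff: the new test_update_buffer_append pins down the guard that keeps the update buffer from growing once training has stopped, which is the kind of unbounded growth that leaks memory. Below is a minimal, self-contained sketch of that guard. It is not the actual RLTrainer source; the _is_policy_updating flag, the plain list standing in for AgentBuffer, the step bookkeeping in _process_trajectory, and the <= comparison are all assumptions chosen so the numbers line up with the test.

# Minimal sketch of the behavior under test -- NOT the real RLTrainer.
# Assumed for illustration: the _is_policy_updating flag, a plain list
# standing in for AgentBuffer, and step counting in _process_trajectory.
class SketchTrainer:
    def __init__(self, max_steps: int = 100) -> None:
        self.update_buffer: list = []
        self.max_steps = max_steps
        self.step = 0
        self._is_policy_updating = True

    def set_is_policy_updating(self, value: bool) -> None:
        self._is_policy_updating = value

    def _process_trajectory(self, trajectory_length: int) -> None:
        # Each processed trajectory advances the trainer's step count.
        self.step += trajectory_length

    @property
    def should_still_train(self) -> bool:
        # Train only while the policy is updating and steps remain.
        return self._is_policy_updating and self.step <= self.max_steps

    def _append_to_update_buffer(self, agentbuffer_trajectory: list) -> None:
        # The guard the new test's final assertion exercises: once
        # training has stopped, appending is a no-op, so the buffer
        # cannot keep growing (and leaking memory) past max_steps.
        if self.should_still_train:
            self.update_buffer.extend(agentbuffer_trajectory)


if __name__ == "__main__":
    t = SketchTrainer(max_steps=100)
    fake_trajectory = [0] * 10          # ten "experiences"
    for _ in range(10):
        t._process_trajectory(len(fake_trajectory))
        t._append_to_update_buffer(fake_trajectory)
    assert len(t.update_buffer) == 100  # grew while still training
    t.set_is_policy_updating(False)
    t._process_trajectory(len(fake_trajectory))
    t._append_to_update_buffer(fake_trajectory)
    assert len(t.update_buffer) == 100  # unchanged after training stopped

Under these assumptions the arithmetic mirrors the test: ten appends of a ten-step trajectory reach exactly 100 experiences, and a further append after set_is_policy_updating(False) leaves the buffer unchanged.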