@@ -65,9 +65,9 @@ def __init__(
         )
         self.step = 0
 
-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1
 
         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
@@ -229,7 +229,9 @@ def _update_sac_policy(self) -> bool:
         )
 
         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
@@ -282,9 +284,8 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():
@@ -327,6 +328,11 @@ def add_policy(
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
         self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()
 
     def get_policy(self, name_behavior_id: str) -> TFPolicy:
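To make the ratio logic in these hunks concrete, below is a minimal, hypothetical sketch of the update-gating pattern they adjust: updates are owed whenever the trainer step, minus the buffer warm-up steps, divided by the number of updates performed so far exceeds `steps_per_update`, and on resume the update count is re-derived from the saved step. The `UpdateGate` class, `resume_step`, and `pending_updates` are invented names, and the per-update increment of `update_steps` is assumed from the surrounding trainer code rather than shown in the diff.

```python
# Hypothetical sketch of the steps_per_update gating pattern; not the
# ml-agents implementation. The increment of update_steps per performed
# update is assumed, since it does not appear in the hunks above.


class UpdateGate:
    def __init__(
        self, steps_per_update: float, buffer_init_steps: int, resume_step: int = 0
    ) -> None:
        self.steps_per_update = steps_per_update
        self.buffer_init_steps = buffer_init_steps
        # Start at 1 to avoid dividing by zero; on resume, assume past updates
        # happened at the configured ratio (mirroring the add_policy hunk).
        self.update_steps = int(max(1, resume_step / steps_per_update))

    def pending_updates(self, trainer_step: int) -> int:
        """Count updates owed at trainer_step, ignoring buffer warm-up steps."""
        owed = 0
        effective_step = trainer_step - self.buffer_init_steps
        while effective_step / self.update_steps > self.steps_per_update:
            self.update_steps += 1  # one update performed
            owed += 1
        return owed


# Example: with steps_per_update=2 and buffer_init_steps=100, no update is owed
# at step 101, and by step 110 the ten effective steps call for four updates.
gate = UpdateGate(steps_per_update=2.0, buffer_init_steps=100)
print(gate.pending_updates(101))  # 0 (ratio 1/1 does not exceed 2.0)
print(gate.pending_updates(110))  # 4 (update_steps climbs from 1 to 5)
```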