@@ -67,7 +67,7 @@ def __init__(
         self.hyperparameters: SACSettings = cast(
             SACSettings, trainer_settings.hyperparameters
         )
-        self.step = 0
+        self._step = 0
 
         # Don't divide by zero
         self.update_steps = 1
@@ -188,7 +188,7 @@ def _is_ready_update(self) -> bool:
         """
         return (
             self.update_buffer.num_experiences >= self.hyperparameters.batch_size
-            and self.step >= self.hyperparameters.buffer_init_steps
+            and self._step >= self.hyperparameters.buffer_init_steps
         )
 
     @timed
@@ -251,9 +251,9 @@ def _update_sac_policy(self) -> bool:
 
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step - self.hyperparameters.buffer_init_steps
+            self._step - self.hyperparameters.buffer_init_steps
         ) / self.update_steps > self.steps_per_update:
-            logger.debug(f"Updating SAC policy at step {self.step}")
+            logger.debug(f"Updating SAC policy at step {self._step}")
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
                 sampled_minibatch = buffer.sample_mini_batch(
@@ -305,12 +305,12 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step - self.hyperparameters.buffer_init_steps
+            self._step - self.hyperparameters.buffer_init_steps
        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name in self.optimizer.reward_signals.keys():
-                logger.debug(f"Updating {name} at step {self.step}")
+                logger.debug(f"Updating {name} at step {self._step}")
                 if name != "extrinsic":
                     reward_signal_minibatches[name] = buffer.sample_mini_batch(
                         self.hyperparameters.batch_size,
@@ -355,11 +355,11 @@ def add_policy(
         self.model_saver.initialize_or_load()
 
         # Needed to resume loads properly
-        self.step = policy.get_current_step()
+        self._step = policy.get_current_step()
         # Assume steps were updated at the correct ratio before
-        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.update_steps = int(max(1, self._step / self.steps_per_update))
         self.reward_signal_update_steps = int(
-            max(1, self.step / self.reward_signal_steps_per_update)
+            max(1, self._step / self.reward_signal_steps_per_update)
         )
 
     def get_policy(self, name_behavior_id: str) -> Policy:
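
The change is mechanical: every read and write of `self.step` in the SAC trainer becomes the underscore-prefixed `self._step`. The diff does not state the motivation, but a common reason for this rename (an assumption here, not confirmed by the diff) is backing a read-only `step` property on the base trainer with a private `_step` field, so outside code can query the step count without reassigning it. A minimal sketch of that pattern, with hypothetical names:

```python
# Minimal sketch (hypothetical class, not the ml-agents source) of the
# property-backed counter pattern the rename is consistent with: `_step`
# is the private backing field, and a read-only `step` property exposes it.
class TrainerSketch:
    def __init__(self) -> None:
        self._step = 0  # private backing field, as in the diff above

    @property
    def step(self) -> int:
        # Read-only view: `trainer.step = 5` raises AttributeError.
        return self._step

    def _increment_step(self, n_steps: int) -> None:
        self._step += n_steps


trainer = TrainerSketch()
trainer._increment_step(3)
print(trainer.step)  # 3
```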
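Separately from the rename, both `while` loops above implement the same ratio bookkeeping: keep updating until the realized steps-per-update ratio, measured from `buffer_init_steps` onward, falls back to the configured target. A self-contained illustration with made-up numbers (it assumes, as the `self.update_steps = 1` initialization in the first hunk suggests, that the counter is incremented once per update):

```python
# Made-up numbers illustrating the loop condition
# (step - buffer_init_steps) / update_steps > steps_per_update.
buffer_init_steps = 100
steps_per_update = 2.0  # target: one update per 2 environment steps
update_steps = 1        # updates so far; starts at 1 to avoid division by zero
current_step = 111      # environment steps collected so far

# Update until the realized ratio drops to the target.
while (current_step - buffer_init_steps) / update_steps > steps_per_update:
    ratio = (current_step - buffer_init_steps) / update_steps
    print(f"update #{update_steps}: ratio = {ratio:.2f}")
    update_steps += 1

# Loop exits once (111 - 100) / update_steps <= 2.0, i.e. update_steps == 6.
```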