Skip to content

Commit 35ed23e

Browse files
author
Ervin T
authored
Make step in trainer private (#5099)
1 parent 63683b7 commit 35ed23e

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

ml-agents/mlagents/trainers/poca/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def add_policy(
287287
self.model_saver.initialize_or_load()
288288

289289
# Needed to resume loads properly
290-
self.step = policy.get_current_step()
290+
self._step = policy.get_current_step()
291291

292292
def get_policy(self, name_behavior_id: str) -> Policy:
293293
"""

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def add_policy(
263263
self.model_saver.initialize_or_load()
264264

265265
# Needed to resume loads properly
266-
self.step = policy.get_current_step()
266+
self._step = policy.get_current_step()
267267

268268
def get_policy(self, name_behavior_id: str) -> Policy:
269269
"""

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
self.hyperparameters: SACSettings = cast(
6868
SACSettings, trainer_settings.hyperparameters
6969
)
70-
self.step = 0
70+
self._step = 0
7171

7272
# Don't divide by zero
7373
self.update_steps = 1
@@ -188,7 +188,7 @@ def _is_ready_update(self) -> bool:
188188
"""
189189
return (
190190
self.update_buffer.num_experiences >= self.hyperparameters.batch_size
191-
and self.step >= self.hyperparameters.buffer_init_steps
191+
and self._step >= self.hyperparameters.buffer_init_steps
192192
)
193193

194194
@timed
@@ -251,9 +251,9 @@ def _update_sac_policy(self) -> bool:
251251

252252
batch_update_stats: Dict[str, list] = defaultdict(list)
253253
while (
254-
self.step - self.hyperparameters.buffer_init_steps
254+
self._step - self.hyperparameters.buffer_init_steps
255255
) / self.update_steps > self.steps_per_update:
256-
logger.debug(f"Updating SAC policy at step {self.step}")
256+
logger.debug(f"Updating SAC policy at step {self._step}")
257257
buffer = self.update_buffer
258258
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
259259
sampled_minibatch = buffer.sample_mini_batch(
@@ -305,12 +305,12 @@ def _update_reward_signals(self) -> None:
305305
)
306306
batch_update_stats: Dict[str, list] = defaultdict(list)
307307
while (
308-
self.step - self.hyperparameters.buffer_init_steps
308+
self._step - self.hyperparameters.buffer_init_steps
309309
) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
310310
# Get minibatches for reward signal update if needed
311311
reward_signal_minibatches = {}
312312
for name in self.optimizer.reward_signals.keys():
313-
logger.debug(f"Updating {name} at step {self.step}")
313+
logger.debug(f"Updating {name} at step {self._step}")
314314
if name != "extrinsic":
315315
reward_signal_minibatches[name] = buffer.sample_mini_batch(
316316
self.hyperparameters.batch_size,
@@ -355,11 +355,11 @@ def add_policy(
355355
self.model_saver.initialize_or_load()
356356

357357
# Needed to resume loads properly
358-
self.step = policy.get_current_step()
358+
self._step = policy.get_current_step()
359359
# Assume steps were updated at the correct ratio before
360-
self.update_steps = int(max(1, self.step / self.steps_per_update))
360+
self.update_steps = int(max(1, self._step / self.steps_per_update))
361361
self.reward_signal_update_steps = int(
362-
max(1, self.step / self.reward_signal_steps_per_update)
362+
max(1, self._step / self.reward_signal_steps_per_update)
363363
)
364364

365365
def get_policy(self, name_behavior_id: str) -> Policy:

ml-agents/mlagents/trainers/trainer/rl_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,10 @@ def _checkpoint(self) -> ModelCheckpoint:
152152
logger.warning(
153153
"Trainer has multiple policies, but default behavior only saves the first."
154154
)
155-
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
155+
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self._step)
156156
export_ext = "onnx"
157157
new_checkpoint = ModelCheckpoint(
158-
int(self.step),
158+
int(self._step),
159159
f"{checkpoint_path}.{export_ext}",
160160
self._policy_mean_reward(),
161161
time.time(),
@@ -199,7 +199,7 @@ def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
199199
Increment the step count of the trainer
200200
:param n_steps: number of steps to increment the step count by
201201
"""
202-
self.step += n_steps
202+
self._step += n_steps
203203
self._next_summary_step = self._get_next_interval_step(self.summary_freq)
204204
self._next_save_step = self._get_next_interval_step(
205205
self.trainer_settings.checkpoint_interval
@@ -213,7 +213,7 @@ def _get_next_interval_step(self, interval: int) -> int:
213213
Get the next step count that should result in an action.
214214
:param interval: The interval between actions.
215215
"""
216-
return self.step + (interval - self.step % interval)
216+
return self._step + (interval - self._step % interval)
217217

218218
def _write_summary(self, step: int) -> None:
219219
"""

ml-agents/mlagents/trainers/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
self._reward_buffer: Deque[float] = deque(maxlen=reward_buff_cap)
4646
self.policy_queues: List[AgentManagerQueue[Policy]] = []
4747
self.trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
48-
self.step: int = 0
48+
self._step: int = 0
4949
self.artifact_path = artifact_path
5050
self.summary_freq = self.trainer_settings.summary_freq
5151
self.policies: Dict[str, Policy] = {}
@@ -78,7 +78,7 @@ def get_step(self) -> int:
7878
Returns the number of steps the trainer has performed
7979
:return: the step count of the trainer
8080
"""
81-
return self.step
81+
return self._step
8282

8383
@property
8484
def threaded(self) -> bool:

0 commit comments

Comments
 (0)