45 changes: 43 additions & 2 deletions ml-agents/mlagents/trainers/learn.py
@@ -6,7 +6,7 @@
import numpy as np
import json

-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Dict

import mlagents.trainers
import mlagents_envs
@@ -22,6 +22,8 @@
from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
+from mlagents.trainers.exception import UnityTrainerException
+from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.timers import (
    hierarchical_timer,
@@ -49,6 +51,33 @@ def parse_command_line(argv: Optional[List[str]] = None) -> RunOptions:
    return RunOptions.from_argparse(args)


+def _get_checkpoint_name(
+    behavior_name: str, checkpoint_dict: Optional[Dict[str, str]]
+) -> str:
+    """
+    Retrieve the checkpoint file name mapped to this behavior; defaults to the
+    most recent checkpoint if no mapping exists.
+    :param behavior_name: Name of the behavior to load a checkpoint for.
+    :param checkpoint_dict: Mapping from behavior_name to checkpoint_file_name.pt.
+    :return: The checkpoint file name for this behavior.
+    """
+    if checkpoint_dict and checkpoint_dict.get(behavior_name):
+        return checkpoint_dict[behavior_name]
+    else:
+        return DEFAULT_CHECKPOINT_NAME


+def _validate_init_full_path(init_file: str) -> None:
+    """
+    Validate the initialization path.
+    :param init_file: Full path to the initialization checkpoint file.
+    :return: None; raises UnityTrainerException if the file is invalid.
+    """
+    if not (os.path.isfile(init_file) and init_file.endswith(".pt")):
+        raise UnityTrainerException(
+            f"Could not initialize from {init_file}. File does not exist or is not a `.pt` file"
+        )


def run_training(run_seed: int, options: RunOptions) -> None:
"""
Launches training session.
Expand All @@ -72,11 +101,23 @@ def run_training(run_seed: int, options: RunOptions) -> None:
    )
    # Make run logs directory
    os.makedirs(run_logs_dir, exist_ok=True)
-    # Load any needed states
+    # Load any needed states in case of resume
    if checkpoint_settings.resume:
        GlobalTrainingStatus.load_state(
            os.path.join(run_logs_dir, "training_status.json")
        )
+    # In case of initialization, set init_path for all behaviors
+    elif checkpoint_settings.maybe_init_path is not None:
+        for behavior_name, ts in options.behaviors.items():
+            if ts.init_path is None:
+                ts.init_path = os.path.join(
+                    checkpoint_settings.maybe_init_path,
+                    behavior_name,
+                    _get_checkpoint_name(
+                        behavior_name, checkpoint_settings.init_checkpoints_list
+                    ),
+                )
+            _validate_init_full_path(ts.init_path)

    # Configure Tensorboard Writers and StatsReporter
    stats_writers = register_stats_writer_plugins(options)
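For orientation, here is a minimal standalone sketch of the path resolution the two helpers above implement; resolve_init_path, the run ids, and the behavior names are hypothetical, not part of the PR:

import os

DEFAULT_CHECKPOINT_NAME = "checkpoint.pt"  # mirrors torch_model_saver.py

def resolve_init_path(maybe_init_path, behavior_name, checkpoint_dict):
    # Prefer the per-behavior mapping; otherwise fall back to the
    # rolling default checkpoint written on every save.
    name = (checkpoint_dict or {}).get(behavior_name) or DEFAULT_CHECKPOINT_NAME
    return os.path.join(maybe_init_path, behavior_name, name)

print(resolve_init_path("results/run1", "Walker", None))
# -> results/run1/Walker/checkpoint.pt
print(resolve_init_path("results/run1", "Walker", {"Walker": "Walker-499850.pt"}))
# -> results/run1/Walker/Walker-499850.pt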
12 changes: 8 additions & 4 deletions ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
@@ -12,6 +12,7 @@


logger = get_logger(__name__)
+DEFAULT_CHECKPOINT_NAME = "checkpoint.pt"


class TorchModelSaver(BaseModelSaver):
@@ -55,7 +56,7 @@ def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
        pytorch_ckpt_path = f"{checkpoint_path}.pt"
        export_ckpt_path = f"{checkpoint_path}.onnx"
        torch.save(state_dict, f"{checkpoint_path}.pt")
-        torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
+        torch.save(state_dict, os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME))
        self.export(checkpoint_path, behavior_name)
        return export_ckpt_path, [pytorch_ckpt_path]

@@ -75,16 +76,19 @@ def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None:
            )
        elif self.load:
            logger.info(f"Resuming from {self.model_path}.")
-            self._load_model(self.model_path, policy, reset_global_steps=reset_steps)
+            self._load_model(
+                os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME),
+                policy,
+                reset_global_steps=reset_steps,
+            )

    def _load_model(
        self,
        load_path: str,
        policy: Optional[TorchPolicy] = None,
        reset_global_steps: bool = False,
    ) -> None:
-        model_path = os.path.join(load_path, "checkpoint.pt")
-        saved_state_dict = torch.load(model_path)
+        saved_state_dict = torch.load(load_path)
        if policy is None:
            modules = self.modules
            policy = self.policy
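The net effect is that `_load_model` now always receives a full `.pt` file path: on resume the saver joins its own model directory with DEFAULT_CHECKPOINT_NAME, while for initialize-from the path arrives already resolved by learn.py. A small sketch of that contract, with hypothetical paths:

import os

DEFAULT_CHECKPOINT_NAME = "checkpoint.pt"

model_path = "results/run1/Walker"                  # directory owned by the saver
init_path = "results/run0/Walker/Walker-499850.pt"  # resolved up front in learn.py

resume_ckpt = os.path.join(model_path, DEFAULT_CHECKPOINT_NAME)
# Both call sites now pass a file path, never a directory:
#   _load_model(resume_ckpt, policy, reset_global_steps=False)  # resume keeps steps
#   _load_model(init_path, policy, reset_global_steps=True)     # init typically resets steps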
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -752,6 +752,7 @@ class CheckpointSettings:
    train_model: bool = parser.get_default("train_model")
    inference: bool = parser.get_default("inference")
    results_dir: str = parser.get_default("results_dir")
+    init_checkpoints_list: Optional[Dict[str, str]] = None

    @property
    def write_path(self) -> str:
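As an illustration of the new field (assuming the attrs-generated constructor accepts keyword overrides; the behavior and file names are made up), a settings object could map behaviors to specific checkpoints like this, with unmapped behaviors falling back to DEFAULT_CHECKPOINT_NAME:

from mlagents.trainers.settings import CheckpointSettings

cs = CheckpointSettings(
    init_checkpoints_list={
        "Walker": "Walker-499850.pt",
        "Crawler": "Crawler-1000000.pt",
    }
)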
7 changes: 5 additions & 2 deletions ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
@@ -9,7 +9,10 @@
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
-from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
+from mlagents.trainers.model_saver.torch_model_saver import (
+    TorchModelSaver,
+    DEFAULT_CHECKPOINT_NAME,
+)
from mlagents.trainers.settings import (
    TrainerSettings,
    NetworkSettings,
@@ -62,7 +65,7 @@ def test_load_save_policy(tmp_path):
    assert policy2.get_current_step() == 2000

    # Try initialize from path 1
-    trainer_params.init_path = path1
+    trainer_params.init_path = os.path.join(path1, DEFAULT_CHECKPOINT_NAME)
    model_saver3 = TorchModelSaver(trainer_params, path2)
    policy3 = create_policy_mock(trainer_params)
    model_saver3.register(policy3)
2 changes: 0 additions & 2 deletions ml-agents/mlagents/trainers/trainer/trainer_factory.py
@@ -100,8 +100,6 @@ def _initialize_trainer(
    :return:
    """
    trainer_artifact_path = os.path.join(output_path, brain_name)
-    if init_path is not None:
-        trainer_settings.init_path = os.path.join(init_path, brain_name)
Comment on lines -103 to -104
@maryamhonari (Contributor, Author) commented on Sep 8, 2021:
Not sure why init_path is passed up to here; was there a specific reason to set this in the trainer factory? I moved this logic to learn.py and, if it's harmless, will remove the init_path attribute.

A contributor replied:
Have you tried using init_path to see if it does what we expect of it? Reading a bit of the code in torch_model_saver.py, it looks like setting init_path in the trainer settings initializes the policy from a checkpoint. Is that correct?

@maryamhonari (Contributor, Author) replied:
Correct. Previously trainer_settings.init_path was "result_dir/run_id/brain_name" and the checkpoint name was appended in torch_model_saver.py. This PR sets the full path trainer_settings.init_path = "result_dir/run_id/brain_name/checkpoint_name.pt" in learn.py:102.


    min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)

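The thread above boils down to assembling the checkpoint path in one place instead of two. A before/after sketch under that reading, with a hypothetical run id and behavior name:

import os

# Before: trainer_factory.py set the directory ...
old_init_path = os.path.join("results", "run_id", "Walker")
# ... and torch_model_saver.py appended the file name later:
old_full_path = os.path.join(old_init_path, "checkpoint.pt")

# After: learn.py resolves the full file path up front, so downstream
# code can treat init_path as a ready-to-load .pt file.
new_init_path = os.path.join("results", "run_id", "Walker", "checkpoint.pt")
assert old_full_path == new_init_path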