Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ml-agents/mlagents/trainers/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ class TrainerError(Exception):
pass


class TrainerConfigError(Exception):
"""
Any error related to the configuration of trainers in the ML-Agents Toolkit.
"""

pass


class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.
Expand Down
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_simple_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def close(self):
pass


PPO_CONFIG = """
default:
PPO_CONFIG = f"""
{BRAIN_NAME}:
trainer: ppo
batch_size: 16
beta: 5.0e-3
Expand All @@ -153,8 +153,8 @@ def close(self):
gamma: 0.99
"""

SAC_CONFIG = """
default:
SAC_CONFIG = f"""
{BRAIN_NAME}:
trainer: sac
batch_size: 8
buffer_size: 500
Expand Down
73 changes: 69 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from mlagents.trainers.trainer_util import load_config, _load_config
from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.brain import BrainParameters


@pytest.fixture
Expand Down Expand Up @@ -36,6 +37,10 @@ def dummy_config():
use_curiosity: false
curiosity_strength: 0.0
curiosity_enc_size: 1
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)

Expand Down Expand Up @@ -212,7 +217,7 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}

with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir=summaries_dir,
Expand All @@ -228,8 +233,68 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
trainers[brain_name] = trainer_factory.generate(brain_parameters)


def test_handles_no_default_section():
"""
Make sure the trainer setup handles a missing "default" in the config.
"""
brain_name = "testbrain"
config = dummy_config()
no_default_config = {brain_name: config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)

trainer_factory = trainer_util.TrainerFactory(
trainer_config=no_default_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
trainer_factory.generate(brain_parameters)


def test_raise_if_no_config_for_brain():
"""
Make sure the trainer setup raises a friendlier exception if both "default" and the brain name
are missing from the config.
"""
brain_name = "testbrain"
config = dummy_config()
bad_config = {"some_other_brain": config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)

trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
with pytest.raises(TrainerConfigError):
trainer_factory.generate(brain_parameters)


def test_load_config_missing_file():
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
load_config("thisFileDefinitelyDoesNotExist.yaml")


Expand All @@ -250,6 +315,6 @@ def test_load_config_invalid_yaml():
- not
- parse
"""
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
fp = io.StringIO(file_contents)
_load_config(fp)
36 changes: 22 additions & 14 deletions ml-agents/mlagents/trainers/trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, TextIO

from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.ppo.trainer import PPOTrainer
Expand Down Expand Up @@ -80,8 +80,14 @@ def initialize_trainer(
:param multi_gpu: Whether to use multi-GPU training
:return:
"""
trainer_parameters = trainer_config["default"].copy()
brain_name = brain_parameters.brain_name
if "default" not in trainer_config and brain_name not in trainer_config:
raise TrainerConfigError(
f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). '
"See config/trainer_config.yaml for an example."
)

trainer_parameters = trainer_config.get("default", {}).copy()
trainer_parameters["summary_path"] = "{basedir}/{name}".format(
basedir=summaries_dir, name=str(run_id) + "_" + brain_name
)
Expand All @@ -96,13 +102,19 @@ def initialize_trainer(
trainer_parameters.update(trainer_config[_brain_key])

trainer: Trainer = None # type: ignore # will be set to one of these, or raise
if trainer_parameters["trainer"] == "offline_bc":
if "trainer" not in trainer_parameters:
raise TrainerConfigError(
f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).'
)
trainer_type = trainer_parameters["trainer"]

if trainer_type == "offline_bc":
raise UnityTrainerException(
"The offline_bc trainer has been removed. To train with demonstrations, "
"please use a PPO or SAC trainer with the GAIL Reward Signal and/or the "
"Behavioral Cloning feature enabled."
)
elif trainer_parameters["trainer"] == "ppo":
elif trainer_type == "ppo":
trainer = PPOTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length
Expand All @@ -115,7 +127,7 @@ def initialize_trainer(
run_id,
multi_gpu,
)
elif trainer_parameters["trainer"] == "sac":
elif trainer_type == "sac":
trainer = SACTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length
Expand All @@ -128,10 +140,8 @@ def initialize_trainer(
run_id,
)
else:
raise UnityEnvironmentException(
"The trainer config contains "
"an unknown trainer type for "
"brain {}".format(brain_name)
raise TrainerConfigError(
f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
)
return trainer

Expand All @@ -141,11 +151,9 @@ def load_config(config_path: str) -> Dict[str, Any]:
with open(config_path) as data_file:
return _load_config(data_file)
except IOError:
raise UnityEnvironmentException(
f"Config file could not be found at {config_path}."
)
raise TrainerConfigError(f"Config file could not be found at {config_path}.")
except UnicodeDecodeError:
raise UnityEnvironmentException(
raise TrainerConfigError(
f"There was an error decoding Config file from {config_path}. "
f"Make sure your file is save using UTF-8"
)
Expand All @@ -158,7 +166,7 @@ def _load_config(fp: TextIO) -> Dict[str, Any]:
try:
return yaml.safe_load(fp)
except yaml.parser.ParserError as e:
raise UnityEnvironmentException(
raise TrainerConfigError(
"Error parsing yaml file. Please check for formatting errors. "
"A tool such as http://www.yamllint.com/ can be helpful with this."
) from e