Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ and this project adheres to
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Set gym version in gym-unity to gym release 0.20.0
- Changed default behavior to restart crashed Unity environments rather than exiting.
- Rate & lifetime limits on this are configurable via 3 new yaml options
1. env_settings.max_lifetime_restarts (--max-lifetime-restarts) [default=10]
2. env_settings.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_settings.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
### Bug Fixes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
Expand Down
3 changes: 3 additions & 0 deletions docs/Training-ML-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ env_settings:
base_port: 5005
num_envs: 1
seed: -1
max_lifetime_restarts: 10
restarts_rate_limit_n: 1
restarts_rate_limit_period_s: 60
```

#### Engine settings
Expand Down
5 changes: 4 additions & 1 deletion ml-agents-envs/mlagents_envs/rpc_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def create_server(self):

try:
# Establish communication grpc
self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
self.server = grpc.server(
thread_pool=ThreadPoolExecutor(max_workers=10),
options=(("grpc.so_reuseport", 1),),
)
self.unity_to_external = UnityToExternalServicerImplementation()
add_UnityToExternalProtoServicer_to_server(
self.unity_to_external, self.server
Expand Down
20 changes: 20 additions & 0 deletions ml-agents/mlagents/trainers/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,26 @@ def _create_parser() -> argparse.ArgumentParser:
"passed to the executable.",
action=DetectDefault,
)
argparser.add_argument(
"--max-lifetime-restarts",
default=10,
help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
"Can be set to -1 if no limit is desired.",
action=DetectDefault,
)
argparser.add_argument(
"--restarts-rate-limit-n",
default=1,
help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
"restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
action=DetectDefault,
)
argparser.add_argument(
"--restarts-rate-limit-period-s",
default=60,
help="The period of time --restarts-rate-limit-n applies to.",
action=DetectDefault,
)
argparser.add_argument(
"--torch",
default=False,
Expand Down
5 changes: 5 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,11 @@ class EnvironmentSettings:
base_port: int = parser.get_default("base_port")
num_envs: int = attr.ib(default=parser.get_default("num_envs"))
seed: int = parser.get_default("seed")
max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts")
restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n")
restarts_rate_limit_period_s: int = parser.get_default(
"restarts_rate_limit_period_s"
)

@num_envs.validator
def validate_num_envs(self, attribute, value):
Expand Down
122 changes: 117 additions & 5 deletions ml-agents/mlagents/trainers/subprocess_env_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle
import enum
Expand Down Expand Up @@ -251,6 +252,14 @@ def __init__(
self.env_workers: List[UnityEnvWorker] = []
self.step_queue: Queue = Queue()
self.workers_alive = 0
self.env_factory = env_factory
self.run_options = run_options
self.env_parameters: Optional[Dict] = None
# Each worker is correlated with a list of times they restarted within the last time period.
self.recent_restart_timestamps: List[List[datetime.datetime]] = [
[] for _ in range(n_env)
]
self.restart_counts: List[int] = [0] * n_env
for worker_idx in range(n_env):
self.env_workers.append(
self.create_worker(
Expand Down Expand Up @@ -293,6 +302,105 @@ def _queue_steps(self) -> None:
env_worker.send(EnvironmentCommand.STEP, env_action_info)
env_worker.waiting = True

def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
    """
    Restart every worker that reported ENV_EXITED, then reset all environments.

    :param first_failure: The first ENV_EXITED response observed; any other
        failures still sitting in the step queue are collected via
        _drain_step_queue and handled in the same pass.
    :raises Exception: re-raised from _assert_worker_can_restart when a crash
        is non-recoverable or a worker has exhausted its restart quota.
    """
    if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
        return
    # Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
    # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
    other_failures: Dict[int, Exception] = self._drain_step_queue()
    # TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
    failures: Dict[int, Exception] = {
        **{first_failure.worker_id: first_failure.payload},
        **other_failures,
    }
    for worker_id, ex in failures.items():
        # Raises if this worker is not allowed to restart; checked before any
        # restart bookkeeping so a denied worker leaves no partial state.
        self._assert_worker_can_restart(worker_id, ex)
        logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
        # Record this restart for the rate-limit window and lifetime counter.
        self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
        self.restart_counts[worker_id] += 1
        self.env_workers[worker_id] = self.create_worker(
            worker_id, self.step_queue, self.env_factory, self.run_options
        )
    # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
    # outdated data.
    self.reset(self.env_parameters)

def _drain_step_queue(self) -> Dict[int, Exception]:
    """
    Drain all pending responses out of the step queue and collect exceptions
    from crashed workers.

    This effectively pauses all workers: none of them will do anything until
    _queue_steps is called again.

    :return: Mapping of worker_id -> exception payload for every worker that
        reported ENV_EXITED while draining.
    :raises TimeoutError: if some workers are still pending after one minute.
    """
    all_failures: Dict[int, Exception] = {}
    workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
    deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
    while workers_still_pending and deadline > datetime.datetime.now():
        try:
            while True:
                step: EnvironmentResponse = self.step_queue.get_nowait()
                if step.cmd == EnvironmentCommand.ENV_EXITED:
                    # A crashed worker stays pending until its final message
                    # (e.g. CLOSED) arrives, so put it (back) in the set.
                    workers_still_pending.add(step.worker_id)
                    all_failures[step.worker_id] = step.payload
                else:
                    # discard() instead of remove(): a response can arrive from
                    # a worker that was never marked pending, and remove()
                    # would raise KeyError in that case.
                    workers_still_pending.discard(step.worker_id)
                    self.env_workers[step.worker_id].waiting = False
        except EmptyQueueException:
            pass
    # Fail only if workers are actually still pending. Checking the deadline
    # alone (as before) raised a spurious TimeoutError when the drain finished
    # successfully just after the deadline elapsed.
    if workers_still_pending:
        still_waiting = {w.worker_id for w in self.env_workers if w.waiting}
        raise TimeoutError(f"Workers {still_waiting} stuck in waiting state")
    return all_failures

def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
    """
    Check whether we can recover from an exception raised by a worker.

    Returns normally when the exception is a recoverable Unity failure and the
    worker still has restart quota. Otherwise the original exception is
    re-raised (after logging an error when the quota was exceeded).

    :param worker_id: Index of the crashed worker.
    :param exception: The exception the worker died with.
    :raises Exception: the original exception whenever a restart is not allowed.
    """
    # Only Unity-side communication/timeout/environment failures are
    # recoverable; any other exception type is a real bug and is re-raised.
    recoverable_exceptions = (
        UnityCommunicationException,
        UnityTimeOutException,
        UnityEnvironmentException,
        UnityCommunicatorStoppedException,
    )
    if isinstance(exception, recoverable_exceptions):
        if self._worker_has_restart_quota(worker_id):
            return
        logger.error(
            f"Worker {worker_id} exceeded the allowed number of restarts."
        )
    raise exception

def _worker_has_restart_quota(self, worker_id: int) -> bool:
self._drop_old_restart_timestamps(worker_id)
max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
max_limit_check = (
max_lifetime_restarts == -1
or self.restart_counts[worker_id] < max_lifetime_restarts
)

rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
rate_limit_check = (
rate_limit_n == -1
or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
)

return rate_limit_check and max_limit_check

def _drop_old_restart_timestamps(self, worker_id: int) -> None:
"""
Drops environment restart timestamps that are outside of the current window.
"""

def _filter(t: datetime.datetime) -> bool:
return t > datetime.datetime.now() - datetime.timedelta(
seconds=self.run_options.env_settings.restarts_rate_limit_period_s
)

self.recent_restart_timestamps[worker_id] = list(
filter(_filter, self.recent_restart_timestamps[worker_id])
)

def _step(self) -> List[EnvironmentStep]:
# Queue steps for any workers which aren't in the "waiting" state.
self._queue_steps()
Expand All @@ -306,15 +414,18 @@ def _step(self) -> List[EnvironmentStep]:
while True:
step: EnvironmentResponse = self.step_queue.get_nowait()
if step.cmd == EnvironmentCommand.ENV_EXITED:
env_exception: Exception = step.payload
raise env_exception
self.env_workers[step.worker_id].waiting = False
if step.worker_id not in step_workers:
# If even one env exits try to restart all envs that failed.
self._restart_failed_workers(step)
# Clear state and restart this function.
worker_steps.clear()
step_workers.clear()
self._queue_steps()
elif step.worker_id not in step_workers:
self.env_workers[step.worker_id].waiting = False
worker_steps.append(step)
step_workers.add(step.worker_id)
except EmptyQueueException:
pass

step_infos = self._postprocess_steps(worker_steps)
return step_infos

Expand All @@ -339,6 +450,7 @@ def set_env_parameters(self, config: Dict = None) -> None:
EnvironmentParametersSideChannel for each worker.
:param config: Dict of environment parameter keys and values
"""
self.env_parameters = config
for ew in self.env_workers:
ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)

Expand Down
65 changes: 63 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import mock
from unittest.mock import Mock, MagicMock
from unittest.mock import Mock, MagicMock, call, ANY
import unittest
import pytest
from queue import Empty as EmptyQueue
Expand All @@ -14,7 +14,10 @@
from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.exception import (
UnityEnvironmentException,
UnityCommunicationException,
)
from mlagents.trainers.tests.simple_test_envs import (
SimpleEnvironment,
UnexpectedExceptionEnvironment,
Expand Down Expand Up @@ -153,6 +156,64 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
manager.env_workers[1].previous_step,
]

@mock.patch(
    "mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
)
def test_crashed_env_restarts(self, mock_create_worker):
    """Worker 0 crashes mid-step; the manager should restart it (and only it),
    reset both environments, and resume stepping them."""
    crashing_worker = MockEnvWorker(
        0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
    )
    # Replacement worker that create_worker hands back for slot 0 after the crash.
    restarting_worker = MockEnvWorker(
        0, EnvironmentResponse(EnvironmentCommand.RESET, 0, 0)
    )
    healthy_worker = MockEnvWorker(
        1, EnvironmentResponse(EnvironmentCommand.RESET, 1, 1)
    )
    # create_worker is called twice during construction (workers 0 and 1)
    # and a third time when worker 0 is restarted.
    mock_create_worker.side_effect = [
        crashing_worker,
        healthy_worker,
        restarting_worker,
    ]
    manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 2)
    manager.step_queue = Mock()
    # Scripted queue traffic: worker 0 reports ENV_EXITED, then CLOSED while
    # the queue is drained; after the restart both workers step normally.
    manager.step_queue.get_nowait.side_effect = [
        EnvironmentResponse(
            EnvironmentCommand.ENV_EXITED,
            0,
            UnityCommunicationException("Test msg"),
        ),
        EnvironmentResponse(EnvironmentCommand.CLOSED, 0, None),
        EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(0, None, {})),
        EmptyQueue(),
        EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(1, None, {})),
        EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(2, None, {})),
        EmptyQueue(),
    ]
    step_mock = Mock()
    last_steps = [Mock(), Mock(), Mock()]
    # Sanity-check that construction wired our mocks into the worker slots.
    assert crashing_worker is manager.env_workers[0]
    assert healthy_worker is manager.env_workers[1]
    crashing_worker.previous_step = last_steps[0]
    crashing_worker.waiting = True
    healthy_worker.previous_step = last_steps[1]
    healthy_worker.waiting = True
    manager._take_step = Mock(return_value=step_mock)
    manager._step()
    # Both the surviving worker and the restarted one should have been
    # re-configured (env params), reset, and then stepped, in that order.
    healthy_worker.send.assert_has_calls(
        [
            call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
            call(EnvironmentCommand.RESET, ANY),
            call(EnvironmentCommand.STEP, ANY),
        ]
    )
    restarting_worker.send.assert_has_calls(
        [
            call(EnvironmentCommand.ENVIRONMENT_PARAMETERS, ANY),
            call(EnvironmentCommand.RESET, ANY),
            call(EnvironmentCommand.STEP, ANY),
        ]
    )

@mock.patch("mlagents.trainers.subprocess_env_manager.SubprocessEnvManager._step")
@mock.patch(
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.training_behaviors",
Expand Down