diff --git a/ml-agents/mlagents/tf_utils/__init__.py b/ml-agents/mlagents/tf_utils/__init__.py index 239f11b423..2acce8bb8c 100644 --- a/ml-agents/mlagents/tf_utils/__init__.py +++ b/ml-agents/mlagents/tf_utils/__init__.py @@ -1,2 +1,3 @@ from mlagents.tf_utils.tf import tf as tf # noqa from mlagents.tf_utils.tf import set_warnings_enabled # noqa +from mlagents.tf_utils.tf import generate_session_config # noqa diff --git a/ml-agents/mlagents/tf_utils/tf.py b/ml-agents/mlagents/tf_utils/tf.py index 6a2917da6d..0cbd2d4145 100644 --- a/ml-agents/mlagents/tf_utils/tf.py +++ b/ml-agents/mlagents/tf_utils/tf.py @@ -23,8 +23,23 @@ def set_warnings_enabled(is_enabled: bool) -> None: """ - Enable or disable tensorflow warnings (notabley, this disables deprecation warnings. + Enable or disable tensorflow warnings (notably, this disables deprecation warnings. :param is_enabled: """ level = tf_logging.WARN if is_enabled else tf_logging.ERROR tf_logging.set_verbosity(level) + + +def generate_session_config() -> tf.ConfigProto: + """ + Generate a ConfigProto to use for ML-Agents that doesn't consume all of the GPU memory + and allows for soft placement in the case of multi-GPU. + """ + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + # For multi-GPU training, set allow_soft_placement to True to allow + # placing the operation into an alternative device automatically + # to prevent from exceptions if the device doesn't suppport the operation + # or the device does not exist + config.allow_soft_placement = True + return config diff --git a/ml-agents/mlagents/trainers/tf_policy.py b/ml-agents/mlagents/trainers/tf_policy.py index edf5076ef5..659f1b0353 100644 --- a/ml-agents/mlagents/trainers/tf_policy.py +++ b/ml-agents/mlagents/trainers/tf_policy.py @@ -3,6 +3,7 @@ import numpy as np from mlagents.tf_utils import tf +from mlagents import tf_utils from mlagents_envs.exception import UnityException from mlagents.trainers.policy import Policy @@ -69,14 +70,9 @@ def __init__(self, seed, brain, trainer_parameters): self.model_path = trainer_parameters["model_path"] self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5) self.graph = tf.Graph() - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - # For multi-GPU training, set allow_soft_placement to True to allow - # placing the operation into an alternative device automatically - # to prevent from exceptions if the device doesn't suppport the operation - # or the device does not exist - config.allow_soft_placement = True - self.sess = tf.Session(config=config, graph=self.graph) + self.sess = tf.Session( + config=tf_utils.generate_session_config(), graph=self.graph + ) self.saver = None if self.use_recurrent: self.m_size = trainer_parameters["memory_size"] diff --git a/ml-agents/mlagents/trainers/trainer.py b/ml-agents/mlagents/trainers/trainer.py index c34825fffb..3319e05ace 100644 --- a/ml-agents/mlagents/trainers/trainer.py +++ b/ml-agents/mlagents/trainers/trainer.py @@ -3,6 +3,7 @@ from typing import Dict, List, Deque, Any from mlagents.tf_utils import tf +from mlagents import tf_utils from collections import deque @@ -70,7 +71,7 @@ def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None: :param input_dict: A dictionary that will be displayed in a table on Tensorboard. """ try: - with tf.Session() as sess: + with tf.Session(config=tf_utils.generate_session_config()) as sess: s_op = tf.summary.text( key, tf.convert_to_tensor(