|
| 1 | +from distutils.util import strtobool |
| 2 | +import os |
1 | 3 | import logging |
2 | 4 | from typing import Any, List, Set, NamedTuple |
| 5 | +from distutils.version import LooseVersion |
3 | 6 |
|
4 | 7 | try: |
5 | 8 | import onnx |
|
18 | 21 | from tensorflow.python.framework import graph_util |
19 | 22 | from mlagents.trainers import tensorflow_to_barracuda as tf2bc |
20 | 23 |
|
| 24 | +if LooseVersion(tf.__version__) < LooseVersion("1.12.0"): |
| 25 | + # ONNX is only tested on 1.12.0 and later |
| 26 | + ONNX_EXPORT_ENABLED = False |
| 27 | + |
| 28 | + |
21 | 29 | logger = logging.getLogger("mlagents.trainers") |
22 | 30 |
|
23 | 31 | POSSIBLE_INPUT_NODES = frozenset( |
@@ -67,18 +75,28 @@ def export_policy_model( |
67 | 75 | logger.info(f"Exported {settings.model_path}.nn file") |
68 | 76 |
|
69 | 77 | # Save to onnx too (if we were able to import it) |
70 | | - if ONNX_EXPORT_ENABLED and settings.convert_to_onnx: |
71 | | - try: |
72 | | - onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) |
73 | | - onnx_output_path = settings.model_path + ".onnx" |
74 | | - with open(onnx_output_path, "wb") as f: |
75 | | - f.write(onnx_graph.SerializeToString()) |
76 | | - logger.info(f"Converting to {onnx_output_path}") |
77 | | - except Exception: |
78 | | - logger.exception( |
79 | | - "Exception trying to save ONNX graph. Please report this error on " |
80 | | - "https://github.com/Unity-Technologies/ml-agents/issues and " |
81 | | - "attach a copy of frozen_graph_def.pb" |
| 78 | + if ONNX_EXPORT_ENABLED: |
| 79 | + if settings.convert_to_onnx: |
| 80 | + try: |
| 81 | + onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) |
| 82 | + onnx_output_path = settings.model_path + ".onnx" |
| 83 | + with open(onnx_output_path, "wb") as f: |
| 84 | + f.write(onnx_graph.SerializeToString()) |
| 85 | + logger.info(f"Converting to {onnx_output_path}") |
| 86 | + except Exception: |
| 87 | + # Make conversion errors fatal depending on environment variables (only done during CI) |
| 88 | + if _enforce_onnx_conversion(): |
| 89 | + raise |
| 90 | + logger.exception( |
| 91 | + "Exception trying to save ONNX graph. Please report this error on " |
| 92 | + "https://github.com/Unity-Technologies/ml-agents/issues and " |
| 93 | + "attach a copy of frozen_graph_def.pb" |
| 94 | + ) |
| 95 | + |
| 96 | + else: |
| 97 | + if _enforce_onnx_conversion(): |
| 98 | + raise RuntimeError( |
| 99 | + "ONNX conversion enforced, but couldn't import dependencies." |
82 | 100 | ) |
83 | 101 |
|
84 | 102 |
|
@@ -203,3 +221,16 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str |
203 | 221 | for n in nodes: |
204 | 222 | logger.info("\t" + n) |
205 | 223 | return nodes |
| 224 | + |
| 225 | + |
| 226 | +def _enforce_onnx_conversion() -> bool: |
| 227 | + env_var_name = "TEST_ENFORCE_ONNX_CONVERSION" |
| 228 | + if env_var_name not in os.environ: |
| 229 | + return False |
| 230 | + |
| 231 | + val = os.environ[env_var_name] |
| 232 | + try: |
| 233 | + # This handles e.g. "false" converting reasonably to False |
| 234 | + return strtobool(val) |
| 235 | + except Exception: |
| 236 | + return False |
0 commit comments