-
Notifications
You must be signed in to change notification settings - Fork 4.4k
enforce onnx conversion in CI #3628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,8 @@ | ||
| from distutils.util import strtobool | ||
| import os | ||
| import logging | ||
| from typing import Any, List, Set, NamedTuple | ||
| from distutils.version import LooseVersion | ||
|
|
||
| try: | ||
| import onnx | ||
|
|
@@ -18,6 +21,11 @@ | |
| from tensorflow.python.framework import graph_util | ||
| from mlagents.trainers import tensorflow_to_barracuda as tf2bc | ||
|
|
||
| if LooseVersion(tf.__version__) < LooseVersion("1.12.0"): | ||
| # ONNX is only tested on 1.12.0 and later | ||
| ONNX_EXPORT_ENABLED = False | ||
|
|
||
|
|
||
| logger = logging.getLogger("mlagents.trainers") | ||
|
|
||
| POSSIBLE_INPUT_NODES = frozenset( | ||
|
|
@@ -67,18 +75,28 @@ def export_policy_model( | |
| logger.info(f"Exported {settings.model_path}.nn file") | ||
|
|
||
| # Save to onnx too (if we were able to import it) | ||
| if ONNX_EXPORT_ENABLED and settings.convert_to_onnx: | ||
| try: | ||
| onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) | ||
| onnx_output_path = settings.model_path + ".onnx" | ||
| with open(onnx_output_path, "wb") as f: | ||
| f.write(onnx_graph.SerializeToString()) | ||
| logger.info(f"Converting to {onnx_output_path}") | ||
| except Exception: | ||
| logger.exception( | ||
| "Exception trying to save ONNX graph. Please report this error on " | ||
| "https://github.com/Unity-Technologies/ml-agents/issues and " | ||
| "attach a copy of frozen_graph_def.pb" | ||
| if ONNX_EXPORT_ENABLED: | ||
| if settings.convert_to_onnx: | ||
| try: | ||
| onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) | ||
| onnx_output_path = settings.model_path + ".onnx" | ||
| with open(onnx_output_path, "wb") as f: | ||
| f.write(onnx_graph.SerializeToString()) | ||
| logger.info(f"Converting to {onnx_output_path}") | ||
| except Exception: | ||
| # Make conversion errors fatal depending on environment variables (only done during CI) | ||
| if _enforce_onnx_conversion(): | ||
| raise | ||
| logger.exception( | ||
| "Exception trying to save ONNX graph. Please report this error on " | ||
| "https://github.com/Unity-Technologies/ml-agents/issues and " | ||
| "attach a copy of frozen_graph_def.pb" | ||
| ) | ||
|
|
||
| else: | ||
| if _enforce_onnx_conversion(): | ||
| raise RuntimeError( | ||
| "ONNX conversion enforced, but couldn't import dependencies." | ||
|
||
| ) | ||
|
|
||
|
|
||
|
|
@@ -203,3 +221,16 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str | |
| for n in nodes: | ||
| logger.info("\t" + n) | ||
| return nodes | ||
|
|
||
|
|
||
| def _enforce_onnx_conversion() -> bool: | ||
| env_var_name = "TEST_ENFORCE_ONNX_CONVERSION" | ||
| if env_var_name not in os.environ: | ||
| return False | ||
|
|
||
| val = os.environ[env_var_name] | ||
| try: | ||
| # This handles e.g. "false" converting reasonably to False | ||
| return strtobool(val) | ||
| except Exception: | ||
| return False | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,5 +3,5 @@ grpcio==1.11.0 | |
| numpy==1.14.1 | ||
| Pillow==4.2.1 | ||
| protobuf==3.6 | ||
| tensorflow==1.7 | ||
| tensorflow==1.7.0 | ||
| h5py==2.9.0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also filter out tf 2.x ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't want to explicitly filter it out. A future version of tf2onnx should support f2 2.x, in which case the
importshould succeed.