Skip to content

Commit 9956099

Browse files
authored
Add pytest option for updating tank and fix save_mlir function. (huggingface#413)
* Use IREE tf tools to save .mlir modules when generating shark_tank.
* Add option to pytest for enabling auto-updates to local shark tank.
* xfail mobilenet torch on cpu, cuda and fix CI macos setup.
* Update test-models.yml to disable macos vulkan CI.
1 parent f97b8ff commit 9956099

File tree

9 files changed

+67
-40
lines changed

9 files changed

+67
-40
lines changed

.github/workflows/test-models.yml

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ jobs:
3636
suite: cuda
3737
- os: ubuntu-latest
3838
suite: cpu
39+
- os: MacStudio
40+
suite: vulkan
3941
- os: MacStudio
4042
suite: cuda
4143
- os: MacStudio
@@ -96,7 +98,7 @@ jobs:
9698
cd $GITHUB_WORKSPACE
9799
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
98100
source shark.venv/bin/activate
99-
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
101+
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
100102
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
101103
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
102104
@@ -106,15 +108,15 @@ jobs:
106108
cd $GITHUB_WORKSPACE
107109
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
108110
source shark.venv/bin/activate
109-
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
111+
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
110112
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
111113
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
112114
113115
- name: Validate Vulkan Models (MacOS)
114116
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
115117
run: |
116118
cd $GITHUB_WORKSPACE
117-
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
119+
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
118120
source shark.venv/bin/activate
119121
echo "VULKAN SDK PATH wo setup: $VULKAN_SDK"
120122
cd /Users/anush/VulkanSDK/1.3.224.1/
@@ -123,18 +125,12 @@ jobs:
123125
echo "VULKAN SDK PATH with setup: $VULKAN_SDK"
124126
echo $PATH
125127
pip list | grep -E "torch|iree"
126-
pip uninstall -y torch iree-compiler iree-runtime
127-
pip install https://download.pytorch.org/whl/nightly/cpu/torch-1.14.0.dev20221010-cp310-none-macosx_11_0_arm64.whl
128-
pip install https://github.com/llvm/torch-mlir/releases/download/oneshot-20221011.55/torch_mlir-20221011.55-cp310-cp310-macosx_11_0_universal2.whl
129-
pip install https://github.com/nod-ai/SHARK-Runtime/releases/download/candidate-20221011.179/iree_compiler-20221011.179-cp310-cp310-macosx_11_0_universal2.whl
130-
pip install https://github.com/nod-ai/SHARK-Runtime/releases/download/candidate-20221011.179/iree_runtime-20221011.179-cp310-cp310-macosx_11_0_universal2.whl
131-
pip list | grep -E "torch|iree"
132-
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush" tank/test_models.py -k vulkan
128+
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
133129
134130
- name: Validate Vulkan Models (a100)
135131
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
136132
run: |
137133
cd $GITHUB_WORKSPACE
138134
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
139135
source shark.venv/bin/activate
140-
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan
136+
pytest --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank

conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def pytest_addoption(parser):
3636
default="False",
3737
help="Enables uploading of reproduction artifacts upon test case failure during iree-compile or validation. Must be passed with --ci_sha option ",
3838
)
39+
parser.addoption(
40+
"--update_tank",
41+
action="store_true",
42+
default="False",
43+
help="Update local shark tank with latest artifacts.",
44+
)
3945
parser.addoption(
4046
"--ci_sha",
4147
action="store",

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ requires = [
44
"wheel",
55
"packaging",
66

7-
"numpy==1.22.4",
8-
"torch-mlir>=20220428.420",
9-
"iree-compiler>=20220427.13",
10-
"iree-runtime>=20220427.13",
7+
"numpy>=1.22.4",
8+
"torch-mlir>=20221021.633",
9+
"iree-compiler>=20221022.190",
10+
"iree-runtime>=20221022.190",
1111
]
1212
build-backend = "setuptools.build_meta"

requirements-importer-macos.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
1+
-f https://download.pytorch.org/whl/nightly/cpu/
22
--pre
33

44
numpy
5-
torch
5+
torch==1.14.0.dev20221021
66
torchvision
77

88
tqdm

requirements-importer.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ lit
3232
pyyaml
3333
python-dateutil
3434
sacremoses
35-
chardet
3635

3736
# web dependencies.
3837
gradio

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
backend_deps = []
1111
if "NO_BACKEND" in os.environ.keys():
1212
backend_deps = [
13-
"iree-compiler>=20220427.13",
14-
"iree-runtime>=20220427.13",
13+
"iree-compiler>=20221022.190",
14+
"iree-runtime>=20221022.190",
1515
]
1616

1717
setup(
@@ -33,11 +33,11 @@
3333
"Operating System :: OS Independent",
3434
],
3535
packages=find_packages(exclude=("examples")),
36-
python_requires=">=3.7",
36+
python_requires=">=3.9",
3737
install_requires=[
3838
"numpy",
3939
"PyYAML",
40-
"torch-mlir>=20220428.420",
40+
"torch-mlir>=20221021.633",
4141
]
4242
+ backend_deps,
4343
)

setup_venv.sh

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,15 @@ fi
7676
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
7777
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
7878
if [ "$torch_mlir_bin" = true ]; then
79-
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
80-
if [ $? -eq 0 ];then
81-
echo "Successfully Installed torch-mlir"
79+
if [[ $(uname -s) = 'Darwin' ]]; then
80+
echo "MacOS detected. Please install torch-mlir from source or .whl, as dependency problems may occur otherwise."
8281
else
83-
echo "Could not install torch-mlir" >&2
82+
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
83+
if [ $? -eq 0 ];then
84+
echo "Successfully Installed torch-mlir"
85+
else
86+
echo "Could not install torch-mlir" >&2
87+
fi
8488
fi
8589
else
8690
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
@@ -109,6 +113,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
109113
echo "${Yellow}macOS detected.. installing macOS importer tools"
110114
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
111115
$PYTHON -m pip install --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
116+
$PYTHON -m pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20221024.636/torch_mlir-20221024.636-cp310-cp310-macosx_11_0_universal2.whl
112117
fi
113118
fi
114119

shark/shark_importer.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,25 @@ def _torch_mlir(self, is_dynamic, tracing_required):
7575
self.module, self.inputs, is_dynamic, tracing_required
7676
)
7777

78-
def _tf_mlir(self, func_name):
78+
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
7979
from iree.compiler import tf as tfc
8080

8181
return tfc.compile_module(
82-
self.module, exported_names=[func_name], import_only=True
82+
self.module,
83+
exported_names=[func_name],
84+
import_only=True,
85+
output_file=save_dir,
8386
)
8487

85-
def _tflite_mlir(self, func_name):
88+
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
8689
from iree.compiler import tflite as tflitec
8790
from shark.iree_utils._common import IREE_TARGET_MAP
8891

8992
self.mlir_model = tflitec.compile_file(
9093
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter
9194
input_type="tosa",
9295
import_only=True,
96+
output_file=save_dir,
9397
)
9498
return self.mlir_model
9599

@@ -99,6 +103,7 @@ def import_mlir(
99103
is_dynamic=False,
100104
tracing_required=False,
101105
func_name="forward",
106+
save_dir="./shark_tmp/",
102107
):
103108
if self.frontend in ["torch", "pytorch"]:
104109
if self.inputs == None:
@@ -108,10 +113,10 @@ def import_mlir(
108113
sys.exit(1)
109114
return self._torch_mlir(is_dynamic, tracing_required), func_name
110115
if self.frontend in ["tf", "tensorflow"]:
111-
return self._tf_mlir(func_name), func_name
116+
return self._tf_mlir(func_name, save_dir), func_name
112117
if self.frontend in ["tflite", "tf-lite"]:
113118
func_name = "main"
114-
return self._tflite_mlir(func_name), func_name
119+
return self._tflite_mlir(func_name, save_dir), func_name
115120

116121
# Converts the frontend specific tensors into np array.
117122
def convert_to_numpy(self, array_tuple: tuple):
@@ -130,20 +135,22 @@ def save_data(
130135
outputs_name = "golden_out.npz"
131136
func_file_name = "function_name"
132137
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
133-
inputs = [x.cpu().detach() for x in inputs]
138+
try:
139+
inputs = [x.cpu().detach() for x in inputs]
140+
except AttributeError:
141+
try:
142+
inputs = [x.numpy() for x in inputs]
143+
except AttributeError:
144+
inputs = [x for x in inputs]
134145
np.savez(os.path.join(dir, inputs_name), *inputs)
135146
np.savez(os.path.join(dir, outputs_name), *outputs)
136147
np.save(os.path.join(dir, func_file_name), np.array(func_name))
137148

138149
mlir_str = mlir_data
139150
if self.frontend == "torch":
140151
mlir_str = mlir_data.operation.get_asm()
141-
elif self.frontend == "tf":
142-
mlir_str = mlir_data.decode("latin-1")
143-
elif self.frontend == "tflite":
144-
mlir_str = mlir_data.decode("latin-1")
145-
with open(os.path.join(dir, model_name_mlir), "w") as mlir_file:
146-
mlir_file.write(mlir_str)
152+
with open(os.path.join(dir, model_name_mlir), "w") as mlir_file:
153+
mlir_file.write(mlir_str)
147154

148155
return
149156

@@ -160,9 +167,13 @@ def import_debug(
160167
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
161168
)
162169
sys.exit(1)
163-
170+
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
171+
artifact_path = os.path.join(dir, model_name_mlir)
164172
imported_mlir = self.import_mlir(
165-
is_dynamic, tracing_required, func_name
173+
is_dynamic,
174+
tracing_required,
175+
func_name,
176+
save_dir=artifact_path,
166177
)
167178
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
168179
# TODO: Check for multiple outputs.

tank/test_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(self, config):
131131

132132
def create_and_check_module(self, dynamic, device):
133133
shark_args.local_tank_cache = self.local_tank_cache
134+
shark_args.update_tank = self.update_tank
134135
if self.config["framework"] == "tf":
135136
model, func_name, inputs, golden_out = download_tf_model(
136137
self.config["model_name"],
@@ -266,6 +267,9 @@ def test_module(self, dynamic, device, config):
266267
self.module_tester.local_tank_cache = self.pytestconfig.getoption(
267268
"local_tank_cache"
268269
)
270+
self.module_tester.update_tank = self.pytestconfig.getoption(
271+
"update_tank"
272+
)
269273
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
270274
if (
271275
config["model_name"] == "distilbert-base-uncased"
@@ -350,6 +354,7 @@ def test_module(self, dynamic, device, config):
350354
):
351355
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/390")
352356
if config["model_name"] == "squeezenet1_0" and device in [
357+
"cpu",
353358
"metal",
354359
"vulkan",
355360
]:
@@ -392,6 +397,11 @@ def test_module(self, dynamic, device, config):
392397
"microsoft/resnet-50",
393398
] and device in ["metal", "vulkan"]:
394399
pytest.xfail(reason="Vulkan Numerical Error (mostly conv)")
400+
if config["model_name"] == "mobilenet_v3_small" and device in [
401+
"cuda",
402+
"cpu",
403+
]:
404+
pytest.xfail(reason="https://github.com/nod-ai/SHARK/issues/424")
395405
if config["framework"] == "tf" and dynamic == True:
396406
pytest.skip(
397407
reason="Dynamic shapes not supported for this framework."

0 commit comments

Comments
 (0)