Skip to content

Commit 40eea21

Browse files
authored
Enable conv nchw-to-nhwc flag by default for most models + minor fixes (huggingface#584)
1 parent d2475ec commit 40eea21

File tree

4 files changed

+48
-10
lines changed

4 files changed

+48
-10
lines changed

shark/iree_utils/compile_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import iree.compiler as ireec
1616
from shark.iree_utils._common import iree_device_map, iree_target_map
1717
from shark.iree_utils.benchmark_utils import *
18+
from shark.parser import shark_args
1819
import numpy as np
1920
import os
2021
import re
@@ -66,6 +67,16 @@ def get_iree_common_args():
6667
]
6768

6869

70+
# Args that are suitable only for certain models or groups of models.
71+
# shark_args are passed down from pytests to control which models compile with these flags,
72+
# but they can also be set in shark/parser.py
73+
def get_model_specific_args():
74+
ms_args = []
75+
if shark_args.enable_conv_transform == True:
76+
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
77+
return ms_args
78+
79+
6980
def create_dispatch_dirs(bench_dir, device):
7081
protected_files = ["ordered-dispatches.txt"]
7182
bench_dir_path = bench_dir.split("/")
@@ -213,14 +224,22 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
213224

214225

215226
def compile_module_to_flatbuffer(
216-
module, device, frontend, func_name, model_config_path, extra_args
227+
module,
228+
device,
229+
frontend,
230+
func_name,
231+
model_config_path,
232+
extra_args,
233+
model_name="None",
217234
):
218235
# Setup Compile arguments wrt to frontends.
219236
input_type = ""
220237
args = get_iree_frontend_args(frontend)
221238
args += get_iree_device_args(device, extra_args)
222239
args += get_iree_common_args()
240+
args += get_model_specific_args()
223241
args += extra_args
242+
print(args)
224243

225244
if frontend in ["tensorflow", "tf"]:
226245
input_type = "mhlo"

shark/parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,11 @@ def dir_file(path):
105105
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
106106
)
107107

108+
parser.add_argument(
109+
"--enable_conv_transform",
110+
default=True,
111+
action="store",
112+
help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
113+
)
114+
108115
shark_args, unknown = parser.parse_known_args()

tank/model_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,14 @@ def get_vision_model(torch_model):
143143
import torchvision.models as models
144144

145145
vision_models_dict = {
146-
"alexnet": models.alexnet(pretrained=True),
147-
"resnet18": models.resnet18(pretrained=True),
148-
"resnet50": models.resnet50(pretrained=True),
149-
"resnet101": models.resnet101(pretrained=True),
150-
"squeezenet1_0": models.squeezenet1_0(pretrained=True),
151-
"wide_resnet50_2": models.wide_resnet50_2(pretrained=True),
152-
"mobilenet_v3_small": models.mobilenet_v3_small(pretrained=True),
153-
"mnasnet1_0": models.mnasnet1_0(pretrained=True),
146+
"alexnet": models.alexnet(weights="DEFAULT"),
147+
"resnet18": models.resnet18(weights="DEFAULT"),
148+
"resnet50": models.resnet50(weights="DEFAULT"),
149+
"resnet101": models.resnet101(weights="DEFAULT"),
150+
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
151+
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
152+
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
153+
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
154154
}
155155
if isinstance(torch_model, str):
156156
torch_model = vision_models_dict[torch_model]

tank/test_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,11 @@ def __init__(self, config):
127127
self.config = config
128128

129129
def create_and_check_module(self, dynamic, device):
130+
130131
shark_args.local_tank_cache = self.local_tank_cache
131132
shark_args.update_tank = self.update_tank
133+
if self.config["model_name"] in ["alexnet", "resnet18"]:
134+
shark_args.enable_conv_transform = False
132135
model, func_name, inputs, golden_out = download_model(
133136
self.config["model_name"],
134137
tank_url=self.tank_url,
@@ -347,7 +350,16 @@ def test_module(self, dynamic, device, config):
347350
pytest.xfail(
348351
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
349352
)
350-
if config["model_name"] == "mobilenet_v3_small":
353+
if config["model_name"] == "mobilenet_v3_small" and device not in [
354+
"cpu"
355+
]:
356+
pytest.xfail(
357+
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
358+
)
359+
if config["model_name"] == "mnasnet1_0" and device not in [
360+
"cpu",
361+
"cuda",
362+
]:
351363
pytest.xfail(
352364
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
353365
)

0 commit comments

Comments
 (0)