Skip to content

Commit d2475ec

Browse files
authored
Add mnasnet to torch models and minor fixes. (huggingface#577)
* Minor fixes to benchmark runner * Add Mnasnet to tank.
1 parent b3bcf4b commit d2475ec

File tree

6 files changed

+24
-24
lines changed

6 files changed

+24
-24
lines changed

shark/shark_benchmark_runner.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -260,19 +260,12 @@ def get_metadata(self, modelname):
260260
return [param_count, model_tags, model_notes]
261261

262262
def compare_bench_results(self, baseline: str, result: str):
263-
# Takes two numbers represented as strings and returns "<n>x slower/faster", as in "result is <n>x slower than baseline".
263+
# Takes a baseline and a result string and calculates a comparison, e.g. "1.04x baseline".
264264
a = float(baseline)
265265
b = float(result)
266-
if a < b:
267-
# result slower than baseline
268-
comparison = (b - a) / a
269-
comp_str = f"{round(comparison, 2)}x slower"
270-
elif a > b:
271-
# result faster than baseline
272-
comparison = a / b
273-
comp_str = f"{round(comparison, 2)}x faster"
274-
else:
275-
comp_str = "equal"
266+
# result faster than baseline
267+
comparison = a / b
268+
comp_str = f"{round(comparison, 2)}x baseline"
276269
return comp_str
277270

278271
def benchmark_all_csv(
@@ -327,7 +320,7 @@ def benchmark_all_csv(
327320
bench_result["ms/iter"],
328321
) = self.benchmark_frontend(modelname)
329322
self.frontend_result = bench_result["ms/iter"]
330-
bench_result["vs. PyTorch/TF"] = "="
323+
bench_result["vs. PyTorch/TF"] = "baseline"
331324
(
332325
bench_result["param_count"],
333326
bench_result["tags"],

tank/all_models.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ resnet50,linalg,torch,1e-2,1e-3,default
3232
squeezenet1_0,linalg,torch,1e-2,1e-3,default
3333
wide_resnet50_2,linalg,torch,1e-2,1e-3,default
3434
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default
35+
mnasnet1_0,linalg,torch,1e-2,1e-3,default

tank/model_metadata.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ roberta-base,False,False,-,-,-
2828
xlm-roberta-base,False,False,-,-,-
2929
facebook/convnext-tiny-224,False,False,-,-,-
3030
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
31-
31+
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"

tank/model_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"squeezenet1_0",
1616
"wide_resnet50_2",
1717
"mobilenet_v3_small",
18+
"mnasnet1_0",
1819
]
1920
hf_img_cls_models = [
2021
"google/vit-base-patch16-224",
@@ -149,6 +150,7 @@ def get_vision_model(torch_model):
149150
"squeezenet1_0": models.squeezenet1_0(pretrained=True),
150151
"wide_resnet50_2": models.wide_resnet50_2(pretrained=True),
151152
"mobilenet_v3_small": models.mobilenet_v3_small(pretrained=True),
153+
"mnasnet1_0": models.mnasnet1_0(pretrained=True),
152154
}
153155
if isinstance(torch_model, str):
154156
torch_model = vision_models_dict[torch_model]
@@ -160,6 +162,8 @@ def get_vision_model(torch_model):
160162

161163
################################################################################
162164

165+
####################### Other PyTorch HF Models ###############################
166+
163167
# Utility function for comparing two tensors (torch).
164168
def compare_tensors(torch_tensor, numpy_tensor, rtol=1e-02, atol=1e-03):
165169
# torch_to_numpy = torch_tensor.detach().numpy()

tank/test_models.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,21 @@ def postprocess_outputs(self, golden_out, result):
227227

228228

229229
def run_test(module_tester, dynamic, device):
230+
import multiprocessing
231+
230232
tempdir = tempfile.TemporaryDirectory(
231233
prefix=module_tester.tmp_prefix, dir="./shark_tmp/"
232234
)
233235
module_tester.temp_dir = tempdir.name
234236

235237
with ireec.tools.TempFileSaver(tempdir.name):
236-
module_tester.create_and_check_module(dynamic, device)
238+
p = multiprocessing.Process(
239+
target=module_tester.create_and_check_module,
240+
args=(dynamic, device),
241+
)
242+
p.start()
243+
p.join()
244+
return p
237245

238246

239247
class SharkModuleTest(unittest.TestCase):
@@ -339,10 +347,7 @@ def test_module(self, dynamic, device, config):
339347
pytest.xfail(
340348
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
341349
)
342-
if config["model_name"] == "mobilenet_v3_small" and device in [
343-
"metal",
344-
"vulkan",
345-
]:
350+
if config["model_name"] == "mobilenet_v3_small":
346351
pytest.xfail(
347352
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
348353
)
@@ -417,9 +422,5 @@ def test_module(self, dynamic, device, config):
417422
# We must create a new process each time we benchmark a model to allow
418423
# for Tensorflow to release GPU resources. Using the same process to
419424
# benchmark multiple models leads to OOM.
420-
p = multiprocessing.Process(
421-
target=run_test, args=(self.module_tester, dynamic, device)
422-
)
423-
p.start()
424-
p.join()
425-
assert not p.exitcode
425+
426+
run_test(self.module_tester, dynamic, device)

tank/torch_model_list.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residual
1616
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
1717
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
1818
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
19+
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"

0 commit comments

Comments
 (0)