Skip to content

Commit 17dba60

Browse files
authored
Add huggingface top5 image classification automodel (huggingface#268)
1 parent 064aa3b commit 17dba60

File tree

10 files changed

+431
-24
lines changed

10 files changed

+431
-24
lines changed

generate_sharktank.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def create_hash(file_name):
4343
def save_torch_model(torch_model_list):
4444
from tank.model_utils import get_hf_model
4545
from tank.model_utils import get_vision_model
46+
from tank.model_utils import get_hf_img_cls_model
4647

4748
with open(torch_model_list) as csvfile:
4849
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -51,15 +52,19 @@ def save_torch_model(torch_model_list):
5152
torch_model_name = row[0]
5253
tracing_required = row[1]
5354
model_type = row[2]
55+
is_dynamic = row[3]
5456

5557
tracing_required = False if tracing_required == "False" else True
58+
is_dynamic = False if is_dynamic == "False" else True
5659

5760
model = None
5861
input = None
5962
if model_type == "vision":
6063
model, input, _ = get_vision_model(torch_model_name)
6164
elif model_type == "hf":
6265
model, input, _ = get_hf_model(torch_model_name)
66+
elif model_type == "hf_img_cls":
67+
model, input, _ = get_hf_img_cls_model(torch_model_name)
6368

6469
torch_model_name = torch_model_name.replace("/", "_")
6570
torch_model_dir = os.path.join(
@@ -85,12 +90,13 @@ def save_torch_model(torch_model_list):
8590
)
8691
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
8792
# Generate torch dynamic models.
88-
mlir_importer.import_debug(
89-
is_dynamic=True,
90-
tracing_required=tracing_required,
91-
dir=torch_model_dir,
92-
model_name=torch_model_name + "_dynamic",
93-
)
93+
if is_dynamic:
94+
mlir_importer.import_debug(
95+
is_dynamic=True,
96+
tracing_required=tracing_required,
97+
dir=torch_model_dir,
98+
model_name=torch_model_name + "_dynamic",
99+
)
94100

95101

96102
def save_tf_model(tf_model_list):

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ tqdm
77
# SHARK Downloader
88
gsutil
99

10+
# generate_sharktank
11+
transformers==4.18.0
12+
1013
# Testing
1114
pytest
1215
pytest-xdist
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from shark.iree_utils._common import check_device_drivers, device_driver_info
2+
from shark.shark_inference import SharkInference
3+
from shark.shark_downloader import download_torch_model
4+
5+
import unittest
6+
import pytest
7+
import numpy as np
8+
9+
10+
class DeitModuleTester:
11+
def __init__(
12+
self,
13+
benchmark=False,
14+
):
15+
self.benchmark = benchmark
16+
17+
def create_and_check_module(self, dynamic, device):
18+
model, func_name, inputs, golden_out = download_torch_model(
19+
"facebook/deit-small-distilled-patch16-224", dynamic
20+
)
21+
22+
shark_module = SharkInference(
23+
model,
24+
func_name,
25+
device=device,
26+
mlir_dialect="linalg",
27+
is_benchmark=self.benchmark,
28+
)
29+
shark_module.compile()
30+
result = shark_module.forward(inputs)
31+
32+
print(np.allclose(golden_out[0], result[0], rtol=1e-02, atol=1e-03))
33+
34+
35+
class DeitModuleTest(unittest.TestCase):
36+
@pytest.fixture(autouse=True)
37+
def configure(self, pytestconfig):
38+
self.module_tester = DeitModuleTester(self)
39+
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
40+
41+
def test_module_static_cpu(self):
42+
dynamic = False
43+
device = "cpu"
44+
self.module_tester.create_and_check_module(dynamic, device)
45+
46+
@pytest.mark.skipif(
47+
check_device_drivers("gpu"), reason=device_driver_info("gpu")
48+
)
49+
def test_module_static_gpu(self):
50+
dynamic = False
51+
device = "gpu"
52+
self.module_tester.create_and_check_module(dynamic, device)
53+
54+
@pytest.mark.skipif(
55+
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
56+
)
57+
def test_module_static_vulkan(self):
58+
dynamic = False
59+
device = "vulkan"
60+
self.module_tester.create_and_check_module(dynamic, device)
61+
62+
63+
if __name__ == "__main__":
64+
# dynamic = False
65+
# device = "cpu"
66+
# module_tester = DeiteModuleTester()
67+
# module_tester.create_and_check_module(dynamic, device)
68+
unittest.main()

tank/google_vit-base-patch16-224_tf/google_vit-base-patch16-224_tf_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def test_module_static_vulkan(self):
6363

6464

6565
if __name__ == "__main__":
66-
dynamic = False
67-
device = "cpu"
68-
module_tester = VitBaseModuleTester()
69-
module_tester.create_and_check_module(dynamic, device)
70-
# unittest.main()
66+
# dynamic = False
67+
# device = "cpu"
68+
# module_tester = VitBaseModuleTester()
69+
# module_tester.create_and_check_module(dynamic, device)
70+
unittest.main()
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from shark.iree_utils._common import check_device_drivers, device_driver_info
2+
from shark.shark_inference import SharkInference
3+
from shark.shark_downloader import download_torch_model
4+
5+
import unittest
6+
import pytest
7+
import numpy as np
8+
9+
10+
class VitBaseModuleTester:
11+
def __init__(
12+
self,
13+
benchmark=False,
14+
):
15+
self.benchmark = benchmark
16+
17+
def create_and_check_module(self, dynamic, device):
18+
model, func_name, inputs, golden_out = download_torch_model(
19+
"google/vit-base-patch16-224", dynamic
20+
)
21+
22+
shark_module = SharkInference(
23+
model,
24+
func_name,
25+
device=device,
26+
mlir_dialect="linalg",
27+
is_benchmark=self.benchmark,
28+
)
29+
shark_module.compile()
30+
result = shark_module.forward(inputs)
31+
32+
print(np.allclose(golden_out[0], result[0], rtol=1e-02, atol=1e-03))
33+
34+
35+
class VitBaseModuleTest(unittest.TestCase):
36+
@pytest.fixture(autouse=True)
37+
def configure(self, pytestconfig):
38+
self.module_tester = VitBaseModuleTester(self)
39+
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
40+
41+
def test_module_static_cpu(self):
42+
dynamic = False
43+
device = "cpu"
44+
self.module_tester.create_and_check_module(dynamic, device)
45+
46+
@pytest.mark.skipif(
47+
check_device_drivers("gpu"), reason=device_driver_info("gpu")
48+
)
49+
def test_module_static_gpu(self):
50+
dynamic = False
51+
device = "gpu"
52+
self.module_tester.create_and_check_module(dynamic, device)
53+
54+
@pytest.mark.skipif(
55+
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
56+
)
57+
def test_module_static_vulkan(self):
58+
dynamic = False
59+
device = "vulkan"
60+
self.module_tester.create_and_check_module(dynamic, device)
61+
62+
63+
if __name__ == "__main__":
64+
# dynamic = False
65+
# device = "cpu"
66+
# module_tester = VitBaseModuleTester()
67+
# module_tester.create_and_check_module(dynamic, device)
68+
unittest.main()
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from shark.iree_utils._common import check_device_drivers, device_driver_info
2+
from shark.shark_inference import SharkInference
3+
from shark.shark_downloader import download_torch_model
4+
5+
import unittest
6+
import pytest
7+
import numpy as np
8+
9+
10+
class BeitModuleTester:
11+
def __init__(
12+
self,
13+
benchmark=False,
14+
):
15+
self.benchmark = benchmark
16+
17+
def create_and_check_module(self, dynamic, device):
18+
model, func_name, inputs, golden_out = download_torch_model(
19+
"microsoft/beit-base-patch16-224-pt22k-ft22k", dynamic
20+
)
21+
22+
shark_module = SharkInference(
23+
model,
24+
func_name,
25+
device=device,
26+
mlir_dialect="linalg",
27+
is_benchmark=self.benchmark,
28+
)
29+
shark_module.compile()
30+
result = shark_module.forward(inputs)
31+
32+
print(np.allclose(golden_out[0], result[0], rtol=1e-02, atol=1e-03))
33+
34+
35+
class BeitModuleTest(unittest.TestCase):
36+
@pytest.fixture(autouse=True)
37+
def configure(self, pytestconfig):
38+
self.module_tester = BeitModuleTester(self)
39+
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
40+
41+
def test_module_static_cpu(self):
42+
dynamic = False
43+
device = "cpu"
44+
self.module_tester.create_and_check_module(dynamic, device)
45+
46+
@pytest.mark.skipif(
47+
check_device_drivers("gpu"), reason=device_driver_info("gpu")
48+
)
49+
def test_module_static_gpu(self):
50+
dynamic = False
51+
device = "gpu"
52+
self.module_tester.create_and_check_module(dynamic, device)
53+
54+
@pytest.mark.skipif(
55+
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
56+
)
57+
def test_module_static_vulkan(self):
58+
dynamic = False
59+
device = "vulkan"
60+
self.module_tester.create_and_check_module(dynamic, device)
61+
62+
63+
if __name__ == "__main__":
64+
# dynamic = False
65+
# device = "cpu"
66+
# module_tester = BeitModuleTester()
67+
# module_tester.create_and_check_module(dynamic, device)
68+
unittest.main()
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from shark.iree_utils._common import check_device_drivers, device_driver_info
2+
from shark.shark_inference import SharkInference
3+
from shark.shark_downloader import download_torch_model
4+
5+
import unittest
6+
import pytest
7+
import numpy as np
8+
9+
10+
class ResnetModuleTester:
11+
def __init__(
12+
self,
13+
benchmark=False,
14+
):
15+
self.benchmark = benchmark
16+
17+
def create_and_check_module(self, dynamic, device):
18+
model, func_name, inputs, golden_out = download_torch_model(
19+
"microsoft/resnet-50", dynamic
20+
)
21+
22+
shark_module = SharkInference(
23+
model,
24+
func_name,
25+
device=device,
26+
mlir_dialect="linalg",
27+
is_benchmark=self.benchmark,
28+
)
29+
shark_module.compile()
30+
result = shark_module.forward(inputs)
31+
32+
print(np.allclose(golden_out[0], result[0], rtol=1e-01, atol=1e-03))
33+
34+
35+
class ResnetModuleTest(unittest.TestCase):
36+
@pytest.fixture(autouse=True)
37+
def configure(self, pytestconfig):
38+
self.module_tester = ResnetModuleTester(self)
39+
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
40+
41+
def test_module_static_cpu(self):
42+
dynamic = False
43+
device = "cpu"
44+
self.module_tester.create_and_check_module(dynamic, device)
45+
46+
@pytest.mark.skipif(
47+
check_device_drivers("gpu"), reason=device_driver_info("gpu")
48+
)
49+
def test_module_static_gpu(self):
50+
dynamic = False
51+
device = "gpu"
52+
self.module_tester.create_and_check_module(dynamic, device)
53+
54+
@pytest.mark.skipif(
55+
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
56+
)
57+
def test_module_static_vulkan(self):
58+
dynamic = False
59+
device = "vulkan"
60+
self.module_tester.create_and_check_module(dynamic, device)
61+
62+
63+
if __name__ == "__main__":
64+
# dynamic = False
65+
# device = "cpu"
66+
# module_tester = ResnetModuleTester()
67+
# module_tester.create_and_check_module(dynamic, device)
68+
unittest.main()

0 commit comments

Comments
 (0)