#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import sys
import executorch

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models import Backend, Model, MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS
from executorch.examples.xnnpack.quantization.utils import quantize as quantize_xnn
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from test_base import ModelTest

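# End-to-end flow exercised below: export the eager model, optionally quantize
# it with the XNNPACK quantizer, partition and lower to the XNNPACK backend,
# serialize an ExecuTorch program, run it via the pybindings runtime, and
# compare the results against eager-mode reference outputs.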
def test_model_xnnpack(model: Model, quantize: bool) -> None:
    # Instantiate the eager model and its example inputs from the registry.
    model_instance, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[str(model)]
    )

    # Capture reference outputs in eager mode before any transformation.
    model_instance.eval()
    ref_outputs = model_instance(*example_inputs)

    if quantize:
        # Quantize the pre-autograd export using the model's XNNPACK
        # quantization settings.
        quant_type = MODEL_NAME_TO_OPTIONS[str(model)].quantization
        model_instance = torch.export.export_for_training(
            model_instance, example_inputs
        )
        model_instance = quantize_xnn(
            model_instance.module(), example_inputs, quant_type
        )

    # Export, partition to the XNNPACK backend, and serialize an ExecuTorch
    # program.
    lowered = to_edge_transform_and_lower(
        torch.export.export(model_instance, example_inputs),
        partitioner=[XnnpackPartitioner()],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
        ),
    ).to_executorch()

    # Run the serialized program through the pybindings runtime.
    loaded_model = _load_for_executorch_from_buffer(lowered.buffer)
    et_outputs = loaded_model([*example_inputs])

    # Normalize a single-tensor reference output so both sides can be
    # compared elementwise.
    if isinstance(ref_outputs, torch.Tensor):
        ref_outputs = (ref_outputs,)

    assert len(ref_outputs) == len(et_outputs)
    for i in range(len(ref_outputs)):
        assert torch.allclose(ref_outputs[i], et_outputs[i], atol=1e-5)

def run_tests(model_tests: List[ModelTest]) -> None:
    # Dispatch each requested model test to its backend-specific runner.
    for model_test in model_tests:
        if model_test.backend == Backend.Xnnpack:
            test_model_xnnpack(model_test.model, quantize=False)
        else:
            raise RuntimeError(f"Unsupported backend {model_test.backend}.")

if __name__ == "__main__":
    # Run MobileNetV3 (Mv3) through the XNNPACK lowering flow by default.
    run_tests(
        model_tests=[
            ModelTest(
                model=Model.Mv3,
                backend=Backend.Xnnpack,
            ),
        ]
    )