From e9076fe20f6db2900b4c3e6ceebdea5baebcdc80 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 01:47:11 -0700 Subject: [PATCH 01/12] [PT2E][X86] Migrate fusion passes in Inductor to torchao --- .../pt2e/test_x86inductor_fusion.py | 2578 +++++++++++++++ .../inductor/fx_passes/quantization.py | 2854 +++++++++++++++++ torchao/quantization/pt2e/pt2e/lowering.py | 31 +- 3 files changed, 5461 insertions(+), 2 deletions(-) create mode 100644 test/quantization/pt2e/test_x86inductor_fusion.py create mode 100644 torchao/prototype/inductor/fx_passes/quantization.py diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py new file mode 100644 index 0000000000..9073b1facb --- /dev/null +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -0,0 +1,2578 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +# Owner(s): ["oncall: quantization"] +import contextlib +import copy +import itertools +import unittest + +import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +from torch._dynamo import config as dynamo_config +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.testing._internal.common_quantization import ( + _generate_qdq_quantized_model, + skipIfNoDynamoSupport, + skipIfNoONEDNN, + skipIfNoONEDNNBF16, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + IS_FBCODE, + IS_LINUX, + IS_X86, + MI300_ARCH, + parametrize, + skipIfRocm, + skipIfRocmArch, + TEST_ACL, + xfailIfACL, +) +from torch.testing._internal.inductor_utils import ( + _check_has_dynamic_shape, + clone_preserve_strides_offset, + HAS_CPU, +) +from torchao.quantization.pt2e.pt2e.lowering import lower_pt2e_quantized_to_x86 + + +# The dict value is match_nodes(computation_op+unary_op) +unary_list = { + torch.nn.ReLU(): 2, + torch.nn.Sigmoid(): 2, + torch.nn.Tanh(): 2, + torch.nn.Hardswish(): 6, + torch.nn.LeakyReLU(0.1, inplace=False): 4, + # Use floats for min/max, otherwise they can get converted to symints + torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, + torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, + torch.nn.GELU(approximate="none"): 6, + torch.nn.GELU(approximate="tanh"): 10, + torch.nn.ReLU6(): 3, + torch.nn.SiLU(): 3, + torch.nn.Hardsigmoid(): 5, +} + +non_decomposed_unary_list = [ + torch.nn.ReLU, + torch.nn.Sigmoid, + torch.nn.Tanh, +] + +# The dict value is (match_count, match_nodes, inplace) +binary_list = { + lambda x, y: torch.add(x, y): (1, 2, False), # call_function + lambda x, y: torch.add(y, x): (1, 2, False), # call_function + lambda x, y: x.add(y): (1, 2, False), # call_method + lambda x, y: x.add_(y): (1, 2, True), # call_method + lambda x, y: torch.sub(x, y): (1, 2, False), # call_function + lambda x, y: x.sub(y): (1, 2, False), # call_method + lambda x, y: x.sub_(y): (1, 2, True), # call_method +} + +quantization_add_fn_list = [ + lambda x, y: torch.add(x, y), + lambda x, y: x.add(y), +] + +quantization_inplace_add_fn_list = [ + lambda x, y: x.add_(y), +] + + +def get_default_quantizer(is_qat, is_dynamic): + quantizer = 
X86InductorQuantizer() + quantizer.set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) + return quantizer + + +def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): + # this function is to decide how many kernels are generated + # while testing conv2d/3d/deconv2d + # the assumption is: + # (1) There will be a to_dtype kernel for input for lp + # (2) inductor always use channel_last format, there will + # be a to_channel_last format for input + # (3) to_dtype and to_channel_last for input can be fused + # (4) inductor always get channel last format from mkldnn_conv_pointwise(binary), + # and force the output to have same stride with eager. + # So there will be a to_contiguous for output if eager output is contiguouse + mod = copy.deepcopy(mod) + mod = mod.to(device=device) + input = input.clone() + input = input.to(device) + + if dtype == torch.float32: + maybe_autocast = contextlib.nullcontext() + else: + maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype) + with torch.no_grad(), maybe_autocast: + output = mod(input) + input_kernel, output_kernel = 0, 0 + if ( + input.is_contiguous(memory_format=torch.contiguous_format) + or dtype != torch.float32 + or (TEST_ACL and dim == 4) + ): + input_kernel = 1 + if output.is_contiguous(memory_format=torch.contiguous_format) or ( + TEST_ACL and dtype == torch.bfloat16 + ): + output_kernel = 1 + return input_kernel + output_kernel + + +@config.patch({"freezing": True}) +class TestPatternMatcherBase(TestCase): + def _check_unary_is_decomposed(self, unary_fn): + return not any( + isinstance(unary_fn, fn) + for fn in [torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Tanh] + ) + + def _clone_inputs(self, inputs): + def clone(x): + if not isinstance(x, torch.Tensor): + return x + return x.clone() + + return tuple(clone(x) for x in inputs) + + def _test_common( + self, + mod, + inputs, + matcher_check_fn, + atol=1e-5, + rtol=1.3e-6, + check_autocast=torch.float32, + check_quantization=False, + is_qat=False, + dtype=None, + is_dynamic=False, + quantizer=None, + compile_options={}, # noqa: B006 + ): + if not hasattr(self, "device"): + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) + device = "xpu" if has_xpu else "cpu" + else: + device = self.device + + mod = mod.to(device=device) + if device != "cpu": + inputs = tuple( + clone_preserve_strides_offset(x, device=device) for x in inputs + ) + counters.clear() + torch._dynamo.reset() + if check_autocast == torch.bfloat16 and ( + torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu" + ): + maybe_autocast = torch.amp.autocast( + device_type=device, dtype=torch.bfloat16 + ) + atol, rtol = 1e-2, 1e-2 + elif check_autocast == torch.float16 and ( + torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu" + ): + maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16) + atol, rtol = 1e-2, 1e-2 + else: + assert check_autocast == torch.float32 + maybe_autocast = contextlib.nullcontext() + if check_quantization: + convert_model = _generate_qdq_quantized_model( + mod, inputs, is_qat, is_dynamic, quantizer + ) + with torch.no_grad(), maybe_autocast: + _ = lower_pt2e_quantized_to_x86(convert_model)(*inputs) + matcher_check_fn() + else: + with torch.no_grad(), maybe_autocast: + clone_inputs = self._clone_inputs(inputs) + expected = mod(*inputs) + actual = lower_pt2e_quantized_to_x86(mod, **compile_options)( + *clone_inputs + ) + 
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + matcher_check_fn() + + def _test_code_common( + self, + mod, + inputs, + include_ops, + exclude_ops, + atol=1e-5, + rtol=1.3e-6, + check_quantization=False, + check_dynamic=None, + num_include_ops=None, + quantizer=None, + ): + with torch.no_grad(): + clone_inputs = self._clone_inputs(inputs) + if check_quantization: + mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer) + expected = mod(*inputs) + actual, (source_code,) = run_and_get_code( + lower_pt2e_quantized_to_x86(mod, fullgraph=True, dynamic=check_dynamic), + *clone_inputs, + ) + for op in include_ops: + self.assertIn(op, source_code) + if num_include_ops is not None: + assert len(include_ops) == len(num_include_ops) + for i in range(len(include_ops)): + self.assertEqual( + source_code.count(include_ops[i]), num_include_ops[i] + ) + for op in exclude_ops: + self.assertNotIn(op, source_code) + if check_dynamic is not None: + _check_has_dynamic_shape(self, source_code) + if not check_quantization: + # Skip due to reduce range setting for Quantization on preCI system. + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + + +class TestPatternMatcher(TestPatternMatcherBase): + def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) + self.conv3 = torch.nn.Conv2d( + 128, 128, kernel_size=3, stride=1, groups=4 + ) + + def forward(self, x): + return self.conv3(self.conv2(self.conv(x))) + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 + # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] + # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), + # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], + 18 if int8_mixed_bf16 else 12, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_cpu(self): + r""" + This testcase will quantize a single Conv2d module. + """ + self._qconv2d_test_helper("cpu") + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfRocmArch(MI300_ARCH) + def test_qconv2d_int8_mixed_bf16(self): + r""" + This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. 
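+
+        For illustration, the flow that _test_common drives for this test looks
+        roughly like the sketch below (not executed here; mod and v stand for
+        the helper's module and input):
+
+            quantizer = get_default_quantizer(is_qat=False, is_dynamic=False)
+            converted = _generate_qdq_quantized_model(mod, (v,), quantizer=quantizer)
+            with torch.no_grad(), torch.amp.autocast("cpu", dtype=torch.bfloat16):
+                compiled = lower_pt2e_quantized_to_x86(converted)
+                _ = compiled(v)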
+ """ + self._qconv2d_test_helper(int8_mixed_bf16=True) + + def _qconv2d_unary_test_helper( + self, + device="cpu", + int8_mixed_bf16=False, + unary_op=torch.nn.ReLU(), + qconv_unary_matcher_nodes=None, + ): + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.conv2 = torch.nn.Conv2d( + 128, 128, kernel_size=3, stride=1, bias=False + ) + self.unary_fn2 = copy.deepcopy(unary_op) + + def forward(self, x): + tmp = self.unary_fn(self.conv(x)) + return self.unary_fn2(self.conv2(tmp)) + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. QConv2D Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + if qconv_unary_matcher_nodes: + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_nodes"], + 0 if TEST_ACL else qconv_unary_matcher_nodes, + ) + + self._test_common( + mod, + (v,), + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + matcher_check_fn=matcher_check_fn, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_relu_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern. + """ + self._qconv2d_unary_test_helper(device="cpu") + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_relu_int8_mixed_bf16_xpu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. + """ + self._qconv2d_unary_test_helper(int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_relu6_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_hardtanh_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardtanh(), + int8_mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_hardswish_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. 
+ Match.nodes: + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, + clamp_max, mul, div, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardswish(), + int8_mixed_bf16=True, + qconv_unary_matcher_nodes=17, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_silu_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_silu_int8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, + convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.SiLU(), + int8_mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + def _qconv2d_add_test_helper( + self, device="cpu", use_relu=False, int8_mixed_bf16=False + ): + r""" + This testcase will quantize a Conv2d->Add pattern as: + X + / \ + Conv1(X) Conv2(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.conv3(tmp) + tmp2 = self.conv4(tmp) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list: + mod = M(add_fn, use_relu).eval().to(device=device) + v = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 + ) + # 2. 
Qconv2d Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 2, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + ) + + def _qconv2d_add_test_helper2( + self, device="cpu", use_relu=False, int8_mixed_bf16=False + ): + r""" + This testcase will quantize two Conv2d->Add patterns as: + + Conv(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + , and + + extra input Conv(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + swap_inputs, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + self.swap_inputs = swap_inputs + + def forward(self, x, x2, x3): + x1 = self.conv1(x) + if self.swap_inputs: + tmp = self.add_fn(x2, x1) + else: + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.conv2(tmp) + if self.swap_inputs: + res = self.add_fn2(x3, tmp1) + else: + res = self.add_fn2(tmp1, x3) + if self.use_relu: + res = self.relu2(res) + return res + + for add_fn, swap_inputs in itertools.product( + quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True] + ): + mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device) + x = torch.randn( + (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device + ) + x2 = torch.randn( + (1, 6, 6, 6), dtype=torch.float32, requires_grad=False, device=device + ) + x3 = torch.randn( + (1, 6, 4, 4), dtype=torch.float32, requires_grad=False, device=device + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. 
Qconv2d Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 2, + ) + + self._test_common( + mod, + (x, x2, x3), + matcher_check_fn, + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_add_cpu(self): + self._qconv2d_add_test_helper() + self._qconv2d_add_test_helper2() + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_add_int8_mixed_bf16(self): + self._qconv2d_add_test_helper(int8_mixed_bf16=True) + self._qconv2d_add_test_helper2(int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_add_relu_cpu(self): + self._qconv2d_add_test_helper(use_relu=True) + self._qconv2d_add_test_helper2(use_relu=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qconv2d_add_relu_int8_mixed_bf16(self): + self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True) + self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_add_broadcast_shapes_cpu(self): + r""" + This testcase will quantize Conv2d->add pattern using broadcast shape inputs. + Conv2d->Add fusion will fail for the broadcast shape inputs case. + """ + + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.conv = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1) + + def forward(self, x1, x2): + return torch.add(self.conv(x1), x2) + + bias_list = [True, False] + for bias in bias_list: + mod = M(bias).eval() + x1 = torch.randn((2, 32, 9, 9)) + x2 = torch.randn((2, 32, 1, 1)) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 + ) + # 2. 
Qconv2d Binary Unary fusion in post-grad fusion pass * 0 + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], 0 + ) + + self._test_common( + mod, + (x1, x2), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_with_concat_cpu(self): + channel_1 = 32 + channel_2 = 16 + channel_3 = 8 + channel_4 = int(channel_2 * 2 + channel_3) + + class Model(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d( + channel_1, channel_2, 1, stride=1, dilation=1, padding=0 + ) + self.conv2 = torch.nn.Conv2d( + channel_1, channel_2, 1, stride=1, dilation=1, padding=0 + ) + self.conv3 = torch.nn.Conv2d( + channel_2, channel_3, 3, stride=1, dilation=1, padding=1 + ) + + self.conv = torch.nn.Conv2d( + channel_4, channel_2, 1, stride=1, dilation=1, padding=0 + ) + + def forward(self, x: torch.Tensor): + x1 = self.conv1(x) + x2 = self.conv2(x) + x3 = self.conv3(x2) + res = torch.cat([x1, x2, x3], dim=1) + res = self.conv(res) + return res + + mod = Model().eval() + v = torch.randn( + (8, channel_1, 40, 40), dtype=torch.float32, requires_grad=False + ) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 + ) + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 3, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 4 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_add_2(self): + r""" + This testcase prevents this pattern be matched as a conv_binary fusion by mistake. + Conv(X) 3 + \ / + Add + We see this pattern in Mobilenet v3 large which add is decomposed from torch.nn.Hardswish or torch.nn.Hardsigmoid. 
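+
+        Roughly, the decompositions look like (illustrative only, not executed):
+
+            hardswish(x)   ~ x * clamp(x + 3, 0, 6) / 6
+            hardsigmoid(x) ~ clamp(x + 3, 0, 6) / 6
+
+        so the add's other input is the scalar 3, not a second dequantized conv
+        output, and it must stay out of the qconv2d_binary fusion.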
+ """ + + class M(torch.nn.Module): + def __init__( + self, + post_op, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.post_op = post_op + + def forward(self, x): + return self.post_op(self.conv(x)) + + for post_op in [ + torch.nn.Hardswish(inplace=True), + torch.nn.Hardsigmoid(inplace=True), + ]: + mod = M(post_op).eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( + 1 + ) + + def matcher_check_fn(): + # Shouldn't hit conv binary fusion + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], 0 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv2d_add_3(self): + r""" + This testcase will test below model: + x + / \ + conv1 maxpool + \ / \ + add conv2 + \ / + cat + Based on default recipe of x86InductorQuantizer, we will see this pattern after convert: + qconv1 maxpool + \ | + \ q1 + \ / \ + \ dq1 qconv2 + \ / + add + | + q2 + Since q1 has 2 users and qconv2 is not ancestor node of qconv1, we shouldn't fuse: + int8 + / + qconv1 dq1 + \ / + add + | + q2 + | + int8 + Instead we can match and fuse this pattern into qconv_binary: + qconv1 fp32 + \ / + add + | + fp32 + """ + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1) + self.maxpool = torch.nn.MaxPool2d( + kernel_size=3, stride=1, padding=0, dilation=1 + ) + + def forward(self, x): + tmp1 = self.conv1(x) + tmp2 = self.maxpool(x) + add = torch.add(tmp1, tmp2) + tmp3 = self.conv2(tmp2) + return torch.cat((add, tmp3), dim=1) + + mod = M().eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 1, + ) + # The matched qconv binary pattern should have 2 nodes [qconv, add] + # instead of 11 which has dequant in binary input and output quant + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_nodes"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d(self): + r""" + This testcase will quantize a single Conv2d module with qat flow. + """ + + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) + self.bn = torch.nn.BatchNorm2d(128) + + def forward(self, x): + return self.bn(self.conv(x)) + + mod = M().train() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) + + def matcher_check_fn(): + # 1. Dequant-conv pattern matched in quantization weight prepack * 1 + # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 4 + ) + # 2. 
QConv2D Unary fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default, quantize_per_tensor] + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_nodes"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + is_qat=True, + ) + + def _qat_qconv2d_unary_cpu_test_helper( + self, + unary_op=torch.nn.ReLU(), + ): + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.bn = torch.nn.BatchNorm2d(3) + self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn2 = copy.deepcopy(unary_op) + self.bn2 = torch.nn.BatchNorm2d(3) + + def forward(self, x): + tmp = self.unary_fn(self.bn(self.conv(x))) + return self.unary_fn2(self.bn2(self.conv2(tmp))) + + mod = M() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) + + def matcher_check_fn(): + # 1. Dequant-conv pattern matched in quantization weight prepack * 1 + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. QConv2D Unary fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + is_qat=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_qconv2d_relu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_qconv2d_relu6(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern with qat flow. + """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_qconv2d_hardtanh(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern with qat flow. + """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_qconv2d_silu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern with qat flow. + """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_qconv2d_hardswish(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern with qat flow. 
+ """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_add(self): + r""" + This testcase will quantize a Conv2d->Add pattern as: + X + / \ + Conv1(X) Conv2(X) + \ / + Add + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.bn1 = torch.nn.BatchNorm2d(6) + self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.bn2 = torch.nn.BatchNorm2d(6) + + def forward(self, x): + x1 = self.bn1(self.conv1(x)) + x2 = self.bn2(self.conv2(x)) + return x1 + x2 + + mod = M().train() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) + + def matcher_check_fn(): + # 1. Dequant-conv pattern matched in quantization weight prepack * 2 + # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 + ) + # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor] + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_nodes"], + 0 if TEST_ACL else 4, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + is_qat=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_add_relu(self): + r""" + This testcase will quantize a Conv2d->Add->ReLU pattern as: + X + / \ + Conv1(X) Conv2(X) + \ / + Add + | + ReLU + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.bn1 = torch.nn.BatchNorm2d(6) + self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.bn2 = torch.nn.BatchNorm2d(6) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x1 = self.bn1(self.conv1(x)) + x2 = self.bn2(self.conv2(x)) + return self.relu(x1 + x2) + + mod = M().train() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) + + def matcher_check_fn(): + # 1. Dequant-conv pattern matched in quantization weight prepack * 2 + # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 + ) + # 2. 
Qconv2d Binary fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor] + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_nodes"], + 0 if TEST_ACL else 5, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + is_qat=True, + ) + + def _test_qconv2d_dequant_promotion_helper(self, device="cpu"): + r""" + This testcase tests if dequant node before conv2d is promoted correctly: + X + | + Conv1(X) + / \ + Conv2(X) Conv3(X) + \ / + Add + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + + def forward(self, x): + temp = self.conv1(x) + temp = self.conv2(temp) + self.conv3(temp) + return temp + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant pattern matcher for dequant promotion * 1 + # [dequantize_per_tensor] + self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) + self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1) + # 2. Dequant-conv pattern matched in quantization weight prepack * 3 + # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 12 + ) + # 3. Qconv2d Binary fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default_1, add_3] + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_nodes"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv2d_binary_lower_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_dequant_promotion_cpu(self): + self._test_qconv2d_dequant_promotion_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv1d_relu_cpu(self): + r""" + This testcase will quantize Conv1d->ReLU pattern. + """ + device = "cpu" + unary_op = torch.nn.ReLU() + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv1d(3, 128, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.conv2 = torch.nn.Conv1d( + 128, 128, kernel_size=3, stride=1, bias=False + ) + self.unary_fn2 = copy.deepcopy(unary_op) + + def forward(self, x): + tmp = self.unary_fn(self.conv(x)) + return self.unary_fn2(self.conv2(tmp)) + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. 
QConv2D Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + + self._test_common( + mod, + (v,), + check_quantization=True, + matcher_check_fn=matcher_check_fn, + ) + + def _qlinear_test_helper( + self, + inputs, + device="cpu", + int8_mixed_bf16=False, + do_permute=False, + matcher_check_fn=None, + bias=True, + is_dynamic=False, + is_qat=False, + ): + class M(torch.nn.Module): + def __init__(self, use_bias, do_permute=False): + super().__init__() + self.linear = torch.nn.Linear(4, 3, use_bias) + self.linear2 = torch.nn.Linear(3, 4, use_bias) + self.do_permute = do_permute + + def forward(self, x): + if self.do_permute: + x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4)) + return self.linear2(self.linear(x)) + + mod = M(bias, do_permute=do_permute).eval().to(device=device) + assert isinstance(inputs, tuple) + + def __convert_tensor_to_device(input, device): + return input.to(device=device) if isinstance(input, torch.Tensor) else input + + inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs) + + def _default_matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + + self._test_common( + mod, + inputs, + matcher_check_fn=( + matcher_check_fn + if matcher_check_fn is not None + else _default_matcher_check_fn + ), + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_quantization=True, + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_cpu(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_dynamic_qlinear_cpu(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), bias=bias, is_dynamic=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_dynamic_qlinear_qat_cpu(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_dynamic_qlinear_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_int8_mixed_bf16(self): + r""" + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. 
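+
+        For a 3D input the decomposed graph typically reshapes around the
+        matmul, e.g. (a sketch; the exact ops depend on the decomposition):
+
+            (2, 3, 4) -> view(6, 4) -> mm/addmm -> view(2, 3, 4)
+
+        which is why some of the matched patterns include extra view nodes
+        compared to the 2D case.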
+ """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module. + * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [True, False]: + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 13 if bias else 12, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module for int8_bf16. + * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [True, False]: + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 17 if bias else 16, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + int8_mixed_bf16=True, + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + ) + + def _qlinear_unary_test_helper( + self, inputs, unary_op=torch.nn.ReLU(), device="cpu", int8_mixed_bf16=False + ): + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.linear = torch.nn.Linear(4, 4, use_bias) + self.unary_fn = copy.deepcopy(unary_op) + self.linear2 = torch.nn.Linear(4, 4, use_bias) + self.unary_fn2 = copy.deepcopy(unary_op) + + def forward(self, x): + tmp = self.unary_fn(self.linear(x)) + return self.unary_fn2(self.linear2(tmp)) + + bias_list = [True, False] + for bias in bias_list: + mod = M(bias).eval().to(device=device) + + def matcher_check_fn(): + # 1. dequant-linear pattern matched in quantization weight prepack + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + # 2. QLinear Unary fusion in post-grad fusion pass + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qlinear_unary_lower_count"], + 0 if TEST_ACL else 2, + ) + + self._test_common( + mod, + inputs, + matcher_check_fn, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_relu_cpu(self): + r""" + This testcase will quantize a Linear->ReLU pattern. + """ + self._qlinear_unary_test_helper((torch.randn((2, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_relu_int8_mixed_bf16(self): + r""" + This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper((torch.randn((2, 4)),), int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_relu_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern. 
+ """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_gelu_cpu(self): + r""" + This testcase will quantize a Linear->GELU pattern. + """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_gelu_int8_mixed_bf16(self): + r""" + This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization. + """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True + ) + + def _qlinear_add_test_helper( + self, + device="cpu", + use_relu=False, + int8_mixed_bf16=False, + is_qat=True, + is_dynamic=True, + ): + r""" + This testcase will quantize two consecutive Linear->Add(->relu) patterns as: + X + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + def fake_quant(x): + # to produce a float32 result as extra input + qlib = torch.ops.quantized_decomposed + if device == "cpu": + qmin, qmax, dtype = 0, 255, torch.uint8 + else: + qmin, qmax, dtype = -128, 127, torch.int8 + x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, qmin, qmax, dtype) + x = qlib.dequantize_per_tensor.default(x, 0.0166785, 42, qmin, qmax, dtype) + return x + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + fake_quant_before_extra_input, + ): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + self.fake_quant_before_extra_input = fake_quant_before_extra_input + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x) + if self.fake_quant_before_extra_input: + x2 = fake_quant(x2) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.linear3(tmp) + tmp2 = self.linear4(tmp) + if self.fake_quant_before_extra_input: + tmp2 = fake_quant(tmp2) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + add_fn_list = [ + lambda x, y: x + y, + lambda x, y: y + x, + lambda x, y: x.add_(y), + lambda x, y: y.add_(x), + ] + fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] + shape_list = [(4, 4), (4, 4, 4)] + cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list) + for add_fn, fq_x2, shape in cases: + mod = M(add_fn, use_relu, fq_x2).eval().to(device=device) + v = torch.randn( + shape, dtype=torch.float32, requires_grad=False, device=device + ).add(1) + + def matcher_check_fn(): + # 1. 
Dequant-linear pattern matched in quantization weight prepack * 4
+                self.assertEqual(
+                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4
+                )
+                # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm]
+                nodes_per_match = 6 if int8_mixed_bf16 else 4
+                if len(shape) == 3:
+                    # pattern = [dequant_per_tensor, (convert_dtype), (view), \
+                    # dequant_per_channel, (convert_dtype), (view), permute, addmm]
+                    nodes_per_match += 2
+                self.assertEqual(
+                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
+                    4 * nodes_per_match,
+                )
+                # 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2
+                self.assertEqual(
+                    counters["inductor"]["qlinear_binary_matcher_count"],
+                    0 if TEST_ACL else 2,
+                )
+                # Two linear-binary patterns are matched
+                # matched pattern 1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor]
+                # matched pattern 2 = [qlinear, add, (convert dtype), (relu)]
+                # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary
+                to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
+                expected_matcher_nodes = (
+                    (4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary
+                )
+                self.assertEqual(
+                    counters["inductor"]["qlinear_binary_matcher_nodes"],
+                    0 if TEST_ACL else expected_matcher_nodes,
+                )
+                self.assertEqual(
+                    counters["inductor"]["qlinear_binary_lower_count"],
+                    0 if TEST_ACL else 2,
+                )
+
+            self._test_common(
+                mod,
+                (v,),
+                matcher_check_fn,
+                check_quantization=True,
+                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
+                is_qat=is_qat,
+                is_dynamic=is_dynamic,
+            )
+
+            if TEST_ACL:
+                continue
+
+            if torch._inductor.config.cpp_wrapper:
+                # For CPP wrapper
+                self._test_code_common(
+                    mod,
+                    (v,),
+                    [
+                        "aoti_torch_cpu__qlinear_pointwise_tensor",
+                        "aoti_torch_cpu__qlinear_pointwise_binary_tensor",
+                    ],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2, 2],
+                )
+            else:
+                # For python wrapper
+                self._test_code_common(
+                    mod,
+                    (v,),
+                    [
+                        "torch.ops.onednn.qlinear_pointwise.tensor",
+                        "torch.ops.onednn.qlinear_pointwise.binary",
+                    ],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2, 2],
+                )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @parametrize("use_relu", [True, False])
+    @parametrize("is_qat", [True, False])
+    @parametrize("is_dynamic", [True, False])
+    def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic):
+        self._qlinear_add_test_helper(
+            use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @parametrize("use_relu", [True, False])
+    @parametrize("is_qat", [True, False])
+    @parametrize("is_dynamic", [True, False])
+    def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
+        self._qlinear_add_test_helper(
+            int8_mixed_bf16=True,
+            use_relu=use_relu,
+            is_qat=is_qat,
+            is_dynamic=is_dynamic,
+        )
+
+    def _qlinear_dequant_promotion_test_helper(
+        self,
+        inputs,
+        device="cpu",
+        int8_mixed_bf16=False,
+        is_dynamic=False,
+        matcher_check_fn=None,
+    ):
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                **kwargs,
+            ):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(4, 4)
+                self.linear2 = torch.nn.Linear(4, 4)
+                self.linear3 = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                temp = self.linear1(x)
+                temp = self.linear2(temp) + self.linear3(temp)
+                return temp
+
+        mod = M().eval().to(device=device)
+
+        def default_matcher_check_fn():
+            # 1. 
Dequant pattern matcher for dequant promotion * 1 + self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) + # 2. dequant-linear pattern matched in quantization weight prepack * 3 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 + ) + # 3. QLinear Unary fusion in post-grad fusion pass * 1 + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + inputs, + matcher_check_fn=( + matcher_check_fn + if matcher_check_fn is not None + else default_matcher_check_fn + ), + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_quantization=True, + is_dynamic=is_dynamic, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_cpu(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_int8_mixed_bf16(self): + r""" + Test with int8_mixed_bf16 quantization. + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 4)),), int8_mixed_bf16=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): + r""" + Test with int8_mixed_bf16 quantization. + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 3, 4)),), int8_mixed_bf16=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_dynamic_cpu(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + + def matcher_check_fn(): + # 1. Dequant pattern matcher for dequant promotion * 1 + self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) + # 2. dequant-linear pattern matched in quantization weight prepack * 3 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 + ) + + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 4)),), + matcher_check_fn=matcher_check_fn, + is_dynamic=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_mul_cpu(self): + r""" + This testcase will quantize a Linear->Mul pattern. 
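+
+        i.e. the eager graph is roughly y = mul(linear(x1), x2); only the
+        dequant->linear weight-prepack match is asserted below, and no binary
+        fusion of the trailing mul is expected.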
+ """ + + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.linear = torch.nn.Linear(4, 5, use_bias) + + def forward(self, x1, x2): + return torch.mul(self.linear(x1), x2) + + bias_list = [True, False] + for bias in bias_list: + mod = M(bias).eval() + x1 = torch.randn((2, 4)) + x2 = torch.randn((2, 5)) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + + self._test_common( + mod, + (x1, x2), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + def test_qmaxpool2d(self): + r""" + This testcase will quantize Conv2d->ReLU->MaxPool2d pattern. + """ + + class M(torch.nn.Module): + def __init__( + self, + kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + self.maxpool = torch.nn.MaxPool2d(3, **kwargs) + + def forward(self, x): + return self.maxpool(self.relu(self.conv(x))) + + kwargs_list = [ + {"stride": 2}, + {"stride": 2, "padding": 1}, + {"stride": 2, "padding": 1, "dilation": 1}, + {"stride": 2, "padding": 1, "dilation": 1, "ceil_mode": False}, + ] + for kwargs in kwargs_list: + mod = M(kwargs).eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( + 1 + ) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qmaxpool2d_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 + ) + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 1, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], + 0 if TEST_ACL else 1, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + def test_qflatten(self): + r""" + This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten->cat pattern. 
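+
+        The flatten between the pooled, quantized output and the cat lowers to
+        a reshape; the qreshape pattern is expected to absorb it so the int8
+        tensor is reshaped directly instead of being dequantized and
+        re-quantized around the view (see qreshape_matcher_count below).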
+ """ + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return torch.cat( + [ + torch.flatten( + self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1 + ) + ] + ) + + mod = M().eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qreshape_matcher_count"], 0 if TEST_ACL else 1 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + @skipIfNoDynamoSupport + def test_qcat(self): + r""" + This testcase will quantize cat based pattern: + X + / \ + Conv1(X) Pow(x) + \ \ + \ Conv2(X) + \ / + Cat + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.conv2 = torch.nn.Conv2d( + 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 + ) + + def forward(self, x): + temp1 = self.conv(x) + temp2 = self.conv2(torch.pow(x, 2)) + return torch.cat((temp1, temp2), 1) + + mod = M().eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qcat_matcher_count"], 0 if TEST_ACL else 1 + ) + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + ) + + def _test_linear_dynamic_fp16_helper(self, use_relu: bool): + class M(torch.nn.Module): + def __init__(self, bias: bool, use_relu: bool): + super().__init__() + self.linear = torch.nn.Linear(256, 256, bias=bias) + self.relu = torch.nn.ReLU() + self.use_relu = use_relu + + def forward(self, x): + if self.use_relu: + return self.relu(self.linear(x)) + return self.linear(x) + + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + quantizer.set_module_type_qconfig( + torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config() + ) + bias_list = [True, False] + input_ndim_list = [2, 3] + x_contig_list = [True, False] + cases = itertools.product(bias_list, input_ndim_list, x_contig_list) + for bias, input_ndim, x_contig in cases: + x_shape = (4, 256) if input_ndim == 2 else (4, 1, 256) + x = torch.randn(x_shape) + if not x_contig: + x = x[0::2, ...] + mod = M(bias, use_relu).eval() + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + # Matched nodes: + # (1) w to fp16, (2) w to fp32, (3) permute w, (4) mm/addmm/bmm + # If x.ndim == 3 and x is contiguous, two view nodes are added. + # If x.ndim == 3 and x is not contiguous, two expand nodes and one add node are added. 
+ nodes_count = 4 + if input_ndim > 2: + if x_contig: + nodes_count += 2 + else: + nodes_count += 3 if bias else 2 + if use_relu: + nodes_count += 1 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + nodes_count, + ) + + self._test_common( + mod, + (x,), + atol=1e-2, + rtol=1e-2, + matcher_check_fn=matcher_check_fn, + check_quantization=True, + quantizer=quantizer, + ) + linear_op_str = ( + "torch.ops.onednn.linear_relu_dynamic_fp16.default" + if use_relu + else "torch.ops.onednn.linear_dynamic_fp16.default" + ) + self._test_code_common( + mod, + (x,), + [linear_op_str], + ["torch.ops.aten.addmm.default", "torch.ops.aten.mm.default"], + check_quantization=True, + quantizer=quantizer, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_linear_dynamic_fp16(self): + self._test_linear_dynamic_fp16_helper(use_relu=False) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_linear_relu_dynamic_fp16(self): + self._test_linear_dynamic_fp16_helper(use_relu=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + # TODO: investigate options of torch.compile in fbcode + @unittest.skipIf(IS_FBCODE, "Failing in fbcode") + @parametrize("has_bias", [True, False]) + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("per_channel_quant", [True, False]) + @parametrize("dynamic", [True, False]) + def test_smooth_quant_with_int_mm( + self, has_bias, dtype, per_channel_quant, dynamic + ): + r""" + This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao. + The pattern is: + (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape + or + (with bias) pattern_no_bias -> add -> reshape -> reshape + """ + if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + return + M = 16 + in_feature = 32 + out_feature = 64 + q_min, q_max = -32, 31 + + class Mod(torch.nn.Module): + def __init__( + self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool + ): + super().__init__() + self.dtype = dtype + self.has_bias = has_bias + self.b = torch.randint( + q_min, q_max, [in_feature, out_feature], dtype=torch.int8 + ) + self.per_channel_quant = per_channel_quant + a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01 + a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01 + self.a_scale = ( + a_scale_per_channel + if self.per_channel_quant + else a_scale_per_tensor + ) + self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01 + self.b_scale = self.b_scale.to(dtype) + self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None + + def forward(self, a): + out_shape = a.shape[:-1] + (self.b.size(-1),) + a_reshaped = a.reshape(-1, a.size(-1)) + c = torch._int_mm(a_reshaped, self.b) + c = c.to(self.dtype) + c_shape = c.shape + a_scale = self.a_scale.expand(c.shape) + c = c * a_scale + c = c * self.b_scale + if self.has_bias: + c = c.reshape([1, *list(c_shape)]) + c = c + self.bias + c = c.reshape(c_shape) + c = c.reshape(out_shape) + return c + + mod = Mod(dtype, has_bias, per_channel_quant).eval() + a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + if dynamic: + nodes_count = 10 if has_bias else 7 + else: + nodes_count = 7 if has_bias else 6 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + nodes_count, + ) + + self._test_common( + mod, + (a,), + 
matcher_check_fn=matcher_check_fn, + check_autocast=dtype, + compile_options={"dynamic": dynamic}, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + # TODO: investigate options of torch.compile in fbcode + @unittest.skipIf(IS_FBCODE, "Failing in fbcode") + @parametrize("has_bias", [True, False]) + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("dynamic", [True, False]) + @parametrize("reshape_a", [True, False]) + @parametrize( + "M", + [ + 1, + 32, + ], + ) + @parametrize("inplace_add", [True, False]) + @parametrize("expand_a_scale", [True, False]) + def test_da8w8_sym_act_sym_wgt_with_int_mm( + self, has_bias, dtype, dynamic, reshape_a, M, inplace_add, expand_a_scale + ): + r""" + This testcase check if we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao, + when activation is symmetrically quantized dynamically & weights are symmetrically quantized (statically) + The pattern is: + (no bias) _int_mm -> convert_element_type -> ([expand_a] -> mul) -> mul + or + (with bias) pattern_no_bias -> add + Expansion of the scale of activation is optional. + The pattern depiction doesn't mean that convert_element_type output is fed into expand_a as input, + but simply that activation scale may be applied after an expand operation on it. + """ + if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + return + in_feature = 32 + out_feature = 64 + q_min, q_max = -32, 31 + # we only test for qlinear_binary in this case + test_for_pointwise_binary = ( + True + if M == 1 + and inplace_add + and not expand_a_scale + and not dynamic + and not has_bias + else False + ) + if test_for_pointwise_binary and not IS_X86: + self.skipTest("Some UTs are only supported on x86_64 CPUs") + + class Mod(torch.nn.Module): + def __init__(self, dtype: torch.dtype, has_bias: bool): + super().__init__() + self.dtype = dtype + self.has_bias = has_bias + self.b = torch.randint( + q_min, q_max, [in_feature, out_feature], dtype=torch.int8 + ) + self.a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01 + self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01 + self.b_scale = self.b_scale.to(dtype) + self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None + self.additive = torch.rand([M, out_feature], dtype=dtype) + + def forward(self, a): + if reshape_a: + a_reshaped = a.reshape(-1, a.size(-1)) + else: + a_reshaped = a + c = torch._int_mm(a_reshaped, self.b) + c = c.to(self.dtype) + if expand_a_scale: + a_scale = self.a_scale.expand(c.shape) + else: + a_scale = self.a_scale + c = c * a_scale + c = c * self.b_scale + if self.has_bias: + c = c + self.bias + elif inplace_add and test_for_pointwise_binary: + # When M is 1, dynamic shapes are enabled with torch.compile, has_bias is False, + # expand_a_scale is False and inplace_add is true, + # the output's outermost dim's stride can't be determined due to some Inductor bug. 
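+                    # The in-place add below is the binary op expected to be fused
+                    # into qlinear (see the qlinear_binary_matcher_count check below).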
+ c.add_(self.additive) + return c + + mod = Mod(dtype, has_bias).eval() + a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + + self._test_common( + mod, + (a,), + matcher_check_fn, + check_autocast=dtype, + compile_options={"dynamic": dynamic}, + ) + if test_for_pointwise_binary: + self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) + + +@dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } +) +class TestDynamicPatternMatcher(TestPatternMatcherBase): + @xfailIfACL + def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): + r""" + This testcase will quantize a single Conv2d->Maxpool2d->Linear module + with dynamic batch size input. + """ + + class M(torch.nn.Module): + def __init__( + self, + **kwargs, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 16, (2, 2), stride=(1, 1), padding=(1, 1) + ) + self.relu = torch.nn.ReLU() + self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + temp = self.relu(self.conv(x)) + temp = self.maxpool2d(temp) + temp = self.avgpool(temp) + temp = torch.flatten(temp, 1) + return self.linear(temp) + + mod = M().eval() + v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) + if include_ops is None: + include_ops = [ + "torch.ops.onednn.qconv_pointwise", + "torch.ops.quantized.max_pool2d", + "torch.ops.onednn.qlinear_pointwise", + ] + exclude_ops = [] + self._test_code_common( + mod, + (v,), + include_ops, + exclude_ops, + check_quantization=True, + check_dynamic=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qat_bn_conv2d(self): + r""" + This testcase will quantize a single BN Conv2d module with qat flow. 
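+        After lowering, a single qconv weight-prepack match is expected.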
+ """ + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.bn1 = torch.nn.BatchNorm2d(3) + self.bn2 = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(self.bn1(x)) + return self.bn2(x) + + mod = M().train() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + is_qat=True, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_q_attention_block(self): + class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + transpose_for_score=False, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.softmax = torch.nn.Softmax(dim=-1) + self.transpose_for_score = transpose_for_score + if self.transpose_for_score: + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + if self.transpose_for_score: + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + attention = self.softmax(scores) + weighted = torch.matmul(attention, v) + return weighted + + for annotate_matmul in [False, True]: + mod = SelfAttnLikeModule( + input_dim=64 * 16, + transpose_for_score=True, + num_attention_heads=16, + attention_head_size=64, + ).eval() + v = torch.randn(2, 384, 1024) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 + ) + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 3 if annotate_matmul and not TEST_ACL else 0, + ) + + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + quantizer=quantizer, + ) + + +instantiate_parametrized_tests(TestPatternMatcher) +if __name__ == "__main__": + if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available(): + # set weight_prepack = False to skip fusion passes in pytorch core + import torch._inductor.config + torch._inductor.config.cpp.weight_prepack = False + run_tests() diff --git a/torchao/prototype/inductor/fx_passes/quantization.py b/torchao/prototype/inductor/fx_passes/quantization.py new file mode 100644 index 0000000000..06d4a9489c --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/quantization.py @@ -0,0 +1,2854 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import math +import operator 
+from typing import Any
+
+import torch
+from torch._dynamo.utils import counters
+from torch.fx.experimental.symbolic_shapes import has_free_symbols
+from torch.fx.node import map_arg
+
+from torch._inductor.pattern_matcher import (
+    Arg,
+    CallFunction,
+    filter_nodes,
+    KeywordArg,
+    Match,
+)
+from torch._inductor.fx_passes.freezing_patterns import register_freezing_graph_pattern
+
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+quantized_decomposed = torch.ops.quantized_decomposed
+quantized = torch.ops.quantized
+
+# Only for per-tensor quant, since permute may change the channel idx
+_PER_TENSOR_QUANTIZE_OPS = [
+    quantized_decomposed.quantize_per_tensor.default,
+    quantized_decomposed.quantize_per_tensor.tensor,
+]
+
+_VIEW_OPS = [
+    aten.transpose.int,
+    aten.permute.default,
+    aten.view.default,
+]
+
+"""
+The quantization.py file primarily incorporates passes related to quantization fusion
+in Inductor, including:
+1. Dequant Promotion;
+2. Conv/GEMM weight prepack with the oneDNN library;
+3. Conv/GEMM quantization fusion with the output quant node (if present);
+4. Quantization fusion for other pointwise operators, such as qmaxpool2d, qcat and more.
+
+It also covers int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
+in the patterns for int8-mixed-bf16, compared with int8-mixed-fp32, is:
+1. There is a to(dtype=torch.bfloat16) node at the activation and weight inputs of Conv/GEMM.
+2. There is a to(dtype=torch.float32) node at the output of Conv/GEMM before it is fed into the next quant node.
+Refer to https://github.com/pytorch/pytorch/issues/111640 for the detailed design of int8-mixed-bf16
+quantization.
+"""
+
+
+def _get_pattern_output_dtype(match: Match):
+    """
+    Get the pattern's output dtype from the output node's meta.
+    Assume there is only 1 output node in the matched pattern.
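+    The returned dtype is expected to be one of int8, uint8, float32 or bfloat16.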
+ """ + pattern_output_nodes = match.output_nodes() + assert len(pattern_output_nodes) == 1 + output_node = pattern_output_nodes[0] + assert isinstance(output_node, torch.fx.Node) + output_dtype = output_node.meta["val"].dtype + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + return output_dtype + + +def _may_generate_pattern_with_dtype_convert( + pattern, dtype=Arg(), with_dtype_convert=True, users=1 +): + if with_dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + _users=users, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): + # only insert to_dtype if is_bf16 is True + computation_call = _may_generate_pattern_with_dtype_convert( + call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + ) + return unary_fusion(computation_call) + + +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) + return dequantize_per_tensor_activation_pattern + + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qconv2d_binary_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.binary, + KeywordArg("x"), + 
KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("accum"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("accum_scale"), + KeywordArg("accum_zero_point"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("x_2"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("x2_scale"), + KeywordArg("x2_zp"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +dequantize_accum_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("accum"), + KeywordArg("accum_scale"), + KeywordArg("accum_zp"), + Arg(), + Arg(), + KeywordArg("accum_dq_dtype"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + dtype_convert=False, + swap_inputs=False, +): + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + dtype_convert, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, + ), + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value 
== expected_value
+    else:
+        assert len(check_node.args) >= (args_index + 1)
+        actual_value = check_node.args[args_index]
+        return actual_value == expected_value
+
+
+def _is_valid_quantized_conv_optimization_pattern():
+    def fn(match):
+        output_dtype = _get_pattern_output_dtype(match)
+        if output_dtype in [torch.float32, torch.bfloat16]:
+            # Only keep matched pattern with same output_dtype
+            qconv_node_after_weight_prepack = filter_nodes(
+                match.nodes, torch.ops.onednn.qconv_pointwise
+            )[0]
+            return _check_node_kwarg_arg_value(
+                qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
+            )
+        return True
+
+    return fn
+
+
+def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False):
+    return (
+        _is_valid_qconv_binary_optimization_pattern()
+        if has_binary_post_op
+        else _is_valid_quantized_conv_optimization_pattern()
+    )
+
+
+def _is_valid_quantized_linear_optimization_pattern():
+    def fn(match):
+        output_dtype = _get_pattern_output_dtype(match)
+        if output_dtype in [torch.float32, torch.bfloat16]:
+            # Only keep matched pattern with same output_dtype
+            qlinear_node_after_weight_prepack = filter_nodes(
+                match.nodes, torch.ops.onednn.qlinear_pointwise
+            )[0]
+            return _check_node_kwarg_arg_value(
+                qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
+            )
+        return True
+
+    return fn
+
+
+def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False):
+    return (
+        _is_valid_qlinear_binary_optimization_pattern()
+        if has_binary_post_op
+        else _is_valid_quantized_linear_optimization_pattern()
+    )
+
+
+def _is_valid_qconv_binary_optimization_pattern():
+    return _is_valid_quantized_op_binary_optimization_pattern(
+        torch.ops.onednn.qconv_pointwise
+    )
+
+
+def _is_valid_qlinear_binary_optimization_pattern():
+    return _is_valid_quantized_op_binary_optimization_pattern(
+        torch.ops.onednn.qlinear_pointwise,
+        # we don't insert q-dq for the extra input due to accuracy issues
+        extra_input_from_dequant=False,
+    )
+
+
+def _is_valid_quantized_op_binary_optimization_pattern(
+    qop, extra_input_from_dequant=True
+):
+    # Check if it's a valid binary pattern for qconv2d and qlinear:
+    # * qop_pointwise should only have one user
+    # * If extra_input_from_dequant is True, the extra input of the binary node should come from a dequant pattern
+    # * the two inputs of the binary node should have attribute "meta" and should be tensors
+    # * the two inputs of the binary node should have the same shape
+    # * All users of the extra input in this pattern should be
+    #   ancestor nodes of the compute node, except for the binary node
+    #   connected to the compute node.
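+    # Roughly, with extra_input_from_dequant=True and a float output, the accepted
+    # graph shape looks like (a sketch, not an exhaustive description):
+    #     dequantize_per_tensor(extra_input) --\
+    #     qop_pointwise(x, packed_weight, ...) --> binary op (e.g. add)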
+ def fn(match): + output_dtype = _get_pattern_output_dtype(match) + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype in [torch.float32, torch.bfloat16]: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + + from torch._inductor.fx_passes.mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if (output_dtype in [torch.uint8, torch.int8]) + or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) + ) + if ( + len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + dequant_node = ( + dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- reshape <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- dequant + ) + else: + dequant_node = ( + dequant_pattern_end_node # pattern: linear <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- dequant + ) + + if ( + dequant_node.target + in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( 
+        pattern,
+        extra_check=_is_valid_dequant_promotion_pattern(dtype),
+        pass_number=pass_number,
+    )
+    def dequant_promotion(match: Match, *args, **kwargs):
+        # Dequant_promotion will transform
+        # graph 1:
+        #            quant
+        #      + - - - | - - - +
+        #      |    dequant    |
+        #      |    /     \    |
+        #      |  node1  node2 |
+        #      + - | - - - | - +
+        #        quant   quant
+        # into:
+        # graph 2:
+        #            quant
+        #      + - - / - \ - - +
+        #      |dequant dequant|
+        #      |    |      |   |
+        #      |  node1  node2 |
+        #      + - | - - - | - +
+        #        quant   quant
+        # In graph 1, the dequant node is shared by node1 and node2;
+        # as a result, neither node1 nor node2 can form an int8
+        # fusion pattern.
+        # After this transformation, graph 2 can hit the int8
+        # fusion pattern dequant-node-quant for node1 and node2,
+        # respectively.
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        def clone_to_new_node(graph, source_node, user_node):
+            # Clone the source_node to a new node
+            # Replace user_node's input from source_node to new_node
+            assert source_node.op == "call_function", (
+                "clone_to_new_node only supports node.op call_function"
+            )
+            with graph.inserting_before(user_node):
+                new_node = graph.call_function(
+                    source_node.target,
+                    args=source_node.args,
+                    kwargs=source_node.kwargs,
+                )
+                new_node.meta = copy.copy(source_node.meta)
+                user_node.replace_input_with(source_node, new_node)
+            return new_node
+
+        # Find the start node and end node of a dequant pattern
+        # * End node should be the match.output_node()
+        # * Start node should be the dequantize_per_tensor node
+        dequant_pattern_end_node = match.output_node()
+        assert dequant_pattern_end_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            prims.convert_element_type.default,
+            aten.reshape.default,
+        ]
+
+        # For a dequant pattern, we expect to see the node list as:
+        # * OPT(aten.reshape.default)
+        # * OPT(prims.convert_element_type.default) (to_bf16)
+        # * dequantize_per_tensor
+        def _find_first_node_in_dequant_pattern(_node):
+            if _node.target in [
+                quantized_decomposed.dequantize_per_tensor.default,
+                quantized_decomposed.dequantize_per_tensor.tensor,
+            ]:
+                # For a dequant pattern, we expect the start node to be a dequantize_per_tensor node
+                return _node
+            else:
+                assert len(_node.args) >= 1, (
+                    "In a dequant pattern, each node should have at least 1 arg."
+                )
+                return _find_first_node_in_dequant_pattern(_node.args[0])
+
+        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
+            dequant_pattern_end_node
+        )
+
+        assert dequant_pattern_start_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+        ]
+
+        # Clone the dequant pattern for each user node
+        graph = match.graph
+        user_node_list = list(dequant_pattern_end_node.users)
+        for user_node in user_node_list[1:]:
+            _source_node = dequant_pattern_end_node
+            _user_node = user_node
+            while _source_node != dequant_pattern_start_node.args[0]:
+                _user_node = clone_to_new_node(graph, _source_node, _user_node)
+                _source_node = _source_node.args[0]  # type: ignore[assignment]
+
+        counters["inductor"]["dequant_promotion_matcher_count"] += 1
+        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
+
+
+def _is_valid_dequant_conv_pattern(dtype):
+    def _inner(match):
+        # Here we do some further checks to ensure:
+        # 1. It's a conv node with input dim of 3 or 4, since we only support lowering of conv1d/2d now.
+        # 2. The dequant pattern has only 1 user, i.e. the conv node.
+ # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") + or meta_value.dim() not in [3, 4] + ): + # Only support conv1d/2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. 
+ x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_node) # type: ignore[arg-type] + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) # type: ignore[arg-type] + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 
+ # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> dequant + dequant_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> dequant + activation_to_bf16_node = act_reshape_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + dequant_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> dequant + dequant_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> dequant + activation_to_bf16_node = linear_node.args[input_index] + dequant_node = activation_to_bf16_node.args[0] + return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. 
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + assert dequant_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = 
weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + input_dim_exceeds_two=False, + is_tensor_overload=False, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + 
_may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, + is_tensor_overload=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, + is_tensor_overload=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + is_tensor_overload, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + input_dim_exceeds_two, + is_tensor_overload, + ) + + +def _generate_linear_dynamic_fp16_pattern( + _dequant_weight_pattern, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + dtype = torch.float32 + t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype) + + if input_dim_exceeds_two and not input_contiguous: + # pattern is + # x -> expand -> bmm (-> add) (-> relu) + # w -> dequant -> permute -> expand / + pattern_no_bias = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + KeywordArg("x"), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + pattern_with_bias = CallFunction( + aten.add.Tensor, + pattern_no_bias, + KeywordArg("b"), + ) + if relu_fused: + pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias) + pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias) + return pattern_with_bias, pattern_no_bias + + x_pattern_with_reshape = _may_generate_pattern_with_reshape( + KeywordArg("x"), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ) + 
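+    # Two patterns are returned below: addmm (with bias) and mm (no bias), each
+    # optionally wrapped by an output reshape and, if relu_fused, a trailing relu.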
dequant_linear_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + dequant_linear_no_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=is_tensor_overload + ), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. 
+ # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, + input_dim_exceeds_two, + is_tensor_overload=is_tensor_overload, + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/ + # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias, is_tensor_overload in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + is_tensor_overload=is_tensor_overload, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +def _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + def _extra_check_fn(match: Match): + return match.kwargs["dtype_fp16"] == torch.float16 + + @register_freezing_graph_pattern( + pattern, + extra_check=_extra_check_fn, + pass_number=pass_number, + ) + def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + fp32 activation + | + mm/addmm <- t <- to_fp32 <- to_fp16 <- weight + | + (reshape) <- (relu) + + OR + + fp32 activation + | + expand + | + bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight + | + (add) <- (relu) + + Insert weight prepack node and change the pattern to: + fp32 activation + | + onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight + (or onednn.linear_relu_dynamic_fp16) + """ + # find params + x = kwargs["x"] + w = kwargs["w"] + bias = kwargs["b"] if "b" in kwargs else None + + # find linear node + nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default] + linear_nodes = [] + for node 
in nodes_to_find: + linear_nodes.extend(filter_nodes(match.nodes, node)) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + assert isinstance(linear_node, torch.fx.node.Node) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + # find relu node + relu_node = None + if relu_fused: + relu_node = match.output_node() + assert isinstance(relu_node, torch.fx.node.Node) + + # find reshape node, expand node and add node + ( + act_reshape_node, + output_reshape_node, + expand_x_node, + expand_w_node, + add_bias_node, + ) = (None, None, None, None, None) + t_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + t_node = linear_node.args[weight_index] + output_reshape_node = next(iter(linear_node.users)) + assert output_reshape_node.target is aten.reshape.default + else: + expand_x_node = linear_node.args[input_index] + expand_w_node = linear_node.args[weight_index] + assert isinstance(expand_w_node, torch.fx.node.Node) + t_node = expand_w_node.args[0] + if bias: + add_bias_node = next(iter(linear_node.users)) + assert add_bias_node.target is aten.add.Tensor + else: + t_node = linear_node.args[weight_index] + assert isinstance(t_node, torch.fx.node.Node) + + w_to_fp32_node = t_node.args[0] + assert ( + isinstance(w_to_fp32_node, torch.fx.node.Node) + and w_to_fp32_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + w_to_fp16_node = w_to_fp32_node.args[0] + assert ( + isinstance(w_to_fp16_node, torch.fx.node.Node) + and w_to_fp16_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + w, + x_shape, + ) + packed_weight_op = torch.ops.onednn.linear_prepack_fp16 + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + # create new linear node and insert on graph + new_args: tuple[Any, ...] 
= ( + x, + prepack_weight_node, + bias, + ) + linear_op = ( + torch.ops.onednn.linear_relu_dynamic_fp16.default + if relu_fused + else torch.ops.onednn.linear_dynamic_fp16.default + ) + new_linear_node = graph.call_function(linear_op, args=new_args) + out_node = match.output_node() + out_node.replace_all_uses_with(new_linear_node) + + # Erase the original nodes in the reverse order + new_linear_node.meta.update(out_node.meta) + if relu_node is not None: + graph.erase_node(relu_node) + if output_reshape_node is not None: + graph.erase_node(output_reshape_node) + if add_bias_node is not None: + graph.erase_node(add_bias_node) + graph.erase_node(linear_node) + if act_reshape_node is not None: + assert isinstance(act_reshape_node, torch.fx.node.Node) + graph.erase_node(act_reshape_node) + if expand_x_node is not None: + assert isinstance(expand_x_node, torch.fx.node.Node) + graph.erase_node(expand_x_node) + if expand_w_node is not None: + assert isinstance(expand_w_node, torch.fx.node.Node) + graph.erase_node(expand_w_node) + graph.erase_node(t_node) + graph.erase_node(w_to_fp32_node) + graph.erase_node(w_to_fp16_node) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _register_linear_dynamic_fp16_weight_prepack(): + to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + weight_pattern = CallFunction( + to_dtype_op, + CallFunction( + to_dtype_op, + KeywordArg("w"), + KeywordArg("dtype_fp16"), + ), + KeywordArg("dtype_fp32"), + ) + cases = itertools.product( + [False, True], # input_dim_exceeds_two + [True, False], # input_contiguous + [False, True], # relu fused + ) + for input_dim_exceeds_two, input_contiguous, relu_fused in cases: + patterns = _generate_linear_dynamic_fp16_pattern( + weight_pattern, + input_dim_exceeds_two, + input_contiguous, + relu_fused, + ) + for pattern in patterns: + _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number=0 if relu_fused else 1, + input_dim_exceeds_two=input_dim_exceeds_two, + input_contiguous=input_contiguous, + relu_fused=relu_fused, + ) + + +def _register_smooth_quant_int_mm_pattern(): + """ + The pattern is: + (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape + or + (with bias) pattern_no_bias -> add (-> reshape -> reshape) + """ + + # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist + # When torch.compile'ing with dynamic=False, they don't exist + def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True): + return CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mul.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten._int_mm.default, + CallFunction( + aten.reshape.default, + KeywordArg("a"), + KeywordArg("in_shape"), + ) + if reshape_a + else KeywordArg("a"), + KeywordArg("b"), + ), + KeywordArg("dtype"), + ), + ( + CallFunction( + aten.expand.default, + KeywordArg("x_scale"), + Arg(), + ) + if expand_a_scale + else KeywordArg("x_scale") + ), + ), + KeywordArg("w_scale"), + ) + + def _with_outer_reshape(pattern): + return CallFunction( + aten.reshape.default, pattern, KeywordArg("out_shape_no_bias") + ) + + # for torch.compile(dynamic=False) + pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False)) + pattern_with_bias_1 = CallFunction( + aten.add.Tensor, + pattern_no_bias_1, + KeywordArg("bias"), + ) + # for 
torch.compile(dynamic=True) + pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True)) + pattern_with_bias_2 = CallFunction( + aten.reshape.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.add.Tensor, + pattern_no_bias_2, + KeywordArg("bias"), + ), + Arg(), + ), + KeywordArg("out_shape_with_bias"), + ) + + # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # when both activation and weights are symmetrically quantized. + # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. + # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. + # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op. + pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=False, reshape_a=False + ) + pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=True, reshape_a=False + ) + + def _validate_pattern(match: Match): + if len(match.nodes) not in [4, 5, 6, 7, 10]: + return False + # Make sure weight is a constant + aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0] + if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node): + return False + if aten_int_mm_node.args[1].op != "get_attr": + return False + + if len(match.nodes) == 10: + # Check the two tailing reshape nodes can be fused + if match.nodes[9].args[1] != match.nodes[6].args[1]: + return False + if len(match.nodes) == 10 or ( + len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor + ): + bias_idx = 7 if len(match.nodes) == 10 else 6 + # Check bias shape + bias_node = match.nodes[bias_idx].args[1] + if not isinstance(bias_node, torch.fx.node.Node): + return False + if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr] + return False + return True + + pattern_to_pass_number = { + pattern_no_bias_2: 0, + pattern_with_bias_2: 0, + pattern_no_bias_1: 1, + pattern_with_bias_1: 1, + pattern1_with_no_outer_or_act_reshape: 2, + pattern2_with_no_outer_or_act_reshape: 2, + } + for pattern, pass_number in pattern_to_pass_number.items(): + + @register_freezing_graph_pattern( + pattern, + extra_check=_validate_pattern, + pass_number=pass_number, + ) + def _int_mm_weight_prepack(match: Match, *args, **kwargs): + bias = kwargs.get("bias", None) + x = kwargs["a"] + weight = kwargs["b"] + dtype = kwargs["dtype"] + x_scale = kwargs["x_scale"] + w_scale = kwargs["w_scale"] + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. 
+ x_shape = None + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + transpose_node = match.graph.call_function( + aten.permute.default, args=(weight, [1, 0]) + ) + contig_node = match.graph.call_function( + aten.contiguous.default, args=(transpose_node,) + ) + packed_weight_inputs = ( + contig_node, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = match.graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + dummy_zp = None + w_scale = match.graph.call_function( + prims.convert_element_type.default, args=(w_scale, torch.float32) + ) + + x_scale_shape = x_scale.meta.get("tensor_meta").shape + x_scale_is_scalar = False + if not has_free_symbols(x_scale_shape): + prod = 1 + for d in x_scale_shape: + prod *= d + x_scale_is_scalar = prod == 1 + + new_args: tuple[Any, ...] + if x_scale_is_scalar: + # in this case, we can call onednn.qlinear directly + new_args = ( + x, + x_scale, + dummy_zp, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + else: + # onednn.qlinear does not support per-channel quantization of x + # so in this case, we have to apply x scale and add bias ourselves after qlinear + in_shape = kwargs.get("in_shape", None) + if in_shape is None: + x_reshaped = x + else: + x_reshaped = match.graph.call_function( + aten.reshape.default, args=(x, in_shape) + ) + new_args = ( + x_reshaped, + 1.0, # x_scale + 0, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + None, # bias + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise, args=new_args + ) + # apply x scale + new_out_node = match.graph.call_function( + aten.mul.Tensor, args=(new_linear_node, x_scale) + ) + + # Add bias and reshape + has_outer_reshape = ( + kwargs.get("out_shape_with_bias", None) is not None + or kwargs.get("out_shape_no_bias", None) is not None + ) + + if has_outer_reshape: + out_shape = kwargs.get( + "out_shape_with_bias", kwargs["out_shape_no_bias"] + ) + if bias is not None: + new_out_node = match.graph.call_function( + aten.add.Tensor, args=(new_out_node, bias) + ) + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + else: + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + out_node.replace_all_uses_with(new_out_node) + new_out_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +class PostOpAttr: + def __init__( + self, + binary_op_name: str = "none", + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ) -> None: + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + 
self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + +def _register_qconv_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + if post_op_attr.unary_op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + post_op_attr.scalars_attr = [min_value, max_value] + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + accum = ( + kwargs["accum"] + if output_dtype in [torch.uint8, torch.int8] + else kwargs["accum_after_dequant"] + ) + accum_scale = ( + kwargs["accum_scale"] + if output_dtype in [torch.uint8, torch.int8] + else 1.0 + ) + accum_zp = ( + kwargs["accum_zp"] + if output_dtype in [torch.uint8, torch.int8] + else 0 + ) + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_conv_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qconv2d_binary_matcher_count" + if has_binary_post_op + else "qconv_unary_matcher_count" + ) + nodes_key = ( + "qconv2d_binary_matcher_nodes" + if has_binary_post_op + else "qconv_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + return qconv + + +def _register_qconv_unary_fusion(): + from torch._inductor.fx_passes.mkldnn_fusion import ( + _hardswish_fusion, + _hardtanh_fusion, + _silu_fusion, + ) + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a 
pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + conv_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + get_qconv_pt2e_pattern(1), + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + +def _register_qconv_binary_fusion(): + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + swap_binary_inputs_list = [False, True] + binary_replace_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + ), + PostOpAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), 
+ dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ), + ), + } + ) + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ) + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + else: + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + +def _register_qlinear_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qlinear_post_op_fusion(match: Match, *args, **kwargs): + """ + Match the pattern: + qlinear - post op + """ + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] 
= ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + other = kwargs["other"] if "other" in kwargs else kwargs["accum"] + x2_scale = 1.0 + x2_zp = 0 + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_linear_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qlinear_binary_matcher_count" + if has_binary_post_op + else "qlinear_unary_matcher_count" + ) + nodes_key = ( + "qlinear_binary_matcher_nodes" + if has_binary_post_op + else "qlinear_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + +def _register_qlinear_unary_fusion(): + from torch._inductor.fx_passes.mkldnn_fusion import ( + _gelu_fusion_1 as _gelu_fusion_erf, + _gelu_fusion_2 as _gelu_fusion_tanh, + ) + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + 
x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + +def _register_qlinear_binary_fusion(): + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. 
+ """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + PostOpAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # 
totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + +@functools.lru_cache(None) +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() + _register_linear_dynamic_fp16_weight_prepack() + + # Step 4: weight prepack for SmoothQuant from Torchao + _register_smooth_quant_int_mm_pattern() + + # Step 5: QLinear post op Fusion + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): + # skip fusion on ARM + _register_qconv_unary_fusion() + _register_qconv_binary_fusion() + _register_qlinear_unary_fusion() + _register_qlinear_binary_fusion() + + +def quant_lift_up(module_graph: torch.fx.graph.Graph): + """ + Lift up the quant node before view like nodes. It can benefit performance + of Attention like block. For example, we have the pattern as: + + DQ + DQ LINEAR + LINEAR VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + Q Q + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + We want to lift up the the quant nodes from matmul before view like nodes + as the output of Linear node. + + DQ + DQ LINEAR + LINEAR Q + Q VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + It produces a DQ->LINEAR->Q pattern which can be fused by backend. + """ + + def is_view_op(node): + return node.op == "call_function" and node.target in _VIEW_OPS + + for node in module_graph.nodes: + # Leslie: Here we verify that the quant node has exactly + # one input FX node, with constant scalar value for scale and zero point. + # For the case input of quant node has more than one input FX nodes, + # extend the implementation to lift up all the connected nodes + # before the view nodes to keep the topological order. 
+ if ( + node.op == "call_function" + and node.target in _PER_TENSOR_QUANTIZE_OPS + and len(node.all_input_nodes) == 1 + and is_view_op(node.all_input_nodes[0]) + ): + quant_node = node + input_node_of_quant = quant_node.args[0] + + # Check the nodes along lift up path has only 1 user node + # Propagate view like node to find where to insert the new quant node + could_lift_up = True + current_node = quant_node + input_node = current_node.args[0] + while is_view_op(input_node): + if len(input_node.users) != 1: + could_lift_up = False + break + current_node = input_node + input_node = current_node.args[0] + + # Further check the input node of the first view node has only 1 user node + if could_lift_up and len(input_node.users) == 1: + # Replace dequant's input from quant to quant's input + quant_node.replace_all_uses_with(input_node_of_quant) + # Insert the new quant node + with module_graph.inserting_before(current_node): + new_quant_node = module_graph.node_copy(quant_node) + input_node.replace_all_uses_with(new_quant_node) + + # Update inputs of new_quant_node + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == input_node_of_quant: + return input_node + else: + return n + + new_args = map_arg(new_quant_node.args, maybe_replace_node) + new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) + new_quant_node.args = new_args # type: ignore[assignment] + new_quant_node.kwargs = new_kwargs # type: ignore[assignment] + module_graph.erase_node(quant_node) diff --git a/torchao/quantization/pt2e/pt2e/lowering.py b/torchao/quantization/pt2e/pt2e/lowering.py index 5491623b66..9caefa7a52 100644 --- a/torchao/quantization/pt2e/pt2e/lowering.py +++ b/torchao/quantization/pt2e/pt2e/lowering.py @@ -7,26 +7,53 @@ import torch from torch._inductor.constant_folding import constant_fold from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from typing import Optional __all__ = [ "lower_pt2e_quantized_to_x86", ] +FUSION_PATH_REGISTERED = False + def lower_pt2e_quantized_to_x86( model: torch.fx.GraphModule, - example_inputs: tuple[torch.Tensor, ...], + example_inputs: Optional[tuple[torch.Tensor, ...]] = None, + compile: bool = True, + **compile_options: Optional[dict], ) -> torch.fx.GraphModule: """Lower a PT2E-qantized model to x86 backend. Args: * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + * `compile` (bool): whether to torch.compile the model. Default is True. + Torch.compile brings more performance improvement. + * `compile_options` (dict): options for torch.compile. Return: - A GraphModule lowered to x86 backend. + A module lowered to x86 backend. 
""" + if compile: + global FUSION_PATH_REGISTERED + if not FUSION_PATH_REGISTERED: + global torch + from torchao.prototype.inductor.fx_passes.quantization import ( + _register_quantization_weight_pack_pass, + quant_lift_up, + ) + import torch._inductor.config + + torch._inductor.config.pre_grad_custom_pass = quant_lift_up + _register_quantization_weight_pack_pass() + FUSION_PATH_REGISTERED = True + return torch.compile(model, **compile_options) + + assert example_inputs is not None, ( + "example_inputs should not be None when compile is False" + ) + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] decomp_table = torch.export.default_decompositions() From 2642a4b2e4a2b21c83b0c51f8f958e7cf3ebd363 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 02:09:50 -0700 Subject: [PATCH 02/12] Fix conflict after merging main --- test/quantization/pt2e/test_x86inductor_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 9073b1facb..bb7f0462e9 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -41,7 +41,7 @@ clone_preserve_strides_offset, HAS_CPU, ) -from torchao.quantization.pt2e.pt2e.lowering import lower_pt2e_quantized_to_x86 +from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 # The dict value is match_nodes(computation_op+unary_op) From bd4b9ae4c511189e5d2ead903db41cb033ee52a8 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 02:24:38 -0700 Subject: [PATCH 03/12] Fix CI --- test/quantization/pt2e/test_x86inductor_fusion.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index bb7f0462e9..6cc269fdff 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -29,10 +29,8 @@ IS_FBCODE, IS_LINUX, IS_X86, - MI300_ARCH, parametrize, skipIfRocm, - skipIfRocmArch, TEST_ACL, xfailIfACL, ) @@ -312,7 +310,7 @@ def test_qconv2d_cpu(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @skipIfRocmArch(MI300_ARCH) + @skipIfRocm def test_qconv2d_int8_mixed_bf16(self): r""" This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. 
@@ -1740,8 +1738,6 @@ def matcher_check_fn(): check_quantization=True, num_include_ops=[2, 2], ) - print("time 2:", time.time() - t0) - t0 = time.time() else: # For python wrapper self._test_code_common( From 3aaa0f515cdb20d179e7c15bd660209589aa2aad Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 02:34:44 -0700 Subject: [PATCH 04/12] Fix format issues --- test/quantization/pt2e/test_x86inductor_fusion.py | 10 +++++----- .../prototype/inductor/fx_passes/quantization.py | 14 ++++++-------- torchao/quantization/pt2e/lowering.py | 6 ++++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 6cc269fdff..1a3e9f1fe6 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -15,7 +15,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._inductor import config -from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.testing._internal.common_quantization import ( @@ -25,22 +25,22 @@ skipIfNoONEDNNBF16, ) from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, IS_FBCODE, IS_LINUX, IS_X86, + TEST_ACL, + instantiate_parametrized_tests, parametrize, skipIfRocm, - TEST_ACL, xfailIfACL, ) from torch.testing._internal.inductor_utils import ( + HAS_CPU, _check_has_dynamic_shape, clone_preserve_strides_offset, - HAS_CPU, ) -from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 +from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 # The dict value is match_nodes(computation_op+unary_op) unary_list = { diff --git a/torchao/prototype/inductor/fx_passes/quantization.py b/torchao/prototype/inductor/fx_passes/quantization.py index 06d4a9489c..4ccb2a1f31 100644 --- a/torchao/prototype/inductor/fx_passes/quantization.py +++ b/torchao/prototype/inductor/fx_passes/quantization.py @@ -3,24 +3,20 @@ import copy import functools import itertools -import math -import operator from typing import Any import torch from torch._dynamo.utils import counters -from torch.fx.experimental.symbolic_shapes import has_free_symbols -from torch.fx.node import map_arg - +from torch._inductor.fx_passes.freezing_patterns import register_freezing_graph_pattern from torch._inductor.pattern_matcher import ( Arg, CallFunction, - filter_nodes, KeywordArg, Match, + filter_nodes, ) -from torch._inductor.fx_passes.freezing_patterns import register_freezing_graph_pattern - +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.node import map_arg aten = torch.ops.aten prims = torch.ops.prims @@ -2421,6 +2417,8 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs): def _register_qlinear_unary_fusion(): from torch._inductor.fx_passes.mkldnn_fusion import ( _gelu_fusion_1 as _gelu_fusion_erf, + ) + from torch._inductor.fx_passes.mkldnn_fusion import ( _gelu_fusion_2 as _gelu_fusion_tanh, ) diff --git a/torchao/quantization/pt2e/lowering.py b/torchao/quantization/pt2e/lowering.py index 9caefa7a52..f8fa0f7337 100644 --- a/torchao/quantization/pt2e/lowering.py +++ b/torchao/quantization/pt2e/lowering.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD 3-Clause license found 
in the # LICENSE file in the root directory of this source tree. +from typing import Optional + import torch from torch._inductor.constant_folding import constant_fold from torch._inductor.fx_passes.freezing_patterns import freezing_passes -from typing import Optional __all__ = [ "lower_pt2e_quantized_to_x86", @@ -39,11 +40,12 @@ def lower_pt2e_quantized_to_x86( global FUSION_PATH_REGISTERED if not FUSION_PATH_REGISTERED: global torch + import torch._inductor.config + from torchao.prototype.inductor.fx_passes.quantization import ( _register_quantization_weight_pack_pass, quant_lift_up, ) - import torch._inductor.config torch._inductor.config.pre_grad_custom_pass = quant_lift_up _register_quantization_weight_pack_pass() From 4491a989ef9cceeab1fbe326b697b74eeb29c5f8 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 02:37:53 -0700 Subject: [PATCH 05/12] Fix format issue --- test/quantization/pt2e/test_x86inductor_fusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 1a3e9f1fe6..e9a15d3f6c 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2570,5 +2570,6 @@ def matcher_check_fn(): if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available(): # set weight_prepack = False to skip fusion passes in pytorch core import torch._inductor.config + torch._inductor.config.cpp.weight_prepack = False run_tests() From e260028c5393d599baeaa2ac1d9e2ca21dac25a0 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 02:58:21 -0700 Subject: [PATCH 06/12] Fix versioning issue in UT --- test/quantization/pt2e/test_x86inductor_fusion.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index e9a15d3f6c..504aaf9782 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -37,11 +37,18 @@ from torch.testing._internal.inductor_utils import ( HAS_CPU, _check_has_dynamic_shape, - clone_preserve_strides_offset, ) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 +if TORCH_VERSION_AT_LEAST_2_6: + from torch.testing._internal.common_utils import TEST_ACL +else: + TEST_ACL = False + # The dict value is match_nodes(computation_op+unary_op) unary_list = { torch.nn.ReLU(): 2, @@ -173,10 +180,6 @@ def _test_common( device = self.device mod = mod.to(device=device) - if device != "cpu": - inputs = tuple( - clone_preserve_strides_offset(x, device=device) for x in inputs - ) counters.clear() torch._dynamo.reset() if check_autocast == torch.bfloat16 and ( From 692566ef272647a263a0fd57c1caff83abb1a200 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 03:00:31 -0700 Subject: [PATCH 07/12] Fix format issue --- test/quantization/pt2e/test_x86inductor_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 504aaf9782..73e06876d4 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -39,10 +39,10 @@ _check_has_dynamic_shape, ) +from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 from torchao.utils import ( 
TORCH_VERSION_AT_LEAST_2_6, ) -from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 if TORCH_VERSION_AT_LEAST_2_6: from torch.testing._internal.common_utils import TEST_ACL From 11c3924823979b0f132c1bea9329d64996eed5ba Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 04:28:40 -0700 Subject: [PATCH 08/12] Fix CI --- test/quantization/pt2e/test_x86inductor_fusion.py | 1 - torchao/prototype/inductor/__init__.py | 0 torchao/prototype/inductor/fx_passes/__init__.py | 0 3 files changed, 1 deletion(-) create mode 100644 torchao/prototype/inductor/__init__.py create mode 100644 torchao/prototype/inductor/fx_passes/__init__.py diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 73e06876d4..71e70c332d 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -28,7 +28,6 @@ IS_FBCODE, IS_LINUX, IS_X86, - TEST_ACL, instantiate_parametrized_tests, parametrize, skipIfRocm, diff --git a/torchao/prototype/inductor/__init__.py b/torchao/prototype/inductor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 1c2a948ee7504c6ad928f7685b60fdd447d31097 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 04:37:03 -0700 Subject: [PATCH 09/12] Fix CI --- test/quantization/pt2e/test_x86inductor_fusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 71e70c332d..1fa1a0df58 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -31,7 +31,6 @@ instantiate_parametrized_tests, parametrize, skipIfRocm, - xfailIfACL, ) from torch.testing._internal.inductor_utils import ( HAS_CPU, @@ -2406,7 +2405,6 @@ def matcher_check_fn(): } ) class TestDynamicPatternMatcher(TestPatternMatcherBase): - @xfailIfACL def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" This testcase will quantize a single Conv2d->Maxpool2d->Linear module @@ -2569,7 +2567,7 @@ def matcher_check_fn(): instantiate_parametrized_tests(TestPatternMatcher) if __name__ == "__main__": - if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available(): + if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): # set weight_prepack = False to skip fusion passes in pytorch core import torch._inductor.config From ee2c1b15affc2d467123655b3fce334362777a8d Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 28 Apr 2025 07:53:22 -0700 Subject: [PATCH 10/12] Fix CI --- test/quantization/pt2e/test_x86inductor_fusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 1fa1a0df58..786280d6df 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -40,6 +40,7 @@ from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_8, ) if TORCH_VERSION_AT_LEAST_2_6: @@ -251,6 +252,7 @@ def _test_code_common( torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, 
"Requires torch 2.8+") class TestPatternMatcher(TestPatternMatcherBase): def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): @@ -2404,6 +2406,7 @@ def matcher_check_fn(): "specialize_float": True, } ) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8+") class TestDynamicPatternMatcher(TestPatternMatcherBase): def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" From 524281a909c4e7b70a85629d43b42e272658c687 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 29 Apr 2025 23:26:42 -0700 Subject: [PATCH 11/12] Move registration of Inductor fusion passes to x86_inductor_quantizer.py --- .../pt2e/test_x86inductor_fusion.py | 43 +++++++++++++++---- .../pt2e/inductor_passes/__init__.py | 0 .../pt2e/inductor_passes/x86.py} | 0 torchao/quantization/pt2e/lowering.py | 31 +------------ .../pt2e/quantizer/x86_inductor_quantizer.py | 12 ++++++ 5 files changed, 47 insertions(+), 39 deletions(-) create mode 100644 torchao/quantization/pt2e/inductor_passes/__init__.py rename torchao/{prototype/inductor/fx_passes/quantization.py => quantization/pt2e/inductor_passes/x86.py} (100%) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 786280d6df..78204fb756 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -11,15 +11,13 @@ import unittest import torch -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._inductor import config from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code -from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( - _generate_qdq_quantized_model, skipIfNoDynamoSupport, skipIfNoONEDNN, skipIfNoONEDNNBF16, @@ -37,7 +35,16 @@ _check_has_dynamic_shape, ) -from torchao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 +import torchao +import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, +) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_8, @@ -102,6 +109,26 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer +def _generate_qdq_quantized_model( + mod, inputs, is_qat=False, is_dynamic=False, quantizer=None +): + maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() + with maybe_no_grad: + export_model = export_for_training(mod, inputs, strict=True).module() + quantizer = ( + quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) + ) + prepare_model = ( + prepare_qat_pt2e(export_model, quantizer) + if is_qat + else prepare_pt2e(export_model, quantizer) + ) + prepare_model(*inputs) + torchao.quantization.pt2e.move_exported_model_to_eval(prepare_model) + convert_model = convert_pt2e(prepare_model) + return convert_model + + def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): # this function is to decide how many kernels are generated # while testing conv2d/3d/deconv2d @@ -201,15 +228,13 @@ def _test_common( mod, inputs, is_qat, 
is_dynamic, quantizer ) with torch.no_grad(), maybe_autocast: - _ = lower_pt2e_quantized_to_x86(convert_model)(*inputs) + _ = torch.compile(convert_model)(*inputs) matcher_check_fn() else: with torch.no_grad(), maybe_autocast: clone_inputs = self._clone_inputs(inputs) expected = mod(*inputs) - actual = lower_pt2e_quantized_to_x86(mod, **compile_options)( - *clone_inputs - ) + actual = torch.compile(mod, **compile_options)(*clone_inputs) torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) matcher_check_fn() @@ -232,7 +257,7 @@ def _test_code_common( mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer) expected = mod(*inputs) actual, (source_code,) = run_and_get_code( - lower_pt2e_quantized_to_x86(mod, fullgraph=True, dynamic=check_dynamic), + torch.compile(mod, fullgraph=True, dynamic=check_dynamic), *clone_inputs, ) for op in include_ops: diff --git a/torchao/quantization/pt2e/inductor_passes/__init__.py b/torchao/quantization/pt2e/inductor_passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/inductor/fx_passes/quantization.py b/torchao/quantization/pt2e/inductor_passes/x86.py similarity index 100% rename from torchao/prototype/inductor/fx_passes/quantization.py rename to torchao/quantization/pt2e/inductor_passes/x86.py diff --git a/torchao/quantization/pt2e/lowering.py b/torchao/quantization/pt2e/lowering.py index f8fa0f7337..76dad800cd 100644 --- a/torchao/quantization/pt2e/lowering.py +++ b/torchao/quantization/pt2e/lowering.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional - import torch from torch._inductor.constant_folding import constant_fold from torch._inductor.fx_passes.freezing_patterns import freezing_passes @@ -14,48 +12,21 @@ "lower_pt2e_quantized_to_x86", ] -FUSION_PATH_REGISTERED = False - def lower_pt2e_quantized_to_x86( model: torch.fx.GraphModule, - example_inputs: Optional[tuple[torch.Tensor, ...]] = None, - compile: bool = True, - **compile_options: Optional[dict], + example_inputs: tuple[torch.Tensor, ...], ) -> torch.fx.GraphModule: """Lower a PT2E-qantized model to x86 backend. Args: * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. - * `compile` (bool): whether to torch.compile the model. Default is True. - Torch.compile brings more performance improvement. - * `compile_options` (dict): options for torch.compile. Return: A module lowered to x86 backend. 
""" - if compile: - global FUSION_PATH_REGISTERED - if not FUSION_PATH_REGISTERED: - global torch - import torch._inductor.config - - from torchao.prototype.inductor.fx_passes.quantization import ( - _register_quantization_weight_pack_pass, - quant_lift_up, - ) - - torch._inductor.config.pre_grad_custom_pass = quant_lift_up - _register_quantization_weight_pack_pass() - FUSION_PATH_REGISTERED = True - return torch.compile(model, **compile_options) - - assert example_inputs is not None, ( - "example_inputs should not be None when compile is False" - ) - def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] decomp_table = torch.export.default_decompositions() diff --git a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py index 9b49cadf77..d49d4b9602 100644 --- a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py +++ b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py @@ -1625,3 +1625,15 @@ def _same_shape(n1: Node, n2: Node): def validate(self, model: torch.fx.GraphModule) -> None: pass + + +# Register Inductor fusion passes +import torch._inductor.config + +from torchao.quantization.pt2e.inductor_passes.x86 import ( + _register_quantization_weight_pack_pass, + quant_lift_up, +) + +torch._inductor.config.pre_grad_custom_pass = quant_lift_up +_register_quantization_weight_pack_pass() From 8e4532fc7f762909c8d5b8a252d0b0b22aaeca04 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 30 Apr 2025 00:44:52 -0700 Subject: [PATCH 12/12] Fix CI --- .../quantization/pt2e/quantizer/x86_inductor_quantizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py index d49d4b9602..cc296ebe33 100644 --- a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py +++ b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py @@ -1634,6 +1634,8 @@ def validate(self, model: torch.fx.GraphModule) -> None: _register_quantization_weight_pack_pass, quant_lift_up, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 -torch._inductor.config.pre_grad_custom_pass = quant_lift_up -_register_quantization_weight_pack_pass() +if TORCH_VERSION_AT_LEAST_2_8: + torch._inductor.config.pre_grad_custom_pass = quant_lift_up + _register_quantization_weight_pack_pass()