diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index e6c68b06e4..74d00a04cb 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -2876,6 +2876,73 @@ def test_lowering_to_x86(self):
             lower=True,
         )
 
+    @skipIfNoX86
+    def test_annotate_mul_tensor(self):
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y
+
+        class M2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum(-1)
+
+        class M3(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum()
+
+        class M4(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum().item()
+
+        for Mod in [M1, M2, M3, M4]:
+            with override_quantized_engine("x86"), torch.no_grad():
+                m = Mod().eval()
+                example_inputs = (torch.randn(64, 64), torch.randn(64, 64))
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                quantizer.set_function_type_qconfig(
+                    torch.mul, quantizer.get_global_quantization_config()
+                )
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2
+                    if isinstance(m, M1)
+                    else 0,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2
+                    if isinstance(m, M1)
+                    else 0,
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
+                }
+                node_list = [
+                    torch.ops.aten.mul.Tensor,
+                ]
+                if isinstance(m, M1):
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    ] + node_list
+
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
index 51861f1a3f..a74905f9c2 100644
--- a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
+++ b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
@@ -93,6 +93,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
     torch.ops.aten.conv1d.default,
     torch.ops.aten.conv2d.default,
     torch.ops.aten.linear.default,
+    torch.ops.aten.mul.Tensor,
 }
 
 # A superset of default_quantizable_ops includes operators support the int8 data type
@@ -219,6 +220,12 @@ def _map_module_function_to_aten_operator_type():
             ],
             torch.ops.aten.matmul.default,
         ),
+        (
+            [
+                torch.mul,
+            ],
+            torch.ops.aten.mul.Tensor,
+        ),
     )
     for map_item in map_list:
         module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[arg-type, call-overload]
@@ -735,6 +742,7 @@ def _annotate_with_config(
         self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn)
         self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn)
         self._annotate_matmul(model, quantization_config, filter_fn)
+        self._annotate_mul_tensor(model, quantization_config, filter_fn)
 
         # Step2: Recipe to propagate annotation for patterns beside conv/linear.
         # Go through all the nodes from start to end.
@@ -1577,5 +1585,47 @@ def _annotate_linear_binary_unary(
                 )
             )
 
+    def _annotate_mul_tensor(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        def _is_tensor(n: Node):
+            return isinstance(n, Node) and isinstance(
+                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
+            )
+
+        def _same_shape(n1: Node, n2: Node):
+            return n1.meta["val"].shape == n2.meta["val"].shape
+
+        for node in model.graph.nodes:
+            if node.target != torch.ops.aten.mul.Tensor:
+                continue
+
+            if _skip_annotate([node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize(node)
+                continue
+
+            assert len(node.args) == 2
+            if not (_is_tensor(node.args[0]) and _is_tensor(node.args[1])):
+                continue
+
+            if not _same_shape(node.args[0], node.args[1]):
+                continue
+
+            input_qspec_map = {}
+            mul_node = node
+            for input_node in mul_node.args:
+                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+            mul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass
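
Reviewer note: a minimal end-to-end sketch (not part of this diff) of how the new torch.mul annotation could be exercised outside the _test_quantizer helper. It reuses the quantizer setup from the test above; the export_for_training call and the torchao.quantization.pt2e.quantize_pt2e import path are assumptions and may differ between versions.

# Hypothetical usage sketch; import paths are assumed, not confirmed by this PR.
import torch
from torch.export import export_for_training

import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


class Mul(torch.nn.Module):
    def forward(self, x, y):
        # Same-shape tensor-tensor multiply: the only case _annotate_mul_tensor accepts.
        return x * y


m = Mul().eval()
example_inputs = (torch.randn(64, 64), torch.randn(64, 64))

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# Opt in to quantizing torch.mul, which this PR maps to aten.mul.Tensor.
quantizer.set_function_type_qconfig(
    torch.mul, quantizer.get_global_quantization_config()
)

exported = export_for_training(m, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
# The converted graph is expected to contain quantize/dequantize_per_tensor
# nodes around the mul, matching the M1 case of the test above.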