
[Quant][PT2E][X86] Enable annotation of aten.mul.tensor with X86InductorQuantizer #2075


Merged 1 commit on Apr 24, 2025
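
In short: `X86InductorQuantizer` can now annotate `aten.mul.Tensor` when both operands are same-shape tensors and the user opts `torch.mul` in via the function-type qconfig. Below is a minimal sketch of that opt-in, mirroring the new test in this PR; the torchao import path is inferred from the file touched here and may differ across releases.

```python
import torch

# Import path inferred from torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py;
# adjust to your torchao/torch version.
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# With this PR, torch.mul maps to aten.mul.Tensor, so the function-type qconfig
# below makes same-shape tensor-tensor multiplies eligible for annotation.
quantizer.set_function_type_qconfig(
    torch.mul, quantizer.get_global_quantization_config()
)
```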
67 changes: 67 additions & 0 deletions test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -2876,6 +2876,73 @@ def test_lowering_to_x86(self):
lower=True,
)

    @skipIfNoX86
    def test_annotate_mul_tensor(self):
        class M1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x * y

        class M2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x * y.sum(-1)

        class M3(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x * y.sum()

        class M4(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x * y.sum().item()

        for Mod in [M1, M2, M3, M4]:
            with override_quantized_engine("x86"), torch.no_grad():
                m = Mod().eval()
                example_inputs = (torch.randn(64, 64), torch.randn(64, 64))
                quantizer = X86InductorQuantizer().set_global(
                    xiq.get_default_x86_inductor_quantization_config()
                )
                quantizer.set_function_type_qconfig(
                    torch.mul, quantizer.get_global_quantization_config()
                )
                node_occurrence = {
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2
                    if isinstance(m, M1)
                    else 0,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2
                    if isinstance(m, M1)
                    else 0,
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
                }
                node_list = [
                    torch.ops.aten.mul.Tensor,
                ]
                if isinstance(m, M1):
                    node_list = [
                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    ] + node_list

                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )


if __name__ == "__main__":
    run_tests()
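
For context (not part of this diff), here is a hedged sketch of the flow a test like this drives end to end. The `_test_quantizer` helper's internals are not shown in the PR; the export/prepare/convert calls follow the standard PT2E API in `torch.ao.quantization.quantize_pt2e`, which this repo may re-export elsewhere.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


class M1(torch.nn.Module):
    def forward(self, x, y):
        # Same-shape operands: the only variant the test expects to be quantized.
        return x * y


example_inputs = (torch.randn(64, 64), torch.randn(64, 64))

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
quantizer.set_function_type_qconfig(
    torch.mul, quantizer.get_global_quantization_config()
)

# Standard PT2E flow: export, annotate/observe, calibrate, convert.
# export_for_training is available in PyTorch >= 2.5; older versions used a
# different capture API.
exported = torch.export.export_for_training(M1().eval(), example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)

# Roughly what node_occurrence asserts: two quantize_per_tensor nodes around the
# mul for M1, and none for the variants whose second operand is reduced first.
q_count = sum(
    n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default
    for n in converted.graph.nodes
)
print(q_count)  # expected 2 for M1 per the test above
```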
50 changes: 50 additions & 0 deletions torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
@@ -93,6 +93,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
    torch.ops.aten.conv1d.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.linear.default,
    torch.ops.aten.mul.Tensor,
}

# A superset of default_quantizable_ops includes operators that support the int8 data type
@@ -219,6 +220,12 @@ def _map_module_function_to_aten_operator_type():
            ],
            torch.ops.aten.matmul.default,
        ),
        (
            [
                torch.mul,
            ],
            torch.ops.aten.mul.Tensor,
        ),
    )
    for map_item in map_list:
        module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[arg-type, call-overload]
@@ -735,6 +742,7 @@ def _annotate_with_config(
        self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn)
        self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn)
        self._annotate_matmul(model, quantization_config, filter_fn)
        self._annotate_mul_tensor(model, quantization_config, filter_fn)

        # Step2: Recipe to propagate annotation for patterns beside conv/linear.
        # Go through all the nodes from start to end.
@@ -1577,5 +1585,47 @@ def _annotate_linear_binary_unary(
                    )
                )

    def _annotate_mul_tensor(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[FilterFn] = None,
    ):
        def _is_tensor(n: Node):
            return isinstance(n, Node) and isinstance(
                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
            )

        def _same_shape(n1: Node, n2: Node):
            return n1.meta["val"].shape == n2.meta["val"].shape

        for node in model.graph.nodes:
            if node.target != torch.ops.aten.mul.Tensor:
                continue

            if _skip_annotate([node], filter_fn):
                continue

            if quantization_config is None:
                _annotate_nodes_not_quantize(node)
                continue

            assert len(node.args) == 2
            if not (_is_tensor(node.args[0]) and _is_tensor(node.args[1])):
                continue

            if not _same_shape(node.args[0], node.args[1]):
                continue

            input_qspec_map = {}
            mul_node = node
            for input_node in mul_node.args:
                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
            mul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                _annotated=True,
                _is_output_of_quantized_pattern=True,
            )

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass