From a840ef535dbf91bdfca97742d469a541eb33c1e1 Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Mon, 16 Jun 2025 13:57:11 +0000
Subject: [PATCH 1/6] quantize_affine_float8/dequantize_affine_float8 not
 decomposed in inductor

---
 test/float8/test_compile.py              | 59 ++++++++++++++++++++++++
 torchao/quantization/quant_primitives.py | 24 +++++++++-
 torchao/utils.py                         | 11 +++--
 3 files changed, 89 insertions(+), 5 deletions(-)

diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index ac5d1f8d96..fb4d6ea316 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -37,6 +37,10 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
+from torchao.quantization.quant_primitives import (
+    dequantize_affine_float8,
+    quantize_affine_float8,
+)
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
 
@@ -392,5 +396,60 @@ def test_dynamic_scale_numeric_parity(
     assert torch.equal(float8_eager._data, float8_compile._data)
 
 
+@pytest.mark.parametrize(
+    "float8_dtype",
+    [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ],
+)
+@pytest.mark.parametrize(
+    "hp_dtype",
+    [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ],
+)
+@unittest.skipIf(
+    not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is lower than 2.5"
+)
+def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
+    input = torch.randn(10, 10)
+    with torch.no_grad():
+        torch._dynamo.reset()
+        expected_scale = torch.tensor(2.0)
+        expected_quantized = quantize_affine_float8(
+            input,
+            expected_scale,
+            float8_dtype,
+        )
+        expected_dequantized = dequantize_affine_float8(
+            expected_quantized,
+            expected_scale,
+            output_dtype=hp_dtype,
+        )
+        test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+            torch.compile(quantize_affine_float8),
+            input,
+            expected_scale,
+            float8_dtype,
+        )
+        torch.testing.FileCheck().check(
+            "torch.ops.torchao.quantize_affine_float8.default"
+        ).run(code_q)
+        test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+            torch.compile(dequantize_affine_float8),
+            test_q,
+            expected_scale,
+            hp_dtype,
+        )
+        torch.testing.FileCheck().check(
+            "torch.ops.torchao.dequantize_affine_float8.default"
+        ).run(code_dq)
+        torch.testing.assert_close(expected_quantized, test_q)
+        torch.testing.assert_close(expected_dequantized, test_dq)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index df136bc06e..72b3935157 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2270,10 +2270,11 @@ def _expand_scale_to_tensor_shape(
     return expanded_scale
 
 
+@_register_custom_op(quant_lib, False)
 def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    float8_dtype: torch.dtype,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2290,10 +2291,20 @@ def _quantize_affine_float8(
     return fp8_tensor
 
 
+@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
+def _quantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=float8_dtype)
+
+
+@_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    output_dtype: torch.dtype = torch.float32,
+    output_dtype: torch.dtype,
 ) -> torch.Tensor:
     """
     Dequantizes the float8 tensor to a high precision tensor.
@@ -2305,3 +2316,12 @@ def _dequantize_affine_float8(
     hp_tensor = fp8_tensor * scale_expanded
 
     return hp_tensor.to(output_dtype)
+
+
+@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
+def _dequantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=output_dtype)
diff --git a/torchao/utils.py b/torchao/utils.py
index 416d23d785..99a0a729f5 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
     return n + k - (n % k)
 
 
-def _register_custom_op(lib):
+def _register_custom_op(lib, implicit=True):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allowing them to be decomposed for the inductor path
 
@@ -206,6 +206,10 @@ def _the_op_that_needs_to_be_preserved(...)
     """
     from torch._inductor.decomposition import register_decomposition
 
+    dispatch_key = (
+        "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd"
+    )
+
     def decorator(fn):
         if TORCH_VERSION_AT_LEAST_2_5:
             from torch._library.infer_schema import infer_schema
@@ -221,11 +225,12 @@ def decorator(fn):
             op_name = fn.__name__[1:]
             schema = op_name + infer_schema(fn, mutates_args={})
             lib.define(schema)
-            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+            lib.impl(op_name, fn, dispatch_key)
 
             lib_namespace = lib.ns
             op = getattr(getattr(torch.ops, lib_namespace), op_name)
-            register_decomposition([op])(fn)
+            if implicit:
+                register_decomposition([op])(fn)
             return op
         else:
             return fn

From 02d045b267c20cca276770682d8769ec9d8c258b Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Mon, 16 Jun 2025 14:31:21 +0000
Subject: [PATCH 2/6] remove redundant unittest.skipIf

---
 test/float8/test_compile.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index fb4d6ea316..43f0d8e2f2 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -411,9 +411,6 @@ def test_dynamic_scale_numeric_parity(
         torch.bfloat16,
     ],
 )
-@unittest.skipIf(
-    not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is lower than 2.5"
-)
 def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
     input = torch.randn(10, 10)
     with torch.no_grad():
         torch._dynamo.reset()

From 9860c56e87ef83986f9ca25b87c32e7a3023f186 Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Wed, 18 Jun 2025 15:44:02 +0000
Subject: [PATCH 3/6] fix rebase issue

---
 test/float8/test_compile.py              | 10 ++++------
 torchao/quantization/quant_primitives.py |  8 ++++----
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 43f0d8e2f2..64feaf7b5d 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -37,10 +37,6 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
-from torchao.quantization.quant_primitives import (
-    dequantize_affine_float8,
-    quantize_affine_float8,
-)
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
 
@@ -412,6 +408,8 @@ def test_dynamic_scale_numeric_parity(
     ],
 )
 def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
+    quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
+    dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
     input = torch.randn(10, 10)
     with torch.no_grad():
         torch._dynamo.reset()
@@ -419,7 +417,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
         expected_quantized = quantize_affine_float8(
             input,
             expected_scale,
-            float8_dtype,
+            float8_dtype=float8_dtype,
         )
         expected_dequantized = dequantize_affine_float8(
             expected_quantized,
             expected_scale,
@@ -430,7 +428,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
             torch.compile(quantize_affine_float8),
             input,
             expected_scale,
-            float8_dtype,
+            float8_dtype=float8_dtype,
         )
         torch.testing.FileCheck().check(
             "torch.ops.torchao.quantize_affine_float8.default"
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 72b3935157..56e8422197 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2274,7 +2274,7 @@ def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    float8_dtype: torch.dtype,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2295,7 +2295,7 @@ def _quantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    float8_dtype: torch.dtype,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     return torch.empty_like(tensor, dtype=float8_dtype)
@@ -2304,7 +2304,7 @@ def _quantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    output_dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
     Dequantizes the float8 tensor to a high precision tensor.
@@ -2322,6 +2322,6 @@ def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    output_dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     return torch.empty_like(tensor, dtype=output_dtype)

From ca662f343e9164afa761e352e2a8e72400501c0a Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Wed, 18 Jun 2025 15:47:32 +0000
Subject: [PATCH 4/6] select the dispatch key via a decomposed flag

---
 torchao/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchao/utils.py b/torchao/utils.py
index 99a0a729f5..4814c7ec63 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
     return n + k - (n % k)
 
 
-def _register_custom_op(lib, implicit=True):
+def _register_custom_op(lib, decomposed=True):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allowing them to be decomposed for the inductor path
 
@@ -207,7 +207,7 @@ def _the_op_that_needs_to_be_preserved(...)
     from torch._inductor.decomposition import register_decomposition
 
     dispatch_key = (
-        "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd"
+        "CompositeImplicitAutograd" if decomposed else "CompositeExplicitAutograd"
     )
 
     def decorator(fn):
@@ -229,7 +229,7 @@ def decorator(fn):
             lib_namespace = lib.ns
             op = getattr(getattr(torch.ops, lib_namespace), op_name)
-            if implicit:
+            if decomposed:
                 register_decomposition([op])(fn)
             return op
         else:
             return fn

From d6ab45dcc24422969dfdca8fefa6d997e479aa81 Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Thu, 19 Jun 2025 09:13:44 +0000
Subject: [PATCH 5/6] To be more explicit, use the name inductor_decomposed
 instead

---
 torchao/utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchao/utils.py b/torchao/utils.py
index 4814c7ec63..1a12fb0668 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
     return n + k - (n % k)
 
 
-def _register_custom_op(lib, decomposed=True):
+def _register_custom_op(lib, inductor_decomposed=True):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allowing them to be decomposed for the inductor path
 
@@ -207,7 +207,9 @@ def _the_op_that_needs_to_be_preserved(...)
     from torch._inductor.decomposition import register_decomposition
 
     dispatch_key = (
-        "CompositeImplicitAutograd" if decomposed else "CompositeExplicitAutograd"
+        "CompositeImplicitAutograd"
+        if inductor_decomposed
+        else "CompositeExplicitAutograd"
     )
 
     def decorator(fn):
@@ -229,7 +231,7 @@ def decorator(fn):
             lib_namespace = lib.ns
             op = getattr(getattr(torch.ops, lib_namespace), op_name)
-            if decomposed:
+            if inductor_decomposed:
                 register_decomposition([op])(fn)
             return op
         else:
             return fn

From 4e0445f629ddaf2b43ac80eaeb8f18fbcfd5e50f Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Thu, 19 Jun 2025 17:01:51 +0000
Subject: [PATCH 6/6] Change unit test path

---
 test/dtypes/test_affine_quantized_float.py | 40 ++++++++++++++++
 test/float8/test_compile.py                | 54 ----------------------
 2 files changed, 40 insertions(+), 54 deletions(-)

diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index b63a406715..33a1fe66a7 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -675,6 +675,46 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
+        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
+        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
+        input = torch.randn(10, 10)
+        with torch.no_grad():
+            torch._dynamo.reset()
+            expected_scale = torch.tensor(2.0)
+            expected_quantized = quantize_affine_float8(
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            expected_dequantized = dequantize_affine_float8(
+                expected_quantized,
+                expected_scale,
+                output_dtype=hp_dtype,
+            )
+            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(quantize_affine_float8),
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.quantize_affine_float8.default"
+            ).run(code_q)
+            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(dequantize_affine_float8),
+                test_q,
+                expected_scale,
+                hp_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.dequantize_affine_float8.default"
+            ).run(code_dq)
+            torch.testing.assert_close(expected_quantized, test_q)
+            torch.testing.assert_close(expected_dequantized, test_dq)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 64feaf7b5d..ac5d1f8d96 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -392,59 +392,5 @@ def test_dynamic_scale_numeric_parity(
     assert torch.equal(float8_eager._data, float8_compile._data)
 
 
-@pytest.mark.parametrize(
-    "float8_dtype",
-    [
-        torch.float8_e4m3fn,
-        torch.float8_e5m2,
-    ],
-)
-@pytest.mark.parametrize(
-    "hp_dtype",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
-    quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-    dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-    input = torch.randn(10, 10)
-    with torch.no_grad():
-        torch._dynamo.reset()
-        expected_scale = torch.tensor(2.0)
-        expected_quantized = quantize_affine_float8(
-            input,
-            expected_scale,
-            float8_dtype=float8_dtype,
-        )
-        expected_dequantized = dequantize_affine_float8(
-            expected_quantized,
-            expected_scale,
-            output_dtype=hp_dtype,
-        )
-        test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-            torch.compile(quantize_affine_float8),
-            input,
-            expected_scale,
-            float8_dtype=float8_dtype,
-        )
-        torch.testing.FileCheck().check(
-            "torch.ops.torchao.quantize_affine_float8.default"
-        ).run(code_q)
-        test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-            torch.compile(dequantize_affine_float8),
-            test_q,
-            expected_scale,
-            hp_dtype,
-        )
-        torch.testing.FileCheck().check(
-            "torch.ops.torchao.dequantize_affine_float8.default"
-        ).run(code_dq)
-        torch.testing.assert_close(expected_quantized, test_q)
-        torch.testing.assert_close(expected_dequantized, test_dq)
-
-
 if __name__ == "__main__":
     pytest.main([__file__])