diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index b63a406715..33a1fe66a7 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -675,6 +675,46 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
+        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
+        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
+        input = torch.randn(10, 10)
+        with torch.no_grad():
+            torch._dynamo.reset()
+            expected_scale = torch.tensor(2.0)
+            expected_quantized = quantize_affine_float8(
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            expected_dequantized = dequantize_affine_float8(
+                expected_quantized,
+                expected_scale,
+                output_dtype=hp_dtype,
+            )
+            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(quantize_affine_float8),
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.quantize_affine_float8.default"
+            ).run(code_q)
+            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(dequantize_affine_float8),
+                test_q,
+                expected_scale,
+                hp_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.dequantize_affine_float8.default"
+            ).run(code_dq)
+            torch.testing.assert_close(expected_quantized, test_q)
+            torch.testing.assert_close(expected_dequantized, test_dq)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index df136bc06e..56e8422197 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2270,6 +2270,7 @@ def _expand_scale_to_tensor_shape(
     return expanded_scale
 
 
+@_register_custom_op(quant_lib, False)
 def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2290,6 +2291,16 @@
     return fp8_tensor
 
 
+@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
+def _quantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=float8_dtype)
+
+
+@_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2305,3 +2316,12 @@
     hp_tensor = fp8_tensor * scale_expanded
 
     return hp_tensor.to(output_dtype)
+
+
+@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
+def _dequantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=output_dtype)
diff --git a/torchao/utils.py b/torchao/utils.py
index 416d23d785..1a12fb0668 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
     return n + k - (n % k)
 
 
-def _register_custom_op(lib):
+def _register_custom_op(lib, inductor_decomposed=True):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allow them to be decomposed for inductor path
 
@@ -206,6 +206,12 @@ def _the_op_that_needs_to_be_preserved(...)
     """
     from torch._inductor.decomposition import register_decomposition
 
+    dispatch_key = (
+        "CompositeImplicitAutograd"
+        if inductor_decomposed
+        else "CompositeExplicitAutograd"
+    )
+
    def decorator(fn):
        if TORCH_VERSION_AT_LEAST_2_5:
            from torch._library.infer_schema import infer_schema
@@ -221,11 +227,12 @@ def decorator(fn):
             op_name = fn.__name__[1:]
             schema = op_name + infer_schema(fn, mutates_args={})
             lib.define(schema)
-            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+            lib.impl(op_name, fn, dispatch_key)
             lib_namespace = lib.ns
             op = getattr(getattr(torch.ops, lib_namespace), op_name)
-            register_decomposition([op])(fn)
+            if inductor_decomposed:
+                register_decomposition([op])(fn)
             return op
         else:
             return fn
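
Note (not part of the patch): the sketch below illustrates the registration split that inductor_decomposed=False selects. The namespace "my_ns", the op "fp8_quant", and its kernel body are hypothetical stand-ins; only the dispatch keys, the Meta kernel, and the decomposition registration mirror what the patch does.

    import torch

    # FRAGMENT allows adding ops to a namespace, matching how torchao
    # creates quant_lib; "my_ns" here is a made-up namespace.
    my_lib = torch.library.Library("my_ns", "FRAGMENT")
    my_lib.define("fp8_quant(Tensor t, Tensor scale) -> Tensor")

    # Stand-in kernel body; the real torchao op does scaled fp8 casting.
    def _fp8_quant(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return (t / scale).to(torch.float8_e4m3fn)

    # inductor_decomposed=False path: register under CompositeExplicitAutograd
    # and do NOT register an inductor decomposition, so torch.compile keeps
    # torch.ops.my_ns.fp8_quant.default opaque in the generated code.
    my_lib.impl("fp8_quant", _fp8_quant, "CompositeExplicitAutograd")

    # A Meta kernel becomes necessary once the op is no longer implicitly
    # decomposed, so fake tensors can propagate shape and dtype.
    @torch.library.impl(my_lib, "fp8_quant", "Meta")
    def _fp8_quant_meta(t, scale):
        return torch.empty_like(t, dtype=torch.float8_e4m3fn)

    # For contrast, the inductor_decomposed=True path (the previous default)
    # would instead do:
    #     my_lib.impl("fp8_quant", _fp8_quant, "CompositeImplicitAutograd")
    #     register_decomposition([torch.ops.my_ns.fp8_quant])(_fp8_quant)
    # which lets inductor break the op into primitive ops.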