diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6a5ea8ef9d..8e047985c5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py
index 82fb117363..4ed0974172 100644
--- a/test/kernel/test_autotuner.py
+++ b/test/kernel/test_autotuner.py
@@ -16,6 +16,7 @@
 
 logging.basicConfig(level=logging.INFO)
 
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 class TestQuantFlow(unittest.TestCase):
 
@@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype):
         assert out32_2.dtype == out32_1.dtype
         torch.testing.assert_allclose(out32_1, out32_2)
 
+    @parameterized.expand(
+        [
+            ("cuda", torch.bfloat16),
+            ("cuda", torch.float16),
+        ]
+    )
+    @unittest.skipIf(not is_H100, "Needs H100")
+    def test_int_mm_float8(self, device, dtype):
+        from torchao.kernel import intmm
+
+        dtype = torch.bfloat16
+        m, k, n = (128, 64, 16)
+        x = torch.randn(m, k, dtype=dtype, device=device)
+        w = torch.randn(n, k, dtype=dtype, device=device).t()
+        x_float8 = x.to(dtype=torch.float8_e4m3fn)
+        w_float8 = w.to(dtype=torch.float8_e4m3fn)
+        out32_1 = intmm.safe_int_mm(x_float8, w_float8)
+        assert out32_1.dtype == torch.int32
+
     @parameterized.expand(
         [
             ("cuda", torch.bfloat16),
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 418e75d039..025f36ec39 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):
         if target_dtype in FP8_TYPES:
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb16a9..81e7b19b15 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -69,7 +69,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
             input = (
                 input.contiguous()
             )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except Exception:
+            # fallback path, would run on H100 for float8 dtypes
+            # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 39482caf84..089add1d87 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
-from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
     def from_float(cls, weight):
         return weight
 
+class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    """
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
+    """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
+    @classmethod
+    def from_float(cls, weight):
+        block_size = (1, weight.shape[1])
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
+
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
@@ -493,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
@@ -617,6 +638,8 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
+        assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
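For reference, a minimal usage sketch of the float8 autoquant option added above; it is not part of the patch. It assumes the public torchao.autoquant entry point, that OTHER_AUTOQUANT_CLASS_LIST is importable from torchao.quantization.autoquant as defined in this diff, and a CUDA device with compute capability >= 8.9. The model and input shapes are placeholders.

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

# Placeholder model and input; any module containing nn.Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Passing qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST restricts autoquant's
# search to AQFloat8WeightOnlyQuantizedLinearWeight instead of the default
# int8/float options in DEFAULT_AUTOQUANT_CLASS_LIST.
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
out = model(example_input)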