diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index d00b96d3bb..245abe0d02 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -531,6 +531,21 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(
@@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
     @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ddeb4ef2fb..60a7341e39 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -39,6 +39,7 @@
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
 )
+from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
@@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    if isinstance(model, Float8Linear):
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
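
For reviewers, a minimal sketch of the workflow the new test exercises: convert a model for
float8 training, then apply post-training float8 weight-only quantization on top. It mirrors
the added test_quantize case; it assumes a CUDA device with arch 8.9+ and that
convert_to_float8_training is importable from torchao.float8 (the documented training API),
so treat it as illustrative rather than as part of this patch.

    import torch
    import torch.nn as nn

    from torchao.float8 import convert_to_float8_training
    from torchao.float8.float8_linear import Float8Linear
    from torchao.quantization.quant_api import float8_weight_only, quantize_

    # Build a small model and swap nn.Linear modules to Float8Linear for training.
    m = nn.Sequential(nn.Linear(32, 32)).cuda()
    m = convert_to_float8_training(m)
    assert isinstance(m[0], Float8Linear)

    # With the quant_api change above, quantize_ first unwraps Float8Linear back to a
    # plain nn.Linear (reusing its weight and bias) before applying the config, so the
    # two APIs now compose.
    quantize_(m, float8_weight_only())

    # Inference after quantization; the weight is stored as torch.float8_e4m3fn.
    with torch.no_grad():
        m(torch.randn(32, 32, device="cuda"))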