Skip to content

Add support for quantize_() with Float8Linear module #1344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
import torch.nn as nn

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
Expand Down Expand Up @@ -531,6 +531,21 @@ def test_inference_mode(self):
with torch.inference_mode(mode=True):
m(x)

@unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
def test_quantize(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
from torchao.quantization.quant_api import float8_weight_only, quantize_

quantize_(m, float8_weight_only())
assert (
m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
), "Post quantization dtype should be torch.float8_e4m3fn"
with torch.no_grad():
m(x)


class TestScaledMM:
@unittest.skipIf(
Expand Down Expand Up @@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_cuda_8_9, "CUDA not available")
Expand Down
7 changes: 7 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
to_affine_quantized_intx,
to_marlinqqq_quantized_intx,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
Expand Down Expand Up @@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter(
Returns:
None
"""
if isinstance(model, Float8Linear):
with torch.device("meta"):
new_module = nn.Linear(model.in_features, model.out_features)
new_module.weight = model.weight
new_module.bias = model.bias
model = new_module
if filter_fn(model, cur_fqn[:-1]):
if device is not None:
model.to(device=device) # move to device before quantization
Expand Down
Loading