Add Float8ActInt4WeightQATQuantizer #2289


Merged
merged 1 commit on Jun 5, 2025
63 changes: 62 additions & 1 deletion test/quantization/test_qat.py
@@ -17,6 +17,9 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

from torchao import quantize_
from torchao.float8.config import ScalingGranularity
Contributor:
I kinda hate that we have both ScalingGranularity and the Granularity of the other FP8 inference APIs

Contributor:
I think this is worth fixing before landing. @andrewor14, how about just using rowwise scaling (since I assume that's the one you want) and removing the option to configure it? That will at least keep this problem away from the BC surface of QAT in a way that we can more easily fix later.

Contributor Author:
Yeah sure
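
For context, the two overlapping notions of granularity being discussed look roughly like this; a minimal sketch, assuming the module paths used in the imports of this diff (member names taken from the test below):

```python
# Sketch of the two "granularity" APIs referenced in this thread; only
# AXISWISE activations and per-group/per-channel weights are exercised by this PR.
from torchao.float8.config import ScalingGranularity            # float8 training: TENSORWISE, AXISWISE
from torchao.quantization.granularity import PerAxis, PerGroup  # quantization APIs: PerAxis, PerGroup, ...

# Rowwise float8 activation scaling corresponds to AXISWISE scaling here ...
act_scaling = ScalingGranularity.AXISWISE
# ... while the int4 weight fake quantization below uses per-group granularity.
weight_granularity = PerGroup(64)
```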

from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.quantization.granularity import (
PerAxis,
PerGroup,
@@ -40,15 +43,18 @@
)
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
_Float8RowwiseActivationFakeQuantizer,
)
from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Float8ActInt4WeightQATQuantizer,
Int4WeightOnlyQATLinear,
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.utils import (
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_Float8RowwiseFakeQuantize,
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
@@ -68,6 +74,7 @@
)
from torchao.quantization.utils import (
_get_per_token_block_size,
compute_error,
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
@@ -1474,7 +1481,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
numerics that match exactly over N trials.
"""
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.utils import compute_error

num_trials = 1000
group_size = 16
@@ -1688,6 +1694,61 @@ def test_qat_range_learning(self):
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))

def test_float8_rowwise_fake_quantize(self):
"""
Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`.
"""
torch.manual_seed(self.SEED)
dtype = torch.float8_e4m3fn
x = torch.randn(32, 64)
axiswise_dim = 0
out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
out_expected = hp_tensor_to_float8_dynamic(
x,
dtype,
LinearMMConfig(),
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
).to_original_precision()
torch.testing.assert_close(out, out_expected, atol=0, rtol=0)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is below 2.6"
)
def test_qat_fp8a4w_quantizer(self):
"""
Test basic model training with `Float8ActInt4WeightQATQuantizer`.
"""
torch.manual_seed(self.SEED)
m = M()
qat_quantizer = Float8ActInt4WeightQATQuantizer()
qat_model = qat_quantizer.prepare(m)
for linear in [m.linear1, m.sub.linear, m.linear2]:
self.assertIsInstance(linear, FakeQuantizedLinear)
self.assertIsInstance(
linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
)
self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
prev_weight = copy.deepcopy(m.linear1.weight)

# Simulate training
optimizer = torch.optim.SGD(
m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer.zero_grad()
target = torch.randn(1, 512).float()
example_inputs = m.example_inputs()
out = qat_model(*example_inputs)
loss = loss_fn(out, target)
loss.backward()
optimizer.step()
# Assert that weights have valid gradients and are being updated
new_weight = m.linear1.weight
self.assertIsNotNone(new_weight.grad)
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))


if __name__ == "__main__":
unittest.main()
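
A minimal sketch of how the prepared model could additionally be sanity-checked against the float baseline with `compute_error` (SQNR in dB), in the spirit of the tests above; `M()` and `example_inputs()` are the test helpers used above, and the 20 dB threshold is only illustrative:

```python
# Hedged sketch: compare fake-quantized vs. float outputs via SQNR.
import copy

import torch

from torchao.quantization.qat import Float8ActInt4WeightQATQuantizer
from torchao.quantization.utils import compute_error

model = M()
baseline = copy.deepcopy(model)                       # keep a float copy before prepare()
qat_model = Float8ActInt4WeightQATQuantizer(group_size=64).prepare(model)

with torch.no_grad():
    x = baseline.example_inputs()
    sqnr = compute_error(baseline(*x), qat_model(*x))
assert sqnr > 20, f"unexpectedly low SQNR: {sqnr:.1f} dB"  # illustrative threshold
```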
2 changes: 2 additions & 0 deletions torchao/quantization/qat/__init__.py
@@ -11,13 +11,15 @@
Int4WeightOnlyEmbeddingQATQuantizer,
)
from .linear import (
Float8ActInt4WeightQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)

__all__ = [
"ComposableQATQuantizer",
"FakeQuantizeConfig",
"Float8ActInt4WeightQATQuantizer",
"FromIntXQuantizationAwareTrainingConfig",
"Int4WeightOnlyEmbeddingQATQuantizer",
"Int4WeightOnlyQATQuantizer",
21 changes: 21 additions & 0 deletions torchao/quantization/qat/fake_quantizer.py
@@ -32,6 +32,7 @@
from .utils import (
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_Float8RowwiseFakeQuantize,
)


@@ -186,3 +187,23 @@ def __repr__(self) -> str:
Return a human readable representation of this `FakeQuantizer` with config details.
"""
return "FakeQuantizer(%s)" % self.config


class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module):
"""
Simple fake quantizer for float8 rowwise fake quantization, intended for activations only.
"""

def __init__(self):
super().__init__()
self.enabled = True

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.enabled:
return _Float8RowwiseFakeQuantize.apply(
x,
torch.float8_e4m3fn,
-1,
)
else:
return x
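
A minimal sketch of how the new fake quantizer behaves, assuming it is imported from this module: the forward pass quantize-dequantizes rowwise to float8_e4m3fn, the backward pass is a straight-through estimator, and the `enabled` flag turns the module into an identity:

```python
# Hedged sketch of _Float8RowwiseActivationFakeQuantizer behavior.
import torch

from torchao.quantization.qat.fake_quantizer import (
    _Float8RowwiseActivationFakeQuantizer,
)

fq = _Float8RowwiseActivationFakeQuantizer()
x = torch.randn(4, 8, requires_grad=True)

y = fq(x)                 # quantize-dequantize along the last dim
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # STE: gradient passes through unchanged

fq.enabled = False
assert torch.equal(fq(x), x)  # identity when disabled
```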
118 changes: 111 additions & 7 deletions torchao/quantization/qat/linear.py
@@ -28,7 +28,10 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

from .api import FakeQuantizeConfig
from .fake_quantizer import FakeQuantizer
from .fake_quantizer import (
FakeQuantizer,
_Float8RowwiseActivationFakeQuantizer,
)
from .utils import (
_get_qmin_qmax,
)
@@ -145,6 +148,11 @@ def from_linear(
return new_linear


# ===========================
# | QAT quantizer interface |
# ===========================


class _LegacyQATQuantizer(TwoStepQuantizer):
"""
Base class for sharing common methods across legacy QAT quantizers.
@@ -157,9 +165,30 @@ def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return None


# =========================================================
# | Linear int8 dynamic activations + int4 weight QAT |
# =========================================================
def enable_linear_fake_quant(
mod: torch.nn.Module,
enabled: bool = True,
):
"""
Helper function to enable fake quantization in `FakeQuantizedLinear`.
"""
if isinstance(mod, FakeQuantizedLinear):
if mod.activation_fake_quantizer is not None:
mod.activation_fake_quantizer.enabled = enabled
if mod.weight_fake_quantizer is not None:
mod.weight_fake_quantizer.enabled = enabled


def disable_linear_fake_quant(mod: torch.nn.Module):
"""
Helper function to disable fake quantization in `FakeQuantizedLinear`.
"""
enable_linear_fake_quant(mod, enabled=False)


# ===========================================
# | int8 dynamic activations + int4 weights |
# ===========================================


class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
@@ -307,6 +336,7 @@ def disable_fake_quant(self):
self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_8da4w_fake_quant(mod: torch.nn.Module):
"""
Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -315,6 +345,7 @@ def enable_8da4w_fake_quant(mod: torch.nn.Module):
mod.enable_fake_quant()


# TODO: remove in favor of disable_linear_fake_quant
def disable_8da4w_fake_quant(mod: torch.nn.Module):
"""
Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -357,9 +388,9 @@ def _get_8da4w_weight_config(
)


# ===================================
# | Linear int4 weight-only QAT |
# ===================================
# ====================
# | int4 weight-only |
# ====================


class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
@@ -501,6 +532,7 @@ def disable_fake_quant(self):
self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_4w_fake_quant(mod: torch.nn.Module):
"""
Enable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -509,6 +541,7 @@ def enable_4w_fake_quant(mod: torch.nn.Module):
mod.enable_fake_quant()


# TODO: remove these in favor of disable_linear_fake_quant
def disable_4w_fake_quant(mod: torch.nn.Module):
"""
Disable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -533,3 +566,74 @@ def _get_4w_weight_config(
zero_point_precision=qparams_precision,
zero_point_domain=ZeroPointDomain.FLOAT,
)


# =============================================
# | float8 rowwise activations + int4 weights |
# =============================================


class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
"""
QAT quantizer for applying dynamic rowwise float8 activation + int4
per group/channel symmetric weight fake quantization to linear layers
in the model. Currently only supports rowwise granularity for float8
activations.

Args:
group_size (Optional[int]): the number of elements in each quantized
group for weights, defaults to 64. Use None for per channel.
scale_precision: precision of weight scales, defaults to torch.bfloat16.
"""

def __init__(
self,
group_size: Optional[int] = 64,
scale_precision: torch.dtype = torch.bfloat16,
):
if group_size is not None:
weight_granularity = "per_group"
else:
weight_granularity = "per_channel"
self._weight_config = FakeQuantizeConfig(
dtype=torch.int4,
granularity=weight_granularity,
group_size=group_size,
is_symmetric=True,
is_dynamic=True,
scale_precision=scale_precision,
)

def prepare(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
"""
Swap all `nn.Linear` with `FakeQuantizedLinear` with float8
fake quantizer for activations and int4 fake quantizer for weights.
"""
for name, child in model.named_children():
if isinstance(child, torch.nn.Linear):
# TODO: add a config for float8?
new_linear = FakeQuantizedLinear.from_linear(
child,
weight_config=self._weight_config,
)
new_linear.activation_fake_quantizer = (
_Float8RowwiseActivationFakeQuantizer()
)
setattr(model, name, new_linear)
else:
self.prepare(child)
return model

# TODO: add convert path
def convert(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
raise NotImplementedError

def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet")

def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return self._weight_config
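
A minimal usage sketch of the new quantizer together with the generic enable/disable helpers above; the toy model shape and group size are assumptions for illustration:

```python
# Hedged sketch: prepare a model and toggle fake quantization with the new helpers.
import torch

from torchao.quantization.qat.linear import (
    Float8ActInt4WeightQATQuantizer,
    disable_linear_fake_quant,
    enable_linear_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.Linear(128, 16))
model = Float8ActInt4WeightQATQuantizer(group_size=32).prepare(model)

model.apply(disable_linear_fake_quant)  # FakeQuantizedLinear layers run in plain float
model.apply(enable_linear_fake_quant)   # fake quantization back on for QAT
```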
32 changes: 32 additions & 0 deletions torchao/quantization/qat/utils.py
@@ -16,6 +16,38 @@
)


class _Float8RowwiseFakeQuantize(torch.autograd.Function):
"""
Implementation of float8 rowwise fake quantize with backward STE.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
x: torch.Tensor,
float8_dtype: torch.dtype,
axiswise_dim: int,
):
# compute rowwise scale based on `torchao.float8.float8_utils.tensor_to_scale`
eps = 1e-12
amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
amax = amax.to(torch.float64)
scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps)
scale = scale.to(torch.float32)

# fake quantize
max_value = torch.finfo(float8_dtype).max
x_fq = x.to(torch.float32) * scale
x_fq = x_fq.clamp(min=-max_value, max=max_value)
x_fq = x_fq.to(float8_dtype).to(x.dtype)
x_fq = x_fq / scale
return x_fq.to(x.dtype)

@staticmethod
def backward(ctx, gy):
return gy, None, None


# TODO: delete?
class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
"""
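
A small numeric sketch of the rowwise scale math in `_Float8RowwiseFakeQuantize.forward` above, simplified to skip the float64/float32 casts; `torch.finfo(torch.float8_e4m3fn).max` is 448:

```python
# Hedged numeric sketch of the rowwise fake-quantize round trip.
import torch

x = torch.tensor([[0.1, -3.0, 1.5]])
amax = x.abs().amax(dim=-1, keepdim=True)             # 3.0 for this row
scale = torch.finfo(torch.float8_e4m3fn).max / amax   # 448 / 3.0 ~= 149.3
x_fq = (x * scale).clamp(-448, 448).to(torch.float8_e4m3fn).to(x.dtype) / scale
# The row is scaled so its max magnitude lands at 448, snapped to the float8
# grid, then rescaled back; small entries like 0.1 pick up rounding error.
```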