
Add convert path for quantize_ QAT API #1540


Merged (4 commits, Jan 13, 2025)
65 changes: 65 additions & 0 deletions test/quantization/test_qat.py
@@ -25,6 +25,7 @@
from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import (
@@ -42,6 +43,9 @@
_GenericFakeQuantize,
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
)
from torchao.quantization.quant_primitives import (
MappingType,
TorchAODType,
@@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self):
lambda m, _: isinstance(m, torch.nn.ReLU),
)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4"
)
def test_quantize_api_convert_path(self):
"""
Test that the following:

quantize_(model, intx_quantization_aware_training(...))
quantize_(model, from_intx_quantization_aware_training(...))
quantize_(model, int8_dynamic_activation_int4_weight())

can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
"""
from torchao.quantization.qat import (
Int8DynActInt4WeightQATQuantizer,
)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
baseline_model = copy.deepcopy(m)

# Baseline prepare
baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
baseline_model = baseline_quantizer.prepare(baseline_model)

# quantize_ prepare
activation_config = FakeQuantizeConfig(
torch.int8,
"per_token",
is_symmetric=False,
)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
quantize_(
m,
intx_quantization_aware_training(activation_config, weight_config),
)

# Compare prepared values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

# Baseline convert
baseline_model = baseline_quantizer.convert(baseline_model)

# quantize_ convert
quantize_(m, from_intx_quantization_aware_training())
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
Comment on lines +1319 to +1320

@jerryzh168 (Contributor) commented on Jan 10, 2025:
The scale/zero_point calculated from `int8_dynamic_activation_int4_weight` is not guaranteed to be the same as the ones from QAT, right? Is this OK?

The PR author (Contributor) replied:
Yeah, this is actually also the case in the old flow (`Int8DynActInt4WeightQATQuantizer`). I did verify that the qparams are the same today, because we calculate them the same way before and after convert. If the user really wants to guarantee the same qparams, they can also just store them somewhere and set them manually in the weight tensors, so I think it's OK.
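
For readers who want that stronger guarantee, a minimal sketch of the "store the qparams and re-apply them" workaround might look like the following. It assumes the prepared modules expose a `weight_fake_quantizer` with `scale` and `zero_point` attributes; those attribute names are an assumption based on the QAT modules touched in this PR, not a stable API.

import torch

def snapshot_weight_qparams(model: torch.nn.Module) -> dict:
    """Record the weight scale/zero_point of every fake-quantized module
    in a QAT-prepared model, keyed by module name.
    """
    qparams = {}
    for name, module in model.named_modules():
        # Attribute names below are assumptions, not a guaranteed interface.
        fq = getattr(module, "weight_fake_quantizer", None)
        if fq is not None and getattr(fq, "scale", None) is not None:
            qparams[name] = (
                fq.scale.detach().clone(),
                fq.zero_point.detach().clone(),
            )
    return qparams

# Usage sketch: call this right before convert, then either compare the saved
# values against the converted weights' qparams or write them back manually.
# saved_qparams = snapshot_weight_qparams(prepared_model)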


# Compare converted values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/quantization/qat/__init__.py
@@ -1,6 +1,7 @@
from .api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
from .embedding import (
@@ -18,4 +19,5 @@
"Int4WeightOnlyEmbeddingQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
]
40 changes: 38 additions & 2 deletions torchao/quantization/qat/api.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Callable, List, Optional, Union

import torch

@@ -242,7 +242,7 @@ def __setattr__(self, name: str, value: Any):
def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> torch.nn.Module:
) -> Callable:
"""
Return a function that applies fake quantization to a `torch.nn.Module`,
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
@@ -295,6 +295,42 @@ def _insert_fake_quantize(mod: torch.nn.Module):
return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
"""
Return a function that converts a model with fake quantized modules,
such as :class:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :class:`~torchao.quantization.qat.embedding.FakeQuantizedEmbedding`,
back to a model with the original, corresponding modules without
fake quantization. This should be used with
:func:`~torchao.quantization.quant_api.quantize_`.

Example usage::

from torchao.quantization import quantize_
quantize_(
model_with_fake_quantized_linears,
from_intx_quantization_aware_training(),
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize


class ComposableQATQuantizer(TwoStepQuantizer):
"""
Composable quantizer that users can use to apply multiple QAT quantizers easily.
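Taken together, the prepare and convert transforms in this file give a quantize_-only QAT flow. Below is a minimal end-to-end sketch that mirrors the new test above; the toy model and the training step are placeholders of my own, while the configs and imports follow the paths used in this diff.

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
from torchao.quantization.quant_primitives import TorchAODType

# Placeholder model; any module containing torch.nn.Linear layers works.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
)

# Prepare: swap nn.Linear for fake-quantized equivalents
# (int8 per-token asymmetric activations, int4 grouped weights).
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=16)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... fine-tune the prepared model here ...

# Convert: strip the fake-quantize wrappers, then apply real
# int8 dynamic activation + int4 weight quantization.
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=16))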
18 changes: 18 additions & 0 deletions torchao/quantization/qat/embedding.py
@@ -82,6 +82,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.sparse,
)

def to_embedding(self) -> torch.nn.Embedding:
new_embedding = torch.nn.Embedding(
self.num_embeddings,
self.embedding_dim,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
device=self.weight.device,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_embedding.weight = self.weight
return new_embedding

@classmethod
def from_embedding(
cls,
11 changes: 11 additions & 0 deletions torchao/quantization/qat/linear.py
@@ -105,6 +105,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight
return F.linear(x, w)

def to_linear(self) -> torch.nn.Linear:
new_linear = torch.nn.Linear(
self.in_features, self.out_features, self.bias is not None, device=self.weight.device
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_linear.weight = self.weight
return new_linear

@classmethod
def from_linear(
cls,