Test PARQ with torchao activation quantization #2370

Merged 2 commits on Jun 16, 2025
89 changes: 81 additions & 8 deletions test/prototype/test_parq.py
@@ -27,8 +27,15 @@
 )
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
+    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
@@ -68,9 +75,9 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):


 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False):
+    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
-        self.embedding = nn.Embedding(10, m)
+        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
@@ -83,7 +90,11 @@ def reset_parameters(self):
             nn.init.zeros_(module.bias)

     def example_inputs(self, device=None):
-        return torch.randint(1, 10, (1, 256), device=device)
+        return (
+            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
+            if isinstance(self.embedding, nn.Embedding)
+            else torch.randn(1, self.linear1.in_features, device=device)
+        )

     def forward(self, x):
         x = self.embedding(x)
@@ -150,11 +161,11 @@ def compare_quantized_models(
                 p = p.view(-1, group_size)

             q, Q = quantizer.quantize(p, b=b, dim=-1)
-            q = q.view(original_shape)

             # compare to AffineQuantizedTensor instance
+            q = q.view(original_shape)
             ref = getattr(m_ref, n).weight.dequantize()
-            self.assertTrue(q.equal(ref))
+            torch.testing.assert_close(q, ref, atol=0, rtol=0)

     def compare_parq_convert(
         self,
@@ -182,13 +193,13 @@ def compare_parq_convert(
             p = module.weight.dequantize()  # PARQ weight after quantize_
             p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_

-            self.assertTrue(p_orig.equal(p_ref))
-            self.assertTrue(p.equal(p_ref))
+            torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+            torch.testing.assert_close(p, p_ref, atol=0, rtol=0)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
-        model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
+        model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
         model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
@@ -265,8 +276,70 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         self.compare_parq_convert(model, m_ref, optimizer, config)


+class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3, 4, 8])
+    @common_utils.parametrize("model_dtype", [torch.float16, torch.float32])
+    @common_utils.parametrize("group_size", [32, 128])
+    def test_int8_dynamic_activation_intx_e2e(
+        self,
+        b: int = 2,
+        model_dtype: torch.dtype = torch.float32,
+        group_size: int = 32,
+    ):
+        model = M(embedding=False).to(_DEVICE, dtype=model_dtype)
+        x = model.example_inputs(device=_DEVICE).to(model_dtype)
+
+        # reference model using native quantization
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantizer = UnifTorchaoQuantizer()
+        config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b],
+            weight_granularity=PerGroup(group_size),
+            weight_mapping_type=quantizer.mapping_type,
+            act_mapping_type=MappingType.ASYMMETRIC,
+        )
+        quantize_(m_ref, config)
+        ref_out = m_ref(x)
+
+        # quantize weights with PARQ
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        # apply torchao quantized activations on top
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            granularity="per_token",
+            mapping_type=config.act_mapping_type,
+        )
+        filter_fn = optimizer.get_filter_fn(model)
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
+            filter_fn=filter_fn,
+        )
+        out = model(x)
+        torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
+
+        # equivalent to torchao's convert step
+        model.eval()
+        optimizer.restore_latent_params()
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
+        quantize_(model, config, filter_fn=filter_fn)
+        converted_out = model(x)
+        torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
+common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)


 if __name__ == "__main__":
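The new test above encodes the full PARQ-plus-activation-quantization recipe. Stripped of the reference comparison, the optimizer handshake looks roughly like the sketch below; the import paths and the `quant_bits` param-group key are assumptions inferred from the test's `build_param_groups` helper, not verified API.

```python
import torch
from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifTorchaoQuantizer

model = torch.nn.Linear(128, 64, bias=False)
# One param group tagged with its target bit width (key name assumed)
param_groups = [{"params": list(model.parameters()), "quant_bits": 2}]

optimizer = QuantOptimizer(
    torch.optim.SGD(param_groups, lr=0.1),  # any base optimizer works here
    UnifTorchaoQuantizer(),                 # grid matches torchao's quantize_
    ProxHardQuant(),                        # hard projection onto the grid
    quant_per_channel=True,
)

# As in the test: a bare zero_grad()/step() quantizes the weights in place
# while the optimizer keeps a latent full-precision copy.
optimizer.zero_grad()
optimizer.step()
optimizer.restore_latent_params()  # undo the projection before torchao convert
```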
44 changes: 21 additions & 23 deletions torchao/prototype/parq/quant/uniform_torchao.py
@@ -50,8 +50,23 @@ def __init__(
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.eps = eps
-        self.preserve_zero = preserve_zero
-        self.zero_point_domain = zero_point_domain
+
+        # defaults: zero_point_domain=ZeroPointDomain.INT, preserve_zero=True
+        self._choose_qparams = choose_qparams_affine
+        self._quantize = quantize_affine
+        self._dequantize = dequantize_affine
+
+        if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_tinygemm
+            self._quantize = quantize_affine_tinygemm
+            self._dequantize = dequantize_affine_tinygemm
+        elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_dont_preserve_zero
+            self._quantize = quantize_affine
+            self._dequantize = dequantize_affine
+        elif zero_point_domain == ZeroPointDomain.NONE:
+            self._quantize = quantize_affine_no_zero_point
+            self._dequantize = dequantize_affine_no_zero_point
Review thread on this block:

Contributor: These changes aren't required to make the test pass, right? Just cleanup? (We can keep them in this PR; I just wanted to ask for my understanding.)

Contributor (author): Yes, this is just for simplification!
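The simplification in question is resolving the dispatch once in the constructor so the hot path stays branch-free. A generic, self-contained sketch of that pattern with toy stand-ins (not torchao's actual kernels or signatures):

```python
from typing import Callable

import torch

def affine(x: torch.Tensor) -> torch.Tensor:
    return 2.0 * x + 1.0

def affine_no_zero_point(x: torch.Tensor) -> torch.Tensor:
    return 2.0 * x

class Quantizer:
    def __init__(self, use_zero_point: bool = True) -> None:
        # Resolve the branch once here instead of on every quantize() call.
        self._quantize: Callable[[torch.Tensor], torch.Tensor] = (
            affine if use_zero_point else affine_no_zero_point
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        return self._quantize(x)

q = Quantizer(use_zero_point=False)
print(q.quantize(torch.ones(2)))  # tensor([2., 2.])
```

This keeps `quantize()` identical across configurations and makes the supported combinations of `zero_point_domain` and `preserve_zero` explicit in one place.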


     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
@@ -74,24 +89,7 @@ def quantize(
         # assume that p has already been grouped in QuantOptimizer.step
         block_size = (1, p.size(-1)) if dim is not None else p.size()

-        if self.zero_point_domain == ZeroPointDomain.FLOAT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_tinygemm
-            _quantize_affine = quantize_affine_tinygemm
-            _dequantize_affine = dequantize_affine_tinygemm
-        elif self.zero_point_domain == ZeroPointDomain.INT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_dont_preserve_zero
-            _quantize_affine = quantize_affine
-            _dequantize_affine = dequantize_affine
-        else:  # default case: zero_point_domain == ZeroPointDomain.INT/NONE and preserve_zero
-            _choose_qparams_affine = choose_qparams_affine
-            if self.zero_point_domain == ZeroPointDomain.INT:
-                _quantize_affine = quantize_affine
-                _dequantize_affine = dequantize_affine
-            else:
-                _quantize_affine = quantize_affine_no_zero_point
-                _dequantize_affine = dequantize_affine_no_zero_point
-
-        s, zero_point = _choose_qparams_affine(
+        s, zero_point = self._choose_qparams(
             p,
             self.mapping_type,
             block_size,
@@ -101,13 +99,13 @@
             quant_max=self.quant_max,
         )
         q_args = (block_size, s, zero_point, self.target_dtype)
-        q = _quantize_affine(
+        q = self._quantize(
             p,
             *q_args,
             quant_min=self.quant_min,
             quant_max=self.quant_max,
         )
-        q = _dequantize_affine(
+        q = self._dequantize(
             q,
             *q_args,
             output_dtype=p.dtype,
@@ -124,7 +122,7 @@ def quantize(
         else:
            block_size = Q.shape

-        Q = _dequantize_affine(
+        Q = self._dequantize(
             Q,
             block_size,
             *q_args[1:],
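For orientation, here is a minimal sketch of calling the refactored quantizer directly. The constructor default and the `(q, Q)` return convention come from the test and diff above, while the module path, bit width, and shapes are illustrative assumptions.

```python
import torch

from torchao.prototype.parq.quant.uniform_torchao import UnifTorchaoQuantizer

# Defaults per the diff: ZeroPointDomain.INT with preserve_zero=True, so
# _choose_qparams/_quantize/_dequantize resolve to the plain affine variants.
quantizer = UnifTorchaoQuantizer()

p = torch.randn(64, 32)  # pretend each row is one already-grouped weight block
q, Q = quantizer.quantize(p, b=4, dim=-1)  # q: dequantized weights, Q: quant grid

assert q.shape == p.shape  # the quantize-dequantize round trip preserves shape
```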