Merged

Commits (30)
b133369  Split choose_qparams_affine (jainapurva, Apr 28, 2025)
8e4bca8  Remove preserve_zero and zero_point_domain (jainapurva, Apr 29, 2025)
a68a679  Update choose_qparams_affine_min_max (jainapurva, Apr 29, 2025)
b9c7c53  Update float8 choose_qparams (jainapurva, Apr 29, 2025)
ea5525e  Use float8 choose/quantize/dequantize (jainapurva, Apr 29, 2025)
cff885b  Updates to choose_qparams_affine uses (jainapurva, Apr 29, 2025)
f747fff  Test fixes (jainapurva, Apr 30, 2025)
694dab3  Updates (jainapurva, Apr 30, 2025)
62a99a1  Updates (jainapurva, Apr 30, 2025)
3a5efa7  Split quantize_affine based on zero_point_domain (jainapurva, Apr 30, 2025)
57d55b0  Merge remote-tracking branch 'origin/main' into qparam_args (jainapurva, Apr 30, 2025)
6e42999  Fix tests (jainapurva, Apr 30, 2025)
414df66  dequantize_affine and test fixes (jainapurva, May 1, 2025)
e3f307d  Test fixes (jainapurva, May 5, 2025)
2f1ded8  Ignore quantize_pt2e until fixed (jainapurva, May 5, 2025)
f67a1f3  Fix pt2e (jainapurva, May 5, 2025)
d2b47c4  Ruff fixes (jainapurva, May 5, 2025)
e64cf1f  Minor fixes (jainapurva, May 5, 2025)
9780257  Fix qat issues (jainapurva, May 13, 2025)
9462d0a  Merge branch 'main' into qparam_args (jainapurva, May 13, 2025)
a978360  Updates (jainapurva, May 13, 2025)
fae98f1  Updates (jainapurva, May 13, 2025)
d9b2339  Updates (jainapurva, May 13, 2025)
cb760bf  Updates (jainapurva, May 13, 2025)
6c90140  Minor fixes (jainapurva, May 15, 2025)
5ce7e50  Update reviews (jainapurva, May 21, 2025)
17a3674  Updates (jainapurva, May 21, 2025)
0e9bb83  Merge remote-tracking branch 'origin/main' into qparam_args (jainapurva, May 21, 2025)
e4183a5  Updates (jainapurva, May 21, 2025)
214e704  Updates (jainapurva, May 21, 2025)
9 changes: 1 addition & 8 deletions test/dtypes/test_uintx.py
@@ -10,7 +10,6 @@
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
 from torchao.quantization.quant_primitives import (
     MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -112,7 +111,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
     mapping_type = MappingType.SYMMETRIC
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
-    zero_point_domain = ZeroPointDomain.INT
     block_size = (1, group_size)

     scale, zero_point = choose_qparams_affine(
@@ -123,8 +121,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
         eps=eps,
         scale_dtype=torch.float32,
         zero_point_dtype=zero_point_dtype,
-        preserve_zero=True,
-        zero_point_domain=zero_point_domain,
     )

     aqt = quantize_affine(
@@ -133,15 +129,12 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
         scale,
         zero_point,
         dtype,
-        zero_point_domain=zero_point_domain,
     )
     # Note: output will be uint8 tensor for sub byte tensors for now

     q = to_uintx(aqt, dtype, -1)
     assert q is not None, "quantization failed"
-    deqaunt = dequantize_affine(
-        q, block_size, scale, zero_point, dtype, zero_point_domain=zero_point_domain
-    )
+    deqaunt = dequantize_affine(q, block_size, scale, zero_point, dtype)
     assert deqaunt is not None, "deqauntization failed"
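
With the flags gone, the uintx round trip above reduces to choose, quantize, dequantize with no domain plumbing. A runnable sketch against the post-PR signatures shown in this diff, using plain `torch.uint8` instead of the test's sub-byte dtypes:

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

x = torch.randn(2, 64)
block_size = (1, 64)
dtype = torch.uint8

# quant_min/quant_max are omitted and derived from the target dtype.
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    dtype,
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
q = quantize_affine(x, block_size, scale, zero_point, dtype)
dq = dequantize_affine(q, block_size, scale, zero_point, dtype)  # float32 out
```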
190 changes: 27 additions & 163 deletions test/quantization/test_quant_primitives.py
@@ -9,20 +9,16 @@
 import unittest

 import torch
-from parameterized import parameterized

-from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
+    choose_qparams_affine_tinygemm,
     dequantize_affine,
-    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
-    quantize_affine_float8,
 )

 # TODO: remove test for utils?
@@ -650,35 +646,6 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
             _ = quantize_affine(input, block_size, scale, zero_point, dtype)

-    def test_not_preserve_zero_not_supported(self):
-        """Making sure preserve_zero == False is not supported for symmetric quant"""
-        input = torch.randn(10, 256)
-        n_bit = 4
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = 0
-        quant_max = 2**n_bit - 1
-        eps = 1e-6
-        scale_dtype = torch.bfloat16
-        zero_point_dtype = torch.bfloat16
-        with self.assertRaisesRegex(
-            ValueError,
-            "preserve_zero == False is not supported for symmetric quantization",
-        ):
-            choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=False,
-            )
-
     def test_get_groupwise_affine_qparams(self):
         input = torch.randn(10, 256)
         n_bit = 4
@@ -702,22 +669,33 @@ def test_get_groupwise_affine_qparams(self):
             dtype=torch.bfloat16,
             zero_point_domain=zero_point_domain,
         )
-        scale, zero_point = choose_qparams_affine(
-            input,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype=scale_dtype,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=zero_point_domain == ZeroPointDomain.INT,
-            zero_point_domain=zero_point_domain,
-        )
+        if zero_point_domain == ZeroPointDomain.FLOAT:
+            scale, zero_point = choose_qparams_affine_tinygemm(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+            )
+        else:
+            scale, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+            )

-        self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zero_point_ref))
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zero_point_ref))

     def test_groupwise_affine_quantize_tensor_from_qparams(self):
         input = torch.randn(10, 256)
@@ -847,120 +825,6 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)

-    def test_none_zero_point_domain(self):
-        """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
-        input = torch.randn(10, 256)
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = None
-        quant_max = None
-        eps = 1e-6
-        scale_dtype = torch.float32
-        zero_point_dtype = torch.int64
-        try:
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=None,
-            )
-        except ValueError:
-            # This exception was expected
-            # Now test for ZeroPointDomain.NONE
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=ZeroPointDomain.NONE,
-            )
-            self.assertTrue(zero_point is None)
-        else:
-            # An exception should have been thrown for zero_point_domain None
-            self.assertTrue(
-                False,
-                msg="A runtime exception should have been thrown for zero_point_domain None",
-            )
-
-    @parameterized.expand(
-        [
-            (
-                torch.float32,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.float32,
-                torch.float8_e5m2,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e5m2,
-            ),
-        ]
-    )
-    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
-        input = torch.randn(10, 10)
-
-        # float8 quantization primitives
-        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
-        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
-        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
-
-        # reference implementation using generic primitives
-        expected_scale, _ = choose_qparams_affine(
-            input,
-            MappingType.SYMMETRIC,
-            input.shape,
-            float8_dtype,
-            eps=float8_eps,  # use same EPS as float8 training
-            scale_dtype=torch.float32,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-        )
-        expected_quantized = quantize_affine(
-            input,
-            input.shape,
-            scale,
-            output_dtype=float8_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-        expected_dequantized = dequantize_affine(
-            expected_quantized,
-            input.shape,
-            scale,
-            input_dtype=float8_dtype,
-            output_dtype=hp_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-
-        self.assertTrue(torch.equal(expected_scale, scale))
-        torch.testing.assert_close(expected_quantized, quantized)
-        torch.testing.assert_close(expected_dequantized, dequantized)


 if __name__ == "__main__":
     unittest.main()
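
The if/else added to test_get_groupwise_affine_qparams above is the migration pattern for any caller that used to pass zero_point_domain: branch to the matching function instead. A hypothetical shim (not part of torchao) sketching that dispatch; it assumes ZeroPointDomain remains importable and that the remaining keyword arguments forward unchanged:

```python
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,  # assumption: still exported after this PR
    choose_qparams_affine,
    choose_qparams_affine_tinygemm,
)

def choose_qparams_for_domain(
    x, mapping_type, block_size, target_dtype, zero_point_domain, **kwargs
):
    """Hypothetical helper routing the removed flag to the split functions."""
    if zero_point_domain == ZeroPointDomain.FLOAT:
        # Float zero-point, as used by the tinygemm int4 kernels.
        return choose_qparams_affine_tinygemm(
            x, mapping_type, block_size, target_dtype, **kwargs
        )
    # Integer zero-point: the default path keeps the original entry point.
    return choose_qparams_affine(x, mapping_type, block_size, target_dtype, **kwargs)
```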
6 changes: 0 additions & 6 deletions test/sparsity/test_marlin.py
@@ -14,7 +14,6 @@
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.quantization.quant_primitives import (
     MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
 )
@@ -92,8 +91,6 @@ def test_pack_unpack_equivalence(self):
         eps = 1e-6
         zero_point_dtype = torch.bfloat16
         mapping_type = MappingType.SYMMETRIC
-        preserve_zero = True
-        zero_point_domain = ZeroPointDomain.INT
         scale_dtype = None

         w = torch.rand(shape, dtype=torch.float16, device="cuda")
@@ -112,8 +109,6 @@ def test_pack_unpack_equivalence(self):
             eps,
             scale_dtype,
             zero_point_dtype,
-            preserve_zero,
-            zero_point_domain,
         )
         w_q_24 = quantize_affine(
             w_24,
@@ -123,7 +118,6 @@ def test_pack_unpack_equivalence(self):
             target_dtype,
             quant_min,
             quant_max,
-            zero_point_domain,
         )
         scales = scales.reshape(-1, w_q_24.shape[1])
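
The float8 path gets the same split treatment elsewhere in the stack (commits "Update float8 choose_qparams" and "Use float8 choose/quantize/dequantize"): the generic-primitive reference test was deleted from test_quant_primitives.py above, leaving the dedicated functions as the route. A minimal round trip mirroring the calls from the removed test:

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10)

# Per-tensor float8 scale, then quantize/dequantize with the dedicated ops.
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)
q = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
dq = dequantize_affine_float8(q, scale, output_dtype=torch.float32)
```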