Remove preserve_zero and zero_point_domain from choose_qparams_affine #2149

Merged: 30 commits merged into main from qparam_args on May 21, 2025

Conversation

@jainapurva jainapurva (Contributor) commented Apr 29, 2025

This pull request refactors and simplifies quantization-related code by removing unused or redundant functionality and introducing specialized methods for specific cases. The most important changes are removing the preserve_zero and zero_point_domain parameters from many functions, introducing new specialized quantization and dequantization methods, and updating call sites accordingly.

Refactoring and Simplification:

  • Removed the preserve_zero and zero_point_domain parameters from choose_qparams_affine, quantize_affine, and dequantize_affine calls across multiple files, while introducing specialized methods to handle specific quantization scenarios.

The following table maps each original method and parameter combination to its new specialized method (a hedged migration sketch follows the table):

| Original Method | ZeroPointDomain value | preserve_zero value | New Method |
| --- | --- | --- | --- |
| choose_qparams_affine | INT / NONE | True | choose_qparams_affine |
| choose_qparams_affine | FLOAT | False | choose_qparams_affine_tinygemm |
| choose_qparams_affine | INT | False | choose_qparams_affine_dont_preserve_zero |
| quantize_affine | INT | N/A | quantize_affine |
| quantize_affine | FLOAT | N/A | quantize_affine_float_zero_point |
| quantize_affine | NONE | N/A | quantize_affine_no_zero_point |
| dequantize_affine | INT | N/A | dequantize_affine |
| dequantize_affine | FLOAT | N/A | dequantize_affine_float_zero_point |
| dequantize_affine | NONE | N/A | dequantize_affine_no_zero_point |
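
As a rough migration sketch (argument lists below are assumptions based on the table and on the existing affine quantization primitives, not verbatim torchao signatures), a call site that previously passed preserve_zero=False with ZeroPointDomain.FLOAT would now call the tinygemm / float-zero-point variants directly:

```python
import torch

# Names come from the table above; exact parameter lists are assumptions.
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine_tinygemm,      # was: preserve_zero=False, FLOAT domain
    quantize_affine_float_zero_point,    # was: zero_point_domain=ZeroPointDomain.FLOAT
    dequantize_affine_float_zero_point,
)

w = torch.randn(128, 128, dtype=torch.bfloat16)
block_size = (1, 32)  # per-group quantization along the last dimension

# Before this PR (illustrative):
#   scale, zp = choose_qparams_affine(w, MappingType.ASYMMETRIC, block_size,
#                                     torch.int32, quant_min=0, quant_max=15,
#                                     preserve_zero=False,
#                                     zero_point_domain=ZeroPointDomain.FLOAT)
# After this PR (illustrative):
scale, zp = choose_qparams_affine_tinygemm(
    w, MappingType.ASYMMETRIC, block_size, torch.int32, quant_min=0, quant_max=15
)
w_q = quantize_affine_float_zero_point(
    w, block_size, scale, zp, torch.int32, quant_min=0, quant_max=15
)
w_dq = dequantize_affine_float_zero_point(
    w_q, block_size, scale, zp, torch.int32,
    quant_min=0, quant_max=15, output_dtype=torch.bfloat16
)
```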

Notable updates related to the changes:

  • from_hp_to_intx and from_hp_to_intx_static still take zero_point_domain and preserve_zero as input, and call the respective choose_qparams/quantize/dequantize_affine functions.
  • from_hp_to_floatx and from_hp_to_floatx_static use the float8 methods choose_qparams_affine_float8, quantize_affine_float8, and dequantize_affine_float8 (see the sketch below).
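
A minimal sketch of that float8 path (parameter names here are assumptions; the exact signatures should be checked against torchao.quantization.quant_primitives):

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    quantize_affine_float8,
    dequantize_affine_float8,
)

w = torch.randn(256, 256)
# Compute a scale for an e4m3 target, quantize, then dequantize back.
scale = choose_qparams_affine_float8(w, float8_dtype=torch.float8_e4m3fn)
w_f8 = quantize_affine_float8(w, scale, float8_dtype=torch.float8_e4m3fn)
w_dq = dequantize_affine_float8(w_f8, scale, output_dtype=torch.float32)
```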

The following table lists each AOBaseConfig along with the choose_qparams function calls the backend makes for that configuration (a usage sketch follows the table):

| AOBaseConfig | choose_qparams function(s) |
| --- | --- |
| Int8DynamicActivationInt4WeightConfig | choose_qparams_affine |
| Int8DynamicActivationIntxWeightConfig | choose_qparams_affine / choose_qparams_affine_dont_preserve_zero |
| GemliteUIntXWeightOnlyConfig | choose_qparams_and_quantize_affine_hqq / choose_qparams_affine |
| Int4WeightOnlyConfig | choose_qparams_affine / choose_qparams_affine_tinygemm / choose_qparams_affine_dont_preserve_zero |
| Int8WeightOnlyConfig | choose_qparams_affine |
| Int8DynamicActivationInt8WeightConfig | choose_qparams_affine |
| Float8WeightOnlyConfig | choose_qparams_affine_float8 |
| Float8DynamicActivationFloat8WeightConfig | choose_qparams_affine_float8 |
| Float8StaticActivationFloat8WeightConfig | choose_qparams_affine_float8 |
| UIntXWeightOnlyConfig | choose_qparams_and_quantize_affine_hqq / choose_qparams_affine |
| IntxWeightOnlyConfig | choose_qparams_affine / choose_qparams_affine_dont_preserve_zero |
| FPXWeightOnlyConfig | choose_qparams_affine_fpx |
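
For example (a sketch, assuming the config API exported from torchao.quantization; the group size is just an illustrative value), applying one of these configs via quantize_ lets the backend dispatch to the matching choose_qparams_* variant:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Int4WeightOnlyConfig routes to choose_qparams_affine /
# choose_qparams_affine_tinygemm / choose_qparams_affine_dont_preserve_zero
# depending on its settings (see the table above). The int4 weight-only path
# typically targets bf16 weights on a CUDA device.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")
quantize_(model, Int4WeightOnlyConfig(group_size=128))
```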

pytorch-bot bot commented Apr 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2149

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures

As of commit 214e704 with merge base 212d912:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 29, 2025
@jainapurva jainapurva added the topic: not user facing and topic: for developers labels Apr 29, 2025
@jainapurva jainapurva marked this pull request as ready for review April 30, 2025 17:36
@jainapurva jainapurva marked this pull request as draft April 30, 2025 18:10
@jainapurva jainapurva force-pushed the qparam_args branch 2 times, most recently from 85936a5 to 9780257, May 13, 2025 21:52
@jainapurva jainapurva marked this pull request as ready for review May 14, 2025 05:20
@@ -255,7 +254,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
target_dtype = torch.int32
quant_min = 0
quant_max = 15
zero_point_domain = ZeroPointDomain.FLOAT
# zero_point_domain is ZeroPointDomain.FLOAT
Contributor

nit: remove

Contributor Author

I thought to keep it for now, as it can be an indicator of the previous implementation.

@@ -1025,7 +1025,6 @@ def get_per_token_block_size(x):
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
Contributor

should this be reverted?

@jerryzh168 jerryzh168 (Contributor) left a comment

looks good, thanks @jainapurva for carefully working through this!

Comment on lines 437 to 441
zero_point_domain is optional specifies how we quantize the floating point to quantized data:
INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization
Where we do not want to round values to nearest integer and instead scale and cast.
Contributor

nit: we can just leave the one that is relevant
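
For reference, a toy numeric illustration of the three formulas quoted in this docstring (all values below are made up, and the mid_point expression is an assumption rather than copied from torchao):

```python
float_val = 1.2
scale = 0.1
quant_min, quant_max = 0, 15                 # int4 range [0, 15]
mid_point = (quant_max + quant_min + 1) / 2  # assumed definition of mid_point

# INT domain: integer zero_point added after scaling and rounding
zero_point_int = 8
q_int = round(float_val / scale) + zero_point_int  # 12 + 8 = 20, then clamped to [0, 15]

# FLOAT domain (tinygemm): float zero_point offset around the mid-point
zero_point_float = 0.05
q_float = (float_val - (zero_point_float - scale * mid_point)) / scale

# NONE domain: pure scale-and-cast, no zero_point (floatx use case)
q_none = float_val / scale
```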

raise ValueError("Please use ZeroPointDomain.NONE instead of None")
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
raise ValueError("zero_point should be None when zero_point_domain is NONE")
# if zero_point_domain is None:
Contributor

nit: please remove the commented code before landing

quant_max: Union[int, float],
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function converts AQT tensors to their high precision floating point representation
Contributor

should we only have doc for non-private helper functions?

@jerryzh168 jerryzh168 (Contributor) left a comment

I think the docs have to be updated a bit, commented inline

@jerryzh168 jerryzh168 added the topic: bc-breaking label May 16, 2025
@jainapurva jainapurva merged commit 04fb450 into main May 21, 2025
12 of 19 checks passed