Adds Q/DQ layout support for embedding quantization with IntxWeightOnlyConfig #1972
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1972.
Note: links to docs will display an error until the docs builds have completed.
❌ 1 New Failure as of commit 186f903 with merge base 5ded23c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 2c3b9ac to 05eec5d (Compare).
@@ -1569,6 +1572,92 @@ def _uintx_weight_only_transform(
    return module


@dataclass
class IntxWeightOnlyConfig(AOBaseConfig):
    """
@andrewor14 can you take a look at this and check whether there are any issues with it working well in the QAT workflow with FakeQuantizeConfig?
Strange that we have IntxWeightOnly and Int4WeightOnly
yeah I feel we should probably merge these two
        has_weight_zeros=has_weight_zeros,
    ).quantize(quantized_model_reference)
    quantize_(
        quantized_model_reference,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=weight_dtype,
-           granularity=granularity,
+           granularity=PerRow(),
Should this be PerAxis as well?
It can't be, because that's controlled by Int8DynamicActivationIntxWeightConfig, which uses PerRow until #1968 lands.
@@ -155,7 +154,7 @@ def test_shared_embedding(self):
        quantized_model = copy.deepcopy(model)
        SharedEmbeddingQuantizer(
            weight_dtype=weight_dtype,
-           granularity=granularity,
+           granularity=PerRow(),
And this?
Looks good overall; just need to change PerRow to PerAxis(axis=0), as we discussed in the meeting.
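For reference, a minimal sketch of the two granularity types being discussed; both live in torchao.quantization.granularity, and the comments are my gloss on the thread, not from the PR.

from torchao.quantization.granularity import PerAxis, PerRow

# PerRow: one scale/zero-point per row; used by the linear configs above
# (Int8DynamicActivationIntxWeightConfig) until #1968 lands.
linear_gran = PerRow()

# PerAxis(axis=0): one scale/zero-point per slice along dim 0, i.e. per
# embedding-table row for a 2-D embedding weight -- the replacement
# requested in this thread.
embedding_gran = PerAxis(axis=0)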
Adds Q/DQ layout support for embedding quantization with IntxWeightOnlyConfig (#1972) * up * up * up * up * up * up * up * up
@@ -263,6 +269,9 @@ def _(func, types, args, kwargs):


@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
    if _embedding_q_dq_check(args, kwargs):
        return _embedding_q_dq_impl(args, kwargs)
Why does line 299 only dequantize the weight but not actually run the embedding op?
This will be used to quantize embeddings in ExecuTorch (ET).
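To make the Q/DQ pattern under discussion concrete, here is a hedged sketch of what an implementation like _embedding_q_dq_impl typically looks like; the body below is an assumption based on the thread, not the PR's actual code. The point of the Q/DQ layout is that the weight is explicitly dequantized, so the exported graph carries a dequantize node that ExecuTorch backends can pattern-match and lower, rather than a fused quantized kernel.

import torch.nn.functional as F

def _embedding_q_dq_impl(args, kwargs):
    # input: indices tensor; weight: quantized tensor subclass assumed to
    # expose a dequantize() method (as torchao's affine-quantized tensors do).
    input, weight = args[0], args[1]
    # Materialize a float weight; in an exported graph this shows up as an
    # explicit dequantize op that ET can match and lower.
    dq_weight = weight.dequantize()
    # Run the ordinary embedding lookup on the dequantized weight.
    return F.embedding(input, dq_weight, **kwargs)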