Support loading for static quant weight fp8 act fp8 #730
Conversation
Pull Request Overview
This PR adds support for loading static quantized models with FP8 weights and FP8 activations by implementing a new quantized linear layer class and updating the model conversion infrastructure.
Key changes:
- Implemented `WeightFP8ActFP8StaticQuantLinear` class for handling FP8 weight and activation quantization (see the sketch after this list)
- Updated model conversion logic to detect and handle FP8 static quantization configurations
- Enhanced test coverage to verify both export and loading functionality for static FP8 quantization
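For illustration, here is a minimal sketch of what a static W-FP8/A-FP8 linear layer with pre-calibrated scales can look like in PyTorch. The class, helper, and buffer names (`StaticFP8Linear`, `quant_fp8`, `weight_scale`, `input_scale`) are assumptions for this sketch and do not mirror the actual `WeightFP8ActFP8StaticQuantLinear` API.

```python
# A minimal, self-contained sketch of a static W-FP8 / A-FP8 linear layer.
# All names here are illustrative assumptions, not the auto-round implementation.
import torch
import torch.nn as nn

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max


def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize a float tensor to FP8 using a pre-computed (static) scale."""
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)


class StaticFP8Linear(nn.Module):
    """Linear layer holding FP8 weights plus static weight/activation scales."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=FP8_DTYPE))
        self.register_buffer("weight_scale", torch.ones(()))
        self.register_buffer("input_scale", torch.ones(()))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    @classmethod
    def from_linear(cls, linear: nn.Linear, input_scale: torch.Tensor) -> "StaticFP8Linear":
        """Convert a float nn.Linear using a calibrated (static) activation scale."""
        mod = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
        w_scale = linear.weight.detach().abs().max().clamp(min=1e-12) / FP8_MAX
        mod.weight_scale.copy_(w_scale)
        mod.input_scale.copy_(input_scale)
        mod.weight.copy_(quant_fp8(linear.weight.detach(), w_scale))
        if linear.bias is not None:
            mod.bias.data.copy_(linear.bias.detach())
        return mod

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize the activation with the static input scale, then dequantize
        # both tensors and fall back to a plain high-precision matmul.
        x_dq = quant_fp8(x, self.input_scale).to(x.dtype) * self.input_scale
        w_dq = self.weight.to(x.dtype) * self.weight_scale
        return nn.functional.linear(x_dq, w_dq, self.bias)
```

Once a model's float Linear modules are swapped for such a class, `load_state_dict` can restore the FP8 weights and the pre-calibrated scales from a saved checkpoint, which is the general idea behind loading a statically quantized model from disk.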
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| test/test_cpu/test_export.py | Extended test to verify loading of static FP8 quantized models and renamed test method |
| auto_round/inference/convert_model.py | Added support for `act_dynamic` parameter and FP8 static quantization detection in model conversion |
| auto_round/inference/backend.py | Added FP8 static quantization detection function and updated dynamic import logic (see the sketch after this table) |
| auto_round/export/export_to_autoround/export_to_fp8_woq.py | Implemented new `WeightFP8ActFP8StaticQuantLinear` class with quantization/dequantization methods |
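As a rough illustration of the detection step, a predicate along the following lines can distinguish static W-FP8/A-FP8 configurations. The attribute names read from `config` are assumptions for this sketch, not the exact auto-round fields.

```python
# Hedged sketch of a static W-FP8/A-FP8 detection check; the attribute names
# on `config` are assumptions rather than the real auto-round signature.
def is_static_wfp8afp8(config) -> bool:
    """True when weights and activations are FP8 and activation scales are static."""
    return (
        getattr(config, "data_type", None) == "fp8"
        and getattr(config, "act_data_type", None) == "fp8"
        and not getattr(config, "act_dynamic", True)
    )
```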
This PR is unnecessary for now; you need to work with Heng to fix the FP8.

@wenhuach21 The purpose of this PR is to support loading an existing qmodel from disk and then evaluating its accuracy. cc @n1ck-guo

Yes, but the primary purpose is evaluation, which the fake model should cover well (#731). This is not a product feature, and it involves changes to critical product code. As discussed earlier, please hold this PR for now, or move the code elsewhere without modifying the important HF model inference code.
```diff
@@ -857,8 +859,8 @@ def remove_duplicates(lst):
             format = "auto_round:auto_awq"
         elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
             format = f"auto_round:{self.data_type}"
-        elif is_wfp8afp8(self):  # staic wfp8afp8
+        elif is_static_wfp8afp8(self):  # staic wfp8afp8
             format = "auto_round:fp8"
```
@WeiweiZhang1 you have an AR to refine the formats-related code; please be aware of this change. Additionally, please make sure all UTs in https://github.com/intel/auto-round/blob/main/test/test_cuda/test_transformers.py pass before merging.
The local tests passed.

```
=============== short test summary info ===============
PASSED test_transformers.py::AutoRoundTest::test_convert_from_gptq
PASSED test_transformers.py::AutoRoundTest::test_mixed_bits
PASSED test_transformers.py::AutoRoundTest::test_quantized_model
PASSED test_transformers.py::AutoRoundTest::test_quantized_model_bf16
PASSED test_transformers.py::AutoRoundTest::test_quantized_model_multi_gpu
PASSED test_transformers.py::AutoRoundTest::test_raise_if_non_quantized
PASSED test_transformers.py::AutoRoundTest::test_save_pretrained
SKIPPED [1] test_transformers.py:166: test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see https://github.com/intel/intel-extension-for-pytorch
SKIPPED [1] test_transformers.py:101: test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see https://github.com/intel/intel-extension-for-pytorch
= 7 passed, 2 skipped, 38 warnings in 102.73s (0:01:42) =
```
Summary of the changes:
- Add the `auto_round:torch_fp8_static` format for loading and inference of w8afp8 models
- Update `QuantizationScheme` to support dict-style access (see the sketch after this list)
- Add `act_dynamic` to `QuantizationScheme` and propagate it to the backend check
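As a rough illustration of dict-style access on a scheme object, here is a minimal dataclass sketch; the field names and helper methods are assumptions and may differ from the real `QuantizationScheme`.

```python
# A minimal sketch (assumed field names) of a dataclass-based scheme that also
# supports dict-style access; the real QuantizationScheme may look different.
from dataclasses import dataclass, asdict


@dataclass
class QuantizationScheme:
    bits: int = 8
    data_type: str = "fp8"
    act_data_type: str = "fp8"
    act_dynamic: bool = False

    def __getitem__(self, key: str):
        # Allow scheme["act_dynamic"] alongside scheme.act_dynamic.
        return getattr(self, key)

    def get(self, key: str, default=None):
        return getattr(self, key, default)

    def keys(self):
        return asdict(self).keys()


scheme = QuantizationScheme()
assert scheme["act_dynamic"] is False and scheme.get("bits") == 8
```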