Support loading for static quant weight fp8 act fp8 #730
Merged
Commits (69)
- bb94782 load w8a8 (yiliu30)
- 9bef826 refactor (yiliu30)
- b30a126 add ut (yiliu30)
- eaad3a6 remove example (yiliu30)
- c411ca5 fix typo (yiliu30)
- 9802313 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 6597d5c Update auto_round/export/export_to_autoround/export_to_fp8_woq.py (yiliu30)
- 9b0f32f Update export_to_fp8_woq.py (yiliu30)
- c32daa6 Merge branch 'main' into wfp8-afp8 (yiliu30)
- c136339 megre main (yiliu30)
- 5ebca24 update shape (yiliu30)
- 03cb217 refactor (yiliu30)
- e7280f6 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 66388e5 tmp add bk (yiliu30)
- 17ddd2d refactor code (yiliu30)
- 808449d refine code (yiliu30)
- f74ed6f fix device list (yiliu30)
- 632cf8a fix (yiliu30)
- 5b8b29d refactor code (yiliu30)
- 57b4c19 fix (yiliu30)
- bdf5f3e update (yiliu30)
- ce3384f fix ut (yiliu30)
- 7cea90e Merge branch 'main' into wfp8-afp8 (yiliu30)
- 22d11de correct (yiliu30)
- 9082613 clean (yiliu30)
- 6503355 Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30)
- b687633 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 2202856 fix shape (yiliu30)
- 10f5753 Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30)
- cc42e47 merge with main (yiliu30)
- d0b99a8 fix check (yiliu30)
- 31845d0 clean code (yiliu30)
- fdecdde merge (yiliu30)
- 1f2e674 fix backend check (yiliu30)
- b56ad25 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 4cec318 update config (yiliu30)
- 6b2962f revert change (yiliu30)
- 638718e fix (yiliu30)
- 4df3e8f fix (yiliu30)
- e01603c update (yiliu30)
- 0cdf28b propagate the config (yiliu30)
- 27910da pass config to checker (yiliu30)
- d46acdb add more check (yiliu30)
- fd05799 refine code (yiliu30)
- 3d75c27 fix equal check (yiliu30)
- e0c0d58 fix equal check (yiliu30)
- 75f2928 Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30)
- fa3ec2d fix get (yiliu30)
- ad5269e rename (yiliu30)
- 35e45ed update check (yiliu30)
- f4e254b add warning (yiliu30)
- 7cba242 Merge branch 'main' into wfp8-afp8 (yiliu30)
- ff5a1e9 rename check (yiliu30)
- b98f3db Merge branch 'main' into wfp8-afp8 (yiliu30)
- 50968fd rename (yiliu30)
- 586d6a2 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 5e84ff9 Merge branch 'main' into wfp8-afp8 (yiliu30)
- abd83ac [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 9e2c63f Merge branch 'main' into wfp8-afp8 (yiliu30)
- d332a95 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 94508e3 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 8a4a533 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- f05e38b fix (yiliu30)
- d759ca3 Merge branch 'main' into wfp8-afp8 (yiliu30)
- c58a61c update (yiliu30)
- c89ffc0 Merge branch 'main' into wfp8-afp8 (yiliu30)
- 04ae0fd Merge branch 'main' into wfp8-afp8 (yiliu30)
- 2c34244 fix (yiliu30)
- b3a0910 Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30)
First new file, auto_round/experimental/qmodules/base.py, which defines the QModuleBase interface:

@@ -0,0 +1,54 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch

__all__ = ["QModuleBase"]


class QModuleBase(torch.nn.Module):
    """
    Base class used to describe the weight creation and forward pass
    of different quantization schemes supported by Auto-Round.
    The design is inspired by vLLM's CompressedTensorsScheme:
    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
    """

    def __init__(self):
        super().__init__()

    @classmethod
    @abstractmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
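The two abstract hooks are meant to be driven by the model loader: `from_original` builds the quantized replacement for a layer, and `process_weights_after_loading` runs once the checkpoint weights are in place. As a rough illustration of that lifecycle (the helpers `replace_linears` and `finalize_after_loading` below are hypothetical sketches, not code from this PR or from auto-round's actual loader):

# Hypothetical sketch: one way a loader could drive the QModuleBase interface
# when swapping quantized layers into a model.
import torch

from auto_round.experimental.qmodules.base import QModuleBase


def replace_linears(model: torch.nn.Module, qmodule_cls, config) -> torch.nn.Module:
    """Recursively replace every nn.Linear in `model` with `qmodule_cls`."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            # Build the quantized wrapper from the original layer.
            setattr(model, name, qmodule_cls.from_original(config, child))
        else:
            replace_linears(child, qmodule_cls, config)
    return model


def finalize_after_loading(model: torch.nn.Module) -> None:
    """Let each quantized layer clean up once the checkpoint weights are loaded."""
    for module in model.modules():
        if isinstance(module, QModuleBase):
            module.process_weights_after_loading(module)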
Second new file, implementing WeightFP8ActFP8StaticQuantLinear:

@@ -0,0 +1,122 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Union

import torch

from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.utils import logger

__all__ = ["WeightFP8ActFP8StaticQuantLinear"]


def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
    FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
    qtensor = tensor / scale
    clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
    clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
    return scale, clipped_qtensor_fp8


class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
    hp_dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn

    def __init__(
        self,
        in_features,
        out_features,
        weight: Optional[torch.Tensor] = None,
        weight_scale: Optional[torch.Tensor] = None,
        bias: Union[torch.Tensor, bool, None] = None,
        input_scale: Optional[torch.Tensor] = None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
        self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
        self.dtype = dtype
        if bias is not None:
            if isinstance(bias, bool):
                bias = torch.zeros((out_features,), dtype=dtype)
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.register_parameter("bias", None)
        init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale
        self.register_buffer("weight_scale", init_weight_scale.to(dtype))

        init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale
        self.register_buffer("input_scale", init_input_scale.to(dtype))
        self.pre_dequantized = False

    @classmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        # TODO: correct that config once we add fp8 op support.
        logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
        return 0

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass

    @classmethod
    def from_original(cls, config, original_layer):
        """
        Create a `WeightFP8ActFP8StaticQuantLinear` layer from an original linear layer.
        """
        logger.warning_once(
            "FP8 static quantization is still in experimental stage, the inference speed might be slow."
        )
        device = original_layer.weight.device
        with torch.device(device):
            qdq_linear = cls(
                in_features=original_layer.in_features,
                out_features=original_layer.out_features,
                bias=original_layer.bias,
            )
        return qdq_linear

    def dequant_weight_online(self):
        if self.pre_dequantized:
            return self.weight
        qdq_weight = self.weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
        return qdq_weight

    def pre_dequantize(self):
        if self.pre_dequantized:
            return
        dequant_weight = self.dequant_weight_online()
        del self.weight
        del self.weight_scale
        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
        self.pre_dequantized = True

    def qdq_input(self, bf16_input: torch.Tensor):
        input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
        qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
        return qdq_input_bf16

    @torch.no_grad()
    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(bf16_input)
        qdq_weight = self.dequant_weight_online()
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out
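A minimal usage sketch of the new layer follows, under a few stated assumptions: the class above is importable in the current scope, the installed PyTorch supports torch.float8_e4m3fn, and in real use the FP8 weight and the scales come from a quantized checkpoint rather than being filled in by hand as below. `config=None` is only a placeholder, since the implementation shown here does not read it.

# Illustrative usage sketch under the assumptions stated above.
import torch

torch.manual_seed(0)

# A plain bf16 linear layer standing in for one layer of the original model.
bf16_linear = torch.nn.Linear(16, 32, bias=True, dtype=torch.bfloat16)

# Wrap it; weight, weight_scale, and input_scale start as placeholders inside the module.
qlinear = WeightFP8ActFP8StaticQuantLinear.from_original(config=None, original_layer=bf16_linear)

# A checkpoint loader would normally populate these; fake them here so the
# forward path can be exercised end to end.
qlinear.weight.data = bf16_linear.weight.data.to(torch.float8_e4m3fn)   # FP8 weight
qlinear.weight_scale.copy_(torch.ones(32, dtype=torch.bfloat16))        # per-output-channel weight scale
qlinear.input_scale.copy_(torch.tensor([1.0], dtype=torch.bfloat16))    # static per-tensor input scale

x = torch.randn(2, 16, dtype=torch.bfloat16)
y = qlinear(x)        # QDQ the input, dequantize the weight, then F.linear
print(y.shape)        # torch.Size([2, 32])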
Review comment: @WeiweiZhang1 you have an AR to refine the formats-related code, please be aware of this change.