Merged
69 commits
bb94782
load w8a8
yiliu30 Aug 11, 2025
9bef826
refactor
yiliu30 Aug 12, 2025
b30a126
add ut
yiliu30 Aug 12, 2025
eaad3a6
remove example
yiliu30 Aug 12, 2025
c411ca5
fix typo
yiliu30 Aug 12, 2025
9802313
Merge branch 'main' into wfp8-afp8
yiliu30 Aug 12, 2025
6597d5c
Update auto_round/export/export_to_autoround/export_to_fp8_woq.py
yiliu30 Aug 13, 2025
9b0f32f
Update export_to_fp8_woq.py
yiliu30 Aug 13, 2025
c32daa6
Merge branch 'main' into wfp8-afp8
yiliu30 Aug 13, 2025
c136339
megre main
yiliu30 Aug 24, 2025
5ebca24
update shape
yiliu30 Aug 24, 2025
03cb217
refactor
yiliu30 Aug 26, 2025
e7280f6
Merge branch 'main' into wfp8-afp8
yiliu30 Aug 26, 2025
66388e5
tmp add bk
yiliu30 Aug 26, 2025
17ddd2d
refactor code
yiliu30 Aug 27, 2025
808449d
refine code
yiliu30 Aug 27, 2025
f74ed6f
fix device list
yiliu30 Aug 27, 2025
632cf8a
fix
yiliu30 Aug 27, 2025
5b8b29d
refactor code
yiliu30 Aug 27, 2025
57b4c19
fix
yiliu30 Aug 27, 2025
bdf5f3e
update
yiliu30 Aug 27, 2025
ce3384f
fix ut
yiliu30 Aug 27, 2025
7cea90e
Merge branch 'main' into wfp8-afp8
yiliu30 Aug 28, 2025
22d11de
correct
yiliu30 Aug 28, 2025
9082613
clean
yiliu30 Aug 28, 2025
6503355
Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into …
yiliu30 Aug 28, 2025
b687633
Merge branch 'main' into wfp8-afp8
yiliu30 Aug 28, 2025
2202856
fix shape
yiliu30 Aug 28, 2025
10f5753
Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into …
yiliu30 Aug 29, 2025
cc42e47
merge with main
yiliu30 Aug 29, 2025
d0b99a8
fix check
yiliu30 Aug 29, 2025
31845d0
clean code
yiliu30 Aug 29, 2025
fdecdde
merge
yiliu30 Sep 4, 2025
1f2e674
fix backend check
yiliu30 Sep 4, 2025
b56ad25
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 4, 2025
4cec318
update config
yiliu30 Sep 4, 2025
6b2962f
revert change
yiliu30 Sep 4, 2025
638718e
fix
yiliu30 Sep 4, 2025
4df3e8f
fix
yiliu30 Sep 4, 2025
e01603c
update
yiliu30 Sep 4, 2025
0cdf28b
propagate the config
yiliu30 Sep 4, 2025
27910da
pass config to checker
yiliu30 Sep 4, 2025
d46acdb
add more check
yiliu30 Sep 4, 2025
fd05799
refine code
yiliu30 Sep 4, 2025
3d75c27
fix equal check
yiliu30 Sep 4, 2025
e0c0d58
fix equal check
yiliu30 Sep 4, 2025
75f2928
Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into …
yiliu30 Sep 4, 2025
fa3ec2d
fix get
yiliu30 Sep 4, 2025
ad5269e
rename
yiliu30 Sep 5, 2025
35e45ed
update check
yiliu30 Sep 5, 2025
f4e254b
add warning
yiliu30 Sep 5, 2025
7cba242
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 5, 2025
ff5a1e9
rename check
yiliu30 Sep 5, 2025
b98f3db
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 5, 2025
50968fd
rename
yiliu30 Sep 5, 2025
586d6a2
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 5, 2025
5e84ff9
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 6, 2025
abd83ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2025
9e2c63f
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 8, 2025
d332a95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
94508e3
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 8, 2025
8a4a533
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
f05e38b
fix
yiliu30 Sep 9, 2025
d759ca3
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 9, 2025
c58a61c
update
yiliu30 Sep 9, 2025
c89ffc0
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 9, 2025
04ae0fd
Merge branch 'main' into wfp8-afp8
yiliu30 Sep 10, 2025
2c34244
fix
yiliu30 Sep 10, 2025
b3a0910
Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into …
yiliu30 Sep 10, 2025
10 changes: 6 additions & 4 deletions auto_round/autoround.py
@@ -19,6 +19,7 @@
import time
import traceback
from dataclasses import asdict, fields
from enum import Enum
from typing import Any, Callable, Union

import accelerate
@@ -30,6 +31,7 @@

from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size
from auto_round.export.export_to_autoround import AutoRoundFormat
from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
@@ -857,8 +859,8 @@ def remove_duplicates(lst):
format = "auto_round:auto_awq"
elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
format = f"auto_round:{self.data_type}"
elif is_wfp8afp8(self): # staic wfp8afp8
format = "auto_round:fp8"
elif is_static_wfp8afp8(self): # staic wfp8afp8
Contributor comment: @WeiweiZhang1 you have an AR to refine the formats-related code; please be aware of this change.

format = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
elif self.data_type == "fp" and self.bits == 8 and self.act_bits >= 16: # woq fp8
format = "auto_round:fp8"
elif self.act_bits < 16:
@@ -956,10 +958,10 @@ def _check_supported_format(self, format: str) -> bool:
)
format = "fake"
else:
if not (format == "auto_round" or format == "auto_round:fp8"):
if not (format == "auto_round" or format == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"):
logger.warning(
f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model,"
" change format to auto_round"
f" change format {format} to auto_round"
)
format = "auto_round"
if self.act_group_size != 0 and not self.act_dynamic and format == "auto_round:fp8":
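
For reviewers who want to exercise the new path end to end, a minimal usage sketch follows. Only the format string "auto_round:fp8_static" comes from this diff; the model name and the AutoRound constructor keywords (chosen to mirror the attributes checked in this hunk: data_type, bits, act_bits, act_dynamic, act_group_size) are illustrative assumptions rather than a prescribed recipe.

# Hedged usage sketch: request the new static W8A8-FP8 export format.
# The keyword values below are assumptions; only "auto_round:fp8_static" is from the PR.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "facebook/opt-125m"  # small model, illustration only
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

ar = AutoRound(
    model,
    tokenizer,
    data_type="fp",      # FP weight data type
    bits=8,              # FP8 weights
    act_bits=8,          # FP8 activations
    act_dynamic=False,   # static activation scales
    act_group_size=0,    # per-tensor activation scale
)
ar.quantize()
ar.save_quantized("./opt-125m-w8a8-fp8", format="auto_round:fp8_static")
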
54 changes: 54 additions & 0 deletions auto_round/experimental/qmodules/base.py
@@ -0,0 +1,54 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch

__all__ = ["QModuleBase"]


class QModuleBase(torch.nn.Module):
"""
Base class used to describe the weight creation and forward pass
of different quantization schemes supported by Auto-Round.
The design is inspired by vLLM's CompressedTensorsScheme:
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py

"""

def __init__(self):
super().__init__()

@classmethod
@abstractmethod
def from_original(cls, config, original_layer: torch.nn.Module):
raise NotImplementedError

@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError

@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
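
To make the contract concrete, a minimal hypothetical subclass is sketched below; the class name and its internals are invented for illustration, and only the three overridden methods are what QModuleBase actually requires.

# Hypothetical minimal subclass illustrating the QModuleBase contract.
import torch

from auto_round.experimental.qmodules.base import QModuleBase


class NoopQuantLinear(QModuleBase):
    """Keeps the float weight unchanged; exists only to show the required API."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)

    @classmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        # Build the replacement module with the same shape as the float layer.
        new_layer = cls(original_layer.in_features, original_layer.out_features)
        new_layer.weight.data.copy_(original_layer.weight.data)
        return new_layer

    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # no special device capability needed for this no-op scheme

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass  # nothing to repack after loading
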
122 changes: 122 additions & 0 deletions auto_round/experimental/qmodules/fp8_static.py
@@ -0,0 +1,122 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Union

import torch

from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.utils import logger

__all__ = ["WeightFP8ActFP8StaticQuantLinear"]


def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
qtensor = tensor / scale
clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
return scale, clipped_qtensor_fp8


class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
hp_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn

def __init__(
self,
in_features,
out_features,
weight: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
bias: Union[torch.Tensor, bool, None] = None,
input_scale: Optional[torch.Tensor] = None,
dtype=torch.bfloat16,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
self.dtype = dtype
if bias is not None:
if isinstance(bias, bool):
bias = torch.zeros((out_features,), dtype=dtype)
self.bias = torch.nn.Parameter(bias, requires_grad=False)
else:
self.register_parameter("bias", None)
init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale
self.register_buffer("weight_scale", init_weight_scale.to(dtype))

init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale
self.register_buffer("input_scale", init_input_scale.to(dtype))
self.pre_dequantized = False

@classmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
# TODO: correct that config once we add fp8 op support.
logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
return 0

def process_weights_after_loading(self, layer: torch.nn.Module):
pass

@classmethod
def from_original(cls, config, original_layer):
"""
Create an `WeightFP8ActFP8StaticQuantLinear` layer from an original linear layer.
"""
logger.warning_once(
"FP8 static quantization is still in experimental stage, the inference speed might be slow."
)
device = original_layer.weight.device
with torch.device(device):
qdq_linear = cls(
in_features=original_layer.in_features,
out_features=original_layer.out_features,
bias=original_layer.bias,
)
return qdq_linear

def dequant_weight_online(self):
if self.pre_dequantized:
return self.weight
qdq_weight = self.weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
return qdq_weight

def pre_dequantize(self):
if self.pre_dequantized:
return
dequant_weight = self.dequant_weight_online()
del self.weight
del self.weight_scale
self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
self.pre_dequantized = True

def qdq_input(self, bf16_input: torch.Tensor):
input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
return qdq_input_bf16

@torch.no_grad()
def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:

qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
return out
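
A short round-trip sketch of the QDQ path implemented above: quantize a bf16 weight per output channel to FP8, load it into the module, and run a forward pass. The per-channel scale computation and the calibration value used for input_scale are illustrative assumptions, not code from this PR.

# Hedged round-trip sketch for WeightFP8ActFP8StaticQuantLinear.
import torch

from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

torch.manual_seed(0)
in_features, out_features = 16, 8
linear = torch.nn.Linear(in_features, out_features, dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max                     # 448.0 for e4m3fn
weight = linear.weight.data
weight_scale = weight.abs().amax(dim=1) / fp8_max                  # one scale per output channel
weight_fp8 = (weight / weight_scale.unsqueeze(1)).to(torch.float8_e4m3fn)

qlinear = WeightFP8ActFP8StaticQuantLinear.from_original(config=None, original_layer=linear)
qlinear.weight.data = weight_fp8                                   # FP8 weight, as it would come from a checkpoint
qlinear.weight_scale.data = weight_scale.to(torch.bfloat16)
qlinear.input_scale.data = torch.tensor([0.05], dtype=torch.bfloat16)  # assumed static activation scale

x = torch.randn(2, in_features, dtype=torch.bfloat16)
y = qlinear(x)   # QDQ the input with input_scale, dequantize the weight, then F.linear
print(y.shape)   # torch.Size([2, 8])
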
2 changes: 1 addition & 1 deletion auto_round/export/export_to_autoround/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .export import save_quantized_as_autoround
from .export import save_quantized_as_autoround, AutoRoundFormat
16 changes: 14 additions & 2 deletions auto_round/export/export_to_autoround/export.py
@@ -18,6 +18,7 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from enum import Enum

import threadpoolctl as tctl
import torch
@@ -43,6 +44,12 @@
)


class AutoRoundFormat(str, Enum):
# Weight: FP8, per-channel, may be extended to per-tensor in future
# Activation: FP8, per-tensor
TORCH_FP8_STATIC = "fp8_static"


def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits=16):
"""
Dynamically imports and returns the appropriate QuantLinear class based on the specified backend and parameters.
@@ -152,7 +159,7 @@ def pack_layer(layer_name, model, backend, device=None):

return pack_layer(layer_name, model, backend, device)

if backend == "auto_round:fp8":
if backend == "auto_round:fp8" or backend == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}":
from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer

return pack_layer(layer_name, model, backend, device)
@@ -268,9 +275,14 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)
from auto_round.autoround import AutoRoundFormat

##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
if (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend):
if (
(kwargs.get("sym") is None or kwargs.get("sym"))
and ("gptq" not in backend and "awq" not in backend)
and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend)
):
backend = backend.replace("auto_round", "auto_round:auto_gptq")

model = kwargs["model"]
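
A quick illustration of the enum-driven backend checks introduced above; the strings are the ones added in this diff, and the snippet only mimics the two branches instead of calling the real packing code.

# Mirrors the backend checks above; does not pack anything.
from auto_round.export.export_to_autoround import AutoRoundFormat

backend = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"   # "auto_round:fp8_static"

# pack_layer() routes both FP8 backends to the FP8 packer:
uses_fp8_packer = backend in ("auto_round:fp8", f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}")

# save_quantized_as_autoround() no longer rewrites the backend to the GPTQ kernel for FP8-static:
sym = True
if sym and "gptq" not in backend and "awq" not in backend and AutoRoundFormat.TORCH_FP8_STATIC.value not in backend:
    backend = backend.replace("auto_round", "auto_round:auto_gptq")

print(uses_fp8_packer, backend)   # True auto_round:fp8_static
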