56 changes: 47 additions & 9 deletions auto_round/autoround.py
@@ -38,6 +38,8 @@
SUPPORTED_LAYER_TYPES,
TORCH_VERSION_AT_LEAST_2_6,
CpuInfo,
_generate_block_recipe,
_generate_recipe,
_gguf_args_check,
block_forward,
check_and_mark_fp8_model,
@@ -67,9 +69,9 @@
infer_bits_by_data_type,
init_cache,
is_debug_mode,
is_hpex_available,
is_mx_fp,
is_nv_fp,
is_optimum_habana_available,
is_standard_fp,
llm_load_model,
logger,
@@ -101,6 +103,11 @@ class AutoRound(object):
enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
"""

# If a function is not strictly needed inside AutoRound, defining it elsewhere and
# assembling it here as a class attribute keeps the class readable and maintainable.
_generate_recipe = _generate_recipe
_generate_block_recipe = _generate_block_recipe

def __init__(
self,
model: Union[torch.nn.Module, str],
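For context on the `_generate_recipe = _generate_recipe` lines above: assigning a module-level function to a class attribute turns it into an ordinary method, which lets large helpers live next to the other utilities imported in the first hunk. A minimal, self-contained sketch of that pattern (the names here are illustrative, not the actual AutoRound helpers):

```python
def _generate_report(self, prefix: str) -> str:
    """Defined at module level to keep the class body small."""
    return f"{prefix}: {self.value}"


class Analyzer:
    # Assemble the externally defined function here; it behaves like a normal method.
    _generate_report = _generate_report

    def __init__(self, value: int):
        self.value = value


print(Analyzer(3)._generate_report("result"))  # -> result: 3
```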
@@ -364,7 +371,7 @@ def __init__(

torch.set_printoptions(precision=3, sci_mode=True)

if is_optimum_habana_available():
if is_hpex_available():
logger.info("Optimum Habana is available, import htcore explicitly.")
import habana_frameworks.torch.core as htcore # pylint: disable=E0401
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401
@@ -430,6 +437,10 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

### Recipe Mode ###
self.recipe_mode = False
self.recipe_results = {"recipe": {}, "results": {}}

def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None:
"""Sets the device map for specific blocks in the model.

@@ -1394,7 +1405,11 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
if self.device_map is not None:
accelerate.hooks.remove_hook_from_submodules(block)

if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
if (
hasattr(self, "formats")
and is_nv_fp(self.act_data_type)
and any("nv_fp" in format_ for format_ in self.formats)
):
from auto_round.utils import set_amax_for_all_moe_layers

# enable moe experts act_max automatic generation for linears
@@ -1549,6 +1564,8 @@ def quantize(self):
f"Expected exactly one packing format when 'is_packing_immediate' is True, "
f"but got {len(self.formats)} formats."
)
if self.recipe_mode:
return

self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately

@@ -2439,7 +2456,14 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to

modules = block.modules()
for module in modules:
update_fused_layer_global_scales(module)
try:
update_fused_layer_global_scales(module)
except Exception:
# Mixed precision can break this step, since q, k, and v may not share the same dtype.
logger.warning_once(
"Cannot keep the same global scale for fused layers, "
+ "so the model may not work with vLLM or other runtimes that fuse QKV."
)
round_params = []
minmax_params = []
for n, m in block.named_modules():
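For background on why the fused-layer handling above can fail: fused kernels expect a single global scale across q/k/v, typically derived from the largest amax of the fused weights. A rough sketch of that idea (a hypothetical helper, not the auto_round `update_fused_layer_global_scales` implementation):

```python
import torch

FLOAT8_E4M3_MAX = 448.0  # max of float8_e4m3fn
FLOAT4_E2M1_MAX = 6.0    # max of FP4 E2M1


def unify_qkv_global_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute one NVFP4-style global scale for fused QKV.

    Fused kernels (e.g. a single QKV GEMM) expect one scale, so we take the
    largest amax across the three weights.
    """
    amax = torch.stack([q.abs().amax(), k.abs().amax(), v.abs().amax()]).max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax


q, k, v = (torch.randn(64, 64) for _ in range(3))
print(unify_qkv_global_scale(q, k, v))
```

With mixed-precision recipes, q, k, and v may be quantized with different data types, in which case a single shared scale is not well defined; that is the failure mode the try/except above tolerates.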
@@ -2561,7 +2585,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to
logger.info(f"{unquantized_layer_names} have not been quantized")
with torch.no_grad():
unwrapper_block(block, best_params)
if self.enable_quanted_input:
if self.enable_quanted_input and hasattr(self, "formats"):
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
from auto_round.utils import set_amax_for_all_moe_layers

@@ -2616,6 +2640,20 @@ def quantize_blocks(
clear_memory()
input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
if self.recipe_mode:
Review comment (Contributor): It would be better to wrap this new code into a function and call it as early as possible.
logger.info("[Recipe Mode] starts")
q_input_ids, ref_q_input_ids = None, None # init value
for block_name in tqdm(block_names):
block = get_module(model, block_name)
if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
block = block.to(device)
input_ids, q_input_ids, ref_q_input_ids = self._generate_block_recipe(
block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others
)
if is_hpex_available():
htcore.mark_step()
logger.info("[Recipe Mode] ends")
return
## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
tmp_dtype = self.amp_dtype if self.amp else torch.float32
for i in range(len(input_ids)):
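One way to act on the review comment above is to pull the recipe-mode branch out of quantize_blocks into its own method and call it before the dtype-conversion loop. This is only a sketch; the method name `_run_recipe_mode` is an assumption, and the body reuses the exact names from the diff:

```python
def _run_recipe_mode(self, model, block_names, input_ids, input_others, device):
    """Hypothetical extraction of the recipe-mode branch from quantize_blocks."""
    logger.info("[Recipe Mode] starts")
    q_input_ids, ref_q_input_ids = None, None  # initial values, as in the diff
    for block_name in tqdm(block_names):
        block = get_module(model, block_name)
        if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
            block = block.to(device)
        input_ids, q_input_ids, ref_q_input_ids = self._generate_block_recipe(
            block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others
        )
        if is_hpex_available():
            htcore.mark_step()
    logger.info("[Recipe Mode] ends")


# Inside quantize_blocks, the branch then shrinks to:
# if self.recipe_mode:
#     self._run_recipe_mode(model, block_names, input_ids, input_others, device)
#     return
```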
@@ -2891,7 +2929,7 @@ def scale_loss_and_backward(self, scaler, loss):
"""
scale_loss = loss * 1000
scale_loss.backward()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
return scale_loss

@@ -2908,7 +2946,7 @@ def step(self, scaler, optimizer, lr_schedule):
"""
optimizer.step()
# for hpu
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
optimizer.zero_grad()
lr_schedule.step()
@@ -3117,7 +3155,7 @@ def scale_loss_and_backward(self, scaler, loss):
loss = scaler.scale(loss)

loss.backward()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
return loss

@@ -3131,5 +3169,5 @@ def step(self, scaler, optimizer, lr_schedule):
optimizer.step()
optimizer.zero_grad()
lr_schedule.step()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
8 changes: 6 additions & 2 deletions auto_round/data_type/mxfp.py
@@ -22,6 +22,7 @@
revert_tensor_by_pad,
round_ste,
)
from auto_round.utils import is_hpex_available

MXFP_FORMAT_CACHE = {
# data type: ebits, mbits, emax, max_norm, min_norm
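As a reminder of what the (ebits, mbits, emax, max_norm, min_norm) entries parameterize, here is a deliberately simplified shared-scale quantizer in the MX spirit. It is not the quant_mx implementation that follows and ignores subnormals, NaN handling, clamping to max_norm, and mantissa rounding modes:

```python
import torch


def toy_mx_quant(x: torch.Tensor, group_size: int = 32, mbits: int = 3) -> torch.Tensor:
    """Didactic sketch: every group of `group_size` values shares one power-of-two scale."""
    groups = x.reshape(-1, group_size)
    # Shared exponent per group, derived from the group's largest magnitude.
    shared_exp = torch.floor(torch.log2(groups.abs().amax(dim=-1, keepdim=True) + 1e-30))
    scale = torch.pow(2.0, shared_exp)
    # Keep `mbits` fractional bits of mantissa relative to the shared scale.
    q = torch.round(groups / scale * (1 << mbits)) / (1 << mbits)
    return (q * scale).reshape(x.shape)


x = torch.randn(4, 32)
print((x - toy_mx_quant(x)).abs().max())  # quantization error stays small
```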
@@ -77,7 +78,6 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
return tensor


@torch.compile()
def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs):
"""Quantize the given tensor using the specified parameters.

@@ -128,7 +128,6 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


@torch.compile()
def quant_mx_rceil(
tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs
):
@@ -180,6 +179,11 @@ def quant_mx_rceil(
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


# torch.compile is currently disabled for these kernels (note the `if False` guard below);
# HPU in particular errors out with Habana software 1.22.0.
if False and not is_hpex_available():
quant_mx = torch.compile(quant_mx)
quant_mx_rceil = torch.compile(quant_mx_rceil)

for key in MXFP_FORMAT_CACHE.keys():
QUANT_FUNC_WITH_DTYPE[key] = quant_mx
QUANT_FUNC_WITH_DTYPE[key + "_rceil"] = quant_mx_rceil
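The new guard above currently short-circuits on `if False`, so neither kernel is compiled anywhere. As a general pattern, gating torch.compile on backend support can be sketched roughly like this (the environment variable is illustrative, not an existing auto_round flag):

```python
import os

import torch


def maybe_compile(fn):
    """Compile `fn` only when torch.compile is available and not explicitly disabled.

    Note: many torch.compile failures only surface at the first call, so a
    wrap-time try/except is a convenience, not a full guarantee.
    """
    if os.environ.get("TOY_DISABLE_COMPILE", "0") == "1":
        return fn
    if not hasattr(torch, "compile"):  # very old PyTorch
        return fn
    try:
        return torch.compile(fn)
    except Exception:  # e.g. an unsupported backend
        return fn


@maybe_compile
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 2.0 * x + y


print(scaled_add(torch.ones(4), torch.arange(4.0)))
```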
4 changes: 4 additions & 0 deletions auto_round/data_type/nvfp.py
@@ -89,6 +89,8 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs):
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
# Ensure global_scale is float32; the values used to build it may arrive as bf16/fp16
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
@@ -109,6 +111,8 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, *
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
# Ensure global_scale is float32; the values used to build it may arrive as bf16/fp16
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
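A small, self-contained illustration of why the explicit float32 cast above matters when the amax feeding the scale arrives in reduced precision (the constants mirror the ones used in nvfp.py: 448 for float8_e4m3fn, 6 for FP4 E2M1):

```python
import torch

FLOAT8_E4M3_MAX = 448.0  # max representable value of float8_e4m3fn
FLOAT4_E2M1_MAX = 6.0    # max representable value of FP4 E2M1

t = torch.randn(128, 128, dtype=torch.bfloat16)

amax_bf16 = t.abs().max()                    # stays bf16 (~3 significant decimal digits)
amax_fp32 = t.abs().max().to(torch.float32)  # casting keeps the downstream scale in float32

scale_bf16 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax_bf16
scale_fp32 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax_fp32

print(scale_bf16.dtype, scale_fp32.dtype)       # torch.bfloat16 vs torch.float32
print((scale_bf16.float() - scale_fp32).abs())  # rounding gap introduced by the bf16 scale
```

Since nv_fp4_with_static_gs accepts tensor_max from the caller, the extra cast appears to guarantee a float32 global_scale regardless of the caller's dtype.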