diff --git a/auto_round/autoround.py b/auto_round/autoround.py index aedcd860d..c084c6b15 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -38,6 +38,8 @@ SUPPORTED_LAYER_TYPES, TORCH_VERSION_AT_LEAST_2_6, CpuInfo, + _generate_block_recipe, + _generate_recipe, _gguf_args_check, block_forward, check_and_mark_fp8_model, @@ -67,9 +69,9 @@ infer_bits_by_data_type, init_cache, is_debug_mode, + is_hpex_available, is_mx_fp, is_nv_fp, - is_optimum_habana_available, is_standard_fp, llm_load_model, logger, @@ -101,6 +103,11 @@ class AutoRound(object): enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers. """ + # If function is not necessary for AutoRound, putting it in other place and + # assembling the function here, can improve code readability and maintainability + _generate_recipe = _generate_recipe + _generate_block_recipe = _generate_block_recipe + def __init__( self, model: Union[torch.nn.Module, str], @@ -364,7 +371,7 @@ def __init__( torch.set_printoptions(precision=3, sci_mode=True) - if is_optimum_habana_available(): + if is_hpex_available(): logger.info("Optimum Habana is available, import htcore explicitly.") import habana_frameworks.torch.core as htcore # pylint: disable=E0401 import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] @@ -430,6 +437,10 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: self.enable_torch_compile = False logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") + ### Recipe Mode ### + self.recipe_mode = False + self.recipe_results = {"recipe": {}, "results": {}} + def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None: """Sets the device map for specific blocks in the model. @@ -1394,7 +1405,11 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) - if self.device_map is not None: accelerate.hooks.remove_hook_from_submodules(block) - if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): + if ( + hasattr(self, "formats") + and is_nv_fp(self.act_data_type) + and any("nv_fp" in format_ for format_ in self.formats) + ): from auto_round.utils import set_amax_for_all_moe_layers # enable moe experts act_max automatic generation for linears @@ -1549,6 +1564,8 @@ def quantize(self): f"Expected exactly one packing format when 'is_packing_immediate' is True, " f"but got {len(self.formats)} formats." ) + if self.recipe_mode: + return self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately @@ -2439,7 +2456,14 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to modules = block.modules() for module in modules: - update_fused_layer_global_scales(module) + try: + update_fused_layer_global_scales(module) + except: + # mix-precision may cause error, since q,k,v are not the same dtype. + logger.warning_once( + "Cannot keep the same global scale for fused layers, " + + "so the model may not work with vLLM with fused QKV or else." 
+ ) round_params = [] minmax_params = [] for n, m in block.named_modules(): @@ -2561,7 +2585,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to logger.info(f"{unquantized_layer_names} have not been quantized") with torch.no_grad(): unwrapper_block(block, best_params) - if self.enable_quanted_input: + if self.enable_quanted_input and hasattr(self, "formats"): if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): from auto_round.utils import set_amax_for_all_moe_layers @@ -2616,6 +2640,20 @@ def quantize_blocks( clear_memory() input_ids = to_device(input_ids, self.cache_device) input_others = to_device(input_others, self.cache_device) + if self.recipe_mode: + logger.info("[Recipe Mode] starts") + q_input_ids = None # init value + for block_name in tqdm(block_names): + block = get_module(model, block_name) + if not self.model.device.type == "meta" or self.low_cpu_mem_usage: + block = block.to(device) + input_ids, q_input_ids = self._generate_block_recipe( + block, block_name, input_ids, q_input_ids, input_others + ) + if is_hpex_available(): + htcore.mark_step() + logger.info("[Recipe Mode] ends") + return ## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage tmp_dtype = self.amp_dtype if self.amp else torch.float32 for i in range(len(input_ids)): @@ -2891,7 +2929,7 @@ def scale_loss_and_backward(self, scaler, loss): """ scale_loss = loss * 1000 scale_loss.backward() - if is_optimum_habana_available(): + if is_hpex_available(): htcore.mark_step() return scale_loss @@ -2908,7 +2946,7 @@ def step(self, scaler, optimizer, lr_schedule): """ optimizer.step() # for hpu - if is_optimum_habana_available(): + if is_hpex_available(): htcore.mark_step() optimizer.zero_grad() lr_schedule.step() @@ -3117,7 +3155,7 @@ def scale_loss_and_backward(self, scaler, loss): loss = scaler.scale(loss) loss.backward() - if is_optimum_habana_available(): + if is_hpex_available(): htcore.mark_step() return loss @@ -3131,5 +3169,5 @@ def step(self, scaler, optimizer, lr_schedule): optimizer.step() optimizer.zero_grad() lr_schedule.step() - if is_optimum_habana_available(): + if is_hpex_available(): htcore.mark_step() diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py index 0c6ae265c..eaef3f454 100644 --- a/auto_round/data_type/mxfp.py +++ b/auto_round/data_type/mxfp.py @@ -22,6 +22,7 @@ revert_tensor_by_pad, round_ste, ) +from auto_round.utils import is_hpex_available MXFP_FORMAT_CACHE = { # data type: ebits, mbits, emax, max_norm, min_norm @@ -77,7 +78,6 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"): return tensor -@torch.compile() def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs): """Quantize the given tensor using the specified parameters. @@ -128,7 +128,6 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None -@torch.compile() def quant_mx_rceil( tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs ): @@ -180,6 +179,11 @@ def quant_mx_rceil( return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None +# HPU returns error with Habana software 1.22.0, so skip torch.compile here. 
+if False and not is_hpex_available(): + quant_mx = torch.compile(quant_mx) + quant_mx_rceil = torch.compile(quant_mx_rceil) + for key in MXFP_FORMAT_CACHE.keys(): QUANT_FUNC_WITH_DTYPE[key] = quant_mx QUANT_FUNC_WITH_DTYPE[key + "_rceil"] = quant_mx_rceil diff --git a/auto_round/data_type/nvfp.py b/auto_round/data_type/nvfp.py index 058c5345c..4ee95dc67 100644 --- a/auto_round/data_type/nvfp.py +++ b/auto_round/data_type/nvfp.py @@ -89,6 +89,8 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs): tensor_max = tensor.abs().max().to(torch.float32) global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max) global_scale = global_scale.to(tensor.device) + # Ensure global_scale is in float32, sometimes tensor is in bf16/fp16 + global_scale = global_scale.to(torch.float32) qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v) qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len) return qdq_res.to(orig_dtype), scale, None @@ -109,6 +111,8 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, * tensor_max = tensor.abs().max().to(torch.float32) global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max) global_scale = global_scale.to(tensor.device) + # Ensure global_scale is in float32, sometimes tensor is in bf16/fp16 + global_scale = global_scale.to(torch.float32) qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v) qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len) return qdq_res.to(orig_dtype), scale, None diff --git a/auto_round/utils.py b/auto_round/utils.py index 40fbcc5de..28a7f629c 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -85,6 +85,12 @@ def __getitem__(self, key): SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce) +DTYPE_INFO_MAPPING = { + "nv_fp4": {"bits": 4, "group_size": 16, "sym": True}, + "mx_fp4": {"bits": 4, "group_size": 32, "sym": True}, + "mx_fp8": {"bits": 8, "group_size": 32, "sym": True}, +} + def infer_bits_by_data_type(data_type: str): if data_type is None: @@ -172,6 +178,30 @@ def __call__(self, *args, **kwargs): htcore = LazyImport("habana_frameworks.torch.core") +def is_package_available(package_name): + """Check if the package exists in the environment without importing. + + Args: + package_name (str): package name + """ + from importlib.util import find_spec + + package_spec = find_spec(package_name) + return package_spec is not None + + +if is_package_available("habana_frameworks"): + _hpex_available = True + import habana_frameworks.torch.hpex # pylint: disable=E0401 +else: + _hpex_available = False + + +def is_hpex_available(): + """Returns whether hpex is available.""" + return _hpex_available + + def is_optimum_habana_available(): from transformers.utils.import_utils import is_optimum_available @@ -569,7 +599,7 @@ def is_valid_digit(s): if torch.cuda.is_available(): device = torch.device("cuda") # logger.info("Using GPU device") - elif is_optimum_habana_available(): # pragma: no cover + elif is_hpex_available(): # pragma: no cover device = torch.device("hpu") # logger.info("Using HPU device") elif torch.xpu.is_available(): # pragma: no cover @@ -2502,3 +2532,272 @@ def is_mx_fp(backend): def is_nv_fp(backend): return BackendDataType.NV_FP in backend + + +######################################### Recipe related codes #################################### +def create_mp_block(block, mp_layers, mp_dtype): + """Create a mixed precision block. 
+ + Args: + block (torch.nn.Module): original block. + mp_layers (list): list of layer names to apply mixed precision. + mp_dtype (dict): mixed precision data type configuration. + + Returns: + torch.nn.Module: mixed precision block. + """ + from auto_round.wrapper import WrapperLinear + + for layer_name in mp_layers: + layer = get_module(block, layer_name) + layer.data_type, layer.bits, layer.sym = mp_dtype["data_type"], mp_dtype["bits"], mp_dtype["sym"] + layer.act_data_type, layer.act_bits, layer.act_sym = ( + mp_dtype["act_data_type"], + mp_dtype["act_bits"], + mp_dtype["act_sym"], + ) + for n, m in block.named_modules(): + if isinstance(m, SUPPORTED_LAYER_TYPES): + if check_to_quantized(m): + new_m = WrapperLinear( + m, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + device=m.weight.device, + ) + set_module(block, n, new_m) + if is_hpex_available(): + htcore.mark_step() + return block + + +def recover_mp_block(block, mp_layers, raw_dtype): + """Recover a mixed precision block. + + Args: + block (torch.nn.Module): mixed precision block. + mp_layers (list): list of layer names to recover. + raw_dtype (dict): original data type configuration. + + Returns: + torch.nn.Module: recovered block. + """ + from auto_round.wrapper import WrapperLinear + + for n, m in block.named_modules(): + if isinstance(m, WrapperLinear): + set_module(block, n, m.orig_layer) + for layer_name in mp_layers: + layer = get_module(block, layer_name) + layer.data_type, layer.bits, layer.sym = raw_dtype["data_type"], raw_dtype["bits"], raw_dtype["sym"] + layer.act_data_type, layer.act_bits, layer.act_sym = ( + raw_dtype["act_data_type"], + raw_dtype["act_bits"], + raw_dtype["act_sym"], + ) + if is_hpex_available(): + htcore.mark_step() + return block + + +def get_avg_bits(module): + """ + Calculates the average number of bits per weight element for supported layers in a given module. + + Iterates through all named modules in the module, accumulating the total number of weight elements + and the corresponding bit usage, including additional scale bits for specific data types. + + Args: + module: A neural network module containing layers to be analyzed. + + Returns: + float: The average number of bits per weight element across all supported layers. + + Note: + - Only layers of types specified in SUPPORTED_LAYER_TYPES are considered. + - For certain data types ("fp4_v2", "nv_fp4", "mx_fp4", "mx_fp8"), scale bits are added. + - For "fp4_v2" and "nv_fp4", an additional 32 global scale bits are included. + """ + all_numel = 0 + all_bits = 0 + for n, m in module.named_modules(): + if isinstance(m, SUPPORTED_LAYER_TYPES): + m_numel = m.weight.numel() + all_numel += m_numel + w_bits = m.bits * m_numel + all_bits += w_bits + if m.data_type in ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"): + scale_bits = 8 * (m_numel // m.group_size) + if m.data_type in ("fp4_v2", "nv_fp"): + scale_bits += 32 # global scale bits + all_bits += scale_bits + else: # woq + scale_bits = 16 * (m_numel // m.group_size) + all_bits += scale_bits + + avg_bits = all_bits / all_numel if all_numel > 0 else 0 + return round(avg_bits, 6) + + +def _generate_recipe( + self, + # same data type config as before + mp_dtype={ + "data_type": "mx_fp8", + "act_data_type": "mx_fp8", + }, + # special mix-precision configuration + mp_config={ + "mp_ratio": 1 / 3, + }, +): + """ + Generates a quantization recipe for the model based on the specified mixed-precision data type and configuration. 
+ + Args: + mp_dtype (dict, optional): Dictionary specifying the mixed-precision data types for weights and activations. + Defaults to {"data_type": "mx_fp8", "act_data_type": "mx_fp8"}. + mp_config (dict, optional): Dictionary specifying the mixed-precision configuration parameters such as + ratio, loss weight, and numel weight. Defaults to {"mp_ratio": 1/3}. + + Returns: + dict: A dictionary containing the quantization recipe for each layer, excluding the "lm_head" layer. + """ + self.recipe_mp_dtype = mp_dtype + self.recipe_mp_config = mp_config + # traverse all blocks with self.recipe_mode=True + self.recipe_mode = True + self.quantize() + self.recipe_mode = False + # combine self.layer_config with self.recipe_results["recipe"] + recipe_layer_config = copy.deepcopy(self.layer_config) + recipe_layer_config.update(self.recipe_results["recipe"]) + recipe_layer_config.pop("lm_head") # lm_head is not included in the recipe + self.recipe_results["recipe"] = recipe_layer_config + # dump average bits of all blocks + avg_bits_all_block = 0 + for block_name, result in self.recipe_results["results"].items(): + avg_bits_all_block += result["bits"] + avg_bits_all_block /= len(self.recipe_results["results"]) + logger.info(f"[Recipe Mode] Average bits of all blocks: {round(avg_bits_all_block, 3)}") + return self.recipe_results + + +def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, input_others): + from itertools import combinations + + # fetch mix-precision recipe configuration + sample_num = self.recipe_mp_config.get("sample_num", 8) + quantizable_layers = [n for n, m in block.named_modules() if isinstance(m, SUPPORTED_LAYER_TYPES)] + target_bits = self.recipe_mp_config.get("target_bits", None) + if target_bits is None: + mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 3) + + # calculate the number of layers to use mix-precision + mp_ratio_list = [f"{i}/{len(quantizable_layers)}" for i in range(1, len(quantizable_layers))] + quantizable_num = int(mp_ratio * len(quantizable_layers)) # It's ceiling + logger.warning_once( + f"[Recipe Mode] {len(quantizable_layers)} layers are detected, so the available mp_ratio values are {mp_ratio_list}" + ) + logger.warning_once(f"[Recipe Mode] {quantizable_num} layers of each block use the mixed precision.") + # fetch raw low-bits dtype of block for recovering mix-precision block + layer = get_module(block, quantizable_layers[0]) + raw_dtype = { + "data_type": layer.data_type, + "bits": layer.bits, + "sym": layer.sym, + "act_data_type": layer.act_data_type, + "act_bits": layer.act_bits, + "act_sym": layer.act_sym, + } + # update self.recipe_mp_dtype + self.recipe_mp_dtype.update( + { + "bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["bits"], + "group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["group_size"], + "sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["sym"], + "act_bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["bits"], + "act_group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["group_size"], + "act_sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["sym"], + } + ) + + # generate reference output of sample input_ids + def get_output(block, input_ids): + output = self.get_block_outputs( + block, + input_ids[:sample_num], + input_others, + bs=self.batch_size, + device=self.device, + cache_device=self.cache_device, + save_output=True, + ) + if is_hpex_available(): + htcore.mark_step() + return output + + reference_output = 
get_output(block, input_ids) + q_input_ids = input_ids if q_input_ids is None else q_input_ids + # generate q_output of sample input_ids and get loss + def get_loss(q_block, q_input_ids): + q_output = get_output(q_block, q_input_ids) + total_loss = 0 + mse_loss = torch.nn.MSELoss(reduction="sum").to(self.device) + for i in range(len(q_output)): + loss = mse_loss( # pylint: disable=not-callable + q_output[i].to(torch.float32), reference_output[i].to(torch.float32) + ) + total_loss += float(loss) + if is_hpex_available(): + htcore.mark_step() + return round(total_loss, 6) + + combination_list = [] + avg_bits_list = [] + loss_list = [] + for hp_layers in combinations(quantizable_layers, quantizable_num): + combination_list.append(hp_layers) + # get loss + block = create_mp_block(block, hp_layers, self.recipe_mp_dtype) + # get average bits + avg_bits = get_avg_bits(block) + avg_bits_list.append(avg_bits) + loss = get_loss(block, q_input_ids) + loss_list.append(loss) + block = recover_mp_block(block, hp_layers, raw_dtype) + if is_hpex_available(): + htcore.mark_step() + logger.debug(f"{hp_layers}, {loss}, {avg_bits}") + + # get combination with lowest loss + best_loss = float("inf") + for i, (loss, avg_bits) in enumerate(zip(loss_list, avg_bits_list)): + if best_loss > loss: + best_loss = loss + best_avg_bits = avg_bits + best_combination = combination_list[i] + + logger.info(f"[Recipe Mode] Recipe results of {block_name}:\nMix precision layers: {best_combination};\nAverage bits: {best_avg_bits}.") + # generate output of quantized block of sample input_ids + block = create_mp_block(block, best_combination, self.recipe_mp_dtype) + q_output = get_output(block, q_input_ids) + block = recover_mp_block(block, best_combination, raw_dtype) + # update recipe and results + for layer_name in best_combination: + self.recipe_results["recipe"].update({block_name + "." 
+ layer_name: self.recipe_mp_dtype}) + self.recipe_results["results"].update( + { + block_name: { + "mp_layers": best_combination, + "bits": best_avg_bits, + } + } + ) + if is_hpex_available(): + htcore.mark_step() + + return reference_output, q_output + + +############################################################################################### diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 8942a2c0d..64ecafe0e 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -91,6 +91,7 @@ def __init__( self.orig_layer = orig_layer self.output_device = device self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device + self.extra_repr_org = orig_layer.extra_repr self.enable_minmax_tuning = enable_minmax_tuning self.enable_round_tuning = enable_round_tuning self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None) @@ -451,6 +452,9 @@ def forward(self, x): output = self.orig_forward(x, weight_q, bias).to(self.output_device) return output + def extra_repr(self): + return f"{self.extra_repr_org()}, weight_type={self.data_type}, act_data_type={self.act_data_type}" + class WrapperWALayer(torch.nn.Module): def __init__(self, orig_layer): diff --git a/test/test_cpu/test_recipe.py b/test/test_cpu/test_recipe.py new file mode 100644 index 000000000..f3968592e --- /dev/null +++ b/test/test_cpu/test_recipe.py @@ -0,0 +1,73 @@ +import shutil +import sys +import unittest + +sys.path.insert(0, "../..") +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound + + +class LLMDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.ones([1, 10], dtype=torch.long) + + +class TestAutoRound(unittest.TestCase): + @classmethod + def setUpClass(self): + model_name = "facebook/opt-125m" + self.save_dir = "./saved" + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.llm_dataloader = LLMDataLoader() + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_recipe_api(self): + bits = 4 + act_bits = 4 + data_type = "nv_fp" + act_data_type = "nv_fp4_with_static_gs" + group_size = 16 + sym = True + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + act_bits=act_bits, + data_type=data_type, + act_data_type=act_data_type, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + layer_config = autoround._generate_recipe( + mp_dtype={ + "data_type": "mx_fp8", + "act_data_type": "mx_fp8", + }, + mp_config={ + "mp_ratio": 1 / 3, + "loss_weight": 2.0, + "numel_weight": 1.0, + }, + ) + autoround.layer_config = layer_config + autoround.quantize() + # autoround.quantize_and_save() # save is not supported for mix-precision + print(autoround.model) + + +if __name__ == "__main__": + unittest.main() diff --git a/workspace/README.md b/workspace/README.md new file mode 100644 index 000000000..acfb247a6 --- /dev/null +++ b/workspace/README.md @@ -0,0 +1,29 @@ + + ```bash +############################### Gaudi model path ############################################# +deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 5/7 2>&1 |tee mxfp4_op_5_8b.log + 
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 3/7 2>&1 |tee mxfp4_op_3_8b.log + +deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_op_4_3.3_70b.log + +deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 5/7 --autoround 2>&1 |tee mxfp4_op_5_3.3_70b.log + + +############################### H20 model path ############################################# +deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 64 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 2>&1 |tee mxfp4_op_3_8b.log + +deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 5/7 2>&1 |tee mxfp4_op_5_8b.log + +deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 6/7 2>&1 |tee mxfp4_op_6_8b.log + +deepspeed --include="localhost:0,1,2,3" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 5/7 --autoround 2>&1 |tee mxfp4_op_5_3.3_70b.log + +deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 --autoround 2>&1 |tee mxfp4_op_3_3.3_70b.log + +H20-2-1 +deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_op_4_3.3_70b.log + + +``` +python quantize.py --model_name_or_path facebook/opt-125m --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 1/3 --autoround diff --git a/workspace/quantize.py b/workspace/quantize.py new file mode 100644 index 000000000..f704e5b1c --- /dev/null +++ b/workspace/quantize.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +import transformers + +######################## HPU Memory Optimization ########################### +# ensure that unnecessary memory is released during quantization. 
+os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
+os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
+if int(os.getenv("WORLD_SIZE", "0")) > 0:
+    os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
+    os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")
+from neural_compressor.torch.utils import is_hpex_available
+
+if is_hpex_available():
+    import habana_frameworks.torch.core as htcore
+    from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+    from neural_compressor.torch.utils import get_used_hpu_mem_MB
+
+    htcore.hpu_set_env()
+
+
+def initialize_model_and_tokenizer(model_name_or_path):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+    config = transformers.AutoConfig.from_pretrained(model_name_or_path)
+    # use memory mapping with torch_dtype=config.torch_dtype
+    model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=config.torch_dtype)
+    # shard the model for multi-card runs and enable hpu graph
+    from neural_compressor.torch.utils import local_rank, logger, world_size
+
+    if world_size > 1:
+        ds_inference_kwargs = {
+            "dtype": config.torch_dtype,
+            "tensor_parallel": {"tp_size": world_size},
+        }
+        import deepspeed
+
+        ds_model = deepspeed.init_inference(model, **ds_inference_kwargs)
+        model = ds_model.module
+    model.eval()
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Habana FP8 quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or path"
+    )
+    parser.add_argument("--output_dir", type=str, default="saved_results", help="directory to save results")
+    parser.add_argument("--device", type=str, default="hpu", help="device")
+    parser.add_argument("--dtype", type=str, default="mx_fp4", help="quantization data type")
+    parser.add_argument("--quantize", action="store_true", help="whether to quantize the model")
+    parser.add_argument("--tune", action="store_true", help="whether to enable autoround tuning iterations")
+    parser.add_argument("--autoround", action="store_true", help="whether to run autoround recipe mode")
+    parser.add_argument("--iters", default=None, type=int, help="iters for autoround.")
+    parser.add_argument("--seqlen", default=None, type=int, help="sequence length for autoround.")
+    parser.add_argument("--nsamples", default=None, type=int, help="number of samples for autoround.")
+    parser.add_argument("--target_bits", default=5, type=float, help="target average bits for recipe mode.")
+    parser.add_argument("--target_loss_ratio", default=1.2, type=float, help="target loss ratio for recipe mode.")
+    parser.add_argument(
+        "--use_hpu_graph", action="store_true", help="whether to use hpu graph mode to accelerate performance"
+    )
+    parser.add_argument(
+        "--enable_block_wise_calibration", action="store_true", help="whether to use block-wise calibration"
+    )
+    parser.add_argument(
+        "--disable_optimum_habana", action="store_true", help="whether to use adapt_transformers_to_gaudi"
+    )
+    parser.add_argument(
+        "--mp_ratio", default="1/3", type=str, help="ratio of layers per block to run in the higher-precision data type, e.g. 3/7."
+    )
+    parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
+    parser.add_argument("--load", action="store_true", help="whether to load the quantized model")
+    parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
+    parser.add_argument("--quant_lm_head", action="store_true",
help="performance measurement") + parser.add_argument("--accuracy", action="store_true", help="accuracy measurement") + parser.add_argument("--performance", action="store_true", help="performance measurement") + parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.") + parser.add_argument("--batch_size", default=4, type=int, help="batch size for accuracy measurement.") + parser.add_argument("--num_fewshot", default=0, type=int, help="num_fewshot of lm_eval.") + parser.add_argument( + "--mxfp8_mod_list", + type=str, + nargs="*", + default=[], # 默认值 + help="List of module names or patterns for MXFP8 quantization.", + ) + parser.add_argument( + "--fp8_mod_list", + type=str, + nargs="+", # 接受一个或多个字符串作为列表 + default=[], # 默认值 + help="List of module names or patterns for FP8 quantization.", + ) + parser.add_argument( + "--bf16_mod_list", + type=str, + nargs="+", # 接受一个或多个字符串作为列表 + default=[], # 默认值 + help="List of module names or patterns for MXFP8 quantization.", + ) + parser.add_argument( + "--dump_stats_path", type=str, default="./hqt_output/measure", help="path and prefix to calibration info file." + ) + parser.add_argument( + "--tasks", + type=str, + nargs="+", # 接受一个或多个字符串作为列表 + default=[ + "piqa", + "hellaswag", + "mmlu", + "winogrande", + "lambada_openai", + ], # 默认值 + help="tasks for accuracy validation, text-generation and code-generation tasks are different.", + ) + parser.add_argument( + "--dataset_name", type=str, default="NeelNanda/pile-10k", help="dataset name for calibration dataloader" + ) + parser.add_argument("--limit", type=int, default=None, help="number of samples for accuracy evaluation") + args = parser.parse_args() + print("Target data type:", args.dtype) + + model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path) + if args.quantize: + lm_head_config = { + "group_size": 32 if "mx" in args.dtype else 16, + "data_type": args.dtype, + "act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype, + } + layer_config = {"lm_head": lm_head_config} + from auto_round import AutoRound + + autoround = AutoRound( + model, + tokenizer, + device=args.device, + iters=200 if args.tune else 0, + low_gpu_mem_usage=True, + group_size=32 if "mx" in args.dtype else 16, + data_type=args.dtype, + act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype, + layer_config=layer_config if args.quant_lm_head else None, + ) + autoround.quantize() + model = autoround.model + + if args.autoround: + from deepspeed.module_inject import ( + LinearAllreduce, + LinearLayer, + ) + + MXFP4_MODULE_MAPPING = { + torch.nn.Linear: None, + torch.nn.EmbeddingBag: None, + LinearLayer: None, + LinearAllreduce: None, + } + + from auto_round import AutoRound + + def match_pattern(name, pattern): + for pat in pattern: + if pat in name: + return True + return False + + layer_config = {} + fp8_config = { + "bits": 8, + "data_type": "fp8", + "act_data_type": "fp8", + } + mxfp4_config = { + "bits": 4, + "group_size": 32, + "data_type": "mx_fp4", + "act_data_type": "mx_fp4", + } + mxfp8_config = { + "bits": 8, + "group_size": 32, + "data_type": "mx_fp8", + "act_data_type": "mx_fp8", + } + module_name_to_quantize: list[str] = [ + n for n, m in model.named_modules() if isinstance(m, tuple(MXFP4_MODULE_MAPPING.keys())) + ] + for name in module_name_to_quantize: + if match_pattern(name, args.mxfp8_mod_list): + layer_config.update({name: mxfp8_config}) + if match_pattern(name, args.fp8_mod_list): + 
layer_config.update({name: fp8_config}) + if args.quant_lm_head: + layer_config.update({"lm_head": mxfp8_config}) + + from auto_round import AutoRound + + autoround = AutoRound( + model, + tokenizer, + device=args.device, + low_gpu_mem_usage=True, + group_size=32 if "mx" in args.dtype else 16, + data_type=args.dtype, + act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype, + layer_config=layer_config, + ) + + recipe_results = autoround._generate_recipe( + mp_config={ + "mp_ratio": float(eval(args.mp_ratio)), + }, + ) + autoround.layer_config = recipe_results["recipe"] + autoround.quantize() + model = autoround.model + + # preprocess model for accuracy and performance measurement + if not args.load and not args.autoround and not args.quantize: + # compare fp8 with bf16, not fp32. + model = model.to(torch.bfloat16) + model = model.eval().to(args.device) + print(model) + + if args.accuracy: + if is_hpex_available(): + model = wrap_in_hpu_graph(model) + htcore.hpu_inference_initialize(model, mark_only_scales_as_const=True) + from neural_compressor.evaluation.lm_eval import LMEvalParser, evaluate + + tasks = ",".join(args.tasks) + eval_args = LMEvalParser( + model="hf", + user_model=model, + tokenizer=tokenizer, + batch_size=args.batch_size, + tasks=tasks, + device="hpu", + pad_to_buckets=True, + num_fewshot=args.num_fewshot, + limit=args.limit, + add_bos_token=True, + ) + results = evaluate(eval_args) + torch.hpu.synchronize() + all_accuracy = {} + for task_name, task_results in results["results"].items(): + if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]: + accu = task_results["acc,none"] + all_accuracy[task_name] = accu + print(f"Accuracy for {task_name}: {accu:.4f}") + print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}") + else: + # model = torch.compile(model) + args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"] + all_accuracy = {} + test_gsm8k = False + test_normal = False + if "gsm8k" in args.tasks: + test_gsm8k = True + args.tasks.remove("gsm8k") + if args.tasks: + test_normal = True + import lm_eval + from lm_eval.models.huggingface import HFLM + + if test_normal: + lm = HFLM( + pretrained=model, + tokenizer=tokenizer, + add_bos_token=True, + batch_size=args.batch_size, + ) + results = lm_eval.simple_evaluate( + lm, + tasks=args.tasks, + limit=args.limit, + ) + for task_name, task_results in results["results"].items(): + if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]: + accu = task_results["acc,none"] + all_accuracy[task_name] = accu + ########################## gms8k ######################### + if test_gsm8k: + lm = HFLM( + pretrained=model, + tokenizer=tokenizer, + add_bos_token=False, + batch_size=args.batch_size, + ) + results_gsm8k = lm_eval.simple_evaluate( + lm, + tasks=["gsm8k"], + limit=args.limit, + ) + for task_name, task_results in results_gsm8k["results"].items(): + accu = task_results["exact_match,strict-match"] + all_accuracy[task_name] = accu + ########################## gms8k end ######################### + for task_name, accu in all_accuracy.items(): + print(f"Accuracy for {task_name}: {accu:.4f}") + print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")
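Recipe mode scores each `combinations(quantizable_layers, quantizable_num)` candidate per block against the unquantized reference output and keeps the lowest-MSE one; the average bits it reports come from `get_avg_bits`, i.e. element bits plus the per-group shared-scale overhead, which works out to 4.25 bits for an mx_fp4 layer (4-bit elements, one 8-bit scale per 32-element group) and 8.25 bits for mx_fp8. Below is a minimal standalone sketch of that accounting for a single mx_fp4 layer; the helper name and the layer size are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of the bit accounting get_avg_bits() applies to one mx_fp4 layer.
# The helper name and the layer size are illustrative assumptions.
def mx_fp4_avg_bits(numel: int, bits: int = 4, group_size: int = 32) -> float:
    weight_bits = bits * numel              # 4 bits per weight element
    scale_bits = 8 * (numel // group_size)  # one 8-bit shared scale per group
    return round((weight_bits + scale_bits) / numel, 6)


print(mx_fp4_avg_bits(4096 * 4096))  # 4.25
```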