-
Notifications
You must be signed in to change notification settings - Fork 52
add autoround._generate_recipe() #758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xin3he
wants to merge
13
commits into
main
Choose a base branch
from
xinhe/mix-precision
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
d8b831e
add autoround._generate_recipe()
xinhe3 a78f99c
remove duplicate code
xinhe3 62f81e5
pop lm_head in recipe dict
xinhe3 79323d6
add more info in code and ut
xinhe3 086eae2
refactor code per review comments
xinhe3 b67c79a
resolve circular import
xinhe3 715e2a1
fix bug
xinhe3 5a631f1
add workspace for testing
xinhe3 2a0e0b8
stop compile
xinhe3 f02331d
use loss only
xinhe3 92025ad
add mxfp4_loss / mxfp8_loss to help
xinhe3 ade7736
add aggressive mode
xinhe3 580ec8b
add ref_q and tight threshold
xinhe3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,8 @@ | |
SUPPORTED_LAYER_TYPES, | ||
TORCH_VERSION_AT_LEAST_2_6, | ||
CpuInfo, | ||
_generate_block_recipe, | ||
_generate_recipe, | ||
_gguf_args_check, | ||
block_forward, | ||
check_and_mark_fp8_model, | ||
|
@@ -67,9 +69,9 @@ | |
infer_bits_by_data_type, | ||
init_cache, | ||
is_debug_mode, | ||
is_hpex_available, | ||
is_mx_fp, | ||
is_nv_fp, | ||
is_optimum_habana_available, | ||
is_standard_fp, | ||
llm_load_model, | ||
logger, | ||
|
@@ -101,6 +103,11 @@ class AutoRound(object): | |
enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers. | ||
""" | ||
|
||
# If function is not necessary for AutoRound, putting it in other place and | ||
# assembling the function here, can improve code readability and maintainability | ||
_generate_recipe = _generate_recipe | ||
_generate_block_recipe = _generate_block_recipe | ||
|
||
def __init__( | ||
self, | ||
model: Union[torch.nn.Module, str], | ||
|
@@ -364,7 +371,7 @@ def __init__( | |
|
||
torch.set_printoptions(precision=3, sci_mode=True) | ||
|
||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
logger.info("Optimum Habana is available, import htcore explicitly.") | ||
import habana_frameworks.torch.core as htcore # pylint: disable=E0401 | ||
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] | ||
|
@@ -430,6 +437,10 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: | |
self.enable_torch_compile = False | ||
logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") | ||
|
||
### Recipe Mode ### | ||
self.recipe_mode = False | ||
self.recipe_results = {"recipe": {}, "results": {}} | ||
|
||
def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None: | ||
"""Sets the device map for specific blocks in the model. | ||
|
||
|
@@ -1394,7 +1405,11 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) - | |
if self.device_map is not None: | ||
accelerate.hooks.remove_hook_from_submodules(block) | ||
|
||
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): | ||
if ( | ||
hasattr(self, "formats") | ||
and is_nv_fp(self.act_data_type) | ||
and any("nv_fp" in format_ for format_ in self.formats) | ||
): | ||
from auto_round.utils import set_amax_for_all_moe_layers | ||
|
||
# enable moe experts act_max automatic generation for linears | ||
|
@@ -1549,6 +1564,8 @@ def quantize(self): | |
f"Expected exactly one packing format when 'is_packing_immediate' is True, " | ||
f"but got {len(self.formats)} formats." | ||
) | ||
if self.recipe_mode: | ||
return | ||
|
||
self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately | ||
|
||
|
@@ -2439,7 +2456,14 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
|
||
modules = block.modules() | ||
for module in modules: | ||
update_fused_layer_global_scales(module) | ||
try: | ||
update_fused_layer_global_scales(module) | ||
except: | ||
# mix-precision may cause error, since q,k,v are not the same dtype. | ||
logger.warning_once( | ||
"Cannot keep the same global scale for fused layers, " | ||
+ "so the model may not work with vLLM with fused QKV or else." | ||
) | ||
round_params = [] | ||
minmax_params = [] | ||
for n, m in block.named_modules(): | ||
|
@@ -2561,7 +2585,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
logger.info(f"{unquantized_layer_names} have not been quantized") | ||
with torch.no_grad(): | ||
unwrapper_block(block, best_params) | ||
if self.enable_quanted_input: | ||
if self.enable_quanted_input and hasattr(self, "formats"): | ||
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): | ||
from auto_round.utils import set_amax_for_all_moe_layers | ||
|
||
|
@@ -2616,6 +2640,20 @@ def quantize_blocks( | |
clear_memory() | ||
input_ids = to_device(input_ids, self.cache_device) | ||
input_others = to_device(input_others, self.cache_device) | ||
if self.recipe_mode: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to wrap this new code into a function and call it as early as possible. |
||
logger.info("[Recipe Mode] starts") | ||
q_input_ids, ref_q_input_ids = None, None # init value | ||
for block_name in tqdm(block_names): | ||
block = get_module(model, block_name) | ||
if not self.model.device.type == "meta" or self.low_cpu_mem_usage: | ||
block = block.to(device) | ||
input_ids, q_input_ids, ref_q_input_ids = self._generate_block_recipe( | ||
block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others | ||
) | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
logger.info("[Recipe Mode] ends") | ||
return | ||
## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
for i in range(len(input_ids)): | ||
|
@@ -2891,7 +2929,7 @@ def scale_loss_and_backward(self, scaler, loss): | |
""" | ||
scale_loss = loss * 1000 | ||
scale_loss.backward() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
return scale_loss | ||
|
||
|
@@ -2908,7 +2946,7 @@ def step(self, scaler, optimizer, lr_schedule): | |
""" | ||
optimizer.step() | ||
# for hpu | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
optimizer.zero_grad() | ||
lr_schedule.step() | ||
|
@@ -3117,7 +3155,7 @@ def scale_loss_and_backward(self, scaler, loss): | |
loss = scaler.scale(loss) | ||
|
||
loss.backward() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
return loss | ||
|
||
|
@@ -3131,5 +3169,5 @@ def step(self, scaler, optimizer, lr_schedule): | |
optimizer.step() | ||
optimizer.zero_grad() | ||
lr_schedule.step() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.