
Commit 92025ad

add mxfp4_loss / mxfp8_loss to help

Signed-off-by: xinhe3 <[email protected]>
1 parent: f02331d

1 file changed: +24 −0 lines

auto_round/utils.py

Lines changed: 24 additions & 0 deletions
@@ -2740,6 +2740,7 @@ def get_output(block, input_ids):
         reference_output = get_output(block, input_ids)
         q_input_ids = input_ids if q_input_ids is None else q_input_ids
         # generate q_output of sample input_ids and get loss
+        @torch.no_grad()
         def get_loss(q_block, q_input_ids):
             q_output = get_output(q_block, q_input_ids)
             total_loss = 0
@@ -2753,6 +2754,29 @@ def get_loss(q_block, q_input_ids):
                 htcore.mark_step()
             return round(total_loss, 6)

+        # get mxfp8 loss
+        hp_layers = quantizable_layers
+        block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
+        mxfp8_loss = get_loss(block, q_input_ids)
+        block = recover_mp_block(block, hp_layers, raw_dtype)
+        hp_layers = []
+        block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
+        mxfp4_loss = get_loss(block, q_input_ids)
+        block = recover_mp_block(block, hp_layers, raw_dtype)
+        logger.info(f"loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss/mxfp8_loss}")
+        if is_hpex_available():
+            htcore.mark_step()
+        if int(block_name.split(".")[-1]) == 0:
+            self.target_loss_ratio = (mxfp4_loss / mxfp8_loss) * (1 - mp_ratio)
+            logger.warning_once(f"[Recipe Mode] Based on the mp_ratio, we set the target_loss_ratio: {self.target_loss_ratio}")
+        if mxfp4_loss / mxfp8_loss > self.target_loss_ratio:
+            quantizable_num += 1
+            logger.warning(f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} > {self.target_loss_ratio}")
+            logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
+        elif mxfp4_loss / mxfp8_loss < 1:  # special case for llama3.3 70B
+            quantizable_num -= 1
+            logger.warning(f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} < 1")
+            logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
         combination_list = []
         avg_bits_list = []
         loss_list = []
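
For readers skimming the change, here is a minimal, self-contained sketch of the block-selection heuristic this commit adds. It assumes a plain MSE stand-in for the per-block loss; block_loss and adjust_quantizable_num are hypothetical helper names used only for illustration and are not part of auto_round/utils.py, while the ratio test itself mirrors the added lines above.

import torch

@torch.no_grad()
def block_loss(block, reference_output, input_ids):
    # Hypothetical stand-in: MSE between the quantized block's output and the reference output.
    return torch.nn.functional.mse_loss(block(input_ids), reference_output).item()

def adjust_quantizable_num(mxfp4_loss, mxfp8_loss, quantizable_num,
                           target_loss_ratio=None, mp_ratio=0.5, first_block=False):
    # Mirrors the ratio test added in this commit (assumed behaviour):
    # on the first block, the target ratio is derived from mp_ratio; afterwards,
    # a block whose mxfp4/mxfp8 loss ratio exceeds the target gets one more
    # mixed-precision (higher-precision) layer, while a ratio below 1 gives one back.
    ratio = mxfp4_loss / mxfp8_loss
    if first_block:
        target_loss_ratio = ratio * (1 - mp_ratio)
    if ratio > target_loss_ratio:
        quantizable_num += 1
    elif ratio < 1:  # special case noted in the diff for llama3.3 70B
        quantizable_num -= 1
    return quantizable_num, target_loss_ratio

In words: the block is first evaluated with all quantizable layers kept at the higher-precision recipe dtype (mxfp8_loss), then with none of them (mxfp4_loss); the ratio of the two losses decides whether this block should receive more or fewer mixed-precision layers.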
