@@ -2740,6 +2740,7 @@ def get_output(block, input_ids):
2740
2740
reference_output = get_output (block , input_ids )
2741
2741
q_input_ids = input_ids if q_input_ids is None else q_input_ids
2742
2742
# generate q_output of sample input_ids and get loss
2743
+ @torch .no_grad ()
2743
2744
def get_loss (q_block , q_input_ids ):
2744
2745
q_output = get_output (q_block , q_input_ids )
2745
2746
total_loss = 0
@@ -2753,6 +2754,29 @@ def get_loss(q_block, q_input_ids):
2753
2754
htcore .mark_step ()
2754
2755
return round (total_loss , 6 )
2755
2756
2757
+ # get mxfp8 and mxfp4 losses for this block, then compare their ratio against target_loss_ratio
2758
+ hp_layers = quantizable_layers
2759
+ block = create_mp_block (block , hp_layers , self .recipe_mp_dtype )
2760
+ mxfp8_loss = get_loss (block , q_input_ids )
2761
+ block = recover_mp_block (block , hp_layers , raw_dtype )
2762
+ hp_layers = []
2763
+ block = create_mp_block (block , hp_layers , self .recipe_mp_dtype )
2764
+ mxfp4_loss = get_loss (block , q_input_ids )
2765
+ block = recover_mp_block (block , hp_layers , raw_dtype )
2766
+ logger .info (f"loss_ratio [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } " )
2767
+ if is_hpex_available ():
2768
+ htcore .mark_step ()
2769
+ if int (block_name .split ("." )[- 1 ]) == 0 :
2770
+ self .target_loss_ratio = (mxfp4_loss / mxfp8_loss ) * (1 - mp_ratio )
2771
+ logger .warning_once (f"[Recipe Mode] Based on the mp_ratio, we set the target_loss_ratio: { self .target_loss_ratio } " )
2772
+ if mxfp4_loss / mxfp8_loss > self .target_loss_ratio :
2773
+ quantizable_num += 1
2774
+ logger .warning (f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } > { self .target_loss_ratio } " )
2775
+ logger .warning (f"[Recipe Mode] Set { quantizable_num } layers using mixed precision for this block." )
2776
+ elif mxfp4_loss / mxfp8_loss < 1 : # special case for llama3.3 70B
2777
+ quantizable_num -= 1
2778
+ logger .warning (f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } < 1" )
2779
+ logger .warning (f"[Recipe Mode] Set { quantizable_num } layers using mixed precision for this block." )
2756
2780
combination_list = []
2757
2781
avg_bits_list = []
2758
2782
loss_list = []
0 commit comments