from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear


def duplicate_weights_for_linear_layer(
@@ -49,6 +50,15 @@ def duplicate_weights_for_linear_layer(
            1,
        ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
        layer.out_features = layer.out_features * repeat
+
+    elif isinstance(layer, FP8DeQuantLinear):
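+        # Repeat each original KV head's rows of the FP8 weight, and the matching
+        # per-output-channel weight scales, once per duplicated KV head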
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+        layer.weight_scale.data = torch.repeat_interleave(
+            layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+        ).view(new_kv_heads * head_dim, -1)
+
    else:
        layer.weight.data = torch.repeat_interleave(
            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
@@ -65,7 +75,6 @@ def main(args):
    model_kwargs = {"attn_implementation": "eager"}
    if args.num_hidden_layers:
        model_kwargs["num_hidden_layers"] = args.num_hidden_layers
-
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    # Undo the effect of replace_transformers_quantizers