2 files changed (+9 −3) in python/llm/src/ipex_llm/transformers/models

First file, in llama_model_forward:

```diff
@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
```
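Both files in this diff wrap a legacy tuple-style cache with `DynamicFp8Cache.from_legacy_cache` once quantized KV is enabled. As a hedged illustration of that interface, the snippet below uses Hugging Face's `DynamicCache`, whose `from_legacy_cache` classmethod ipex-llm's `DynamicFp8Cache` appears to mirror (an assumption based on the call site, not confirmed by this diff); the tensor shapes are arbitrary example values:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy HF caches are per-layer (key, value) tuples of shape
# [batch, num_kv_heads, seq_len, head_dim]; the values below are
# made-up example dimensions.
legacy_cache = tuple(
    (torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64))
    for _ in range(2)  # two decoder layers
)

# from_legacy_cache wraps the tuples into a Cache object, the same
# pattern as the DynamicFp8Cache.from_legacy_cache call sites above.
cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.get_seq_length())  # 16
```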
```diff
@@ -114,7 +117,7 @@ def llama_model_forward(
 
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IPEX-LLM OPT end
```
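The fused-rope gate previously required `rope_type == "llama3"`, which excluded models with no RoPE scaling: in Hugging Face Transformers, `rope_type` falls back to `"default"` when the config has no `rope_scaling` entry. A minimal sketch of how that attribute is derived; the `resolve_rope_type` helper is hypothetical, written to mirror Transformers' fallback behavior:

```python
from transformers import LlamaConfig

def resolve_rope_type(config) -> str:
    # Mirrors Transformers' fallback: rope_type comes from the
    # rope_scaling dict ("rope_type", or the legacy "type" key),
    # and is "default" when no scaling is configured.
    scaling = getattr(config, "rope_scaling", None) or {}
    return scaling.get("rope_type", scaling.get("type", "default"))

plain = LlamaConfig()  # no rope_scaling -> "default"
scaled = LlamaConfig(rope_scaling={
    "rope_type": "llama3", "factor": 8.0,
    "low_freq_factor": 1.0, "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
})
for cfg in (plain, scaled):
    print(resolve_rope_type(cfg) in ["default", "llama3"])  # True, True
```

Under this reading, the widened check lets unscaled Llama models take the fused-rope path alongside Llama-3-style scaled ones.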
Second file, in mllama_text_model_forward, applying the same KV-cache change:

```diff
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
```
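Both call sites now pass a third argument to `use_quantize_kv_cache`: the grouped-query-attention ratio `num_attention_heads // num_key_value_heads`. With GQA, the KV cache stores only `num_key_value_heads` heads, so this ratio tells the heuristic how small the cache already is relative to the attention heads. A sketch of the arithmetic under assumed Llama-7B-like dimensions; the `kv_cache_bytes` helper is hypothetical, not part of ipex-llm:

```python
from transformers import LlamaConfig

# Assumed example dimensions: 32 query heads grouped over 8 KV heads.
config = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)
kv_group = config.num_attention_heads // config.num_key_value_heads
print(kv_group)  # 4 -> the value the diff passes as the new third argument

def kv_cache_bytes(batch: int, seq_len: int, config, dtype_bytes: int = 2) -> int:
    # Hypothetical helper: per-layer KV cache is K and V tensors of shape
    # [batch, num_kv_heads, seq_len, head_dim], summed over decoder layers.
    head_dim = config.hidden_size // config.num_attention_heads
    per_layer = 2 * batch * seq_len * config.num_key_value_heads * head_dim * dtype_bytes
    return per_layer * config.num_hidden_layers

# With GQA the fp16 cache is already kv_group times smaller than a
# full multi-head cache would be, which the quantize-KV decision can
# now take into account.
print(kv_cache_bytes(1, 4096, config) // 2**20, "MiB")  # 512 MiB
```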