Skip to content

Commit 72605c7

Browse files
authored
fix llama3.1/3.2 quantize kv check (#12302)
1 parent 416c191 commit 72605c7

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

python/llm/src/ipex_llm/transformers/models/llama32.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,10 @@ def llama_model_forward(
7979
# IPEX-LLM OPT start: kv cache and quantize kv cache
8080
inputs = input_ids if input_ids is not None else inputs_embeds
8181
use_cache = True if inputs.device.type == "xpu" else use_cache
82-
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
82+
use_quantize_kv = use_quantize_kv_cache(
83+
self.layers[0].mlp.down_proj, inputs,
84+
self.config.num_attention_heads // self.config.num_key_value_heads
85+
)
8386
if use_cache:
8487
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
8588
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
@@ -114,7 +117,7 @@ def llama_model_forward(
114117

115118
# IPEX-LLM OPT start: use fused rope
116119
if (should_use_fuse_rope(hidden_states, position_ids, False)
117-
and self.rotary_emb.rope_type == "llama3"):
120+
and self.rotary_emb.rope_type in ["default", "llama3"]):
118121
position_embeddings = self.rotary_emb.inv_freq
119122
# IEPX_LLM OPT end
120123

python/llm/src/ipex_llm/transformers/models/mllama.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
129129
# IPEX-LLM OPT start: kv cache and quantize kv cache
130130
inputs = input_ids if input_ids is not None else inputs_embeds
131131
use_cache = True if inputs.device.type == "xpu" else use_cache
132-
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
132+
use_quantize_kv = use_quantize_kv_cache(
133+
self.layers[0].mlp.down_proj, inputs,
134+
self.config.num_attention_heads // self.config.num_key_value_heads
135+
)
133136
if use_cache:
134137
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
135138
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

0 commit comments

Comments (0)