python/llm/src/ipex_llm/transformers/models/mistral.py (69 changes: 26 additions & 43 deletions)
@@ -48,8 +48,7 @@
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
     get_compresskv_attn_mask
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_no_cache_xpu
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
@@ -64,7 +63,6 @@
 except ImportError:
     Cache = Tuple[torch.Tensor]
 
-from ipex_llm.transformers.low_bit_linear import FP6, FP16
 
 import os
 
@@ -274,8 +272,6 @@ def mistral_attention_forward_quantized(
     original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    if self.q_proj.qtype not in [FP6, FP16]:
-        use_fuse_rope = False
 
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
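
The two removed lines above, which recur in every attention forward variant in this file, used to force use_fuse_rope off unless the q_proj weights were FP6 or FP16. With that gate gone, the fused rope path is controlled by should_use_fuse_rope alone, which is also why the FP6, FP16 import disappears in the hunk further up. For orientation only, a simplified sketch of the kind of condition such a helper typically checks; this is an assumption, not the actual implementation in ipex_llm.transformers.models.utils:

    def should_use_fuse_rope_sketch(module, hidden_states, position_ids):
        # Illustrative only: fuse rope for XPU inference with explicit position ids.
        return (hidden_states.device.type == "xpu"
                and not module.training
                and not hidden_states.requires_grad
                and position_ids is not None)
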
@@ -304,7 +300,8 @@ def mistral_attention_forward_quantized(
             self.q_proj.weight.qtype,
             self.v_proj.weight.qtype,
             0,
-            self.head_dim)
+            self.head_dim,
+            self.rotary_emb.base)
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
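
In the hunk above, the decoding fast path now receives self.rotary_emb.base in addition to self.head_dim, presumably so the fused kernel can derive the rotary frequencies from the model's actual rope base rather than assuming a fixed default; Mistral variants ship with different rope_theta values, so hard-coding one would rotate decode-time queries and keys incorrectly. For reference, the conventional RoPE inverse-frequency formula that a kernel would reproduce from (head_dim, base) looks like this; a plain PyTorch sketch, not the kernel source:

    import torch

    def rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies: one frequency per pair of channels.
        return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
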
@@ -321,11 +318,9 @@
         kv_seq_len += past_key_value[0].shape[-2]
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral",
-                                                                     self.config.rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
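
This hunk is the core of the refactor and is repeated in each attention forward variant below: the apply_rotary_pos_emb_no_cache_xpu helper, which took the model family string and rope_theta, is replaced by a direct call to xe_addons.rotary_half_inplaced, which takes the module's rotary inv_freq buffer plus position_ids and rotates query_states and key_states in place. The kernel itself is not part of this diff; as an assumption about what it computes, here is an eager rotate-half RoPE reference in plain PyTorch (illustrative names, not part of ipex_llm):

    import torch

    def rotary_half_inplaced_ref(inv_freq, position_ids, q, k):
        # q, k: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len]
        freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [b, s, d/2]
        emb = torch.cat((freqs, freqs), dim=-1)                             # [b, s, d]
        cos = emb.cos()[:, None, :, :]                                      # broadcast over heads
        sin = emb.sin()[:, None, :, :]

        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)

        q.copy_(q * cos + rotate_half(q) * sin)
        k.copy_(k * cos + rotate_half(k) * sin)

Driving the kernel from self.rotary_emb.inv_freq instead of a (model name, rope_theta) pair means whatever base the module's rotary embedding was built with is honored automatically, which lines up with the extra self.rotary_emb.base argument passed to the decoding fast path above.
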
@@ -482,8 +477,6 @@ def mistral_attention_forward_original(
     original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    if self.q_proj.qtype not in [FP6, FP16]:
-        use_fuse_rope = False
 
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -506,7 +499,8 @@ def mistral_attention_forward_original(
             self.q_proj.weight.qtype,
             self.v_proj.weight.qtype,
             kv_seq_len,
-            self.head_dim)
+            self.head_dim,
+            self.rotary_emb.base)
         kv_seq_len += 1
     else:
 
@@ -542,11 +536,9 @@ def mistral_attention_forward_original(
         kv_seq_len += past_key_value[0].shape[-2]
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral",
-                                                                     self.config.rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -708,8 +700,6 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    if self.q_proj.qtype not in [FP6, FP16]:
-        use_fuse_rope = False
 
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
                                                   seq_len=q_len)
@@ -739,7 +729,8 @@ def mistral_attention_forward_4_36_quantized(
             self.q_proj.weight.qtype,
             self.v_proj.weight.qtype,
             0,
-            self.head_dim)
+            self.head_dim,
+            self.rotary_emb.base)
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
@@ -765,11 +756,9 @@
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral",
-                                                                     self.config.rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -928,8 +917,6 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    if self.q_proj.qtype not in [FP6, FP16]:
-        use_fuse_rope = False
 
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
                                                   self.layer_idx,
@@ -958,7 +945,8 @@ def mistral_attention_forward_4_36_original(
             self.q_proj.weight.qtype,
             self.v_proj.weight.qtype,
             kv_seq_len,
-            self.head_dim)
+            self.head_dim,
+            self.rotary_emb.base)
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
@@ -1011,11 +999,9 @@ def mistral_attention_forward_4_36_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral",
-                                                                     self.config.rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1189,8 +1175,6 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    if self.q_proj.qtype not in [FP6, FP16]:
-        use_fuse_rope = False
 
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
                                                   q_len)
@@ -1218,7 +1202,8 @@ def mistral_attention_forward_4_39_original(
             self.q_proj.weight.qtype,
             self.v_proj.weight.qtype,
             kv_seq_len,
-            self.head_dim)
+            self.head_dim,
+            self.rotary_emb.base)
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
@@ -1270,11 +1255,9 @@ def mistral_attention_forward_4_39_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral",
-                                                                     self.config.rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,