From aea4dc046cb051a7c9b6507bf5b812c67ec0c726 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 26 Sep 2024 15:00:57 +0800
Subject: [PATCH] optimize llama 3.2 rope

---
 .../ipex_llm/transformers/models/llama32.py  | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py
index 9889797c832..9cb0f2c30a6 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama32.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -48,6 +48,7 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
@@ -111,6 +112,12 @@ def llama_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
+    # IPEX-LLM OPT start: use fused rope
+    if (should_use_fuse_rope(hidden_states, position_ids, False)
+            and self.rotary_emb.rope_type == "llama3"):
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
@@ -179,11 +186,16 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
-    if position_embeddings is None:
-        cos, sin = self.rotary_emb(value_states, position_ids)
+    if isinstance(position_embeddings, torch.Tensor):
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
     else:
-        cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
         key_states, value_states = past_key_value.update(key_states, value_states,
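
Note (illustrative sketch, not part of the patch): the change makes llama_model_forward pass the rotary embedding's raw inv_freq tensor instead of a (cos, sin) pair whenever the fused-rope path applies, and llama_attention_forward then dispatches on the type it receives. The standalone function below sketches that dispatch under assumed names: apply_rope is a hypothetical wrapper, xe_addons.rotary_half_inplaced and should_use_fuse_rope come from the diff above, and apply_rotary_pos_emb is the stock HuggingFace helper.

# Illustrative sketch only; mirrors the dispatch the patch adds to
# llama_attention_forward, with `apply_rope` as an assumed wrapper name.
import torch


def apply_rope(query_states, key_states, position_ids,
               position_embeddings, rotary_emb=None):
    """Rotate q/k with either the fused XPU kernel or the stock cos/sin path.

    `position_embeddings` is either a plain tensor (the rotary embedding's
    inv_freq, fused path) or a (cos, sin) tuple / None (unfused path).
    """
    if isinstance(position_embeddings, torch.Tensor):
        # Fused path: rotary_half_inplaced rotates q/k in place directly from
        # inv_freq, skipping the cos/sin materialization (IPEX-LLM XPU addon).
        import xe_addons
        xe_addons.rotary_half_inplaced(position_embeddings, position_ids,
                                       query_states, key_states)
        return query_states, key_states

    # Unfused path: fall back to the HuggingFace-style rotation.
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
    if position_embeddings is None:
        # rotary_emb only uses its first argument for device/dtype, so
        # key_states stands in for value_states here.
        cos, sin = rotary_emb(key_states, position_ids)
    else:
        cos, sin = position_embeddings
    return apply_rotary_pos_emb(query_states, key_states, cos, sin)

Handing inv_freq down once from llama_model_forward means each decoder layer's rope reduces to a single in-place kernel call on the fused path, which is the point of the optimization.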