 
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.utils.common import invalidInputError
 
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention, Qwen2VLModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
 from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
 from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -71,29 +72,105 @@ def qwen2_vl_model_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
-    inputs = input_ids if input_ids is not None else inputs_embeds
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    # IPEX-LLM OPT start: kv cache and quantize kv cache
+    inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+    # IPEX-LLM OPT end
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
+                      "You cannot specify both input_ids and inputs_embeds at the same time, "
+                      "and must specify either one")
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
+                                      device=inputs_embeds.device)
+
+    # the hard coded `3` is for temporal, height and width.
+    if position_ids is None:
+        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+    elif position_ids.dim() == 2:
+        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
 
-    return Qwen2VLModel.forward(
-        self=self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    # IPEX-LLM OPT start: use fused 2D rope
+    if (torch.equal(position_ids[0], position_ids[1])
+            and torch.equal(position_ids[0], position_ids[2])
+            and should_use_fuse_rope(hidden_states, position_ids, False)):
+        position_ids = position_ids[0].contiguous()
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                     if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
     )
 
 
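Note on the fused 2D rope branch above: it only fires when all three M-RoPE position-id planes (temporal, height, width) carry identical values, which is the case for text-only prompts and for single-token decode steps; the `(3, batch, seq)` ids then collapse to ordinary 1D positions and `inv_freq` is passed along so the attention layers can use a plain fused rotary kernel. A minimal standalone sketch of that check (illustrative only, not the ipex-llm code itself):

```python
import torch

# Text-only / decode-step ids: every M-RoPE plane is the same arange,
# so the (3, batch, seq) tensor carries no extra information.
position_ids = torch.arange(16).view(1, 1, -1).expand(3, 1, -1)

planes_match = (torch.equal(position_ids[0], position_ids[1])
                and torch.equal(position_ids[0], position_ids[2]))
if planes_match:
    # Collapse to (batch, seq); mirrors position_ids[0].contiguous() in the patch
    flat_ids = position_ids[0].contiguous()
    print(flat_ids.shape)  # torch.Size([1, 16])
```
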
@@ -117,19 +194,23 @@ def qwen2_vl_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
-    if position_embeddings is None:
-        cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_ids.dim() == 2:
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
     else:
-        cos, sin = position_embeddings
-        query_states, key_states = apply_multimodal_rotary_pos_emb(
-            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
-        )
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_multimodal_rotary_pos_emb(
+            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+        )
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
-        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+                                                         self.layer_idx, None)
         kv_seq_len = key_states.shape[-2]
 
     attn_weights = None
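For reference, the attention-side dispatch introduced above can be read as follows: 2D `position_ids` signal that the model forward already collapsed the M-RoPE planes and stashed `inv_freq` in `position_embeddings`, so the fused in-place XPU kernel rotates q/k directly; otherwise the stock HuggingFace multimodal rope path runs. A condensed restatement as a standalone helper (hypothetical name `rotate_qk`; assumes it is bound to the attention module, `xe_addons` being the XPU-only extension shipped with ipex-llm):

```python
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb


def rotate_qk(self, query_states, key_states, value_states,
              position_ids, position_embeddings):
    # Condensed restatement of the patched branch, not a drop-in replacement.
    if position_ids.dim() == 2:
        # Fused path: position_embeddings holds rotary inv_freq (set in model forward)
        import xe_addons  # available only when running ipex-llm on Intel XPU
        xe_addons.rotary_half_inplaced(position_embeddings, position_ids,
                                       query_states, key_states)
        return query_states, key_states
    # Fallback: standard Qwen2-VL multimodal rope
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    return apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
```
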