Skip to content

Commit 5d63aef

Browse files
authored
optimize qwen2 vl again (#12109)
1 parent 03bd01c commit 5d63aef

File tree

1 file changed

+105
-24
lines changed

1 file changed

+105
-24
lines changed

python/llm/src/ipex_llm/transformers/models/qwen2_vl.py

Lines changed: 105 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@
4444

4545
from ipex_llm.transformers.models.common import merge_qkv_base
4646
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
47-
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
47+
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
4848
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
49+
from ipex_llm.utils.common import invalidInputError
4950

50-
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention, Qwen2VLModel
51+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
5152
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
5253
from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
5354
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -71,29 +72,105 @@ def qwen2_vl_model_forward(
7172
return_dict: Optional[bool] = None,
7273
cache_position: Optional[torch.LongTensor] = None,
7374
) -> Union[Tuple, BaseModelOutputWithPast]:
74-
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
75-
inputs = input_ids if input_ids is not None else inputs_embeds
75+
output_attentions = (
76+
output_attentions if output_attentions is not None
77+
else self.config.output_attentions
78+
)
79+
output_hidden_states = (
80+
output_hidden_states if output_hidden_states is not None
81+
else self.config.output_hidden_states
82+
)
7683
use_cache = use_cache if use_cache is not None else self.config.use_cache
84+
85+
# IPEX-LLM OPT start: kv cache and quantize kv cache
86+
inputs = input_ids if input_ids is not None else inputs_embeds
7787
use_cache = True if inputs.device.type == "xpu" else use_cache
7888
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
7989
if use_cache:
8090
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
8191
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
8292
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
8393
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
94+
# IPEX-LLM OPT end
95+
96+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
97+
98+
invalidInputError((input_ids is None) ^ (inputs_embeds is None),
99+
"You cannot specify both input_ids and inputs_embeds at the same time, "
100+
"and must specify either one")
101+
102+
if inputs_embeds is None:
103+
inputs_embeds = self.embed_tokens(input_ids)
104+
105+
if cache_position is None:
106+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
107+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
108+
device=inputs_embeds.device)
109+
110+
# the hard coded `3` is for temporal, height and width.
111+
if position_ids is None:
112+
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
113+
elif position_ids.dim() == 2:
114+
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
115+
116+
causal_mask = self._update_causal_mask(
117+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
118+
)
119+
120+
hidden_states = inputs_embeds
84121

85-
return Qwen2VLModel.forward(
86-
self=self,
87-
input_ids=input_ids,
88-
attention_mask=attention_mask,
89-
position_ids=position_ids,
90-
past_key_values=past_key_values,
91-
inputs_embeds=inputs_embeds,
92-
use_cache=use_cache,
93-
output_attentions=output_attentions,
94-
output_hidden_states=output_hidden_states,
95-
return_dict=return_dict,
96-
cache_position=cache_position,
122+
# create position embeddings to be shared across the decoder layers
123+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
124+
125+
# IPEX-LLM OPT start: use fused 2D rope
126+
if (torch.equal(position_ids[0], position_ids[1])
127+
and torch.equal(position_ids[0], position_ids[2])
128+
and should_use_fuse_rope(hidden_states, position_ids, False)):
129+
position_ids = position_ids[0].contiguous()
130+
position_embeddings = self.rotary_emb.inv_freq
131+
# IEPX_LLM OPT end
132+
133+
# decoder layers
134+
all_hidden_states = () if output_hidden_states else None
135+
all_self_attns = () if output_attentions else None
136+
next_decoder_cache = None
137+
138+
for decoder_layer in self.layers:
139+
layer_outputs = decoder_layer(
140+
hidden_states,
141+
attention_mask=causal_mask,
142+
position_ids=position_ids,
143+
past_key_value=past_key_values,
144+
output_attentions=output_attentions,
145+
use_cache=use_cache,
146+
cache_position=cache_position,
147+
position_embeddings=position_embeddings,
148+
)
149+
150+
hidden_states = layer_outputs[0]
151+
152+
if use_cache:
153+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
154+
155+
if output_attentions:
156+
all_self_attns += (layer_outputs[1],)
157+
158+
hidden_states = self.norm(hidden_states)
159+
160+
# add hidden states from the last decoder layer
161+
if output_hidden_states:
162+
all_hidden_states += (hidden_states,)
163+
164+
next_cache = next_decoder_cache if use_cache else None
165+
166+
if not return_dict:
167+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
168+
if v is not None)
169+
return BaseModelOutputWithPast(
170+
last_hidden_state=hidden_states,
171+
past_key_values=next_cache,
172+
hidden_states=all_hidden_states,
173+
attentions=all_self_attns,
97174
)
98175

99176

@@ -117,19 +194,23 @@ def qwen2_vl_attention_forward(
117194
self.num_key_value_heads,
118195
self.num_key_value_heads], dim=1)
119196

120-
if position_embeddings is None:
121-
cos, sin = self.rotary_emb(value_states, position_ids)
197+
if position_ids.dim() == 2:
198+
import xe_addons
199+
inv_freq = position_embeddings
200+
xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
122201
else:
123-
cos, sin = position_embeddings
124-
query_states, key_states = apply_multimodal_rotary_pos_emb(
125-
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
126-
)
202+
if position_embeddings is None:
203+
cos, sin = self.rotary_emb(value_states, position_ids)
204+
else:
205+
cos, sin = position_embeddings
206+
query_states, key_states = apply_multimodal_rotary_pos_emb(
207+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
208+
)
127209

128210
kv_seq_len = key_states.shape[-2]
129211
if past_key_value is not None:
130-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
131212
key_states, value_states = past_key_value.update(key_states, value_states,
132-
self.layer_idx, cache_kwargs)
213+
self.layer_idx, None)
133214
kv_seq_len = key_states.shape[-2]
134215

135216
attn_weights = None

0 commit comments

Comments
 (0)