
Commit 9239fd4

add basic support and optimization for qwen2-vl (#12104)

1 parent 828fa01

File tree

2 files changed: +196 −0 lines changed

python/llm/src/ipex_llm/transformers/convert.py

Lines changed: 14 additions & 0 deletions
@@ -1000,6 +1000,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "qwen2_audio":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.language_model.apply(merge_qkv)
+    if model.config.model_type == "qwen2_vl":
+        from ipex_llm.transformers.models.qwen2_vl import merge_qkv
+        model.apply(merge_qkv)
     if model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
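The `_optimize_pre` hunk above registers `merge_qkv` so that every Qwen2-VL attention block gets its separate q/k/v projections fused before weight conversion. As a rough illustration of the idea only (the real logic lives in `merge_qkv_base` in ipex_llm.transformers.models.common and may differ in detail), the fusion could look like this hypothetical sketch:

    import torch
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention


    def merge_qkv_sketch(module: torch.nn.Module):
        # Hypothetical sketch: fuse q_proj/k_proj/v_proj into a single qkv_proj so the
        # patched attention forward can run one GEMM and split the result per head.
        if isinstance(module, Qwen2VLAttention):
            q, k, v = module.q_proj, module.k_proj, module.v_proj
            qkv = torch.nn.Linear(q.in_features,
                                  q.out_features + k.out_features + v.out_features,
                                  bias=q.bias is not None)
            qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
            if q.bias is not None:
                qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
            module.qkv_proj = qkv
            del module.q_proj, module.k_proj, module.v_proj

The concatenation order [q, k, v] matters: the new attention forward in this commit reshapes the fused output to (batch, seq, num_heads + 2 * num_key_value_heads, head_dim) and splits it in that order.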
@@ -1651,6 +1654,17 @@ def _optimize_post(model, lightweight_bmm=False):
                         qwen2_attention_forward)
     elif model.config.model_type == "qwen2_audio":
         _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
+    elif model.config.model_type == "qwen2_vl":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
+        convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
+        convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
+        convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
+        convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
     elif model.config.model_type == "cohere":
         # for CohereForAI/c4ai-command-r-v01
         invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
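The `_optimize_post` hunk relies on `convert_forward` to swap the `forward` method of every matching submodule for the optimized implementation. As a rough illustration of the pattern (not the exact helper defined in convert.py), instance-level monkey-patching might look like this:

    import types

    import torch


    def convert_forward_sketch(model: torch.nn.Module, target_class: type, new_forward):
        # Hypothetical sketch: walk the module tree and rebind `forward` on every
        # instance of `target_class`, so later calls dispatch to the optimized function.
        for _, submodule in model.named_modules():
            if submodule.__class__ == target_class:
                submodule.forward = types.MethodType(new_forward, submodule)

Looking the classes up via `importlib.import_module(model.__class__.__module__)` keeps the patching tied to whatever transformers module actually defined the loaded model, rather than a hard-coded import path.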
python/llm/src/ipex_llm/transformers/models/qwen2_vl.py (new file)
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
from typing import Optional, Tuple, Union, List

import torch

from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache

from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention, Qwen2VLModel
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache


def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, Qwen2VLAttention)


def qwen2_vl_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
    inputs = input_ids if input_ids is not None else inputs_embeds
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    use_cache = True if inputs.device.type == "xpu" else use_cache
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
    if use_cache:
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)

    return Qwen2VLModel.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )


def qwen2_vl_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_key_value_heads,
                                                        self.num_key_value_heads], dim=1)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, cache_kwargs)
        kv_seq_len = key_states.shape[-2]

    attn_weights = None
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if isinstance(past_key_value, DynamicFp8Cache):
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                            attention_mask)
        else:
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                        attention_mask)
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
        import xe_addons
        if isinstance(past_key_value, DynamicFp8Cache):
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
                                                   value_states, attention_mask)
        else:
            attn_output = xe_addons.sdp_causal(query_states, key_states,
                                               value_states, attention_mask)
    else:
        if isinstance(past_key_value, DynamicFp8Cache):
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                            query_states.dtype)
        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                   dtype=torch.float32).to(query_states.dtype)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                   training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
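With both files in place, loading a Qwen2-VL checkpoint through ipex_llm picks up the patched forwards automatically. A minimal usage sketch, assuming a transformers release that ships Qwen2-VL and using ipex_llm's generic `optimize_model` entry point (the checkpoint id and dtype are illustrative, not part of this commit):

    import torch
    from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
    from ipex_llm import optimize_model

    # Illustrative checkpoint id; any Qwen2-VL model should follow the same path.
    model_path = "Qwen/Qwen2-VL-7B-Instruct"

    model = Qwen2VLForConditionalGeneration.from_pretrained(model_path,
                                                            torch_dtype=torch.float16)
    # optimize_model() runs ipex_llm's conversion passes, which now include the
    # qwen2_vl branches added in _optimize_pre / _optimize_post.
    model = optimize_model(model)
    model = model.to("xpu")  # the patched forwards target Intel XPU devices

    processor = AutoProcessor.from_pretrained(model_path)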
