|
| 1 | +# |
| 2 | +# Copyright 2016 The BigDL Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | +# Some parts of this file is adapted from |
| 17 | +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py |
| 18 | +# which is licensed under Apache License 2.0: |
| 19 | +# |
| 20 | +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. |
| 21 | +# |
| 22 | +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX |
| 23 | +# and OPT implementations in this library. It has been modified from its |
| 24 | +# original forms to accommodate minor architectural differences compared |
| 25 | +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
| 26 | +# |
| 27 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 28 | +# you may not use this file except in compliance with the License. |
| 29 | +# You may obtain a copy of the License at |
| 30 | +# |
| 31 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 32 | +# |
| 33 | +# Unless required by applicable law or agreed to in writing, software |
| 34 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 35 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 36 | +# See the License for the specific language governing permissions and |
| 37 | +# limitations under the License. |
| 38 | +# |
| 39 | + |
| 40 | +import math |
| 41 | +from typing import Optional, Tuple, Union, List |
| 42 | + |
| 43 | +import torch |
| 44 | + |
| 45 | +from ipex_llm.transformers.models.common import merge_qkv_base |
| 46 | +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache |
| 47 | +from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal |
| 48 | +from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache |
| 49 | + |
| 50 | +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention, Qwen2VLModel |
| 51 | +from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb |
| 52 | +from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv |
| 53 | +from transformers.modeling_outputs import BaseModelOutputWithPast |
| 54 | +from transformers.cache_utils import Cache |
| 55 | + |
| 56 | + |
| 57 | +def merge_qkv(module: torch.nn.Module): |
| 58 | + merge_qkv_base(module, Qwen2VLAttention) |
| 59 | + |
| 60 | + |
| 61 | +def qwen2_vl_model_forward( |
| 62 | + self, |
| 63 | + input_ids: torch.LongTensor = None, |
| 64 | + attention_mask: Optional[torch.Tensor] = None, |
| 65 | + position_ids: Optional[torch.LongTensor] = None, |
| 66 | + past_key_values: Optional[List[torch.FloatTensor]] = None, |
| 67 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 68 | + use_cache: Optional[bool] = None, |
| 69 | + output_attentions: Optional[bool] = None, |
| 70 | + output_hidden_states: Optional[bool] = None, |
| 71 | + return_dict: Optional[bool] = None, |
| 72 | + cache_position: Optional[torch.LongTensor] = None, |
| 73 | +) -> Union[Tuple, BaseModelOutputWithPast]: |
| 74 | + # IPEX-LLM OPT: kv cache and quantize kv cache and sdp |
| 75 | + inputs = input_ids if input_ids is not None else inputs_embeds |
| 76 | + use_cache = use_cache if use_cache is not None else self.config.use_cache |
| 77 | + use_cache = True if inputs.device.type == "xpu" else use_cache |
| 78 | + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) |
| 79 | + if use_cache: |
| 80 | + if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): |
| 81 | + past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) |
| 82 | + elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache): |
| 83 | + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) |
| 84 | + |
| 85 | + return Qwen2VLModel.forward( |
| 86 | + self=self, |
| 87 | + input_ids=input_ids, |
| 88 | + attention_mask=attention_mask, |
| 89 | + position_ids=position_ids, |
| 90 | + past_key_values=past_key_values, |
| 91 | + inputs_embeds=inputs_embeds, |
| 92 | + use_cache=use_cache, |
| 93 | + output_attentions=output_attentions, |
| 94 | + output_hidden_states=output_hidden_states, |
| 95 | + return_dict=return_dict, |
| 96 | + cache_position=cache_position, |
| 97 | + ) |
| 98 | + |
| 99 | + |
| 100 | +def qwen2_vl_attention_forward( |
| 101 | + self, |
| 102 | + hidden_states: torch.Tensor, |
| 103 | + attention_mask: Optional[torch.Tensor] = None, |
| 104 | + position_ids: Optional[torch.LongTensor] = None, |
| 105 | + past_key_value: Optional[Cache] = None, |
| 106 | + output_attentions: bool = False, |
| 107 | + use_cache: bool = False, |
| 108 | + cache_position: Optional[torch.LongTensor] = None, |
| 109 | + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, |
| 110 | +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 111 | + bsz, q_len, _ = hidden_states.size() |
| 112 | + |
| 113 | + qkv = self.qkv_proj(hidden_states) |
| 114 | + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) |
| 115 | + qkv = qkv.transpose(1, 2) |
| 116 | + query_states, key_states, value_states = qkv.split([self.num_heads, |
| 117 | + self.num_key_value_heads, |
| 118 | + self.num_key_value_heads], dim=1) |
| 119 | + |
| 120 | + if position_embeddings is None: |
| 121 | + cos, sin = self.rotary_emb(value_states, position_ids) |
| 122 | + else: |
| 123 | + cos, sin = position_embeddings |
| 124 | + query_states, key_states = apply_multimodal_rotary_pos_emb( |
| 125 | + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] |
| 126 | + ) |
| 127 | + |
| 128 | + kv_seq_len = key_states.shape[-2] |
| 129 | + if past_key_value is not None: |
| 130 | + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| 131 | + key_states, value_states = past_key_value.update(key_states, value_states, |
| 132 | + self.layer_idx, cache_kwargs) |
| 133 | + kv_seq_len = key_states.shape[-2] |
| 134 | + |
| 135 | + attn_weights = None |
| 136 | + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): |
| 137 | + import xe_addons |
| 138 | + if isinstance(past_key_value, DynamicFp8Cache): |
| 139 | + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, |
| 140 | + attention_mask) |
| 141 | + else: |
| 142 | + attn_output = xe_addons.sdp(query_states, key_states, value_states, |
| 143 | + attention_mask) |
| 144 | + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): |
| 145 | + import xe_addons |
| 146 | + if isinstance(past_key_value, DynamicFp8Cache): |
| 147 | + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, |
| 148 | + value_states, attention_mask) |
| 149 | + else: |
| 150 | + attn_output = xe_addons.sdp_causal(query_states, key_states, |
| 151 | + value_states, attention_mask) |
| 152 | + else: |
| 153 | + if isinstance(past_key_value, DynamicFp8Cache): |
| 154 | + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, |
| 155 | + query_states.dtype) |
| 156 | + # repeat k/v heads if n_kv_heads < n_heads |
| 157 | + key_states = repeat_kv(key_states, self.num_key_value_groups) |
| 158 | + value_states = repeat_kv(value_states, self.num_key_value_groups) |
| 159 | + |
| 160 | + attn_weights = torch.matmul(query_states, |
| 161 | + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
| 162 | + |
| 163 | + if attention_mask is not None: # no matter the length, we just slice it |
| 164 | + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| 165 | + attn_weights = attn_weights + causal_mask |
| 166 | + |
| 167 | + # upcast attention to fp32 |
| 168 | + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, |
| 169 | + dtype=torch.float32).to(query_states.dtype) |
| 170 | + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, |
| 171 | + training=self.training) |
| 172 | + attn_output = torch.matmul(attn_weights, value_states) |
| 173 | + |
| 174 | + attn_output = attn_output.transpose(1, 2).contiguous() |
| 175 | + attn_output = attn_output.reshape(bsz, q_len, -1) |
| 176 | + |
| 177 | + attn_output = self.o_proj(attn_output) |
| 178 | + |
| 179 | + if not output_attentions: |
| 180 | + attn_weights = None |
| 181 | + |
| 182 | + return attn_output, attn_weights, past_key_value |
0 commit comments