Commit 9a93808

fix and optimize minicpm v 2 (#11799)
1 parent: d8d887e

2 files changed: +45 −12 lines

python/llm/src/ipex_llm/transformers/convert.py

Lines changed: 10 additions & 1 deletion
@@ -1726,6 +1726,11 @@ def safe_bmm_fwd(*args, **kwargs):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
+        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            # MiniCPM-V 2
+            model.llm.config.model_type = "minicpm"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
@@ -1739,7 +1744,11 @@ def safe_bmm_fwd(*args, **kwargs):
 
         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
-        if model.vpm.config.model_type == "siglip":
+        if not hasattr(model.vpm, "config"):
+            # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
+        elif model.vpm.config.model_type == "siglip":
            # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
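
Taken together, the convert.py change recognizes MiniCPM-V 2 by its config shape (hidden_size 2304, vocab_size 122753), temporarily relabels its language model as "minicpm" so that _optimize_post applies the MiniCPM optimizations, and, because this model's vision tower exposes no config attribute, binds the new minicpmv_get_vision_embedding instead of patching a SigLIP module. The usage sketch below is not part of the commit; it only shows how these branches would typically be reached, and the checkpoint id, low-bit format, and target device are assumptions.

# Usage sketch only, not part of this commit: loading a MiniCPM-V 2 checkpoint
# through ipex_llm is what drives the convert.py branches above. The model id,
# low-bit format, and device below are illustrative assumptions.
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

model_path = "openbmb/MiniCPM-V-2"  # assumption: any MiniCPM-V 2 checkpoint path

model = AutoModel.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",  # assumption: one of ipex_llm's low-bit formats
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# At this point model.llm has been routed through _optimize_post as "minicpm",
# and model.get_vision_embedding is the patched minicpmv_get_vision_embedding.
model = model.half().to("xpu")  # assumption: running on an Intel GPU via IPEX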

python/llm/src/ipex_llm/transformers/models/minicpmv.py

Lines changed: 35 additions & 11 deletions
@@ -15,18 +15,21 @@
 #
 
 
+import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def siglip_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -58,17 +61,7 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
-def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    if scores.device.type == "xpu":
-        import xe_addons
-        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
-    else:
-        score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores.scatter_(1, input_ids, score)
-    return scores
-
-
+# MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
         self,
@@ -106,6 +99,37 @@ def minicpmv_chat(
     return minicpmv_chat
 
 
+# MiniCPM-V-2
+def minicpmv_get_vision_embedding(self, pixel_values):
+    res = []
+    dtype = self.dtype
+
+    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
+        H, W = pixel_value.shape[-2:]
+        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+        return resampler(vision_embedding, target_size)
+
+    for pixel_value in pixel_values:
+        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+        res.append(result)
+    return torch.vstack(res)
+
+
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
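
The relocated patched_repetition_penalty_call keeps its XPU fast path through xe_addons and otherwise reproduces the gather/where/scatter logic of transformers' RepetitionPenaltyLogitsProcessor in place. The short sketch below, with made-up logits and penalty values, walks through that CPU branch.

# Self-contained sketch (assumption: plain CPU tensors and made-up numbers) of
# what the non-XPU branch of patched_repetition_penalty_call computes. Logits of
# tokens that already appear in input_ids are divided by the penalty when
# positive and multiplied by it when negative, then written back in place.
import torch

penalty = 1.2
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # batch of 1, vocab_size of 4
input_ids = torch.tensor([[1, 3]])              # token ids generated so far

score = torch.gather(scores, 1, input_ids)      # [[-1.0, 3.0]]
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)

print(scores)  # tensor([[ 2.0000, -1.2000,  0.5000,  2.5000]])

On an XPU device the same update is performed by a single in-place kernel, xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, penalty), as shown in the diff above.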
