11 changes: 10 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -1726,6 +1726,11 @@ def safe_bmm_fwd(*args, **kwargs):
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
model.generate = MethodType(minicpmv_generate, model)

if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
# MiniCPM-V 2
model.llm.config.model_type = "minicpm"
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
model.llm.config.model_type = "minicpmv"
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
# MiniCPM-V 2.6
model.llm.config.model_type = "qwen2"
@@ -1739,7 +1744,11 @@ def safe_bmm_fwd(*args, **kwargs):

vpm_modeling_module_name = model.vpm.__class__.__module__
vpm_module = importlib.import_module(vpm_modeling_module_name)
if model.vpm.config.model_type == "siglip":
if not hasattr(model.vpm, "config"):
# MiniCPM-V 2
from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
elif model.vpm.config.model_type == "siglip":
# MiniCPM-V 2.6
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
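
For context, a minimal usage sketch of how the new MiniCPM-V 2 branch above would be reached. The model id, low-bit format, and device below are assumptions for illustration and are not taken from this PR:

```python
# Illustrative sketch only: load MiniCPM-V 2 through ipex_llm so that
# _optimize_post sees hidden_size == 2304 and vocab_size == 122753 and
# takes the new MiniCPM-V 2 path. Checkpoint name, low-bit setting and
# device are assumptions, not part of this diff.
import torch
from ipex_llm.transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2",        # assumed checkpoint with the matching config
    trust_remote_code=True,
    load_in_low_bit="sym_int4",   # triggers ipex_llm's low-bit conversion and optimization
)
model = model.half().to("xpu")    # assumes an Intel GPU (XPU) device is available
```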
46 changes: 35 additions & 11 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -15,18 +15,21 @@
#


import math
import torch
from typing import Optional
from ipex_llm.transformers.models.common import merge_qkv_base
from transformers import AutoProcessor
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


# MiniCPM-V-2_5 and MiniCPM-V-2_6
def merge_qkv(module: torch.nn.Module):
merge_qkv_base(module, "SiglipAttention")
merge_qkv_base(module, "Idefics2VisionAttention")


# MiniCPM-V-2_5 and MiniCPM-V-2_6
def siglip_attention_forward(
self,
hidden_states: torch.Tensor,
@@ -58,17 +61,7 @@ def siglip_attention_forward(
return attn_output, attn_weights


def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
if scores.device.type == "xpu":
import xe_addons
xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
else:
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores


# MiniCPM-V-2_5
def minicpmv_chat_wrapper(origin_chat):
def minicpmv_chat(
self,
@@ -106,6 +99,37 @@ def minicpmv_chat(
return minicpmv_chat


# MiniCPM-V-2
def minicpmv_get_vision_embedding(self, pixel_values):
res = []
dtype = self.dtype

def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
H, W = pixel_value.shape[-2:]
target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))

if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
return resampler(vision_embedding, target_size)

for pixel_value in pixel_values:
result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
res.append(result)
return torch.vstack(res)


def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
if scores.device.type == "xpu":
import xe_addons
xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
else:
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores


def minicpmv_generate_wrapper(origin_generate):
def generate(
*inputs,
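
As a side note on the relocated `patched_repetition_penalty_call`: its non-XPU fallback reproduces the standard repetition-penalty update from `RepetitionPenaltyLogitsProcessor`. A small self-contained sketch with invented numbers, not part of the diff:

```python
# Standalone sketch of the non-XPU fallback in patched_repetition_penalty_call;
# the penalty value and tensors below are invented for illustration.
import torch

penalty = 1.5
scores = torch.tensor([[2.0, -1.0, 0.5]])   # logits over a 3-token vocabulary
input_ids = torch.tensor([[0, 1]])          # tokens already generated

score = torch.gather(scores, 1, input_ids)                        # logits of seen tokens
score = torch.where(score < 0, score * penalty, score / penalty)  # penalize seen tokens
scores.scatter_(1, input_ids, score)                              # write them back in place
print(scores)                               # tensor([[ 1.3333, -1.5000,  0.5000]])
```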