15 | 15 | # |
16 | 16 |
17 | 17 |
| 18 | +import math |
18 | 19 | import torch |
19 | 20 | from typing import Optional |
20 | 21 | from ipex_llm.transformers.models.common import merge_qkv_base |
21 | 22 | from transformers import AutoProcessor |
22 | 23 | from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor |
23 | 24 |
24 | 25 |
| 26 | +# MiniCPM-V-2_5 and MiniCPM-V-2_6 |
25 | 27 | def merge_qkv(module: torch.nn.Module): |
26 | 28 | merge_qkv_base(module, "SiglipAttention") |
27 | 29 | merge_qkv_base(module, "Idefics2VisionAttention") |
28 | 30 |
29 | 31 |
| 32 | +# MiniCPM-V-2_5 and MiniCPM-V-2_6 |
30 | 33 | def siglip_attention_forward( |
31 | 34 |     self, |
32 | 35 |     hidden_states: torch.Tensor, |
@@ -58,17 +61,7 @@ def siglip_attention_forward( |
58 | 61 |     return attn_output, attn_weights |
59 | 62 |
60 | 63 |
61 | | -def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): |
62 | | -    if scores.device.type == "xpu": |
63 | | -        import xe_addons |
64 | | -        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty) |
65 | | -    else: |
66 | | -        score = torch.gather(scores, 1, input_ids) |
67 | | -        score = torch.where(score < 0, score * self.penalty, score / self.penalty) |
68 | | -        scores.scatter_(1, input_ids, score) |
69 | | -    return scores |
70 | | - |
71 | | - |
| 64 | +# MiniCPM-V-2_5 |
72 | 65 | def minicpmv_chat_wrapper(origin_chat): |
73 | 66 |     def minicpmv_chat( |
74 | 67 |         self, |
@@ -106,6 +99,37 @@ def minicpmv_chat( |
106 | 99 |     return minicpmv_chat |
107 | 100 |
108 | 101 |
| 102 | +# MiniCPM-V-2 |
| 103 | +def minicpmv_get_vision_embedding(self, pixel_values): |
| 104 | +    res = [] |
| 105 | +    dtype = self.dtype |
| 106 | + |
| 107 | +    def process_each_pixel(pixel_value, dtype, config, vpm, resampler): |
| 108 | +        H, W = pixel_value.shape[-2:] |
| 109 | +        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size)) |
| 110 | +        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype)) |
| 111 | + |
| 112 | +        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0: |
| 113 | +            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:] |
| 114 | +        return resampler(vision_embedding, target_size) |
| 115 | + |
| 116 | +    for pixel_value in pixel_values: |
| 117 | +        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler) |
| 118 | +        res.append(result) |
| 119 | +    return torch.vstack(res) |
| 120 | + |
| 121 | + |
| 122 | +def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): |
| 123 | +    if scores.device.type == "xpu": |
| 124 | +        import xe_addons |
| 125 | +        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty) |
| 126 | +    else: |
| 127 | +        score = torch.gather(scores, 1, input_ids) |
| 128 | +        score = torch.where(score < 0, score * self.penalty, score / self.penalty) |
| 129 | +        scores.scatter_(1, input_ids, score) |
| 130 | +    return scores |
| 131 | + |
| 132 | + |
109 | 133 | def minicpmv_generate_wrapper(origin_generate): |
110 | 134 |     def generate( |
111 | 135 |         *inputs, |
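Note on the repetition-penalty patch above: the CPU branch reproduces the standard Hugging Face rule (positive logits are divided by the penalty, negative logits are multiplied by it), while the XPU branch calls an in-place xe_addons kernel instead. The hunk shown here does not include the line that installs patched_repetition_penalty_call, so the wiring in the sketch below is an assumption inferred from the RepetitionPenaltyLogitsProcessor import at the top of the file, and the penalty value and tensors are illustrative only.

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

# Assumed wiring (not visible in this hunk): the patched function would replace
# the processor's __call__, e.g.
#   RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
# The example below only exercises the stock CPU behaviour that the patch mirrors.

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
input_ids = torch.tensor([[3, 5]])                          # previously generated token ids
scores = torch.tensor([[0.5, -1.0, 2.0, 4.0, 0.0, -2.0]])   # logits over a 6-token vocabulary
penalized = processor(input_ids, scores)
# token 3: 4.0 / 1.2 (positive logit divided); token 5: -2.0 * 1.2 (negative logit multiplied)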