
Commit cb79dcd

refactor llama convert to fix minicpm-v 2.5 optimization (#11783)
1 parent 7cd6ec9 commit cb79dcd

File tree

1 file changed (+68, -101 lines)

python/llm/src/ipex_llm/transformers/convert.py

Lines changed: 68 additions & 101 deletions
@@ -754,6 +754,10 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "qwen2"
             _optimize_pre(model.llm, qtype=qtype)
             model.llm.config.model_type = "minicpmv"
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            model.llm.config.model_type = "llama"
+            _optimize_pre(model.llm, qtype=qtype)
+            model.llm.config.model_type = "minicpmv"

     return model
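The new branch reuses the existing Llama conversion path for MiniCPM-Llama3-V 2.5 (hidden_size 4096, vocab_size 128256, i.e. a Llama-3-8B language backbone) by temporarily relabelling the wrapped language model before recursing. A minimal, self-contained sketch of that relabelling pattern; optimize_llm and the SimpleNamespace stand-ins are hypothetical, the real code recurses into _optimize_pre:

from types import SimpleNamespace

def optimize_llm(llm):
    # Hypothetical per-architecture optimizer keyed on config.model_type.
    print(f"optimizing as {llm.config.model_type}")

def optimize_minicpmv(model):
    cfg = model.config
    if cfg.hidden_size == 4096 and cfg.vocab_size == 128256:
        # MiniCPM-Llama3-V 2.5: relabel the language model as "llama",
        # run the Llama optimizations, then restore the original tag.
        model.llm.config.model_type = "llama"
        optimize_llm(model.llm)
        model.llm.config.model_type = "minicpmv"

# Toy stand-ins for the real HuggingFace model objects.
llm = SimpleNamespace(config=SimpleNamespace(model_type="minicpmv"))
model = SimpleNamespace(config=SimpleNamespace(hidden_size=4096, vocab_size=128256), llm=llm)
optimize_minicpmv(model)  # prints "optimizing as llama"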

@@ -933,16 +937,6 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):


 def _optimize_post(model, lightweight_bmm=False):
-    from packaging import version
-    from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
-    from ipex_llm.transformers.models.llama import llama_mlp_forward
-    from ipex_llm.transformers.models.llama import llama_decoder_forward
-    from ipex_llm.transformers.models.llama import llama_model_forward
-    from transformers.modeling_utils import PreTrainedModel
-
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
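The llama imports removed above are not simply dropped; the refactor defers them into the branch that actually uses them (see the next hunk), so models that never take the llama path do not touch the llama module at all. A small sketch of that deferred-import pattern; the function name and return strings are illustrative only:

def dispatch_optimizations(model_type: str) -> str:
    # Architecture-specific imports live inside the branch that needs them,
    # so unrelated model types never trigger (or break on) those imports.
    if model_type == "llama":
        from packaging import version  # imported only on the llama path
        assert version.parse("4.41.0") >= version.parse("4.36.0")
        return "llama-specific forwards patched"
    return "no llama-specific work"

print(dispatch_optimizations("llama"))
print(dispatch_optimizations("chatglm"))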
@@ -961,110 +955,80 @@ def _optimize_post(model, lightweight_bmm=False):
     except ModuleNotFoundError:
         pass

+    from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
         logger.info("Only HuggingFace Transformers models are currently "
                     "supported for further optimizations")
         return model

-    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
-    enable_vllm_se_batching = vllm_selective_batching is not None
-    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
-
+    from packaging import version
     trans_version = transformers.__version__
-    if version.parse(trans_version) >= version.parse("4.31.0"):
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaRMSNorm,
-            llama_rms_norm_forward,)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaMLP,
-                        llama_mlp_forward)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaDecoderLayer,
-                        llama_decoder_forward)

+    # convert all nn.LayerNorm
+    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
+    convert_forward(model,
+                    nn.LayerNorm,
+                    bloom_layer_norm_forward)
+    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+    from ipex_llm.transformers.models.llama import llama_mlp_forward
+
+    if model.config.model_type == "llama":
+        from transformers.models.llama.modeling_llama import LlamaRMSNorm
+        from transformers.models.llama.modeling_llama import LlamaMLP
+        from transformers.models.llama.modeling_llama import LlamaAttention
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+        from transformers.models.llama.modeling_llama import LlamaModel
         if version.parse(trans_version) >= version.parse("4.36.0"):
-            # transformers version >= 4.36.0
+            from transformers.models.llama.modeling_llama import LlamaSdpaAttention
+
+            from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+            from ipex_llm.transformers.models.llama import llama_mlp_forward
+            from ipex_llm.transformers.models.llama import llama_decoder_forward
+
+            convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
+            convert_forward(model, LlamaMLP, llama_mlp_forward)
+            convert_forward(model, LlamaDecoderLayer, llama_decoder_forward)
+
+            if version.parse(trans_version) >= version.parse("4.41.0"):
+                from ipex_llm.transformers.models.llama import llama_model_forward_4_41
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
+                convert_forward(model, LlamaModel, llama_model_forward_4_41)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_41)
+                convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_41)
+            elif version.parse(trans_version) >= version.parse("4.38.0"):
+                from ipex_llm.transformers.models.llama import llama_model_forward_4_38
                 from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
-            if version.parse(trans_version) >= version.parse("4.38.0"):
-                if version.parse(trans_version) >= version.parse("4.41.0"):
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_41
-                    from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_41)
-                else:
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_38
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_38)
-            else:
-                from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward_4_36)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
-                    llama_attention_forward_4_38)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                    llama_attention_forward_4_38)
+                convert_forward(model, LlamaModel, llama_model_forward_4_38)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+                convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
+            elif version.parse(trans_version) >= version.parse("4.36.0"):
+                from ipex_llm.transformers.models.llama import llama_model_forward_4_36
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
+                convert_forward(model, LlamaModel, llama_model_forward_4_36)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+                convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
         else:
-            # transformers version between 4.31.0 - 4.35.2
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_31, )
-            if enable_vllm_se_batching:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
+            vllm_se_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING", "").lower() == "true"
+            if vllm_se_batching:
+                from ipex_llm.transformers.models.llama import (
                     llama_model_selective_batching_forward_4_31,
-                )
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_selective_batching_forward_4_31,
                 )
+                convert_forward(model, LlamaModel,
+                                llama_model_selective_batching_forward_4_31)
+                convert_forward(model, LlamaAttention,
+                                llama_attention_selective_batching_forward_4_31)
             else:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward)
-        else:
-            # todo implement 4.28.0 ~ 4.30.2
-            pass
-
-    # convert all nn.LayerNorm
-    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
-    convert_forward(model,
-                    nn.LayerNorm,
-                    bloom_layer_norm_forward)
-
-    if model.config.architectures is not None \
-       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+                from ipex_llm.transformers.models.llama import llama_model_forward
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
+                convert_forward(model, LlamaModel, llama_model_forward)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
+
+    elif (
+        model.config.architectures is not None
+        and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
+    ):
         if hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size in [65024, 64896]:
             # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
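Both the old and the new code patch modules through convert_forward; the refactor mainly flattens the nested transformers-version checks into one if/elif chain and imports the Llama classes once. A minimal sketch of how a convert_forward-style helper and the version gate fit together; ToyAttention and fast_attention_forward are hypothetical, and the rebinding logic is an assumption about how the real helper behaves, not a copy of it:

import torch
import torch.nn as nn
from packaging import version

def convert_forward(model: nn.Module, target_cls, new_forward) -> None:
    # Rebind `forward` on every submodule of the target class
    # (monkey-patching in place of the stock implementation).
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, target_cls)

class ToyAttention(nn.Module):
    def forward(self, x):
        return x

def fast_attention_forward(self, x):
    return x * 2  # stand-in for an optimized attention kernel

def patch_llama(model: nn.Module, trans_version: str) -> None:
    # Version-gated dispatch, mirroring the if/elif chain in the diff.
    if version.parse(trans_version) >= version.parse("4.36.0"):
        convert_forward(model, ToyAttention, fast_attention_forward)
    # older transformers releases would get the 4.31-style forwards here

model = nn.Sequential(ToyAttention())
patch_llama(model, "4.41.2")
print(model(torch.ones(1)))  # tensor([2.]) once the patched forward is active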
@@ -1370,6 +1334,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
         from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2MoeModel,
                         qwen2moe_model_forward)
@@ -1384,7 +1349,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         qwen2moe_moeblock_forward)
         convert_forward(model,
                         module.Qwen2MoeMLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2MoeAttention,
                         qwen2_attention_forward)
@@ -1768,7 +1733,9 @@ def safe_bmm_fwd(*args, **kwargs):
             model.llm.config.model_type = "minicpmv"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
-            pass
+            model.llm.config.model_type = "llama"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"

         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
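With both _optimize_pre and _optimize_post now recursing into the wrapped language model, loading MiniCPM-Llama3-V 2.5 through ipex-llm should apply the Llama optimizations automatically. A sketch of the usual loading flow; the model id, keyword arguments, and the xpu device follow the project's common examples and are assumptions rather than part of this commit:

from ipex_llm.transformers import AutoModel

# Low-bit load; optimize_model=True routes the wrapped Llama-3 language
# model through the refactored llama conversion path shown above.
model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5",
    trust_remote_code=True,
    load_in_4bit=True,
    optimize_model=True,
)
model = model.half().to("xpu")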
