@@ -754,6 +754,10 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "qwen2"
             _optimize_pre(model.llm, qtype=qtype)
             model.llm.config.model_type = "minicpmv"
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            model.llm.config.model_type = "llama"
+            _optimize_pre(model.llm, qtype=qtype)
+            model.llm.config.model_type = "minicpmv"
 
     return model
 
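Note: the added elif mirrors the qwen2 branch just above it. MiniCPM-V 2.5 (hidden_size 4096, vocab_size 128256) wraps a Llama-architecture language model, so its config is relabelled as "llama" just long enough for the llama-specific pre-optimizations to run, then restored to "minicpmv". A minimal sketch of that relabel-and-restore pattern follows; the temporary_model_type helper is illustrative only and not part of the patch.

from contextlib import contextmanager

@contextmanager
def temporary_model_type(config, model_type):
    # Temporarily relabel config.model_type so type-keyed optimizations apply,
    # then always restore the original label.
    original = config.model_type
    config.model_type = model_type
    try:
        yield config
    finally:
        config.model_type = original

# Hypothetical equivalent of the added branch:
# with temporary_model_type(model.llm.config, "llama"):
#     _optimize_pre(model.llm, qtype=qtype)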
@@ -933,16 +937,6 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 
 def _optimize_post(model, lightweight_bmm=False):
-    from packaging import version
-    from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
-    from ipex_llm.transformers.models.llama import llama_mlp_forward
-    from ipex_llm.transformers.models.llama import llama_decoder_forward
-    from ipex_llm.transformers.models.llama import llama_model_forward
-    from transformers.modeling_utils import PreTrainedModel
-
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
@@ -961,110 +955,80 @@ def _optimize_post(model, lightweight_bmm=False):
     except ModuleNotFoundError:
         pass
 
+    from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
         logger.info("Only HuggingFace Transformers models are currently "
                     "supported for further optimizations")
         return model
 
-    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
-    enable_vllm_se_batching = vllm_selective_batching is not None
-    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
-
+    from packaging import version
     trans_version = transformers.__version__
-    if version.parse(trans_version) >= version.parse("4.31.0"):
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaRMSNorm,
-            llama_rms_norm_forward,)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaMLP,
-                        llama_mlp_forward)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaDecoderLayer,
-                        llama_decoder_forward)
 
+    # convert all nn.LayerNorm
+    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
+    convert_forward(model,
+                    nn.LayerNorm,
+                    bloom_layer_norm_forward)
+    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+    from ipex_llm.transformers.models.llama import llama_mlp_forward
+
+    if model.config.model_type == "llama":
+        from transformers.models.llama.modeling_llama import LlamaRMSNorm
+        from transformers.models.llama.modeling_llama import LlamaMLP
+        from transformers.models.llama.modeling_llama import LlamaAttention
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+        from transformers.models.llama.modeling_llama import LlamaModel
         if version.parse(trans_version) >= version.parse("4.36.0"):
-            # transformers version >= 4.36.0
+            from transformers.models.llama.modeling_llama import LlamaSdpaAttention
+
+            from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+            from ipex_llm.transformers.models.llama import llama_mlp_forward
+            from ipex_llm.transformers.models.llama import llama_decoder_forward
+
+            convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
+            convert_forward(model, LlamaMLP, llama_mlp_forward)
+            convert_forward(model, LlamaDecoderLayer, llama_decoder_forward)
+
+        if version.parse(trans_version) >= version.parse("4.41.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_41
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
+            convert_forward(model, LlamaModel, llama_model_forward_4_41)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_41)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_41)
+        elif version.parse(trans_version) >= version.parse("4.38.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_38
             from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
-            if version.parse(trans_version) >= version.parse("4.38.0"):
-                if version.parse(trans_version) >= version.parse("4.41.0"):
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_41
-                    from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_41)
-                else:
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_38
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_38)
-            else:
-                from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward_4_36)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
-                    llama_attention_forward_4_38)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                    llama_attention_forward_4_38)
+            convert_forward(model, LlamaModel, llama_model_forward_4_38)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
+        elif version.parse(trans_version) >= version.parse("4.36.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
+            convert_forward(model, LlamaModel, llama_model_forward_4_36)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
         else:
-            # transformers version between 4.31.0 - 4.35.2
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_31, )
-            if enable_vllm_se_batching:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
+            vllm_se_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING", "").lower() == "true"
+            if vllm_se_batching:
+                from ipex_llm.transformers.models.llama import (
                     llama_model_selective_batching_forward_4_31,
-                )
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_selective_batching_forward_4_31,
                 )
+                convert_forward(model, LlamaModel,
+                                llama_model_selective_batching_forward_4_31)
+                convert_forward(model, LlamaAttention,
+                                llama_attention_selective_batching_forward_4_31)
             else:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward)
-        else:
-            # todo implement 4.28.0 ~ 4.30.2
-            pass
-
-        # convert all nn.LayerNorm
-        from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
-        convert_forward(model,
-                        nn.LayerNorm,
-                        bloom_layer_norm_forward)
-
-    if model.config.architectures is not None \
-            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+                from ipex_llm.transformers.models.llama import llama_model_forward
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
+                convert_forward(model, LlamaModel, llama_model_forward)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
+
+    elif (
+        model.config.architectures is not None
+        and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
+    ):
         if hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size in [65024, 64896]:
             # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
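Note: the rewritten llama branch still funnels everything through convert_forward; the change is that the transformers classes are imported once up front and the version checks form one flat if/elif chain keyed on transformers.__version__. A rough sketch of what a convert_forward-style helper does is shown below; this is an assumption for illustration, not the ipex-llm implementation.

import types

def convert_forward_sketch(model, target_cls, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_cls,
    # so the replacement applies model-wide regardless of nesting depth.
    for _, module in model.named_modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

The VLLM_ENABLE_SELECTIVE_BATCHING handling also collapses from three statements into one: os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING", "").lower() == "true" treats an unset variable as the empty string, so the flag is truthy only when explicitly set to "true" (case-insensitively).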
@@ -1370,6 +1334,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
         from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2MoeModel,
                         qwen2moe_model_forward)
@@ -1384,7 +1349,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         qwen2moe_moeblock_forward)
         convert_forward(model,
                         module.Qwen2MoeMLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2MoeAttention,
                         qwen2_attention_forward)
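Note: Qwen2MoeMLP is now converted with qwen2_mlp_forward instead of the generic llama_mlp_forward. A hypothetical post-conversion check (the check_patched helper below is made up for illustration, not part of the patch) could confirm which forward ended up bound on each MoE MLP module:

def check_patched(model, target_cls_name, expected_forward):
    # Walk the module tree and assert every matching submodule was rebound
    # to the expected replacement forward.
    for name, module in model.named_modules():
        if module.__class__.__name__ == target_cls_name:
            bound = getattr(module.forward, "__func__", module.forward)
            assert bound is expected_forward, (name, getattr(bound, "__name__", bound))

# e.g. check_patched(model, "Qwen2MoeMLP", qwen2_mlp_forward)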
@@ -1768,7 +1733,9 @@ def safe_bmm_fwd(*args, **kwargs):
             model.llm.config.model_type = "minicpmv"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
-            pass
+            model.llm.config.model_type = "llama"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
 
         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)