From 317d3be11306807a799667bf6c3f5a3bdc6f85bc Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 13 Aug 2024 16:48:26 +0800
Subject: [PATCH 1/2] refactor llama convert

---
 .../llm/src/ipex_llm/transformers/convert.py  | 164 +++++++-----------
 1 file changed, 64 insertions(+), 100 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index b8c6899540f..931a28a7f4b 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -754,6 +754,10 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "qwen2"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+        model.llm.config.model_type = "llama"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "minicpmv"
 
     return model
 
@@ -933,16 +937,6 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 
 def _optimize_post(model, lightweight_bmm=False):
-    from packaging import version
-    from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
-    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
-    from ipex_llm.transformers.models.llama import llama_mlp_forward
-    from ipex_llm.transformers.models.llama import llama_decoder_forward
-    from ipex_llm.transformers.models.llama import llama_model_forward
-    from transformers.modeling_utils import PreTrainedModel
-
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
@@ -961,110 +955,78 @@ def _optimize_post(model, lightweight_bmm=False):
     except ModuleNotFoundError:
         pass
 
+    from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
         logger.info("Only HuggingFace Transformers models are currently "
                     "supported for further optimizations")
         return model
 
-    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
-    enable_vllm_se_batching = vllm_selective_batching is not None
-    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
-
+    from packaging import version
     trans_version = transformers.__version__
-    if version.parse(trans_version) >= version.parse("4.31.0"):
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaRMSNorm,
-            llama_rms_norm_forward,)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaMLP,
-                        llama_mlp_forward)
-        convert_forward(model,
-                        transformers.models.llama.modeling_llama.LlamaDecoderLayer,
-                        llama_decoder_forward)
+    # convert all nn.LayerNorm
+    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
+    convert_forward(model,
+                    nn.LayerNorm,
+                    bloom_layer_norm_forward)
+
+    if model.config.model_type == "llama":
+        from transformers.models.llama.modeling_llama import LlamaRMSNorm
+        from transformers.models.llama.modeling_llama import LlamaMLP
+        from transformers.models.llama.modeling_llama import LlamaAttention
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+        from transformers.models.llama.modeling_llama import LlamaModel
         if version.parse(trans_version) >= version.parse("4.36.0"):
-            # transformers version >= 4.36.0
+            from transformers.models.llama.modeling_llama import LlamaSdpaAttention
+
+        from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+        from ipex_llm.transformers.models.llama import llama_mlp_forward
+        from ipex_llm.transformers.models.llama import llama_decoder_forward
+
+        convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
+        convert_forward(model, LlamaMLP, llama_mlp_forward)
+        convert_forward(model, LlamaDecoderLayer, llama_decoder_forward)
+
+        if version.parse(trans_version) >= version.parse("4.41.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_41
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
+            convert_forward(model, LlamaModel, llama_model_forward_4_41)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_41)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_41)
+        elif version.parse(trans_version) >= version.parse("4.38.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_38
             from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
-            if version.parse(trans_version) >= version.parse("4.38.0"):
-                if version.parse(trans_version) >= version.parse("4.41.0"):
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_41
-                    from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_41)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_41)
-                else:
-                    from ipex_llm.transformers.models.llama import llama_model_forward_4_38
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaModel,
-                        llama_model_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaAttention,
-                        llama_attention_forward_4_38)
-                    convert_forward(
-                        model,
-                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                        llama_attention_forward_4_38)
-            else:
-                from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward_4_36)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
-                    llama_attention_forward_4_38)
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
-                    llama_attention_forward_4_38)
+            convert_forward(model, LlamaModel, llama_model_forward_4_38)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
+        elif version.parse(trans_version) >= version.parse("4.36.0"):
+            from ipex_llm.transformers.models.llama import llama_model_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
+            convert_forward(model, LlamaModel, llama_model_forward_4_36)
+            convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
+            convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
         else:
-            # transformers version between 4.31.0 - 4.35.2
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_31, )
-            if enable_vllm_se_batching:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
+            vllm_se_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING", "").lower() == "true"
+            if vllm_se_batching:
+                from ipex_llm.transformers.models.llama import (
                     llama_model_selective_batching_forward_4_31,
-                )
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_selective_batching_forward_4_31,
                 )
+                convert_forward(model, LlamaModel,
+                                llama_model_selective_batching_forward_4_31)
+                convert_forward(model, LlamaAttention,
+                                llama_attention_selective_batching_forward_4_31)
             else:
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward)
-        else:
-            # todo implement 4.28.0 ~ 4.30.2
-            pass
-
-    # convert all nn.LayerNorm
-    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
-    convert_forward(model,
-                    nn.LayerNorm,
-                    bloom_layer_norm_forward)
-
-    if model.config.architectures is not None \
-       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+                from ipex_llm.transformers.models.llama import llama_model_forward
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
+                convert_forward(model, LlamaModel, llama_model_forward)
+                convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
+
+    elif (
+        model.config.architectures is not None
+        and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
+    ):
         if hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size in [65024, 64896]:
             # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
@@ -1768,7 +1730,9 @@ def safe_bmm_fwd(*args, **kwargs):
             model.llm.config.model_type = "minicpmv"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
-            pass
+            model.llm.config.model_type = "llama"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
 
         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)

From 12f1dd0a6e2da049efbe35e664706450e1c5f913 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 13 Aug 2024 17:16:28 +0800
Subject: [PATCH 2/2] fix

---
 python/llm/src/ipex_llm/transformers/convert.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 931a28a7f4b..2d36b5c54c1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -970,6 +970,8 @@ def _optimize_post(model, lightweight_bmm=False):
     convert_forward(model,
                     nn.LayerNorm,
                     bloom_layer_norm_forward)
+    from ipex_llm.transformers.models.llama import llama_rms_norm_forward
+    from ipex_llm.transformers.models.llama import llama_mlp_forward
 
     if model.config.model_type == "llama":
         from transformers.models.llama.modeling_llama import LlamaRMSNorm
@@ -1332,6 +1334,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
         from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2MoeModel,
                         qwen2moe_model_forward)
@@ -1346,7 +1349,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         qwen2moe_moeblock_forward)
         convert_forward(model,
                         module.Qwen2MoeMLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2MoeAttention,
                         qwen2_attention_forward)
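
Note (not part of either commit above): both patches lean on `convert_forward` to swap in optimized forward functions per model type. The sketch below is only an illustration of that dispatch pattern, under the assumption that `convert_forward` walks the module tree and rebinds `forward` on every submodule of the target class; the helper names and the toy model are made up for the example and are not the exact ipex-llm implementation.

```python
import torch
import torch.nn as nn


def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk the module tree and bind new_forward as an instance method on every
    # submodule that is an instance of target_cls, shadowing the class forward.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, module.__class__)


def patched_linear_forward(self, x):
    # Stand-in for an optimized forward (e.g. an ipex-llm replacement kernel).
    return nn.functional.linear(x, self.weight, self.bias)


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    convert_forward(toy, nn.Linear, patched_linear_forward)
    print(toy(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
```

Seen through this sketch, the refactor in PATCH 1 only changes where the per-version llama forwards are imported and which class objects are passed to `convert_forward`; the MiniCPM-V 2.5 branch reuses the same mechanism by temporarily relabeling `model.llm.config.model_type` as "llama".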