
Commit 416c191

Add Qwen pipeline and example (#12292)
* support qwen pipeline
* update error msg
* style
* meet review
* minor
1 parent 4cf1ccc commit 416c191

File tree

6 files changed: +422 −59 lines changed

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Qwen2.5 | [Qwen/Qwen2.5-7b-Instruct](https://huggingface.co/Qwen/Qwen2.5-7b-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |

@@ -30,7 +31,7 @@ pip install --pre --upgrade ipex-llm[npu]

 ## 2. Runtime Configurations

-**Following envrionment variables are required**:
+**Following environment variables are required**:

 ```cmd
 set BIGDL_USE_NPU=1

@@ -46,6 +47,9 @@ python llama2.py
 :: to run Meta-Llama-3-8B-Instruct
 python llama3.py

+:: to run Qwen2.5-7b-Instruct
+python qwen.py
+
 :: to run Baichuan2-7B-Chat
 python baichuan2.py
python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import os
import torch
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Predict Tokens using `generate()` API for npu model"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="Qwen/Qwen2.5-7B-Instruct",  # Or Qwen2-7B-Instruct
        help="The huggingface repo id for the Qwen model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument("--lowbit-path", type=str,
                        default="",
                        help="The path to the lowbit model folder, leave blank if you do not want to save. \
                            If path not exists, lowbit model will be saved there. \
                            Else, lowbit model will be loaded.",
                        )
    parser.add_argument('--prompt', type=str, default="AI是什么?",
                        help='Prompt to infer')
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
    parser.add_argument("--max-context-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=960)
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    if not args.lowbit_path or not os.path.exists(args.lowbit_path):
        model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     optimize_model=True,
                                                     pipeline=True,
                                                     max_context_len=args.max_context_len,
                                                     max_prompt_len=args.max_prompt_len,
                                                     torch_dtype=torch.float16,
                                                     attn_implementation="eager",
                                                     transpose_value_cache=not args.disable_transpose_value_cache,
                                                     mixed_precision=True,
                                                     trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.load_low_bit(
            args.lowbit_path,
            attn_implementation="eager",
            torch_dtype=torch.float16,
            max_context_len=args.max_context_len,
            max_prompt_len=args.max_prompt_len,
            pipeline=True,
            transpose_value_cache=not args.disable_transpose_value_cache)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if args.lowbit_path and not os.path.exists(args.lowbit_path):
        model.save_low_bit(args.lowbit_path)

    print("-" * 80)
    print("done")
    messages = [{"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": args.prompt}]
    text = tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)
    with torch.inference_mode():
        print("finish to load")
        for i in range(5):
            _input_ids = tokenizer([text], return_tensors="pt").input_ids
            print("input length:", len(_input_ids[0]))
            st = time.time()
            output = model.generate(
                _input_ids, max_new_tokens=args.n_predict, do_print=True
            )
            end = time.time()
            print(f"Inference time: {end-st} s")
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
            print("-" * 20, "Input", "-" * 20)
            print(input_str)
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
            print("-" * 20, "Output", "-" * 20)
            print(output_str)

    print("-" * 80)
    print("done")
    print("success shut down")

python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py

Lines changed: 44 additions & 25 deletions
@@ -267,6 +267,43 @@ def convert_minicpm(
     convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)


+def convert_qwen(
+    model: torch.nn.Module,
+    max_output_len=1024,
+    max_prompt_len=1024,
+    decoder=False,
+    inter_pp=None,
+    intra_pp=None,
+    transpose_value_cache=True,
+):
+    from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
+    from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
+    if decoder:
+        decode_runner = DecodeRunner(
+            model,
+            max_seq_len=max_output_len,
+            inter_pp=inter_pp,
+            intra_pp=intra_pp,
+            transpose_value_cache=transpose_value_cache,
+        )
+    else:
+        decode_runner = None
+    prefill_runner = PrefillRunner(
+        model,
+        max_output_len=max_output_len,
+        max_prompt_len=max_prompt_len,
+        transpose_value_cache=transpose_value_cache,
+    )
+    qwen2_model_forward = gen_qwen2_fused_model_forward(
+        prefill_runner=prefill_runner, decode_runner=decode_runner
+    )
+    convert_forward(model, Qwen2Model, qwen2_model_forward)
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+    from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
+    convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
+
+
 def optimize_llm(
     model: torch.nn.Module,
     max_context_len=1024,

@@ -300,31 +337,13 @@ def optimize_llm(
             inter_pp = 2
         else:
             inter_pp = 1
-
-        from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
-        from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
-
-        decode_runner = DecodeRunner(
-            model,
-            max_seq_len=max_context_len,
-            inter_pp=inter_pp,
-            intra_pp=intra_pp,
-            transpose_value_cache=transpose_value_cache,
-        )
-        prefill_runner = PrefillRunner(
-            model,
-            max_output_len=max_context_len,
-            max_prompt_len=max_prompt_len,
-            transpose_value_cache=transpose_value_cache,
-        )
-        qwen2_model_forward = gen_qwen2_fused_model_forward(
-            prefill_runner=prefill_runner, decode_runner=decode_runner
-        )
-        convert_forward(model, Qwen2Model, qwen2_model_forward)
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
-        from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
-        convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
+        convert_qwen(model,
+                     max_output_len=max_context_len,
+                     max_prompt_len=max_prompt_len,
+                     inter_pp=inter_pp,
+                     intra_pp=intra_pp,
+                     decoder=True,
+                     transpose_value_cache=transpose_value_cache)
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
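The point of factoring `convert_qwen` out of `optimize_llm` is reuse: the Python-side path builds both a `DecodeRunner` and a `PrefillRunner` (`decoder=True`), while the native pipeline path in `convert_pipeline.py` only needs the prefill runner (`decoder=False`). A condensed sketch of the two call patterns follows; the wrapper function and the argument values are illustrative, not part of the commit.

import torch
from ipex_llm.transformers.npu_models.convert_mp import convert_qwen

def prepare_qwen_for_npu(model: torch.nn.Module, for_native_pipeline: bool,
                         max_context_len: int = 1024, max_prompt_len: int = 512) -> None:
    # Hypothetical helper mirroring the two call sites touched by this commit.
    if for_native_pipeline:
        # convert_llm: decoding runs in the native pipeline, so only the prefill runner is built.
        convert_qwen(model,
                     max_output_len=max_context_len,   # passed as kv_len at the real call site
                     max_prompt_len=max_prompt_len,
                     decoder=False,
                     transpose_value_cache=True)
    else:
        # optimize_llm: Python-side decoding, so DecodeRunner and PrefillRunner are both built.
        convert_qwen(model,
                     max_output_len=max_context_len,
                     max_prompt_len=max_prompt_len,
                     inter_pp=1, intra_pp=2,            # example splits; optimize_llm picks these itself
                     decoder=True,
                     transpose_value_cache=True)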

python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py

Lines changed: 35 additions & 30 deletions
@@ -140,31 +140,13 @@ def __init__(

         # Self Attention
         if mode == "decode":
-            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op(
+                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op(
+                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64)

-        position_ids = self.create_input_op((self.batch_size, self.seq_len))
-        past_keys = []
-        past_values = []
-        if mode == "decode":
-            for i in range(num_layers):
-                past_key = self.create_cache_op(
-                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                )
-                if transpose_value:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
-                    )
-                else:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                    )
-                past_keys.append(past_key)
-                past_values.append(past_value)
-        else:
-            past_keys = [None] * num_layers
-            past_values = [None] * num_layers
+        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

         if input_layernorm_weights is None:
             input_layernorm_weights = []

@@ -203,6 +185,27 @@ def __init__(
         k_biases = [self.constant(w) for w in k_biases]
         v_biases = [self.constant(w) for w in v_biases]

+        past_keys = []
+        past_values = []
+        if mode == "decode":
+            for i in range(num_layers):
+                past_key = self.create_cache_op(
+                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                )
+                if transpose_value:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
+                    )
+                else:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                    )
+                past_keys.append(past_key)
+                past_values.append(past_value)
+        else:
+            past_keys = [None] * num_layers
+            past_values = [None] * num_layers
+
         hidden_states = input

         curr_key_values = []

@@ -396,8 +399,8 @@ def forward(

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask,
-            position_ids.to(torch.float16),
+            attention_mask.to(torch.int64),
+            position_ids.to(torch.int64),
         )

         for i in range(self.intra_stages):

@@ -514,7 +517,9 @@ def forward(
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
+        inputs = (hidden_states.to(torch.float16),
+                  attention_mask.to(torch.int64),
+                  position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
         hidden_states, past_key, past_value = run_model(

@@ -687,9 +692,9 @@ def run_decode(
             causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
             pad_mask = (0, pad_len)
             padded_causal_mask = F.pad(
-                causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
+                causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
             )
-            padded_causal_mask[:, :, :, -1] = 0.0
+            padded_causal_mask[:, :, :, -1] = 0
             dist.recv(hidden_states, src=rank - 1)
             layer_outputs = multi_decoder(
                 hidden_states,

@@ -973,9 +978,9 @@ def forward(
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.float16),
+            attention_mask.to(torch.int64),
             (0, pad_len, 0, pad_len),
-            value=torch.finfo(torch.float16).min,
+            value=torch.iinfo(torch.int64).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value)
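The recurring change in this file is a dtype switch for the non-hidden-state inputs: the attention mask and position ids are now declared and passed as int64, so padding uses `torch.iinfo(torch.int64).min` rather than the float16 minimum. Here is a standalone illustration of the decode-side mask padding; the sequence lengths are made-up values, not taken from the patch.

import torch
import torch.nn.functional as F

max_seq_len, kv_len = 8, 5                  # assumed sizes for illustration only
pad_len = max_seq_len - kv_len

# Build a small float16 mask and blank out its last slot (mirrors the unchanged context line).
causal_mask = torch.zeros(1, 1, 1, kv_len + 1, dtype=torch.float16)
causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min

# New behaviour: cast to int64 and pad with the int64 minimum instead of float16 -inf,
# then set the final position back to 0, as run_decode now does.
padded = F.pad(causal_mask.to(torch.int64), (0, pad_len), value=torch.iinfo(torch.int64).min)
padded[:, :, :, -1] = 0

print(padded.dtype, padded.shape)           # torch.int64 torch.Size([1, 1, 1, 9])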

python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py

Lines changed: 43 additions & 3 deletions
@@ -196,7 +196,7 @@ def convert_llm(model: torch.nn.Module,
                 group_size: int):
     if group_size == 0:
         n_splits_linear = 1
-        n_splits_down_proj = 1
+        n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
     else:
         n_splits_linear = model.config.hidden_size // group_size
         n_splits_down_proj = model.config.intermediate_size // group_size
@@ -318,9 +318,49 @@
         except:
             invalidInputError(False,
                               "False to InitLLMPipeline.")
+    elif model.config.model_type == "qwen2":
+        with tempfile.TemporaryDirectory() as temp_dir:
+            weight_dir = os.path.join(temp_dir, "model_weights")
+            os.mkdir(weight_dir)
+            layer_num = len(model.model.layers)
+            from .qwen import convert_qwen_layer, convert_lm_head_and_embedding
+            first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
+                                                                            temp_dir, weight_dir)
+
+            param_list = []
+            for layer_idx in range(0, layer_num):
+                param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+            with Pool() as pool:
+                result = pool.starmap(convert_qwen_layer, param_list)
+
+            # Prefill Runner
+            from ipex_llm.transformers.npu_models.convert_mp import convert_qwen
+            convert_qwen(model,
+                         max_output_len=kv_len,
+                         max_prompt_len=max_prompt_len,
+                         decoder=False,
+                         transpose_value_cache=transpose_value_cache)
+
+            # patch attrs for generate
+            model.kv_len = kv_len
+            model.num_head = model.model.layers[0].self_attn.num_key_value_heads
+            model.head_dim = model.model.layers[0].self_attn.head_dim
+            model.num_layers = layer_num
+            model.transpose_value_cache = transpose_value_cache
+            model.vocab_size = model.config.vocab_size
+
+            try:
+                res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
+                                      model.vocab_size, weight_dir, "model",
+                                      first_blob_path, last_blob_path,
+                                      os.path.join(temp_dir, "decoder_layer"))
+            except:
+                invalidInputError(False,
+                                  "False to InitLLMPipeline.")
     else:
-        invalidInputError(False,
-                          "Now we only support Llama2 / Llama3 / Baichuan2 for pipeline running.")
+        invalidInputError(False, "Now we only support Llama2 / Llama3 / Baichuan2 / "
+                                 "Qwen2 / Qwen2.5 / Minicpm for pipeline running.")

     if isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
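Back at the top of this file, the `n_splits_down_proj` change is the one model-specific tweak: when `group_size == 0`, the down-projection is split in two only when `intermediate_size` equals 18944, the MLP width of Qwen2-7B / Qwen2.5-7B. A small sketch of the resulting selection logic; the helper name and the example config values are illustrative, not from the codebase.

def pick_linear_splits(hidden_size: int, intermediate_size: int, group_size: int):
    # Mirrors the split selection in convert_llm after this change.
    if group_size == 0:
        n_splits_linear = 1
        # Qwen2-7B / Qwen2.5-7B report intermediate_size == 18944; their down_proj is split in two.
        n_splits_down_proj = 2 if intermediate_size == 18944 else 1
    else:
        n_splits_linear = hidden_size // group_size
        n_splits_down_proj = intermediate_size // group_size
    return n_splits_linear, n_splits_down_proj

# Example: a Qwen2.5-7B-like config (hidden_size=3584, intermediate_size=18944) with group_size=0
print(pick_linear_splits(3584, 18944, 0))   # -> (1, 2)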
