Merged
1 change: 1 addition & 0 deletions python/llm/dev/benchmark/all-in-one/config.yaml
@@ -36,6 +36,7 @@ test_api:
# - "speculative_cpu" # on Intel CPU, inference with self-speculative decoding
# - "deepspeed_transformer_int4_cpu" # on Intel CPU, deepspeed autotp inference
# - "transformers_int4_npu_win" # on Intel NPU for Windows, transformer-like API, (qtype=int4)
# - "transformers_int4_loadlowbit_npu_win" # on Intel NPU for Windows, transformer-like API, (qtype=int4), use load_low_bit API. Please make sure you have used the save_npu.py to save the converted low bit model
cpu_embedding: False # whether put embedding to CPU
streaming: False # whether output in streaming way (only available now for gpu win related test_api)
optimize_model: False # whether apply further optimization on NPU (only available now for transformers_int4_npu_win test_api)
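Note on the new "transformers_int4_loadlowbit_npu_win" option above: it assumes the converted low-bit checkpoint already exists on disk. The save_npu.py script added in this PR reads repo_id, local_model_hub, and low_bit from this same config.yaml, converts each model, and saves it next to the original with an '-npu-<low_bit>' suffix, which is the directory the benchmark then loads from.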
74 changes: 74 additions & 0 deletions python/llm/dev/benchmark/all-in-one/run.py
@@ -189,6 +189,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
        result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype)
    elif test_api == 'transformers_int4_npu_win':
        result = transformers_int4_npu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, optimize_model, transpose_value_cache)
    elif test_api == 'transformers_int4_loadlowbit_npu_win':
        result = run_transformer_int4_loadlowbit_npu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, optimize_model, transpose_value_cache)
    else:
        invalidInputError(False, "Unknown test_api " + test_api + ", please check your config.yaml.")

@@ -669,6 +671,78 @@ def transformers_int4_npu_win(repo_id,
    gc.collect()
    return result

def run_transformer_int4_loadlowbit_npu_win(repo_id,
                                            local_model_hub,
                                            in_out_pairs,
                                            warm_up,
                                            num_trials,
                                            num_beams,
                                            low_bit,
                                            batch_size,
                                            optimize_model,
                                            transpose_value_cache):
    from ipex_llm.transformers.npu_model import AutoModel, AutoModelForCausalLM
    from transformers import AutoTokenizer, LlamaTokenizer

    model_path = get_model_path(repo_id, local_model_hub)
    in_out_len = in_out_pairs[0].split("-")
    max_output_len = max(int(in_out_len[0]) + int(in_out_len[1]), 1024)
    # Load the model already converted to INT4 and saved by save_npu.py,
    # using the load_low_bit API instead of converting at load time
    st = time.perf_counter()
    if repo_id in CHATGLM_IDS:
        model = AutoModel.load_low_bit(model_path+'-npu-'+low_bit, trust_remote_code=True,
                                       optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value_cache,
                                       torch_dtype=torch.float16, attn_implementation="eager").eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path+'-npu-'+low_bit, trust_remote_code=True)
    elif repo_id in LLAMA_IDS:
        model = AutoModelForCausalLM.load_low_bit(model_path+'-npu-'+low_bit, trust_remote_code=True, torch_dtype=torch.float16,
                                                  optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value_cache,
                                                  use_cache=True, attn_implementation="eager").eval()
        tokenizer = LlamaTokenizer.from_pretrained(model_path+'-npu-'+low_bit, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.load_low_bit(model_path+'-npu-'+low_bit, trust_remote_code=True, torch_dtype=torch.float16,
                                                  optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value_cache,
                                                  use_cache=True, attn_implementation="eager").eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path+'-npu-'+low_bit, trust_remote_code=True)
    end = time.perf_counter()
    load_time = end - st
    print(">> loading of model costs {}s".format(load_time))

    model = BenchmarkWrapper(model)

    result = {}
    with torch.inference_mode():
        for in_out in in_out_pairs:
            in_out_len = in_out.split("-")
            in_len = int(in_out_len[0])
            out_len = int(in_out_len[1])
            input_str = get_continuation_input_str(in_len, tokenizer)
            # As different tokenizers have different encodings,
            # slice the input_ids to ensure the prompt length matches the required length.
            input_ids = tokenizer.encode(input_str, return_tensors="pt")
            input_ids = input_ids[:, :in_len]
            true_str = tokenizer.batch_decode(input_ids)[0]
            input_list = [true_str] * batch_size
            input_ids = tokenizer(input_list, return_tensors="pt").input_ids
            input_ids = input_ids[:, :in_len]
            actual_in_len = input_ids.shape[1]
            result[in_out] = []
            for i in range(num_trials + warm_up):
                st = time.perf_counter()
                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
                                            min_new_tokens=out_len, num_beams=num_beams)
                end = time.perf_counter()
                print("model generate cost: " + str(end - st))
                output = tokenizer.batch_decode(output_ids)
                print(output[0])
                actual_out_len = output_ids.shape[1] - actual_in_len
                if i >= warm_up:
                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
                                           actual_in_len, actual_out_len, load_time])
    del model
    gc.collect()
    return result
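For clarity, a minimal sketch (not part of the diff) of the path convention that ties the two scripts together; repo_id, local_model_hub, and low_bit are assumed to come from config.yaml, and get_model_path is the helper already used above:

def low_bit_checkpoint_path(repo_id, local_model_hub, low_bit):
    # Hypothetical helper spelling out the shared naming convention:
    # save_npu.py saves the converted model to <model_path>-npu-<low_bit>,
    # and run_transformer_int4_loadlowbit_npu_win loads from that same directory.
    return get_model_path(repo_id, local_model_hub) + '-npu-' + low_bit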

def run_optimize_model_gpu(repo_id,
local_model_hub,
82 changes: 82 additions & 0 deletions python/llm/dev/benchmark/all-in-one/save_npu.py
@@ -0,0 +1,82 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This script converts a model to low bit and saves it,
# to support performance tests that use the load_low_bit API.

import time
import torch
import os
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from run import get_model_path

current_dir = os.path.dirname(os.path.realpath(__file__))

def save_npu_model_in_low_bit(repo_id,
                              local_model_hub,
                              low_bit,
                              max_output_len, max_prompt_len, intra_pp, inter_pp, disable_transpose_value_cache):
    model_path = get_model_path(repo_id, local_model_hub)
    # Load the model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    st = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="eager",
        load_in_low_bit="sym_int4",
        optimize_model=True,
        max_output_len=max_output_len,
        max_prompt_len=max_prompt_len,
        intra_pp=intra_pp,
        inter_pp=inter_pp,
        transpose_value_cache=not disable_transpose_value_cache,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    end = time.perf_counter()
    print(">> loading and converting of model costs {}s".format(end - st))

    model.save_low_bit(model_path+'-npu-'+low_bit)
    tokenizer.save_pretrained(model_path+'-npu-'+low_bit)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Predict Tokens using `generate()` API for npu model"
    )
    parser.add_argument("--max-output-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=512)
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
    parser.add_argument("--intra-pp", type=int, default=2)
    parser.add_argument("--inter-pp", type=int, default=2)

    args = parser.parse_args()
    from omegaconf import OmegaConf
    conf = OmegaConf.load(f'{current_dir}/config.yaml')

    for model in conf.repo_id:
        save_npu_model_in_low_bit(repo_id=model,
                                  local_model_hub=conf['local_model_hub'],
                                  low_bit=conf['low_bit'],
                                  max_output_len=args.max_output_len,
                                  max_prompt_len=args.max_prompt_len,
                                  intra_pp=args.intra_pp,
                                  inter_pp=args.inter_pp,
                                  disable_transpose_value_cache=args.disable_transpose_value_cache
                                  )
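As a usage note, the intended flow appears to be: fill in repo_id, local_model_hub, and low_bit in config.yaml, run save_npu.py once to convert and save the low-bit checkpoints, then enable the transformers_int4_loadlowbit_npu_win test_api and run run.py as usual. A minimal sketch of calling the helper directly, with illustrative placeholder values that are not taken from the PR:

# Hypothetical direct invocation of save_npu_model_in_low_bit; normally the
# __main__ block above drives this from config.yaml.
save_npu_model_in_low_bit(repo_id="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
                          local_model_hub="/path/to/local/models",  # placeholder hub directory
                          low_bit="sym_int4",
                          max_output_len=1024,  # the script's argparse defaults
                          max_prompt_len=512,
                          intra_pp=2,
                          inter_pp=2,
                          disable_transpose_value_cache=False)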