
Commit 828fa01

[NPU] Add mixed_precision for Qwen2 7B (#12098)
* Add mixed_precision argument to control whether to use INT8 lm_head for Qwen2-7B-Instruct
* Small fix
* Fixed on load low bit with mixed precision
* Small fix
* Update example accordingly
* Update for default prompt
* Update based on comments
* Final fix
1 parent: 2269768 · commit: 828fa01

File tree: 4 files changed (+41, -24 lines)


python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md

Lines changed: 24 additions & 18 deletions
@@ -26,11 +26,11 @@ Right click and select **Update Driver** -> **Browse my computer for drivers**.
 ## 1. Install
 ### 1.1 Installation on Windows
 We suggest using conda to manage environment:
-```bash
+```cmd
 conda create -n llm python=3.10
 conda activate llm

-# install ipex-llm with 'npu' option
+:: install ipex-llm with 'npu' option
 pip install --pre --upgrade ipex-llm[npu]
 ```

@@ -98,26 +98,26 @@ Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B
 Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B

 ### Run
-```bash
-# to run Llama-2-7b-chat-hf
+```cmd
+:: to run Llama-2-7b-chat-hf
 python llama.py

-# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
+:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct

-# to run Qwen2-1.5B-Instruct LNL driver version: 32.0.101.2715)
+:: to run Qwen2-1.5B-Instruct LNL driver version: 32.0.101.2715)
 python qwen2.py

-# to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715)
+:: to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715)
 python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct

-# to run MiniCPM-1B-sft-bf16
+:: to run MiniCPM-1B-sft-bf16
 python minicpm.py

-# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
+:: to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
 python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16

-# to run Baichuan2-7B-Chat
+:: to run Baichuan2-7B-Chat
 python baichuan2.py
 ```

@@ -137,29 +137,35 @@ If you encounter `TypeError: can't convert meta device type tensor to numpy. Use

 #### Output Problem
 If you encounter output problem, please try to disable the optimization of transposing value cache with following command:
-```bash
-# to run Llama-2-7b-chat-hf
+```cmd
+:: to run Llama-2-7b-chat-hf
 python llama.py --disable-transpose-value-cache

-# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
+:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache

-# to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
+:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
 python qwen2.py --disable-transpose-value-cache

-# to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715)
+:: to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715)
 python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct --disable-transpose-value-cache

-# to run MiniCPM-1B-sft-bf16
+:: to run MiniCPM-1B-sft-bf16
 python minicpm.py --disable-transpose-value-cache

-# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
+:: to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
 python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --disable-transpose-value-cache

-# to run Baichuan2-7B-Chat
+:: to run Baichuan2-7B-Chat
 python baichuan2.py --disable-transpose-value-cache
 ```

+For [Qwen2-7B](./qwen2.py), you could also try to enable mixed precision optimization when encountering output problems:
+
+```cmd
+python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct --mixed-precision
+```
+
 #### Better Performance with High CPU Utilization
 You could enable optimization by setting the environment variable with `set IPEX_LLM_CPU_LM_HEAD=1` for better performance. But this will cause high CPU utilization.

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen2.py

Lines changed: 3 additions & 1 deletion
@@ -43,14 +43,15 @@
 If path not exists, lowbit model will be saved there. \
 Else, lowbit model will be loaded.",
 )
-parser.add_argument('--prompt', type=str, default="What is AI?",
+parser.add_argument('--prompt', type=str, default="AI是什么?",
 help='Prompt to infer')
 parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
 parser.add_argument("--max-output-len", type=int, default=1024)
 parser.add_argument("--max-prompt-len", type=int, default=512)
 parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 parser.add_argument("--intra-pp", type=int, default=None)
 parser.add_argument("--inter-pp", type=int, default=None)
+parser.add_argument("--mixed-precision", action='store_true')

 args = parser.parse_args()
 model_path = args.repo_id_or_model_path
@@ -68,6 +69,7 @@
 intra_pp=args.intra_pp,
 inter_pp=args.inter_pp,
 transpose_value_cache=not args.disable_transpose_value_cache,
+mixed_precision=args.mixed_precision
 )
 else:
 model = AutoModelForCausalLM.load_low_bit(
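
Read together, the two hunks above wire the new CLI switch straight into the loader call. Below is a minimal sketch of the resulting flow, keeping only the keywords that appear in this commit; the import path and the `load_in_low_bit` value are assumptions based on the surrounding example rather than part of the diff:

```python
import argparse

# Assumed import path of the NPU AutoModel class this example uses.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument("--repo-id-or-model-path", type=str)  # defaults omitted here
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=None)
parser.add_argument("--inter-pp", type=int, default=None)
parser.add_argument("--mixed-precision", action="store_true")  # new in this commit
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained(
    args.repo_id_or_model_path,
    load_in_low_bit="sym_int4",  # assumed; mixed precision swaps the lm_head to sym_int8
    intra_pp=args.intra_pp,
    inter_pp=args.inter_pp,
    transpose_value_cache=not args.disable_transpose_value_cache,
    mixed_precision=args.mixed_precision,  # forwarded to the NPU loader
)
```

With `--mixed-precision` left unset the behaviour is unchanged, since `mixed_precision` defaults to `False` in `from_pretrained`.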

python/llm/src/ipex_llm/transformers/npu_model.py

Lines changed: 9 additions & 3 deletions
@@ -78,6 +78,9 @@ def from_pretrained(cls, *args, **kwargs):
 Relevant low bit optimizations will be applied to the model.
 :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
 Default to be ``False``.
+:param mixed_precision: boolean value, Whether to use mixed precision quantization.
+Default to be False. If set to ``True``, we will use ``'sym_int8'`` for lm_head when
+``load_in_low_bit`` is '``sym_int4``' for certain models.
 :return: a model instance
 """
 if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:
@@ -108,7 +111,6 @@ def from_pretrained(cls, *args, **kwargs):
 ignore_argument(kwargs, "load_in_4bit")
 ignore_argument(kwargs, "load_in_8bit")
 ignore_argument(kwargs, "imatrix")
-ignore_argument(kwargs, "mixed_precision")
 ignore_argument(kwargs, "cpu_embedding")
 ignore_argument(kwargs, "embedding_qtype")
 ignore_argument(kwargs, "enable_mp")
@@ -123,6 +125,7 @@ def from_pretrained(cls, *args, **kwargs):
 intra_pp = kwargs.pop("intra_pp", None)
 transpose_value_cache = kwargs.pop("transpose_value_cache", True)
 modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
+mixed_precision = kwargs.pop('mixed_precision', False)

 _args = copy.deepcopy(args)
 _kwargs = copy.deepcopy(kwargs)
@@ -158,7 +161,8 @@ def from_pretrained(cls, *args, **kwargs):
 llm = model

 with torch.no_grad():
-optimize_llm_pre(model, qtype)
+model.config.update({"mixed_precision": mixed_precision})
+optimize_llm_pre(model, qtype, mixed_precision)
 cls.load_convert(qtype, model, "cpu", modules_to_not_convert, *args, **kwargs)
 create_npu_kernels(llm)
 model = model.eval()
@@ -209,6 +213,7 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
 ignore_argument(kwargs, "embedding_qtype")
 ignore_argument(kwargs, "speculative")
 ignore_argument(kwargs, "pipeline_parallel_stages")
+ignore_argument(kwargs, "mixed_precision")
 optimize_model = kwargs.pop("optimize_model", False)
 max_output_len = kwargs.pop("max_output_len", 1024)
 max_prompt_len = kwargs.pop("max_prompt_len", 512)
@@ -258,6 +263,7 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
 config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
 qtype = config_dict.pop("bigdl_transformers_low_bit", False)
 bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
+mixed_precision = config_dict.pop("mixed_precision", False)

 invalidInputError(
 qtype,
@@ -370,7 +376,7 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
 llm = model

 with torch.no_grad():
-optimize_llm_pre(model, qtype)
+optimize_llm_pre(model, qtype, mixed_precision)
 cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
 *model_args, **kwargs)
 create_npu_kernels(llm)
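
The npu_model.py changes also make the setting survive a save/reload cycle: `from_pretrained` pops `mixed_precision` from the kwargs and records it in `model.config`, while `load_low_bit` now ignores the kwarg and instead reads the value back from the stored config before calling `optimize_llm_pre`. A rough sketch of that round trip; the `save_low_bit` call is assumed to be the counterpart the example scripts use to produce the low-bit folder, and the path is hypothetical:

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

low_bit_dir = "./qwen2-7b-npu-low-bit"  # hypothetical save location

# First run: quantize with mixed precision; the flag ends up in the saved config.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    load_in_low_bit="sym_int4",
    mixed_precision=True,  # popped from kwargs, then written to model.config
)
model.save_low_bit(low_bit_dir)  # assumed counterpart used by the examples

# Later runs: mixed_precision is recovered from the saved config, so passing it
# as a kwarg here would simply be ignored (see the new ignore_argument call).
model = AutoModelForCausalLM.load_low_bit(low_bit_dir)
```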

python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py

Lines changed: 5 additions & 2 deletions
@@ -29,7 +29,7 @@ def convert_forward(m, target_m, new_forward):
 convert_forward(sub_m, target_m, new_forward)


-def optimize_llm_pre(model: torch.nn.Module, qtype):
+def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision):
 if model.config.model_type == "baichuan":
 # process NormHead module in Baichuan2 7B
 if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -92,7 +92,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
 # for Qwen2-7B-Insturct, divide lm_head into 14 parts
 if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
 not cpu_lm_head:
-new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=14,
+# Do not split lm_head and use sym_int8 instead when mixed_precison is True
+is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+split_num = 14 if is_split else 1
+new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
 bias=model.lm_head.bias)
 del model.lm_head
 model.lm_head = new_lm_head
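
The new branch boils down to a two-line decision: keep the 14-way slicing of the Qwen2-7B lm_head only for plain `sym_int4_rtn`, and keep the head whole under mixed precision so it can be quantized as sym_int8 instead. A standalone restatement for clarity (the helper name is hypothetical, not part of the patch):

```python
def qwen2_7b_lm_head_split_num(qtype: str, mixed_precision: bool) -> int:
    """Hypothetical helper restating the logic above: slice the Qwen2-7B
    lm_head into 14 parts only when running plain sym_int4; with mixed
    precision the head stays whole so it can be quantized as sym_int8."""
    is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
    return 14 if is_split else 1


assert qwen2_7b_lm_head_split_num("sym_int4_rtn", mixed_precision=False) == 14
assert qwen2_7b_lm_head_split_num("sym_int4_rtn", mixed_precision=True) == 1
```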
