
Commit 03bd01c

optimize npu qwen2 (#12107)
1 parent: 0239902

File tree

2 files changed: +13 -9

python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py

Lines changed: 7 additions & 7 deletions
@@ -399,22 +399,22 @@ def set_weights_async(self, op_id, weights):
         self.setWeights(offset, op_id, *weights)
 
     @staticmethod
-    def run_decoders(inputs, decoders):
+    def run_decoders(inputs, decoders, models_ptr=None):
         x_np = [elem.to(torch.float16).numpy() for elem in inputs]
 
         num_decoders = len(decoders)
         num_inputs = len(x_np)
 
-        with record_function(f"npu_factory"):
-
+        if models_ptr is None:
             array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
             models_ptr = array_type(
                 *[decoders[i]._mm for i in range(num_decoders)]
             )
-            inputs_ptr = (ctypes.c_void_p * num_inputs)(
-                *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
-            )
-            backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
+
+        inputs_ptr = (ctypes.c_void_p * num_inputs)(
+            *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
+        )
+        backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
 
         hidden_states = decoders[-1].torch_out[0]
         new_key_states = []
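
The change above turns the ctypes pointer-array construction into a fallback: run_decoders now accepts an optional models_ptr, rebuilds the array of _mm handles only when none is supplied, and drops the record_function("npu_factory") profiling scope. A minimal sketch of that fast-path pattern, using a hypothetical FakeDecoder in place of a real backend decoder and returning the array instead of calling backend_lib:

import ctypes

class FakeDecoder:
    # Hypothetical stand-in for a backend decoder exposing a native handle (`_mm`).
    def __init__(self):
        self._buf = ctypes.create_string_buffer(8)
        self._mm = ctypes.cast(self._buf, ctypes.POINTER(ctypes.c_char))

def run_decoders(inputs, decoders, models_ptr=None):
    # Rebuild the handle array only when no cached one is supplied.
    if models_ptr is None:
        array_type = ctypes.POINTER(ctypes.c_char) * len(decoders)
        models_ptr = array_type(*[d._mm for d in decoders])
    # The real staticmethod would now marshal `inputs` and call
    # backend_lib.run_decoders(...); returning the array is enough here.
    return models_ptr

decoders = [FakeDecoder() for _ in range(2)]
first = run_decoders([], decoders)           # builds the array (old behaviour)
second = run_decoders([], decoders, first)   # reuses the cached array (new fast path)
assert second is first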

python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py

Lines changed: 6 additions & 2 deletions
@@ -17,7 +17,7 @@
 import os
 import torch
 import time
-
+import ctypes
 from typing import Optional, Sequence, List, Union, Any, Tuple
 import numpy as np
 
@@ -379,6 +379,9 @@ def __init__(
             self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
             offset = offset + curr_linear_ops
 
+        array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
+        self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -402,7 +405,8 @@ def forward(
 
         hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
             inputs,
-            decoders=self.backend_decoders)
+            self.backend_decoders,
+            self.models_ptr)
 
         if self.do_print:
             print("outputs:", hidden_states)

0 commit comments
