Skip to content

Commit 310f18c

Browse files
authored
update NPU pipeline generate (#12182)
* update
* fix style
1 parent 1daab45 commit 310f18c

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def generate(
5959
streamer: Optional["BaseStreamer"] = None,
6060
**kwargs,
6161
):
62+
# if do_print=True, output timing message
63+
do_print = kwargs.pop("do_print", False)
64+
time_start_all, time_t1, idx = time.perf_counter(), None, 0
6265
new_generate_kwargs = {}
6366
for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
6467
value = kwargs.pop(var, None)
@@ -79,7 +82,7 @@ def generate(
7982
thread = threading.Thread(target=generate_serve,
8083
args=(self.kv_len, self.num_head,
8184
self.head_dim, self.num_layers,
82-
new_tokens))
85+
new_tokens - 1))
8386
thread.start()
8487

8588
in_pipe_path = "\\\\.\\pipe\\llminputpipe"
@@ -115,6 +118,8 @@ def generate(
115118

116119
bdata = bdata + eos.to_bytes(4, sys.byteorder)
117120

121+
time_start = time.perf_counter()
122+
118123
input_pipe.write(bytearray(bdata))
119124
input_pipe.flush()
120125

@@ -125,17 +130,32 @@ def generate(
125130
if len(data) == 0:
126131
break
127132
token = int.from_bytes(data, sys.byteorder)
133+
idx += 1
134+
if time_t1 is None:
135+
time_t1 = time.perf_counter()
128136
output_tokens.append(torch.tensor([token]))
129137
if streamer is not None:
130138
streamer.put(torch.tensor([token]))
131139
if token == eos:
132140
break
133141

134142
output = torch.stack(output_tokens, dim=1)
143+
output = torch.cat((inputs, output), dim=1)
135144
if streamer is not None:
136145
streamer.end()
137146

138147
thread.join()
148+
time_end = time.perf_counter()
149+
150+
if do_print:
151+
print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
152+
print(f" Number of input tokens: {input_length}")
153+
print(f" Generated tokens: {idx}")
154+
print(f" First token generation time: {(time_t1 - time_start):.2f} s")
155+
print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
156+
f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
157+
print(f" Generation time: {(time_end - time_start):.2f} s\n")
158+
139159
return output
140160

141161

0 commit comments

Comments (0)