Commit e82c0c5

fix style and check tokens number

1 parent 1ce290b commit e82c0c5

1 file changed (+14, -9 lines)

python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py

Lines changed: 14 additions & 9 deletions
@@ -65,10 +65,21 @@ def generate(
             if value is not None:
                 new_generate_kwargs[var] = value
 
+        if isinstance(inputs[0], torch.Tensor):
+            numpy_input = inputs[0].numpy()
+        else:
+            numpy_input = inputs[0]
+        input_length = numpy.size(numpy_input)
+
+        new_tokens = new_generate_kwargs['max_new_tokens']
+        invalidInputError(input_length + new_tokens <= self.kv_len + 1,
+                          "Input plus output tokens should not exceed max_output_len.")
+
         # start generate_serve by Thread
-        thread = threading.Thread(target=generate_serve, args=(self.kv_len, self.num_head, self.head_dim,
-                                                               self.num_layers,
-                                                               new_generate_kwargs['max_new_tokens']))
+        thread = threading.Thread(target=generate_serve,
+                                  args=(self.kv_len, self.num_head,
+                                        self.head_dim, self.num_layers,
+                                        new_tokens))
         thread.start()
 
         in_pipe_path = "\\\\.\\pipe\\llminputpipe"
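
For reference, the guard added in this hunk rejects a request up front when the prompt length plus the requested number of new tokens cannot fit in the preallocated KV cache. Below is a minimal standalone sketch of the same check; invalidInputError is ipex_llm's validation helper, so the sketch substitutes a plain ValueError, and the check_token_budget name is illustrative, not code from this file:

    import numpy
    import torch

    def check_token_budget(inputs, max_new_tokens, kv_len):
        # Normalize the prompt so numpy.size() works for both
        # torch.Tensor and plain list/ndarray inputs.
        if isinstance(inputs[0], torch.Tensor):
            numpy_input = inputs[0].numpy()
        else:
            numpy_input = inputs[0]
        input_length = numpy.size(numpy_input)

        # "+ 1" of headroom: presumably the final generated token is
        # returned to the caller without being written into the cache.
        if input_length + max_new_tokens > kv_len + 1:
            raise ValueError("Input plus output tokens should not exceed max_output_len.")
        return numpy_input, input_length

For example, a 3-token prompt with max_new_tokens=32 passes for kv_len=1024 but raises for kv_len=16.
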
@@ -92,12 +103,6 @@ def generate(
             else:
                 break
 
-        if isinstance(inputs[0], torch.Tensor):
-            numpy_input = inputs[0].numpy()
-        else:
-            numpy_input = inputs[0]
-        input_length = numpy.size(numpy_input)
-
         bdata = b''
         for i in range(0, input_length):
             d = int(numpy_input[i])
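
The second hunk only relocates the input handling shown above; the loop it leaves in place packs each prompt token id into a byte buffer destined for the input pipe. The hunk cuts off before the packing itself, so the following is a hypothetical illustration only (write_tokens_to_pipe, the 4-byte token width, and the single bulk write are assumptions, not this file's actual code):

    import sys

    def write_tokens_to_pipe(pipe, numpy_input, input_length):
        # Hypothetical sketch: encode each token id as a fixed-width
        # integer and push the whole buffer through the pipe at once.
        bdata = b''
        for i in range(0, input_length):
            d = int(numpy_input[i])
            bdata += d.to_bytes(4, sys.byteorder)  # assumed 4-byte ids
        pipe.write(bdata)
        pipe.flush()
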
