@@ -65,10 +65,21 @@ def generate(
6565 if value is not None :
6666 new_generate_kwargs [var ] = value
6767
68+ if isinstance (inputs [0 ], torch .Tensor ):
69+ numpy_input = inputs [0 ].numpy ()
70+ else :
71+ numpy_input = inputs [0 ]
72+ input_length = numpy .size (numpy_input )
73+
74+ new_tokens = new_generate_kwargs ['max_new_tokens' ]
75+ invalidInputError (input_length + new_tokens <= self .kv_len + 1 ,
76+ "Input plus output tokens should not exceed max_output_len." )
77+
6878 # start generate_serve by Thread
69- thread = threading .Thread (target = generate_serve , args = (self .kv_len , self .num_head , self .head_dim ,
70- self .num_layers ,
71- new_generate_kwargs ['max_new_tokens' ]))
79+ thread = threading .Thread (target = generate_serve ,
80+ args = (self .kv_len , self .num_head ,
81+ self .head_dim , self .num_layers ,
82+ new_tokens ))
7283 thread .start ()
7384
7485 in_pipe_path = "\\ \\ .\\ pipe\\ llminputpipe"
@@ -92,12 +103,6 @@ def generate(
92103 else :
93104 break
94105
95- if isinstance (inputs [0 ], torch .Tensor ):
96- numpy_input = inputs [0 ].numpy ()
97- else :
98- numpy_input = inputs [0 ]
99- input_length = numpy .size (numpy_input )
100-
101106 bdata = b''
102107 for i in range (0 , input_length ):
103108 d = int (numpy_input [i ])
0 commit comments