@@ -59,6 +59,9 @@ def generate(
5959 streamer : Optional ["BaseStreamer" ] = None ,
6060 ** kwargs ,
6161):
62+ # if do_print=True, output timing message
63+ do_print = kwargs .pop ("do_print" , False )
64+ time_start_all , time_t1 , idx = time .perf_counter (), None , 0
6265 new_generate_kwargs = {}
6366 for var in ['max_new_tokens' , 'attention_mask' , 'eos_token_id' ]:
6467 value = kwargs .pop (var , None )
@@ -79,7 +82,7 @@ def generate(
7982 thread = threading .Thread (target = generate_serve ,
8083 args = (self .kv_len , self .num_head ,
8184 self .head_dim , self .num_layers ,
82- new_tokens ))
85+ new_tokens - 1 ))
8386 thread .start ()
8487
8588 in_pipe_path = "\\ \\ .\\ pipe\\ llminputpipe"
@@ -115,6 +118,8 @@ def generate(
115118
116119 bdata = bdata + eos .to_bytes (4 , sys .byteorder )
117120
121+ time_start = time .perf_counter ()
122+
118123 input_pipe .write (bytearray (bdata ))
119124 input_pipe .flush ()
120125
@@ -125,17 +130,32 @@ def generate(
125130 if len (data ) == 0 :
126131 break
127132 token = int .from_bytes (data , sys .byteorder )
133+ idx += 1
134+ if time_t1 is None :
135+ time_t1 = time .perf_counter ()
128136 output_tokens .append (torch .tensor ([token ]))
129137 if streamer is not None :
130138 streamer .put (torch .tensor ([token ]))
131139 if token == eos :
132140 break
133141
134142 output = torch .stack (output_tokens , dim = 1 )
143+ output = torch .cat ((inputs , output ), dim = 1 )
135144 if streamer is not None :
136145 streamer .end ()
137146
138147 thread .join ()
148+ time_end = time .perf_counter ()
149+
150+ if do_print :
151+ print (f" Start the thread and connect the pipe time: { (time_start - time_start_all ):.2f} s" )
152+ print (f" Number of input tokens: { input_length } " )
153+ print (f" Generated tokens: { idx } " )
154+ print (f" First token generation time: { (time_t1 - time_start ):.2f} s" )
155+ print (f" Generation average latency: { (time_end - time_t1 )* 1000 / (idx - 1 ):.2f} ms, "
156+ f"({ (idx - 1 )/ (time_end - time_t1 ):.2f} token/s)" )
157+ print (f" Generation time: { (time_end - time_start ):.2f} s\n " )
158+
139159 return output
140160
141161
0 commit comments