@@ -99,6 +99,7 @@ def __init__(
9999 arrival_time : float ,
100100 queue : RequestOutputCollector | None ,
101101 log_stats : bool ,
102+ stream_interval : int ,
102103 top_p : float | None = None ,
103104 n : int | None = None ,
104105 temperature : float | None = None ,
@@ -126,6 +127,11 @@ def __init__(
126127
127128 self .stats = RequestStateStats (arrival_time = arrival_time ) if log_stats else None
128129
130+ # Stream Interval
131+ self .stream_interval = stream_interval
132+ self .total_num_output_tokens = 0 # Track total num of output tokens
133+ self .sent_tokens_offset = 0 # Offset of sent tokens
134+
129135 @classmethod
130136 def from_new_request (
131137 cls ,
@@ -136,6 +142,7 @@ def from_new_request(
136142 request_index : int ,
137143 queue : RequestOutputCollector | None ,
138144 log_stats : bool ,
145+ stream_interval : int ,
139146 ) -> "RequestState" :
140147 if sampling_params := request .sampling_params :
141148 if not sampling_params .detokenize :
@@ -183,6 +190,7 @@ def from_new_request(
183190 arrival_time = request .arrival_time ,
184191 queue = queue ,
185192 log_stats = log_stats ,
193+ stream_interval = stream_interval ,
186194 )
187195
188196 def make_request_output (
@@ -200,13 +208,40 @@ def make_request_output(
200208 # Only the final output is required in FINAL_ONLY mode.
201209 return None
202210
211+ # Stream Interval buffering: only apply for DELTA mode and stream_interval > 1
212+ is_delta_streaming = self .output_kind == RequestOutputKind .DELTA
213+ if is_delta_streaming and self .stream_interval > 1 :
214+ # Track total tokens generated
215+ self .total_num_output_tokens += len (new_token_ids )
216+
217+ # should send output when it is the first token or reach the stream interval
218+ should_send_output = (
219+ self .sent_tokens_offset == 0
220+ or self .total_num_output_tokens - self .sent_tokens_offset
221+ >= self .stream_interval
222+ )
223+
224+ # Do NOT send output if not finished and should not send output
225+ if not finished and not should_send_output :
226+ return None
227+
228+ # Send tokens from the offset
229+ assert self .detokenizer is not None
230+ tokens_to_send = self .detokenizer .output_token_ids [
231+ self .sent_tokens_offset :
232+ ]
233+ self .sent_tokens_offset = len (self .detokenizer .output_token_ids )
234+ else :
235+ # Send tokens immediately
236+ tokens_to_send = new_token_ids
237+
203238 request_id = self .request_id
204239 if pooling_output is not None :
205240 return self ._new_request_output (
206241 request_id , [self ._new_pooling_output (pooling_output )], finished
207242 )
208243
209- output = self ._new_completion_output (new_token_ids , finish_reason , stop_reason )
244+ output = self ._new_completion_output (tokens_to_send , finish_reason , stop_reason )
210245
211246 if self .parent_req is None :
212247 outputs = [output ]
@@ -305,9 +340,12 @@ def _new_pooling_output(
305340class OutputProcessor :
306341 """Process EngineCoreOutputs into RequestOutputs."""
307342
308- def __init__ (self , tokenizer : AnyTokenizer , log_stats : bool ):
343+ def __init__ (
344+ self , tokenizer : AnyTokenizer , log_stats : bool , stream_interval : int = 1
345+ ):
309346 self .log_stats = log_stats
310347 self .tokenizer = tokenizer
348+ self .stream_interval = stream_interval
311349 self .request_states : dict [str , RequestState ] = {}
312350 self .parent_requests : dict [str , ParentRequest ] = {}
313351 self .lora_states = LoRARequestStates ()
@@ -380,6 +418,7 @@ def add_request(
380418 request_index = request_index ,
381419 queue = queue ,
382420 log_stats = self .log_stats ,
421+ stream_interval = self .stream_interval ,
383422 )
384423 self .request_states [request_id ] = req_state
385424 self .lora_states .add_request (req_state )
0 commit comments