
Commit 04e4a88

support stream interval
Signed-off-by: elvischenv <[email protected]>
1 parent 33a0ea5 commit 04e4a88

5 files changed: +57 −4 lines changed

vllm/config/scheduler.py

Lines changed: 6 additions & 0 deletions
@@ -137,6 +137,12 @@ class SchedulerConfig:
     structured outputs, speculative decoding, and pipeline parallelism.
     """

+    stream_interval: int = field(default=1, ge=1)
+    """The interval (or buffer size) for streaming in terms of token length.
+    A smaller value (1) makes streaming smoother by sending each token immediately,
+    while a larger value (e.g., 10) reduces host overhead and increases throughput
+    by batching multiple tokens before sending."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
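
The new field is declared with ge=1, so values below 1 should be rejected when the config is validated. A minimal sketch of the intended semantics, assuming SchedulerConfig can be constructed standalone with its defaults and is importable from vllm.config:

from vllm.config import SchedulerConfig

# stream_interval=1 (the default) flushes every token to the client as soon as it
# is generated; larger values buffer tokens on the host and flush them in chunks.
cfg = SchedulerConfig(stream_interval=4)
assert cfg.stream_interval == 4
# SchedulerConfig(stream_interval=0) is expected to fail the ge=1 constraint.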

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
@@ -551,6 +551,8 @@ class EngineArgs:

     async_scheduling: bool = SchedulerConfig.async_scheduling

+    stream_interval: int = SchedulerConfig.stream_interval
+
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill

     def __post_init__(self):

@@ -1044,6 +1046,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument(
             "--async-scheduling", **scheduler_kwargs["async_scheduling"]
         )
+        scheduler_group.add_argument(
+            "--stream-interval", **scheduler_kwargs["stream_interval"]
+        )

         # Compilation arguments
         compilation_kwargs = get_kwargs(CompilationConfig)

@@ -1588,6 +1593,7 @@ def create_engine_config(
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            stream_interval=self.stream_interval,
         )

         if not model_config.is_multimodal_model and self.default_mm_loras:
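
With the flag plumbed through EngineArgs, the interval can be set from the CLI or when building the engine config programmatically. A hedged usage sketch: the model name is only illustrative, and create_engine_config() is assumed to be callable without arguments here.

# CLI (assumed standard vLLM serve entrypoint):
#   vllm serve facebook/opt-125m --stream-interval 4

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m", stream_interval=4)
# create_engine_config() forwards the value into SchedulerConfig, per the hunk above.
vllm_config = engine_args.create_engine_config()
print(vllm_config.scheduler_config.stream_interval)  # expected: 4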

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 1 deletion
@@ -128,8 +128,9 @@ def __init__(
         )

         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
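
Both engine front-ends (async and sync) read the interval from the scheduler config and hand it to the OutputProcessor, which defaults to 1 so existing callers keep per-token streaming. A minimal construction sketch; the tokenizer setup is illustrative and not what AsyncLLM itself does:

from transformers import AutoTokenizer
from vllm.v1.engine.output_processor import OutputProcessor

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative tokenizer
processor = OutputProcessor(tokenizer, log_stats=False, stream_interval=4)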

vllm/v1/engine/llm_engine.py

Lines changed: 2 additions & 1 deletion
@@ -109,8 +109,9 @@ def __init__(
         )

         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:

vllm/v1/engine/output_processor.py

Lines changed: 41 additions & 2 deletions
@@ -99,6 +99,7 @@ def __init__(
         arrival_time: float,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
         top_p: float | None = None,
         n: int | None = None,
         temperature: float | None = None,

@@ -126,6 +127,11 @@ def __init__(

         self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

+        # Stream Interval
+        self.stream_interval = stream_interval
+        self.total_num_output_tokens = 0  # Track total num of output tokens
+        self.sent_tokens_offset = 0  # Offset of sent tokens
+
     @classmethod
     def from_new_request(
         cls,

@@ -136,6 +142,7 @@ def from_new_request(
         request_index: int,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
     ) -> "RequestState":
         if sampling_params := request.sampling_params:
             if not sampling_params.detokenize:

@@ -183,6 +190,7 @@ def from_new_request(
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            stream_interval=stream_interval,
         )

     def make_request_output(

@@ -200,13 +208,40 @@ def make_request_output(
             # Only the final output is required in FINAL_ONLY mode.
             return None

+        # Stream Interval buffering: only apply for DELTA mode and stream_interval > 1
+        is_delta_streaming = self.output_kind == RequestOutputKind.DELTA
+        if is_delta_streaming and self.stream_interval > 1:
+            # Track total tokens generated
+            self.total_num_output_tokens += len(new_token_ids)
+
+            # should send output when it is the first token or reach the stream interval
+            should_send_output = (
+                self.sent_tokens_offset == 0
+                or self.total_num_output_tokens - self.sent_tokens_offset
+                >= self.stream_interval
+            )
+
+            # Do NOT send output if not finished and should not send output
+            if not finished and not should_send_output:
+                return None
+
+            # Send tokens from the offset
+            assert self.detokenizer is not None
+            tokens_to_send = self.detokenizer.output_token_ids[
+                self.sent_tokens_offset :
+            ]
+            self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
+        else:
+            # Send tokens immediately
+            tokens_to_send = new_token_ids
+
         request_id = self.request_id
         if pooling_output is not None:
             return self._new_request_output(
                 request_id, [self._new_pooling_output(pooling_output)], finished
             )

-        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
+        output = self._new_completion_output(tokens_to_send, finish_reason, stop_reason)

         if self.parent_req is None:
             outputs = [output]

@@ -305,9 +340,12 @@ def _new_pooling_output(
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""

-    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
+    def __init__(
+        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
+    ):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
+        self.stream_interval = stream_interval
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()

@@ -380,6 +418,7 @@ def add_request(
             request_index=request_index,
             queue=queue,
             log_stats=self.log_stats,
+            stream_interval=self.stream_interval,
         )
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)