11from typing import Dict , List , Mapping , Optional , Type , Union
2-
2+ from dataclasses import dataclass
33from typing_extensions import TypeVar
44
55from vllm .config import VllmConfig
2626
2727_G = TypeVar ("_G" , bound = BaseTokenizerGroup , default = BaseTokenizerGroup )
2828
# Numeric operand type shared by the None-tolerant reducers below.
_NumT = TypeVar("_NumT", int, float)


def _none_safe_min(x: Optional[_NumT],
                   y: Optional[_NumT]) -> Optional[_NumT]:
    """Return the smaller of `x` and `y`, ignoring `None` operands.

    If exactly one operand is `None` the other one is returned unchanged;
    if both are `None` the result is `None`.
    """
    if x is None:
        return y
    if y is None:
        return x
    return min(x, y)


def _none_safe_max(x: Optional[_NumT],
                   y: Optional[_NumT]) -> Optional[_NumT]:
    """Return the larger of `x` and `y`, ignoring `None` operands.

    If exactly one operand is `None` the other one is returned unchanged;
    if both are `None` the result is `None`.
    """
    if x is None:
        return y
    if y is None:
        return x
    return max(x, y)


def _none_safe_sum(x: Optional[_NumT],
                   y: Optional[_NumT]) -> Optional[_NumT]:
    """Return `x + y`, ignoring `None` operands.

    If exactly one operand is `None` the other one is returned unchanged;
    if both are `None` the result is `None`.
    """
    if x is None:
        return y
    if y is None:
        return x
    return x + y
49+
@dataclass
class ParallelSampleChildRequestInfo:
    """Info for aggregating parallel sampling child requests under parent."""
    # Request id of the parent request this child was spawned for
    parent_req_id: str
    # Position of this child among the parent's child requests
    index: int
55+
@dataclass
class ParallelSampleParentRequestInfo:
    """Parallel sampling parent request info."""
    # Number of child requests spawned for this parent
    n: int
    # Number of child requests whose metadata has been forgotten so far
    n_finished: int = 0

    def num_child_requests_remaining(self) -> int:
        """Return how many child requests are still outstanding."""
        # Internal invariant: no more children can finish than were spawned.
        assert self.n >= self.n_finished
        return self.n - self.n_finished
2965
3066class LLMEngine :
3167 """Legacy LLMEngine for backwards compatibility."""
@@ -46,6 +82,14 @@ def __init__(
4682 # TODO: Can we avoid this?
4783 self .model_config = vllm_config .model_config
4884
85+ # Parallel sampling metadata
86+ # - Metadata for aggregating the child requests associated with a parent request
87+ self .child_req_id_to_parent_req_info : Dict [
88+ str , ParallelSampleChildRequestInfo ] = {}
89+ # - Parent request metadata i.e. degree of parallelism and other characteristics
90+ self .parent_req_id_info : Dict [str ,
91+ ParallelSampleParentRequestInfo ] = {}
92+
4993 # Tokenizer (+ ensure liveness if running in another process).
5094 self .tokenizer = init_tokenizer_from_configs (
5195 model_config = vllm_config .model_config ,
@@ -117,8 +161,52 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
117161
118162 return executor_class
119163
def _get_num_core_unfinished_requests(self) -> int:
    """Total number of unfinished requests in engine core.

    This is the raw count reported by the detokenizer and does not
    account for parallel sampling: a parallel sampling request with
    `n=3` contributes `(3-n_complete)` to the total (one per unfinished
    child request; the parent request itself does not count), while an
    unfinished request with `n=1` contributes 1 to the total.

    Returns:
        Total requests in engine core
    """
    return self.detokenizer.get_num_unfinished_requests()
177+
def _get_num_parallel_sampling_parent_unfinished_requests(self) -> int:
    """Total number of registered parallel sampling parent requests.

    Every parent request that still has an entry in
    `parent_req_id_info` counts as 1; all other requests count as 0.

    Returns:
        Number of parallel sampling parent requests
    """
    return len(self.parent_req_id_info)
188+
def _get_num_parallel_sampling_child_unfinished_requests(self) -> int:
    """Total number of unfinished parallel sampling child requests.

    Each registered parent request contributes `(n-n_complete)`, i.e. its
    number of remaining child requests; all other requests count as 0.

    Returns:
        Number of parallel sampling child requests
    """
    # Only the values are needed, and a generator avoids a throwaway list.
    return sum(parent_info.num_child_requests_remaining()
               for parent_info in self.parent_req_id_info.values())
200+
def get_num_unfinished_requests(self) -> int:
    """Number of unfinished requests.

    Each request submitted by the user counts as 1; the child requests
    spawned by parallel sampling requests are not reflected in this count.

    Returns:
        Number of unfinished user-submitted requests
    """
    num_core = self._get_num_core_unfinished_requests()
    num_children = self._get_num_parallel_sampling_child_unfinished_requests()
    num_parents = self._get_num_parallel_sampling_parent_unfinished_requests()
    # The core count includes every child request; swap those out for one
    # entry per parent request.
    return num_core - num_children + num_parents
122210
def has_unfinished_requests(self) -> bool:
    """Return the detokenizer's answer to whether requests are unfinished."""
    return self.detokenizer.has_unfinished_requests()
@@ -127,11 +215,78 @@ def has_unfinished_requests(self) -> bool:
127215 def validate_outputs (cls , outputs , output_type ):
128216 return outputs
129217
def _forget_parallel_sample_child_request_and_maybe_parent(
    self,
    child_request_id: str,
) -> None:
    """Forget a child request's metadata and, if it was the last, its parent's.

    Parent request parallel sampling metadata is forgotten once all of
    its child requests have been forgotten.

    Args:
        child_request_id: id of finished child request

    Raises:
        KeyError: if the child request, or the parent it refers to, is
            not registered
    """
    # Forget child request metadata; a single pop both reads and removes it.
    child_info = self.child_req_id_to_parent_req_info.pop(child_request_id)
    parent_req_id = child_info.parent_req_id
    # Track parent request's remaining child requests & erase parent request
    # metadata if there are no remaining child requests
    parent_info = self.parent_req_id_info[parent_req_id]
    parent_info.n_finished += 1
    if parent_info.num_child_requests_remaining() == 0:
        del self.parent_req_id_info[parent_req_id]
237+
def _maybe_forget_parallel_sample_child_requests(
        self, possible_child_request_ids: List[str]) -> None:
    """Forget parallel sampling metadata for any child requests in a list.

    When a request aborts, if it is a child of a parallel sampling
    request, its parallel sampling metadata is forgotten. Ids that are
    not associated with parallel sampling are left untouched.

    Args:
        possible_child_request_ids: list of request ids to possibly
            forget parallel sampling metadata for
    """
    for possible_child_req_id in possible_child_request_ids:
        # Only registered parallel sampling child requests carry metadata
        if possible_child_req_id not in self.child_req_id_to_parent_req_info:
            continue
        self._forget_parallel_sample_child_request_and_maybe_parent(
            possible_child_req_id)
254+
def abort_request(self, request_ids: List[str]) -> None:
    """Remove request_ids from EngineCore and Detokenizer.

    An id that names a registered parallel sampling parent request is
    replaced by the ids of its still-registered child requests, because
    only the child requests are submitted to the engine core. The
    parallel sampling metadata of aborted child requests is forgotten.

    Args:
        request_ids: ids of the requests to abort
    """
    # An insertion-ordered dict de-duplicates while keeping first-seen
    # order, e.g. when a parent id and one of its child ids are both given.
    resolved: Dict[str, None] = {}
    for req_id in request_ids:
        if req_id in self.parent_req_id_info:
            for child_id, child_info in (
                    self.child_req_id_to_parent_req_info.items()):
                if child_info.parent_req_id == req_id:
                    resolved[child_id] = None
        else:
            resolved[req_id] = None
    ids_to_abort = list(resolved)

    self.engine_core.abort_requests(ids_to_abort)
    self.detokenizer.abort_requests(ids_to_abort)
    self._maybe_forget_parallel_sample_child_requests(ids_to_abort)
261+
def _register_parallel_sampling_parent_request(
    self,
    parent_req_id: str,
    parallel_sample_parent_req_info: ParallelSampleParentRequestInfo,
) -> None:
    """Register the attributes of a parallel sampling (parent) request.

    Any info previously stored under the same parent request id is
    overwritten.

    Args:
        parent_req_id: id of the parallel sampling parent request
        parallel_sample_parent_req_info: info to store for that request
    """
    self.parent_req_id_info[parent_req_id] = parallel_sample_parent_req_info
270+
def _register_parallel_sampling_child_request(
    self,
    parallel_sample_child_req_info: ParallelSampleChildRequestInfo,
) -> str:
    """Register a parallel sampling child request with its parent request.

    Generates the child request id from the parent request id and the
    child's index.

    Side effect: records the child req id -> child request info mapping
    in `child_req_id_to_parent_req_info`.

    Args:
        parallel_sample_child_req_info: parent request id and index of
            the child request to register

    Returns:
        Child request id
    """
    parent_req_id = parallel_sample_child_req_info.parent_req_id
    index = parallel_sample_child_req_info.index
    # No whitespace in the id: it must be stable and exactly reproducible.
    child_req_id = f"{parent_req_id}_parallel_sample_{index}"
    self.child_req_id_to_parent_req_info[child_req_id] = (
        parallel_sample_child_req_info)
    return child_req_id
135290
136291 def add_request (
137292 self ,
@@ -144,6 +299,29 @@ def add_request(
144299 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
145300 priority : int = 0 ,
146301 ) -> None :
302+ if isinstance (params , SamplingParams ) and params .n > 1 :
303+ # Register parallel sampling request
304+ n = params .n
305+ self ._register_parallel_sampling_parent_request (
306+ request_id , ParallelSampleParentRequestInfo (n ))
307+ params .n = 1 # Engine core cannot see `n`
308+ for ndx in range (n ):
309+ # Register child request with parent
310+ child_req_id = self ._register_parallel_sampling_child_request (
311+ ParallelSampleChildRequestInfo (request_id , ndx ))
312+ # Recurse to add child request; `n=1` prevents further recursion
313+ self .add_request (
314+ request_id = child_req_id ,
315+ prompt = prompt ,
316+ params = params ,
317+ arrival_time = arrival_time ,
318+ lora_request = lora_request ,
319+ trace_headers = trace_headers ,
320+ prompt_adapter_request = prompt_adapter_request ,
321+ priority = priority ,
322+ )
323+ # The top-level add_request call is done
324+ return
147325
148326 # 1) Process raw inputs into the request.
149327 detokenizer_req , engine_core_req = self .processor .process_inputs (
@@ -156,6 +334,80 @@ def add_request(
156334 # 3) Add the request to EngineCore.
157335 self .engine_core .add_request (engine_core_req )
158336
def _is_parallel_sampling_child_request(
    self,
    possible_child_request_id: str,
) -> bool:
    """Return whether the id is a registered parallel sampling child id.

    Args:
        possible_child_request_id: request id to look up
    """
    return possible_child_request_id in self.child_req_id_to_parent_req_info
342+
def _maybe_get_parallel_sampling_child_request_info(
    self,
    possible_child_request_id: str,
) -> Optional[ParallelSampleChildRequestInfo]:
    """Look up the parallel sampling info registered for a request id.

    Args:
        possible_child_request_id: request id to look up

    Returns:
        The registered child request info, or `None` if the id is not a
        registered parallel sampling child request
    """
    return self.child_req_id_to_parent_req_info.get(possible_child_request_id)
348+
def _merge_parallel_sampling_child_request_output_in_place(
    self,
    parent_req_output: RequestOutput,
    child_req_output: RequestOutput,
) -> None:
    """Fold one child request output into its parent's output, in place.

    The child's completion outputs are appended to the parent's, the
    cached-token counts are added, and the metrics are combined:
    earliest first-scheduled/first-token times, latest last-token and
    finished times, and summed queue/scheduler/model times. Each metric
    is merged exactly once per call.

    Args:
        parent_req_output: aggregate output to update
        child_req_output: child output to merge; it is not modified,
            except that its metrics object becomes the parent's metrics
            when the parent has none
    """
    # Parent is finished when all children are finished
    parent_req_output.finished = (parent_req_output.finished
                                  and child_req_output.finished)
    p_met = parent_req_output.metrics
    c_met = child_req_output.metrics
    if p_met is None:
        # If current parent request metrics are `None`, update with this
        # child's metrics (which may also be None)
        parent_req_output.metrics = c_met
    elif c_met is not None:
        # Only merge in child request output metrics if the child request
        # output metrics are not `None`
        p_met.first_scheduled_time = _none_safe_min(
            p_met.first_scheduled_time, c_met.first_scheduled_time)
        p_met.first_token_time = _none_safe_min(p_met.first_token_time,
                                                c_met.first_token_time)
        p_met.last_token_time = max(p_met.last_token_time,
                                    c_met.last_token_time)
        p_met.finished_time = _none_safe_max(p_met.finished_time,
                                             c_met.finished_time)
        p_met.time_in_queue = _none_safe_sum(p_met.time_in_queue,
                                             c_met.time_in_queue)
        p_met.scheduler_time = _none_safe_sum(p_met.scheduler_time,
                                              c_met.scheduler_time)
        p_met.model_forward_time = _none_safe_sum(p_met.model_forward_time,
                                                  c_met.model_forward_time)
        p_met.model_execute_time = _none_safe_sum(p_met.model_execute_time,
                                                  c_met.model_execute_time)
    parent_req_output.outputs.extend(child_req_output.outputs)
    parent_req_output.num_cached_tokens = _none_safe_sum(
        parent_req_output.num_cached_tokens,
        child_req_output.num_cached_tokens)
379+
def _maybe_aggregate_parallel_sampling_child_requests(
    self,
    request_outputs: List[RequestOutput],
) -> List[RequestOutput]:
    """Collapse parallel sampling child outputs into one output per parent.

    The first child output seen for a parent is relabelled with the
    parent request id and reused as the parent's output; later child
    outputs of the same parent are merged into it. Outputs whose request
    id is not a registered child request are passed through untouched,
    which includes child requests whose metadata was already forgotten.

    Args:
        request_outputs: outputs to aggregate; child outputs may be
            modified in place

    Returns:
        Outputs in first-seen order, with one entry per parent request
    """
    agg_request_outputs: List[RequestOutput] = []
    # Parent request id -> the output object that accumulates its children
    parent_outputs: Dict[str, RequestOutput] = {}
    for req_output in request_outputs:
        child_info = self._maybe_get_parallel_sampling_child_request_info(
            req_output.request_id)
        if child_info is None:
            # Not a parallel sampling request; don't touch it
            agg_request_outputs.append(req_output)
            continue
        parent_req_id = child_info.parent_req_id
        parent_req_output = parent_outputs.get(parent_req_id)
        if parent_req_output is None:
            # First child output seen for this parent: repurpose the child
            # request output structure as the parent request output.
            req_output.request_id = parent_req_id
            parent_outputs[parent_req_id] = req_output
            agg_request_outputs.append(req_output)
        else:
            # Merge this child request output into the growing request
            # output associated with its parent.
            self._merge_parallel_sampling_child_request_output_in_place(
                parent_req_output, req_output)
    return agg_request_outputs
407+
408+
409+
410+
159411 def step (self ) -> List [RequestOutput ]:
160412
161413 # 1) Get EngineCoreOutput from the EngineCore.
@@ -169,7 +421,9 @@ def step(self) -> List[RequestOutput]:
169421 if requests_to_abort :
170422 self .abort_request (requests_to_abort )
171423
172- return request_outputs
424+ # 4) If necessary, aggregate outputs for parallel sampling child requests
425+ # to be associated with parent request
426+ return self ._maybe_aggregate_parallel_sampling_child_requests (request_outputs )
173427
174428 # TODO(rob): Can we get rid of these?
175429
0 commit comments