|
5 | 5 | import time |
6 | 6 | from collections import defaultdict |
7 | 7 | from collections.abc import Iterable |
8 | | -from typing import Any |
| 8 | +from typing import TYPE_CHECKING, Any |
9 | 9 |
|
10 | 10 | from vllm.config import VllmConfig |
11 | 11 | from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch |
|
34 | 34 | from vllm.v1.spec_decode.metrics import SpecDecodingStats |
35 | 35 | from vllm.v1.structured_output import StructuredOutputManager |
36 | 36 |
|
| 37 | +if TYPE_CHECKING: |
| 38 | + import numpy as np |
| 39 | + import numpy.typing as npt |
| 40 | + |
37 | 41 | logger = init_logger(__name__) |
38 | 42 |
|
39 | 43 |
|
@@ -608,11 +612,8 @@ def schedule(self) -> SchedulerOutput: |
608 | 612 | scheduled_spec_decode_tokens, |
609 | 613 | req_to_new_blocks, |
610 | 614 | ) |
611 | | - scheduled_requests = ( |
612 | | - scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs |
613 | | - ) |
614 | 615 | structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( |
615 | | - scheduled_requests, scheduled_spec_decode_tokens |
| 616 | + num_scheduled_tokens.keys(), scheduled_spec_decode_tokens |
616 | 617 | ) |
617 | 618 | scheduler_output = SchedulerOutput( |
618 | 619 | scheduled_new_reqs=new_reqs_data, |
@@ -876,32 +877,28 @@ def _try_schedule_encoder_inputs( |
876 | 877 |
|
877 | 878 | def get_grammar_bitmask( |
878 | 879 | self, |
879 | | - requests: list[Request], |
| 880 | + scheduled_request_ids: Iterable[str], |
880 | 881 | scheduled_spec_decode_tokens: dict[str, list[int]], |
881 | | - ): |
882 | | - # NOTE: structured_output_request_ids maps |
883 | | - # the request_id of each request that uses |
884 | | - # structured output to its index in the batch. |
885 | | - # This helps us determine how to slice the grammar bitmask |
886 | | - # and apply a valid mask only to requests that |
887 | | - # use structured decoding. |
888 | | - structured_output_request_ids: dict[str, int] = {} |
889 | | - for i, req in enumerate(requests): |
890 | | - if req.use_structured_output: |
891 | | - # PERF: in the case of chunked prefill, |
892 | | - # a request might not include any new tokens. |
893 | | - # Therefore, we might spend some additional |
894 | | - # cycles filling in the bitmask, which could be a big no-op. |
895 | | - structured_output_request_ids[req.request_id] = i |
896 | | - |
| 882 | + ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]: |
| 883 | + # Collect list of scheduled request ids that use structured output. |
| 884 | + # The corresponding rows of the bitmask will be in this order. |
| 885 | + # PERF: in the case of chunked prefill, |
| 886 | + # a request might not include any new tokens. |
| 887 | + # Therefore, we might spend some additional |
| 888 | + # cycles filling in the bitmask, which could be a big no-op. |
| 889 | + structured_output_request_ids = [ |
| 890 | + req_id |
| 891 | + for req_id in scheduled_request_ids |
| 892 | + if (req := self.requests.get(req_id)) and req.use_structured_output |
| 893 | + ] |
897 | 894 | if not structured_output_request_ids: |
898 | | - bitmask = None |
899 | | - else: |
900 | | - bitmask = self.structured_output_manager.grammar_bitmask( |
901 | | - self.requests, |
902 | | - structured_output_request_ids, |
903 | | - scheduled_spec_decode_tokens, |
904 | | - ) |
| 895 | + return structured_output_request_ids, None |
| 896 | + |
| 897 | + bitmask = self.structured_output_manager.grammar_bitmask( |
| 898 | + self.requests, |
| 899 | + structured_output_request_ids, |
| 900 | + scheduled_spec_decode_tokens, |
| 901 | + ) |
905 | 902 | return structured_output_request_ids, bitmask |
906 | 903 |
|
907 | 904 | def update_from_output( |
@@ -1013,12 +1010,10 @@ def update_from_output( |
1013 | 1010 | new_logprobs = logprobs.slice(req_index, req_index + 1) |
1014 | 1011 |
|
1015 | 1012 | if new_token_ids and self.structured_output_manager.should_advance(request): |
1016 | | - # NOTE: structured_output_request |
1017 | | - # should not be None if use_structured_output, we have |
1018 | | - # checked above, so safe to ignore type warning |
1019 | | - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] |
1020 | | - req_id, new_token_ids |
1021 | | - ) |
| 1013 | + struct_output_request = request.structured_output_request |
| 1014 | + assert struct_output_request is not None |
| 1015 | + assert struct_output_request.grammar is not None |
| 1016 | + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) |
1022 | 1017 |
|
1023 | 1018 | if num_nans_in_logits is not None and req_id in num_nans_in_logits: |
1024 | 1019 | request.num_nans_in_logits = num_nans_in_logits[req_id] |
|
0 commit comments