33 changes: 20 additions & 13 deletions vllm/v1/structured_output/__init__.py
@@ -14,7 +14,6 @@
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import torch
     import xgrammar as xgr
 
     from vllm.v1.request import Request
@@ -27,14 +26,18 @@
 class StructuredOutputManager:
 
     def __init__(self, vllm_config: VllmConfig):
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         self.vllm_config = vllm_config
+        self.init_complete = False
+
+    def _delayed_init(self):
+        """Initialization delayed until we know it is needed."""
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            parallel_config=self.vllm_config.parallel_config,
+            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         tokenizer_info = xgr.TokenizerInfo.from_huggingface(
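Note: the shape of this hunk is the classic deferred-initialization pattern. __init__ becomes cheap and side-effect free, and everything expensive (tokenizer construction, xgrammar setup) moves behind a one-time guard flag. A minimal standalone sketch of the pattern, with illustrative names that are not part of the vLLM API:

# Minimal sketch of the deferred-initialization pattern this hunk applies.
# Names here are illustrative stand-ins, not part of the vLLM API.

class LazyManager:

    def __init__(self) -> None:
        # Construction stays cheap: no heavy imports, no tokenizer loading.
        self.init_complete = False

    def _delayed_init(self) -> None:
        # The expensive work runs at most once, and only when a caller
        # actually needs the feature.
        import json  # stand-in for a costly import such as xgrammar
        self._codec = json
        self.init_complete = True

    def handle(self, payload: str) -> object:
        if not self.init_complete:
            self._delayed_init()
        return self._codec.loads(payload)

if __name__ == "__main__":
    m = LazyManager()              # fast even when the feature is unused
    print(m.handle('{"a": 1}'))    # first call pays the one-time setup cost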
@@ -47,12 +50,21 @@ def __init__(self, vllm_config: VllmConfig):
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._grammar_bitmask = xgr.allocate_token_bitmask(
+            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+
+        self.init_complete = True
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
 
+        # The first time this is called, we need to finish initialization
+        # of xgrammar. We defer it to avoid the import of xgrammar and
+        # initialization cost if it is not going to be used.
+        if not self.init_complete:
+            self._delayed_init()
+
         grammar: Future[Grammar] = self.executor.submit(
             self._async_create_grammar, request)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
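Two things happen here: the bitmask moves from lazy allocation in grammar_bitmask() to eager allocation inside _delayed_init(), which is why the Optional[torch.Tensor] annotation (and the TYPE_CHECKING import of torch in the first hunk) can go away, and grammar_init() gains the guard that triggers the one-time setup. For a sense of the memory involved, a back-of-the-envelope sketch; the packing of one bit per vocabulary token into int32 words matches xgrammar's documented bitmask layout, but the concrete numbers below are only illustrative, not taken from this PR:

max_num_seqs = 256                    # stand-in for scheduler_config.max_num_seqs
vocab_size = 32_000                   # stand-in for the model vocab size

words_per_row = -(-vocab_size // 32)  # ceil(32000 / 32) = 1000 int32 words
bytes_total = max_num_seqs * words_per_row * 4
print(words_per_row, bytes_total)     # 1000 words/row, ~1.0 MiB in total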
@@ -100,11 +112,6 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
-        if self._grammar_bitmask is None:
-            self._grammar_bitmask = xgr.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs,
-                self.vocab_size)
-
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
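The net effect, as a reviewer might check it: constructing StructuredOutputManager no longer touches the tokenizer or xgrammar, and only the first structured-output request pays the setup cost. A hypothetical driver sketch (not a test from this PR, and assuming xgrammar is imported lazily at runtime rather than via the TYPE_CHECKING block shown in the first hunk):

import sys

# Hypothetical driver code; vllm_config and structured_request are assumed
# to exist in the surrounding context.
manager = StructuredOutputManager(vllm_config)  # cheap: guard flag only
assert "xgrammar" not in sys.modules            # nothing imported yet

manager.grammar_init(structured_request)        # first use runs _delayed_init()
assert manager.init_complete
assert "xgrammar" in sys.modules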