Commit 3a1e648

[V1] Refactor Structured Output for multiple backends (vllm-project#14694)
Signed-off-by: Russell Bryant <[email protected]>
1 parent 46c759c commit 3a1e648

6 files changed: +290 −185 lines changed

vllm/v1/engine/processor.py

Lines changed: 15 additions & 10 deletions
@@ -119,16 +119,21 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
     def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
-        if self.decoding_config.guided_decoding_backend != "xgrammar":
-            raise ValueError(
-                "Only xgrammar structured output is supported in V1.")
-        if (params.guided_decoding.backend
-                and params.guided_decoding.backend != 'xgrammar'):
-            raise ValueError(
-                "Only xgrammar structured output is supported in V1.")
-        if self.vllm_config.speculative_config:
-            raise ValueError("Structured output is not supported with "
-                             "speculative decoding.")
+
+        supported_backends = ["xgrammar"]
+        engine_level_backend = self.decoding_config.guided_decoding_backend
+        if engine_level_backend not in supported_backends:
+            raise ValueError(f"Only {supported_backends} structured output is "
+                             "supported in V1.")
+        if params.guided_decoding.backend:
+            if params.guided_decoding.backend != engine_level_backend:
+                raise ValueError("Request-level structured output backend "
+                                 "must match engine-level backend. "
+                                 f"{params.guided_decoding.backend}"
+                                 f" != {engine_level_backend}")
+        else:
+            params.guided_decoding.backend = engine_level_backend
+
         if vllm.platforms.current_platform.is_tpu():
             raise ValueError("Structured output is not supported on TPU.")
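
In effect, a request that does not name a backend now inherits the engine-level setting, while a request that names a different backend is rejected. Below is a minimal standalone sketch of that resolution rule; resolve_backend and its inputs are illustrative stand-ins, not vLLM code.

# Minimal standalone sketch of the rule the new check enforces; this helper
# and its names are illustrative, not part of vLLM.
from typing import Optional

SUPPORTED_BACKENDS = ["xgrammar"]


def resolve_backend(engine_level_backend: str,
                    request_backend: Optional[str]) -> str:
    """Return the backend the request will use, or raise on a mismatch."""
    if engine_level_backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"Only {SUPPORTED_BACKENDS} structured output is "
                         "supported in V1.")
    if request_backend and request_backend != engine_level_backend:
        raise ValueError("Request-level structured output backend must match "
                         f"engine-level backend. {request_backend} != "
                         f"{engine_level_backend}")
    # An unset request-level backend inherits the engine-level setting.
    return request_backend or engine_level_backend


assert resolve_backend("xgrammar", None) == "xgrammar"
assert resolve_backend("xgrammar", "xgrammar") == "xgrammar"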

vllm/v1/structured_output/__init__.py

Lines changed: 31 additions & 92 deletions
@@ -7,104 +7,58 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import LazyLoader
-from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
+from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
+                                                     StructuredOutputGrammar)
+from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
 
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import xgrammar as xgr
+    import torch
 
     from vllm.v1.request import Request
-else:
-    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 logger = init_logger(__name__)
 
 
 class StructuredOutputManager:
+    """Engine-level manager for structured output requests."""
 
     def __init__(self, vllm_config: VllmConfig):
+        self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
-        self.init_complete = False
-
-    def _delayed_init(self):
-        """Initialization delayed until we know it is needed."""
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            parallel_config=self.vllm_config.parallel_config,
-            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = self.vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
-            # NOTE: ideally, xgrammar should handle this accordingly.
-            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
-            try:
-                encoded_vocab = [
-                    token for token, _ in sorted(
-                        tokenizer.get_vocab().items(),
-                        key=lambda x: x[1],
-                    )
-                ]
-                stop_token_ids = None
-                if hasattr(
-                        tokenizer,
-                        "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
-            except AttributeError as e:
-                raise ValueError(
-                    f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
-                    "get_vocab method.") from e
-            tokenizer_info = xgr.TokenizerInfo(
-                encoded_vocab=encoded_vocab,
-                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
-                vocab_type=xgr.VocabType.BYTE_FALLBACK,
-                vocab_size=self.vocab_size,
-                stop_token_ids=stop_token_ids,
-                add_prefix_space=True,
-            )
-        else:
-            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
-                vocab_size=self.vocab_size,
-            )
-        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+        self._grammar_bitmask: Optional[torch.Tensor] = None
 
         # The default max_workers if not specified is the number of CPUs * 5,
         # which is way too high since these tasks are CPU-bound, not I/O bound.
         # We also know we would never dominate CPU usage with just grammar
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask = xgr.allocate_token_bitmask(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.vocab_size,
-        )
-
-        self.init_complete = True
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
 
-        # The first time this is called, we need to finish initialization
-        # of xgrammar. We defer it to avoid the import of xgrammar and
-        # initialization cost if it is not going to be used.
-        if not self.init_complete:
-            self._delayed_init()
+        # Initialize the backend the first time it is needed.
+        #
+        # NOTE: We only support a single backend. We do NOT support different
+        # backends on a per-request basis in V1 (for now, anyway...).
+        if self.backend is None:
+            backend_name = request.sampling_params.guided_decoding.backend_name
+            if backend_name == "xgrammar":
+                self.backend = XgrammarBackend(self.vllm_config)
+            else:
+                raise ValueError(
+                    f"Unsupported structured output backend: {backend_name}")
 
-        grammar: Future[Grammar] = self.executor.submit(
-            self._async_create_grammar, request)
+        grammar: Future[StructuredOutputGrammar] = self.executor.submit(
+            self._async_create_grammar, request, self.backend)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
 
-    def _async_create_grammar(self, request: Request) -> Grammar:
+    def _async_create_grammar(
+            self, request: Request,
+            backend: StructuredOutputBackend) -> StructuredOutputGrammar:
         key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]
 
         # Note that the request was validated in the engine core client,
@@ -114,28 +68,8 @@ def _async_create_grammar(self, request: Request) -> Grammar:
         # though it should be unlikely as we test that up front as well.
         request_type, grammar_spec = key
 
-        if request_type == StructuredOutputOptions.JSON:
-            # TODO -- allow any_whitespace to be configurable
-            # pending merge of https://github.com/vllm-project/vllm/pull/12744
-            ctx = self.compiler.compile_json_schema(grammar_spec,
-                                                    any_whitespace=False)
-        elif request_type == StructuredOutputOptions.JSON_OBJECT:
-            ctx = self.compiler.compile_builtin_json_grammar()
-        elif request_type == StructuredOutputOptions.GRAMMAR:
-            ctx = self.compiler.compile_grammar(grammar_spec)
-        elif request_type == StructuredOutputOptions.REGEX:
-            ctx = self.compiler.compile_regex(grammar_spec)
-        else:
-            logger.error("Validation should have already occurred. "
-                         "Please file an issue.")
-            raise ValueError(
-                f"grammar is not of valid supported types. ({request_type!s})")
-
-        return Grammar(
-            matcher=xgr.GrammarMatcher(ctx),
-            vocab_size=self.vocab_size,
-            ctx=ctx,
-        )
+        assert self.backend is not None
+        return self.backend.compile_grammar(request_type, grammar_spec)
 
     def grammar_bitmask(
         self,
@@ -147,14 +81,19 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
+        if self._grammar_bitmask is None:
+            assert self.backend is not None
+            self._grammar_bitmask = self.backend.allocate_token_bitmask(
+                self.vllm_config.scheduler_config.max_num_seqs)
+
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
         bitmask_tensor = self._grammar_bitmask
         for req_id, batch_index in structured_output_request_ids.items():
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.matcher.is_terminated():
+            if not request.grammar.is_terminated():
                 request.grammar.fill_bitmask(bitmask_tensor, batch_index)
         if batch_len < self._grammar_bitmask.shape[0]:
             bitmask_tensor = self._grammar_bitmask[:batch_len]
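
The manager now compiles grammars off the hot path by submitting the work to a small thread pool and stashing the resulting Future on the request. Here is a rough standalone sketch of that pattern; FakeRequest and FakeGrammar are illustrative stand-ins, not vLLM classes.

# Rough standalone sketch of the submit-and-store-Future pattern used by
# StructuredOutputManager above. FakeRequest/FakeGrammar are illustrative
# stand-ins, not vLLM classes.
import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGrammar:
    spec: str


@dataclass
class FakeRequest:
    grammar_spec: str
    grammar: Optional[Future] = None


def compile_grammar(spec: str) -> FakeGrammar:
    # Stand-in for backend.compile_grammar(); real compilation is CPU-bound.
    return FakeGrammar(spec)


# Cap workers at roughly half the CPUs, mirroring the comment in the diff:
# grammar compilation is CPU-bound, so the default (CPUs * 5) is far too high.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
executor = ThreadPoolExecutor(max_workers=max_workers)

req = FakeRequest(grammar_spec='{"type": "object"}')
req.grammar = executor.submit(compile_grammar, req.grammar_spec)

# Later, when the grammar is actually needed, resolve the Future.
print(req.grammar.result().spec)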

vllm/v1/structured_output/backend_types.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class StructuredOutputOptions(enum.Enum):
+    JSON = enum.auto()
+    JSON_OBJECT = enum.auto()
+    REGEX = enum.auto()
+    GRAMMAR = enum.auto()
+    CHOICE = enum.auto()
+
+
+StructuredOutputKey = tuple[StructuredOutputOptions, str]
+
+
+class StructuredOutputGrammar(ABC):
+    """Request-level backend for structured output requests."""
+
+    @abstractmethod
+    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
+        """
+        Determines whether the provided tokens are accepted for the
+        given request.
+
+        Args:
+            request_id (str): The unique identifier for the request.
+            tokens (list[int]): A list of token IDs to evaluate.
+
+        Returns:
+            bool: True if the tokens are accepted, False otherwise.
+        """
+
+    @abstractmethod
+    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
+        """
+        Fills the bitmask for a specific batch index.
+
+        Args:
+            bitmask (torch.Tensor): The bitmask to fill
+            batch_index (int): The index in the bitmask to fill
+        """
+
+    @abstractmethod
+    def is_terminated(self) -> bool:
+        """
+        Checks whether the structured output process has terminated.
+
+        Returns:
+            bool: True if the process is terminated, False otherwise.
+        """
+
+    @abstractmethod
+    def reset(self):
+        """
+        Resets the state of the structured output grammar.
+        """
+
+
+class StructuredOutputBackend(ABC):
+    """Engine-level backend for structured output requests."""
+
+    @abstractmethod
+    def compile_grammar(self, request_type: StructuredOutputOptions,
+                        grammar_spec: str) -> StructuredOutputGrammar:
+        """
+        Compiles a grammar specification into a structured output grammar.
+
+        Args:
+            request_type (StructuredOutputOptions): The type of structured
+                output request.
+            grammar_spec (str): The grammar specification to compile.
+
+        Returns:
+            StructuredOutputGrammar: The compiled structured output grammar.
+        """
+
+    @abstractmethod
+    def allocate_token_bitmask(self, max_num_seqs: int):
+        """
+        Allocates a token bitmask for the specified maximum number of sequences.
+
+        Args:
+            max_num_seqs (int): The maximum number of sequences for which
+                to allocate the bitmask.
+        """
