Skip to content

Commit 1f16b7f

Browse files
russellblochuynh1412mmoskalaarnphm
authored
[Core][V0] Add guidance backend for structured output (vllm-project#14589)
Signed-off-by: Russell Bryant <[email protected]> Co-authored-by: Loc Huynh <[email protected]> Co-authored-by: Michal Moskal <[email protected]> Co-authored-by: Aaron Pham <[email protected]>
1 parent b88be22 commit 1f16b7f

File tree

8 files changed

+167
-13
lines changed

8 files changed

+167
-13
lines changed

benchmarks/benchmark_serving_structured_output.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -999,11 +999,12 @@ def main(args: argparse.Namespace):
999999
type=float,
10001000
default=1.0,
10011001
help="Ratio of Structured Outputs requests")
1002-
parser.add_argument("--structured-output-backend",
1003-
type=str,
1004-
choices=["outlines", "lm-format-enforcer", "xgrammar"],
1005-
default="xgrammar",
1006-
help="Backend to use for structured outputs")
1002+
parser.add_argument(
1003+
"--structured-output-backend",
1004+
type=str,
1005+
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
1006+
default="xgrammar",
1007+
help="Backend to use for structured outputs")
10071008

10081009
args = parser.parse_args()
10091010
main(args)

requirements/common.txt

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pillow # Required for image processing
1818
prometheus-fastapi-instrumentator >= 7.0.0
1919
tiktoken >= 0.6.0 # Required for DBRX tokenizer
2020
lm-format-enforcer >= 0.10.11, < 0.11
21+
llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
2122
outlines == 0.1.11
2223
lark == 1.2.2
2324
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"

tests/entrypoints/llm/test_guided_generate.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
1515

1616
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
17-
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
17+
GUIDED_DECODING_BACKENDS = [
18+
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
19+
]
1820

1921

2022
@pytest.fixture(scope="module")

tests/model_executor/test_guided_processors.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from vllm.sampling_params import GuidedDecodingParams
1717

1818
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
19-
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
19+
GUIDED_DECODING_BACKENDS = [
20+
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
21+
]
2022
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
2123
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
2224

vllm/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2785,7 +2785,9 @@ def compute_hash(self) -> str:
27852785
return hash_str
27862786

27872787
def __post_init__(self):
2788-
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
2788+
valid_guided_backends = [
2789+
'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
2790+
]
27892791

27902792
backend = GuidedDecodingParams(
27912793
backend=self.guided_decoding_backend).backend_name

vllm/model_executor/guided_decoding/__init__.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
7979
"xgrammar does not support Lark grammars and the "
8080
"grammar failed to convert to GBNF.", "outlines")
8181

82+
elif guided_params.json_object:
83+
# https://github.com/mlc-ai/xgrammar/issues/256
84+
fallback_or_error(guided_params,
85+
"xgrammar does not support json_object.",
86+
"guidance")
87+
8288
# If the xgrammar module cannot be imported successfully,
8389
# we should still allow users to use guided decoding with a fallback.
8490
elif not xgr_installed:
@@ -88,9 +94,9 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
8894

8995
if (guided_params.backend_name == "outlines"
9096
and guided_params.json_object is not None):
91-
# outlines doesn't support json_object, fallback to xgrammar
97+
# outlines doesn't support json_object, fallback to guidance
9298
fallback_or_error(guided_params,
93-
"outlines does not support json_object.", "xgrammar")
99+
"outlines does not support json_object.", "guidance")
94100

95101
return guided_params
96102

@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
122128
get_local_xgrammar_guided_decoding_logits_processor)
123129
return get_local_xgrammar_guided_decoding_logits_processor(
124130
guided_params, tokenizer, model_config, reasoner)
125-
131+
if guided_params.backend_name == 'guidance':
132+
from vllm.model_executor.guided_decoding.guidance_decoding import (
133+
get_local_guidance_guided_decoding_logits_processor)
134+
return get_local_guidance_guided_decoding_logits_processor(
135+
guided_params, tokenizer)
126136
raise ValueError(
127137
f"Unknown guided decoding backend '{guided_params.backend}'. "
128-
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
138+
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
139+
)
129140

130141

131142
def get_local_guided_decoding_logits_processor(
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
155166
get_local_xgrammar_guided_decoding_logits_processor)
156167
return get_local_xgrammar_guided_decoding_logits_processor(
157168
guided_params, tokenizer, model_config, reasoner)
169+
if guided_params.backend_name == 'guidance':
170+
from vllm.model_executor.guided_decoding.guidance_decoding import (
171+
get_local_guidance_guided_decoding_logits_processor)
172+
return get_local_guidance_guided_decoding_logits_processor(
173+
guided_params, tokenizer)
158174

159175
raise ValueError(
160176
f"Unknown guided decoding backend '{guided_params.backend}'. "
161-
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
177+
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
178+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from re import escape as regex_escape
3+
4+
import llguidance
5+
from transformers import PreTrainedTokenizerBase
6+
7+
from vllm.model_executor.guided_decoding.guidance_logits_processors import (
8+
GuidanceLogitsProcessor)
9+
from vllm.sampling_params import GuidedDecodingParams
10+
11+
12+
def get_local_guidance_guided_decoding_logits_processor(
13+
guided_params: GuidedDecodingParams,
14+
tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
15+
"""
16+
Given an OpenAI-compatible request, check for guided decoding parameters
17+
and get the necessary logits processor for the given guide.
18+
"""
19+
20+
grm = ""
21+
if guided_params.json:
22+
grm = llguidance.LLMatcher.grammar_from_json_schema(
23+
guided_params.json,
24+
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
25+
elif guided_params.json_object:
26+
grm = llguidance.LLMatcher.grammar_from_json_schema(
27+
'{"type": "object"}',
28+
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
29+
elif guided_params.regex:
30+
grm = llguidance.grammar_from("regex", guided_params.regex)
31+
elif guided_params.choice:
32+
# choice just uses regex
33+
choices = (regex_escape(str(choice))
34+
for choice in guided_params.choice)
35+
choices_regex = "(" + "|".join(choices) + ")"
36+
grm = llguidance.grammar_from("regex", choices_regex)
37+
elif guided_params.grammar:
38+
# this supports Lark and GBNF
39+
grm = llguidance.grammar_from("grammar", guided_params.grammar)
40+
41+
if grm:
42+
return GuidanceLogitsProcessor(grm, tokenizer)
43+
44+
raise ValueError("Unknown guided decoding mode")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
from typing import Any, List
4+
5+
import llguidance
6+
import llguidance.hf
7+
import llguidance.torch
8+
import torch
9+
from transformers import PreTrainedTokenizerBase
10+
11+
from vllm.logger import init_logger
12+
13+
logger = init_logger(__name__)
14+
15+
16+
class GuidanceLogitsProcessor:
17+
"""Base Guidance Logits Processor"""
18+
19+
cached_tokenizers: dict[str, Any] = {}
20+
21+
def __init__(
22+
self,
23+
grammar: str,
24+
tokenizer: PreTrainedTokenizerBase,
25+
) -> None:
26+
"""Base Guidance Logits Processor
27+
28+
Args:
29+
grammar (str)
30+
grammar to guide the generation
31+
tokenizer (PreTrainedTokenizerBase)
32+
model's tokenizer
33+
"""
34+
self.grammar = grammar
35+
self.tokenizer = tokenizer
36+
self.tokenizer_name = tokenizer.name_or_path
37+
self.new_sampling = False
38+
self.initialized = False
39+
40+
def _initialize(self):
41+
if self.initialized:
42+
return
43+
44+
ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
45+
None)
46+
if ll_tokenizer is None:
47+
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
48+
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
49+
50+
self.ll_tokenizer = ll_tokenizer
51+
self.ll_matcher = llguidance.LLMatcher(
52+
self.ll_tokenizer,
53+
self.grammar,
54+
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
55+
)
56+
57+
# create reusable bitmask
58+
self.bitmask = llguidance.torch.allocate_token_bitmask(
59+
1, self.ll_tokenizer.vocab_size)
60+
61+
self.initialized = True
62+
63+
def __call__(
64+
self,
65+
input_ids: List[int],
66+
scores: torch.Tensor,
67+
) -> torch.Tensor:
68+
# we initialize the guidance model here
69+
# to avoid pickling ll_tokenizer and ll_interpreter
70+
self._initialize()
71+
72+
if self.new_sampling and len(input_ids) > 0:
73+
self.ll_matcher.consume_token(input_ids[-1])
74+
err = self.ll_matcher.get_error()
75+
if err:
76+
logger.warning("Error in LLMatcher: %s", err)
77+
78+
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
79+
0)
80+
llguidance.torch.apply_token_bitmask_inplace(
81+
scores, self.bitmask.to(scores.device))
82+
83+
self.new_sampling = True
84+
85+
return scores

0 commit comments

Comments
 (0)