Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import torch
import xgrammar

from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams
from tensorrt_llm.llmapi.llm_args import GuidedDecodingConfig

from ...bindings.executor import GuidedDecodingParams


class GrammarMatcher(ABC):
Expand Down
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/guided_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import torch

from tensorrt_llm.llmapi.llm_args import GuidedDecodingConfig

from ..._utils import nvtx_range
from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams
from ...bindings.executor import GuidedDecodingParams
from ...bindings.internal.batch_manager import LlmRequestType
from ...logger import logger
from ..hostfunc import hostfunc
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version, mpi_disabled
from tensorrt_llm.bindings.executor import GuidedDecodingConfig
from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
ContextChunkingPolicy, LoadFormat,
ContextChunkingPolicy,
GuidedDecodingConfig, LoadFormat,
TorchLlmArgs)
from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
_llguidance_tokenizer_info,
Expand Down
23 changes: 21 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@
KvCacheConfig as _KvCacheConfig,
LookaheadDecodingConfig as _LookaheadDecodingConfig,
PeftCacheConfig as _PeftCacheConfig,
SchedulerConfig as _SchedulerConfig,
GuidedDecodingConfig as _GuidedDecodingConfig) # isort: skip
SchedulerConfig as _SchedulerConfig) # isort: skip
# isort: on

# yapf: enable
Expand Down Expand Up @@ -165,6 +164,26 @@ def _generate_cuda_graph_batch_sizes(max_batch_size: int,
return batch_sizes


class GuidedDecodingConfig(StrictBaseModel):
    """Configuration for guided decoding.

    Bundles the backend selection together with the tokenizer-derived
    data (encoded vocabulary, serialized tokenizer, stop token ids)
    that a guided-decoding grammar backend needs.

    NOTE(review): this class mirrors the former
    ``bindings.executor.GuidedDecodingConfig`` it replaces in the
    surrounding diff — field semantics presumably match the C++
    binding; confirm against that definition.
    """

    class GuidedDecodingBackend(Enum):
        """Selectable guided-decoding backend implementations."""
        # Backed by the xgrammar library.
        XGRAMMAR = 0
        # Backed by the llguidance library.
        LLGUIDANCE = 1

    # Which grammar backend to use; defaults to xgrammar.
    backend: GuidedDecodingBackend = Field(
        default=GuidedDecodingBackend.XGRAMMAR,
        description="The backend for guided decoding config.")
    # Optional token strings of the tokenizer vocabulary, indexed by
    # token id; None means the backend derives it elsewhere.
    encoded_vocab: Optional[List[str]] = Field(
        default=None,
        description="The encoded vocab for guided decoding config.")
    # Optional serialized tokenizer representation passed through to
    # the backend.
    tokenizer_str: Optional[str] = Field(
        default=None,
        description="The tokenizer string for guided decoding config.")
    # Optional ids of tokens that terminate generation for the grammar
    # matcher.
    stop_token_ids: Optional[List[int]] = Field(
        default=None,
        description="The stop token ids for guided decoding config.")


class BaseSparseAttentionConfig(StrictBaseModel):
"""
Configuration for sparse attention.
Expand Down
Loading