Changes from all commits (29 commits)
409da24
Extend on-device sampling support for dual QPC VLMs
quic-xiyushi Oct 23, 2025
e06e175
Fix random_numbers shape
quic-xiyushi Oct 30, 2025
3e242ce
Update example with new random sampling logic
quic-xiyushi Oct 30, 2025
1a01d57
Update to align with recent VLM CB changes
quic-xiyushi Nov 11, 2025
30d6061
Update tests with new random sampling logic
Nov 11, 2025
78ef180
Add code to perform guided decoding
Nov 11, 2025
1fafcdb
Add bitmask to example inputs and dynamic axes
Nov 12, 2025
18ab856
Rename bitmask to token_bitmasks
Nov 12, 2025
b1c049c
Fix typo
Nov 12, 2025
e16e846
Merge branch 'main' into guided_decoding_simple
Nov 19, 2025
1515497
Add flag to enable guided decoding
Nov 19, 2025
d02d04d
Merge remote-tracking branch 'origin/main' into HEAD
quic-xiyushi Nov 19, 2025
97e4baf
Add flag to enable guided decoding
Nov 19, 2025
7b7677b
Update test_sampler_transform for guided decoding
Nov 19, 2025
7cf106e
Refactor
quic-xiyushi Nov 19, 2025
45aed11
Add unit tests
quic-xiyushi Nov 20, 2025
6273ab5
Clean up
quic-xiyushi Nov 20, 2025
ef9ae14
Merge remote-tracking branch 'origin/main' into HEAD
quic-xiyushi Nov 20, 2025
60312b3
Add test for guided decoding
Nov 20, 2025
3789d5a
Update test_sampler.py
quic-xiyushi Nov 20, 2025
251099f
Merge branch 'on-device-sampling-vlm' into guided_decoding_simple
Nov 20, 2025
a24a55d
Enable guided decoding in vlm generation
Nov 20, 2025
55e76e9
Fix bug
Nov 20, 2025
f9355d4
Fix bug
Nov 20, 2025
5e2afb7
Fix hash for VLM's language decoder to include qaic_config
quic-xiyushi Nov 21, 2025
e672701
Merge branch 'on-device-sampling-vlm' into guided_decoding_simple
Nov 21, 2025
eee5314
Enable guided decoding test for vlms
Nov 21, 2025
60cf5ec
Use different config for each vlm
Nov 21, 2025
a71ee65
Update type
Nov 21, 2025
16 changes: 13 additions & 3 deletions QEfficient/generation/text_generation_inference.py
@@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv(
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
:include_guided_decoding (bool, default=False): If True, enables guided token-level filtering
during decoding. Only works when `include_sampler`=True.
sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
@@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)

@@ -442,6 +446,7 @@ def __init__(
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
activate: bool = True,
) -> None:
@@ -451,6 +456,7 @@ def __init__(
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
self.include_guided_decoding = include_guided_decoding
self.sampling_params = sampling_params
self._qpc_path = qpc_path # Store qpc_path for later use

@@ -461,7 +467,9 @@ def __init__(

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
session_inputs=set(self._session.input_names), include_sampler=include_sampler
session_inputs=set(self._session.input_names),
include_sampler=include_sampler,
include_guided_decoding=include_guided_decoding,
)

# Fetch the variables from the QPC
@@ -628,7 +636,7 @@ def prepare_decode_inputs(self):
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
@@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -1067,6 +1075,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
@@ -1082,6 +1091,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
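
For context, a minimal sketch of how the new flag might be exercised from the public API. The tokenizer/model name, qpc_path, array shapes and dtypes, and the token_bitmasks layout are illustrative assumptions, and sampling_params is shown only with the keys named in the docstring above (the docstring's full key list is elided):

import numpy as np
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed model
vocab_size = len(tokenizer)

sampling_params = {
    # Per-sequence parameters; a (batch_size, 1) layout is assumed here.
    "repetition_penalties": np.array([[1.0]], dtype=np.float32),
    "presence_penalties": np.array([[0.0]], dtype=np.float32),
    "temperatures": np.array([[0.7]], dtype=np.float32),
    "top_ks": np.array([[40]], dtype=np.int32),
    "top_ps": np.array([[0.95]], dtype=np.float32),
    # Guided-decoding input consumed when include_guided_decoding=True; shown as an
    # all-ones mask (every token allowed) with an assumed packed-int32 layout.
    "token_bitmasks": np.full((1, (vocab_size + 31) // 32), -1, dtype=np.int32),
}

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="path/to/compiled_qpc",  # assumed path
    prompt=["My name is"],
    include_sampler=True,
    include_guided_decoding=True,
    sampling_params=sampling_params,
)
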
17 changes: 17 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -36,6 +36,7 @@
write_io_files,
)
from QEfficient.utils import LRUCache
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger


@@ -91,6 +92,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -110,6 +112,7 @@
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
include_guided_decoding: Enable guided decoding in on-device sampling
sampling_params: Sampling parameters for on-device sampling
"""
# Validate required parameters
@@ -133,6 +136,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
activate=False, # vision components need to be initialized first
)
@@ -303,6 +307,13 @@ def _execute_chunked_prefill(
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

if self.include_sampler:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
lang_inputs[op] = self.sampling_params[op]

for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -328,6 +339,11 @@

chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]

if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
chunk_inputs[op] = lang_inputs[op]

outputs = self._session.run(chunk_inputs)

if "image_idx_output" in outputs:
@@ -780,6 +796,7 @@ def generate_stream_tokens(
is_tlm=self.is_tlm,
include_sampler=self.include_sampler,
return_pdfs=self.return_pdfs,
include_guided_decoding=self.include_guided_decoding,
sampling_params=self.sampling_params,
)

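
And a hypothetical helper for producing the per-step token_bitmasks entry from a grammar or allow-list constraint. The packed-int32 bit layout (bit i%32 of word i//32 for vocab token i) is an assumption for illustration only and may not match what the compiled sampler actually expects:

import numpy as np

def build_token_bitmask(allowed_token_ids, vocab_size, batch_size=1):
    """Pack an allow-list into a bitmask (assumed layout: bit i%32 of word i//32
    corresponds to vocab token i; a set bit means the token may be sampled)."""
    num_words = (vocab_size + 31) // 32
    mask = np.zeros((batch_size, num_words), dtype=np.uint32)
    for tok in allowed_token_ids:
        word, bit = divmod(int(tok), 32)
        mask[:, word] |= np.uint32(1) << np.uint32(bit)
    # Reinterpret as int32 to match the dtype assumed in the example above.
    return mask.view(np.int32)

# Example: constrain the next decode step to three candidate tokens, then pass the
# result to the generator via sampling_params["token_bitmasks"].
token_bitmasks = build_token_bitmask({5, 17, 42}, vocab_size=32000)
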