diff --git a/omlx/admin/templates/dashboard/_settings.html b/omlx/admin/templates/dashboard/_settings.html
index 500d073f..6b1b228b 100644
--- a/omlx/admin/templates/dashboard/_settings.html
+++ b/omlx/admin/templates/dashboard/_settings.html
@@ -774,7 +774,7 @@
{{ t('settings.models.no
@@ -817,6 +817,11 @@ {{ t('settings.models.no
force_sampling
+
+
+ {{ t('settings.models.badge.speculative') }}
+
+
diff --git a/omlx/engine/batched.py b/omlx/engine/batched.py
index 8a636396..eaba0527 100644
--- a/omlx/engine/batched.py
+++ b/omlx/engine/batched.py
@@ -43,6 +43,8 @@ def __init__(
scheduler_config: Any | None = None,
stream_interval: int = 1,
enable_thinking: bool | None = None,
+ draft_model_path: str | None = None,
+ num_draft_tokens: int = 3,
):
"""
Initialize the batched engine.
@@ -53,12 +55,16 @@ def __init__(
scheduler_config: Optional scheduler configuration
stream_interval: Tokens to batch before streaming (1=every token)
enable_thinking: Enable thinking mode for reasoning models (passed to chat_template_kwargs)
+ draft_model_path: Optional draft model path for speculative decoding
+ num_draft_tokens: Number of tokens to draft per speculative step
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._scheduler_config = scheduler_config
self._stream_interval = stream_interval
self._enable_thinking = enable_thinking
+ self._draft_model_path = draft_model_path
+ self._num_draft_tokens = num_draft_tokens
self._model = None
self._tokenizer = None
@@ -167,6 +173,24 @@ def _load_model_sync():
)
await self._engine.engine.start()
+
+ # Load draft model for speculative decoding
+ if self._draft_model_path:
+ def _load_draft_sync():
+ draft_model, _ = load(self._draft_model_path)
+ return draft_model
+
+ draft_model = await loop.run_in_executor(
+ get_mlx_executor(), _load_draft_sync
+ )
+ self._engine.engine.scheduler.set_draft_model(
+ draft_model, self._num_draft_tokens
+ )
+ logger.info(
+ f"Speculative decoding enabled: draft={self._draft_model_path}, "
+ f"num_draft_tokens={self._num_draft_tokens}"
+ )
+
self._loaded = True
logger.info(f"BatchedEngine loaded: {self._model_name}")
diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py
index b009e717..69b528f6 100644
--- a/omlx/engine/vlm.py
+++ b/omlx/engine/vlm.py
@@ -124,12 +124,16 @@ def __init__(
scheduler_config: Any | None = None,
stream_interval: int = 1,
enable_thinking: bool | None = None,
+ draft_model_path: str | None = None,
+ num_draft_tokens: int = 3,
):
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._scheduler_config = scheduler_config
self._stream_interval = stream_interval
self._enable_thinking = enable_thinking
+ self._draft_model_path = draft_model_path
+ self._num_draft_tokens = num_draft_tokens
self._vlm_model = None
self._processor = None
@@ -242,6 +246,25 @@ def _load_vlm_sync():
await self._engine.engine.start()
+ # Load draft model for speculative decoding
+ if self._draft_model_path:
+ from mlx_lm import load as llm_load
+
+ def _load_draft_sync():
+ draft_model, _ = llm_load(self._draft_model_path)
+ return draft_model
+
+ draft_model = await loop.run_in_executor(
+ get_mlx_executor(), _load_draft_sync
+ )
+ self._engine.engine.scheduler.set_draft_model(
+ draft_model, self._num_draft_tokens
+ )
+ logger.info(
+ f"Speculative decoding enabled: draft={self._draft_model_path}, "
+ f"num_draft_tokens={self._num_draft_tokens}"
+ )
+
# Inject mlx-lm tool calling support into VLM tokenizer
self._inject_tool_calling(self._tokenizer)
diff --git a/omlx/engine_pool.py b/omlx/engine_pool.py
index c7a7775a..d386e8a3 100644
--- a/omlx/engine_pool.py
+++ b/omlx/engine_pool.py
@@ -90,6 +90,7 @@ def __init__(
self._current_model_memory = 0
self._scheduler_config = scheduler_config or SchedulerConfig()
self._process_memory_enforcer: object | None = None # Set by server
+ self._settings_manager: "ModelSettingsManager | None" = None # Set by apply_settings_overrides
@property
def max_model_memory(self) -> int | None:
@@ -174,6 +175,7 @@ def apply_settings_overrides(
self, settings_manager: "ModelSettingsManager"
) -> None:
"""Apply model_type_override from persisted settings to discovered entries."""
+ self._settings_manager = settings_manager
for model_id, entry in self._entries.items():
settings = settings_manager.get_settings(model_id)
if settings.model_type_override:
@@ -430,15 +432,59 @@ async def _load_engine(self, model_id: str) -> None:
engine = RerankerEngine(model_name=entry.model_path)
elif entry.engine_type == "vlm":
# VLMBatchedEngine for vision-language models
+ # Resolve speculative decoding settings
+ vlm_draft_model_path = None
+ vlm_num_draft_tokens = 3
+ if self._settings_manager:
+ vlm_settings = self._settings_manager.get_settings(model_id)
+ if vlm_settings.speculative_decoding and vlm_settings.draft_model:
+ vlm_draft_entry = self._entries.get(vlm_settings.draft_model)
+ if vlm_draft_entry and vlm_draft_entry.model_type in ("llm", "vlm"):
+ vlm_draft_model_path = vlm_draft_entry.model_path
+ vlm_num_draft_tokens = vlm_settings.num_draft_tokens or 3
+ logger.info(
+ f"Speculative decoding for VLM {model_id}: "
+ f"draft={vlm_settings.draft_model}"
+ )
+ else:
+ logger.warning(
+ f"Draft model '{vlm_settings.draft_model}' not found "
+ f"or not a valid type, skipping speculative decoding"
+ )
+
engine = VLMBatchedEngine(
model_name=entry.model_path,
scheduler_config=self._scheduler_config,
+ draft_model_path=vlm_draft_model_path,
+ num_draft_tokens=vlm_num_draft_tokens,
)
else:
# BatchedEngine with continuous batching (default)
+ # Resolve speculative decoding settings
+ draft_model_path = None
+ num_draft_tokens = 3
+ if self._settings_manager:
+ settings = self._settings_manager.get_settings(model_id)
+ if settings.speculative_decoding and settings.draft_model:
+ draft_entry = self._entries.get(settings.draft_model)
+ if draft_entry and draft_entry.model_type == "llm":
+ draft_model_path = draft_entry.model_path
+ num_draft_tokens = settings.num_draft_tokens or 3
+ logger.info(
+ f"Speculative decoding for {model_id}: "
+ f"draft={settings.draft_model}"
+ )
+ else:
+ logger.warning(
+ f"Draft model '{settings.draft_model}' not found "
+ f"or not an LLM, skipping speculative decoding"
+ )
+
engine = BatchedEngine(
model_name=entry.model_path,
scheduler_config=self._scheduler_config,
+ draft_model_path=draft_model_path,
+ num_draft_tokens=num_draft_tokens,
)
try:
diff --git a/omlx/model_discovery.py b/omlx/model_discovery.py
index e8d59563..ceeadcfa 100644
--- a/omlx/model_discovery.py
+++ b/omlx/model_discovery.py
@@ -74,7 +74,6 @@
"xlm-roberta",
"xlm_roberta",
"modernbert",
- "qwen3",
"gemma3-text",
"gemma3_text",
"siglip",
diff --git a/omlx/model_settings.py b/omlx/model_settings.py
index 6e0612f1..f41aabc6 100644
--- a/omlx/model_settings.py
+++ b/omlx/model_settings.py
@@ -50,6 +50,11 @@ class ModelSettings:
ttl_seconds: Optional[int] = None # Auto-unload after idle seconds (None = no TTL)
model_type_override: Optional[str] = None # "llm", "vlm", "embedding", "reranker", or None (auto-detect)
+ # Speculative decoding
+ speculative_decoding: bool = False
+ draft_model: Optional[str] = None # Draft model ID (from discovered models)
+ num_draft_tokens: Optional[int] = None # Number of tokens to draft (None = use default 3)
+
# Model management flags
is_pinned: bool = False
is_default: bool = False # Only one model can be default
diff --git a/omlx/scheduler.py b/omlx/scheduler.py
index dd69c760..cc98de5a 100644
--- a/omlx/scheduler.py
+++ b/omlx/scheduler.py
@@ -1078,14 +1078,387 @@ def __init__(
logger.warning(f"Error detecting Harmony model: {e}, assuming non-Harmony")
self._is_harmony_model = False
+ # Speculative decoding
+ self._draft_model: Any = None
+ self._draft_caches: Dict[int, List[Any]] = {} # uid -> draft prompt cache
+ self._num_draft_tokens: int = 3
+
# Statistics
self.num_requests_processed = 0
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
+ self._speculative_accepted_tokens = 0
+ self._speculative_total_draft_tokens = 0
# Step counter for periodic cleanup
self._step_counter = 0
+ def set_draft_model(self, draft_model: Any, num_draft_tokens: int = 3) -> None:
+ """
+ Set the draft model for speculative decoding.
+
+ When a draft model is set, single-request decode steps will use
+ speculative decoding for faster generation while preserving the
+ existing paged/SSD/prefix cache infrastructure.
+
+ Args:
+ draft_model: The draft MLX model (must share tokenizer with main model)
+ num_draft_tokens: Number of tokens to draft per speculative step
+ """
+ self._draft_model = draft_model
+ self._num_draft_tokens = num_draft_tokens
+ logger.info(
+ f"Speculative decoding enabled: num_draft_tokens={num_draft_tokens}"
+ )
+
+ def _prefill_draft_cache(self, uid: int, prompt_tokens: List[int]) -> None:
+ """
+ Prefill the draft model's cache for a request.
+
+ Called after the main model prefill so the draft cache is in sync.
+
+ Args:
+ uid: The BatchGenerator UID for this request
+ prompt_tokens: The prompt token IDs
+ """
+ if self._draft_model is None:
+ return
+
+ from mlx_lm.models.cache import make_prompt_cache
+
+ draft_cache = make_prompt_cache(self._draft_model)
+ y = mx.array(prompt_tokens, mx.uint32)
+
+ # Prefill in chunks (same as mlx-lm's _prefill pattern)
+ prefill_step_size = 512
+ while y.size > prefill_step_size:
+ self._draft_model(y[:prefill_step_size][None], cache=draft_cache)
+ mx.eval([c.state for c in draft_cache])
+ y = y[prefill_step_size:]
+ mx.clear_cache()
+
+ # Process all but the last remaining token first; the last token is
+ # then run separately below so the draft cache ends fully caught up
+ if y.size > 1:
+ self._draft_model(y[:-1][None], cache=draft_cache)
+ mx.eval([c.state for c in draft_cache])
+
+ # Process the last token to get draft cache fully caught up
+ if y.size >= 1:
+ self._draft_model(y[-1:][None], cache=draft_cache)
+ mx.eval([c.state for c in draft_cache])
+
+ self._draft_caches[uid] = draft_cache
+ logger.debug(f"Draft cache prefilled for UID {uid}: {len(prompt_tokens)} tokens")
+
+ def _cleanup_draft_cache(self, uid: int) -> None:
+ """Remove draft cache for a finished/aborted request."""
+ self._draft_caches.pop(uid, None)
+
+ def _can_speculative_step(self) -> bool:
+ """Check if we can do a speculative decode step."""
+ if self._draft_model is None:
+ return False
+ if not self.batch_generator or not self.batch_generator.active_batch:
+ return False
+ batch = self.batch_generator.active_batch
+ if len(batch) != 1:
+ return False
+ # Lazily prefill draft cache if not yet created
+ uid = batch.uids[0]
+ if uid not in self._draft_caches:
+ request_id = self.uid_to_request_id.get(uid)
+ if request_id is None:
+ return False
+ request = self.running.get(request_id)
+ if request is None or not request.prompt_token_ids:
+ return False
+ try:
+ self._prefill_draft_cache(uid, request.prompt_token_ids)
+ except Exception as e:
+ logger.warning(f"Failed to prefill draft cache: {e}")
+ return False
+ return True
+
+ def _has_non_trimmable_cache(self, cache_list: List[Any]) -> bool:
+ """Check if cache contains non-trimmable layers (e.g., SSM/ArraysCache)."""
+ return any(
+ not getattr(c, "is_trimmable", lambda: True)()
+ for c in cache_list
+ )
+
+ def _snapshot_cache(self, cache_list: List[Any]) -> List[Any]:
+ """Deep copy cache states for later restore.
+
+ Handles both KVCache (state is a tuple of arrays) and ArraysCache
+ (state is a list reference that must be shallow-copied, plus extra
+ attributes like left_padding and lengths).
+ """
+ snapshots = []
+ for c in cache_list:
+ state = c.state
+ if isinstance(state, list):
+ # ArraysCache: .state returns self.cache (list reference).
+ # Must copy the list so __setitem__ in the model doesn't
+ # mutate our snapshot. Also save left_padding/lengths.
+ snapshot = {
+ "state": list(state),
+ "left_padding": c.left_padding,
+ "lengths": getattr(c, "lengths", None),
+ }
+ snapshots.append(("arrays", snapshot))
+ else:
+ # KVCache / BatchKVCache: .state returns a tuple with
+ # sliced arrays that are safe from in-place mutation.
+ snapshots.append(("kv", state))
+ return snapshots
+
+ def _restore_cache(self, cache_list: List[Any], snapshots: List[Any]) -> None:
+ """Restore cache states from snapshots."""
+ for c, (kind, snap) in zip(cache_list, snapshots):
+ if kind == "arrays":
+ c.state = snap["state"]
+ c.left_padding = snap["left_padding"]
+ if hasattr(c, "lengths"):
+ c.lengths = snap["lengths"]
+ else:
+ c.state = snap
+
+ def _speculative_step(self) -> List[Any]:
+ """
+ Execute one speculative decode step with multi-token verify.
+
+ Matches mlx-lm's speculative_generate_step() pattern:
+ 1. Draft model generates N tokens one-by-one
+ 2. Main model verifies ALL N+1 tokens in a SINGLE forward pass
+ 3. Accept/reject by comparing main vs draft tokens
+ 4. Rewind caches for rejected tokens
+
+ For hybrid models with non-trimmable cache layers (e.g., SSM/DeltaNet),
+ uses snapshot/restore approach: save cache state before verify, then
+ restore and re-process only accepted tokens to keep all layers in sync.
+
+ Returns:
+ List of BatchGenerator.Response objects
+ """
+ from mlx_lm.models.cache import trim_prompt_cache
+
+ batch = self.batch_generator.active_batch
+ uid = batch.uids[0]
+ draft_cache = self._draft_caches[uid]
+
+ # Get current state
+ current_y = batch.y # shape: (1,) - pending token (not yet in cache)
+ sampler = batch.samplers[0] or self.batch_generator.sampler
+ logits_procs = batch.logits_processors[0]
+
+ num_draft = min(
+ self._num_draft_tokens,
+ batch.max_tokens[0] - batch.num_tokens[0],
+ )
+
+ if num_draft <= 0:
+ return self.batch_generator.next()
+
+ # Check if cache has non-trimmable layers (hybrid SSM models)
+ has_non_trimmable = self._has_non_trimmable_cache(batch.cache)
+
+ # === Phase 1: Generate N draft tokens ===
+ draft_tokens = []
+ draft_y = current_y[0:1] # shape: (1,)
+
+ with mx.stream(generation_stream):
+ for _ in range(num_draft):
+ draft_logits = self._draft_model(draft_y[None], cache=draft_cache)
+ draft_logits = draft_logits[:, -1, :]
+ draft_logprobs = draft_logits - mx.logsumexp(
+ draft_logits, axis=-1, keepdims=True
+ )
+ draft_y = sampler(draft_logprobs)
+ mx.async_eval(draft_y)
+ draft_tokens.append(draft_y)
+
+ draft_tokens_arr = mx.concatenate(draft_tokens) # shape: (N,)
+
+ mx.eval(draft_tokens_arr)
+ draft_list = draft_tokens_arr.tolist()
+
+ # === Phase 2: Verify ALL tokens in single main model forward pass ===
+ verify_input = mx.concatenate(
+ [current_y, draft_tokens_arr]
+ ) # shape: (N+1,)
+
+ # For hybrid models: snapshot cache before verify
+ if has_non_trimmable:
+ cache_snapshot = self._snapshot_cache(batch.cache)
+ # Materialize snapshot arrays before verify mutates the caches
+ eval_arrays = []
+ for kind, snap in cache_snapshot:
+ if kind == "arrays":
+ eval_arrays.extend(snap["state"])
+ lp = snap["left_padding"]
+ if isinstance(lp, mx.array):
+ eval_arrays.append(lp)
+ ln = snap.get("lengths")
+ if isinstance(ln, mx.array):
+ eval_arrays.append(ln)
+ else:
+ # KV state tuple: (keys, values)
+ eval_arrays.extend(snap)
+ mx.eval(eval_arrays)
+
+ with mx.stream(generation_stream):
+ main_logits = self.batch_generator.model(
+ verify_input[None], cache=batch.cache
+ )
+ main_logits = main_logits[:, -(num_draft + 1) :, :] # (1, N+1, vocab)
+
+ # Sample from all positions at once
+ if logits_procs:
+ all_logprobs = []
+ all_tokens = []
+ for i in range(num_draft + 1):
+ pos_logits = main_logits[:, i, :]
+ for proc in logits_procs:
+ pos_logits = proc(batch.tokens[0], pos_logits)
+ pos_logprobs = pos_logits - mx.logsumexp(
+ pos_logits, axis=-1, keepdims=True
+ )
+ all_tokens.append(sampler(pos_logprobs))
+ all_logprobs.append(pos_logprobs)
+ verified_tokens_arr = mx.concatenate(all_tokens)
+ verified_logprobs = all_logprobs
+ else:
+ logits_2d = main_logits.squeeze(0)
+ all_logprobs = logits_2d - mx.logsumexp(
+ logits_2d, axis=-1, keepdims=True
+ )
+ verified_tokens_arr = sampler(all_logprobs)
+ verified_logprobs = [
+ all_logprobs[i : i + 1] for i in range(num_draft + 1)
+ ]
+
+ mx.eval(verified_tokens_arr, draft_tokens_arr)
+ verified_list = verified_tokens_arr.tolist()
+ if not isinstance(verified_list, list):
+ verified_list = [verified_list]
+
+ # === Phase 3: Accept/reject ===
+ n = 0 # number of accepted draft tokens
+ while n < num_draft:
+ if verified_list[n] != draft_list[n]:
+ break
+ n += 1
+
+
+
+ # === Phase 3b: Rewind/restore caches ===
+ if has_non_trimmable:
+ # Hybrid model: restore full cache state, then re-process
+ # only accepted tokens to keep SSM + KV in sync
+ if n < num_draft:
+ # Partial/no acceptance: restore and re-process accepted tokens
+ self._restore_cache(batch.cache, cache_snapshot)
+ accepted_input = verify_input[: n + 1] # current_y + n accepted drafts
+ with mx.stream(generation_stream):
+ self.batch_generator.model(
+ accepted_input[None], cache=batch.cache
+ )
+ mx.eval([c.state for c in batch.cache])
+ # else: all accepted, cache is already correct (no need to restore)
+ else:
+ # Pure attention model: just trim KV cache
+ trim_prompt_cache(batch.cache, num_draft - n)
+
+ # Rewind draft cache.
+ # Simple KVCache.trim() only decrements offset without slicing
+ # the keys/values arrays. The next update_and_fetch() resets
+ # offset = keys.shape[-2], making the trim ineffective and
+ # leaving stale draft tokens in the cache. Fix by slicing
+ # the arrays to match the new offset after trimming.
+ draft_trim = max(num_draft - n - 1, 0)
+ if draft_trim > 0:
+ trim_prompt_cache(draft_cache, draft_trim)
+ for c in draft_cache:
+ if hasattr(c, "keys") and c.keys is not None:
+ if c.keys.shape[-2] > c.offset:
+ c.keys = c.keys[..., : c.offset, :]
+ c.values = c.values[..., : c.offset, :]
+
+ # Sync draft cache when all drafts accepted
+ if n == num_draft:
+ last_draft_y = mx.array(draft_list[-1:], mx.uint32)
+ with mx.stream(generation_stream):
+ self._draft_model(last_draft_y[None], cache=draft_cache)
+ mx.eval([c.state for c in draft_cache])
+
+ # === Phase 4: Build responses ===
+ responses = []
+ stop_tokens = self.batch_generator.stop_tokens
+
+ emit_tokens = [current_y[0].item()]
+ emit_logprobs = [batch.logprobs[0]]
+ for i in range(n):
+ emit_tokens.append(draft_list[i])
+ emit_logprobs.append(verified_logprobs[i].squeeze(0))
+
+ for token_val, lp in zip(emit_tokens, emit_logprobs):
+ batch.tokens[0] = mx.concatenate(
+ (batch.tokens[0], mx.array([token_val], mx.uint32))
+ )
+ batch.num_tokens[0] += 1
+
+ finish_reason = None
+ if token_val in stop_tokens:
+ finish_reason = "stop"
+ elif batch.num_tokens[0] >= batch.max_tokens[0]:
+ finish_reason = "length"
+
+ responses.append(
+ BatchGenerator.Response(
+ uid=uid,
+ token=token_val,
+ logprobs=lp,
+ finish_reason=finish_reason,
+ prompt_cache=None,
+ )
+ )
+
+ if finish_reason is not None:
+ break
+
+ # Update stats
+ self._speculative_accepted_tokens += n
+ self._speculative_total_draft_tokens += num_draft
+ cumulative_rate = (
+ self._speculative_accepted_tokens / self._speculative_total_draft_tokens
+ if self._speculative_total_draft_tokens > 0
+ else 0.0
+ )
+
+ logger.debug(
+ f"Speculative: {n}/{num_draft} accepted, "
+ f"cumulative={self._speculative_accepted_tokens}/"
+ f"{self._speculative_total_draft_tokens} "
+ f"({cumulative_rate:.1%})"
+ )
+
+ # Update batch state for next step
+ if responses:
+ last_response = responses[-1]
+ if last_response.finish_reason is not None:
+ batch.y = mx.array([last_response.token], mx.uint32)
+ batch.logprobs = [last_response.logprobs]
+ last_response.prompt_cache = batch.extract_cache(0)
+ self.batch_generator.active_batch = None
+ else:
+ batch.y = mx.array([verified_list[n]], mx.uint32)
+ batch.logprobs = [verified_logprobs[n].squeeze(0)]
+
+ mx.async_eval(batch.y, *batch.logprobs, *[t for t in batch.tokens])
+
+ return responses
+
def _calculate_max_blocks(self) -> int:
"""
Calculate maximum cache blocks for paged SSD-only mode.
@@ -2832,10 +3205,25 @@ def _process_batch_responses(
self.total_completion_tokens += request.num_output_tokens
self.num_requests_processed += 1
- logger.debug(
- f"Request {request_id} finished: {response.finish_reason}, "
- f"{request.num_output_tokens} tokens"
- )
+ if self._speculative_total_draft_tokens > 0:
+ rate = (
+ self._speculative_accepted_tokens
+ / self._speculative_total_draft_tokens
+ )
+ logger.debug(
+ f"Request {request_id} finished: {response.finish_reason}, "
+ f"{request.num_output_tokens} tokens, "
+ f"speculative accept rate: "
+ f"{self._speculative_accepted_tokens}/"
+ f"{self._speculative_total_draft_tokens} ({rate:.1%})"
+ )
+ self._speculative_accepted_tokens = 0
+ self._speculative_total_draft_tokens = 0
+ else:
+ logger.debug(
+ f"Request {request_id} finished: {response.finish_reason}, "
+ f"{request.num_output_tokens} tokens"
+ )
logger.log(5, "Request %s generated text:\n%s", request_id, output.output_text)
outputs.append(output)
@@ -2965,6 +3353,10 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
del self.uid_to_request_id[uid]
del self.request_id_to_uid[request_id]
+ # Clean up draft cache for speculative decoding
+ if uid is not None:
+ self._cleanup_draft_cache(uid)
+
# Clean up streaming detokenizer
self._cleanup_detokenizer(request_id)
@@ -3009,6 +3401,7 @@ def _recover_from_cache_error(self) -> None:
if self._boundary_snapshot_store is not None:
self._boundary_snapshot_store.cleanup_all()
self._boundary_snapshot_required = None
+ self._draft_caches.clear()
# Clear stale VLM position state to prevent re-corruption on retry
if hasattr(self.model, "clear_vlm_position_state"):
@@ -3082,7 +3475,11 @@ def step(self, max_retries: int = 1) -> SchedulerOutput:
# Run generation step if we have running requests
if self.batch_generator is not None and self.running:
- responses = self.batch_generator.next()
+ # Use speculative decoding when single request + draft model
+ if self._can_speculative_step():
+ responses = self._speculative_step()
+ else:
+ responses = self.batch_generator.next()
output.has_work = True
if responses:
@@ -3175,6 +3572,21 @@ def get_stats(self) -> Dict[str, Any]:
# Include cache stats
if self.block_aware_cache is not None:
stats["ssd_cache"] = self.block_aware_cache.get_stats()
+ # Speculative decoding stats
+ if self._draft_model is not None:
+ acceptance_rate = 0.0
+ if self._speculative_total_draft_tokens > 0:
+ acceptance_rate = (
+ self._speculative_accepted_tokens
+ / self._speculative_total_draft_tokens
+ )
+ stats["speculative"] = {
+ "enabled": True,
+ "num_draft_tokens": self._num_draft_tokens,
+ "total_draft_tokens": self._speculative_total_draft_tokens,
+ "accepted_tokens": self._speculative_accepted_tokens,
+ "acceptance_rate": round(acceptance_rate, 3),
+ }
return stats
def get_cache_stats(self) -> Optional[Dict[str, Any]]:
@@ -3204,6 +3616,7 @@ def reset(self) -> None:
if self._boundary_snapshot_store is not None:
self._boundary_snapshot_store.cleanup_all()
self._boundary_snapshot_required = None
+ self._draft_caches.clear()
# Clear caches
if self.block_aware_cache is not None:
diff --git a/tests/test_model_settings.py b/tests/test_model_settings.py
index 24108a3c..1f0739c2 100644
--- a/tests/test_model_settings.py
+++ b/tests/test_model_settings.py
@@ -144,6 +144,47 @@ def test_ttl_seconds_excluded_when_none(self):
d = settings.to_dict()
assert "ttl_seconds" not in d
+ def test_speculative_decoding_defaults(self):
+ """Test speculative decoding fields default values."""
+ settings = ModelSettings()
+ assert settings.speculative_decoding is False
+ assert settings.draft_model is None
+ assert settings.num_draft_tokens is None
+
+ def test_speculative_decoding_to_dict(self):
+ """Test speculative decoding fields in to_dict."""
+ settings = ModelSettings(
+ speculative_decoding=True,
+ draft_model="small-model",
+ num_draft_tokens=5,
+ )
+ d = settings.to_dict()
+ assert d["speculative_decoding"] is True
+ assert d["draft_model"] == "small-model"
+ assert d["num_draft_tokens"] == 5
+
+ def test_speculative_decoding_excluded_when_defaults(self):
+ """Test speculative decoding None fields excluded from to_dict when default."""
+ settings = ModelSettings()
+ d = settings.to_dict()
+ # speculative_decoding=False is not None, so it IS included (like force_sampling)
+ assert d.get("speculative_decoding") is False
+ assert "draft_model" not in d
+ assert "num_draft_tokens" not in d
+
+ def test_speculative_decoding_roundtrip(self):
+ """Test speculative decoding fields survive to_dict -> from_dict roundtrip."""
+ original = ModelSettings(
+ speculative_decoding=True,
+ draft_model="draft-0.6b",
+ num_draft_tokens=4,
+ )
+ d = original.to_dict()
+ restored = ModelSettings.from_dict(d)
+ assert restored.speculative_decoding is True
+ assert restored.draft_model == "draft-0.6b"
+ assert restored.num_draft_tokens == 4
+
def test_model_type_override_default(self):
"""Test model_type_override defaults to None."""
settings = ModelSettings()
@@ -341,6 +382,25 @@ def test_forced_ct_kwargs_persist(self):
assert loaded.forced_ct_kwargs == ["enable_thinking"]
assert loaded.chat_template_kwargs == {"enable_thinking": False}
+ def test_speculative_decoding_persist(self):
+ """Test speculative decoding settings survive save/load cycle."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ manager = ModelSettingsManager(Path(tmpdir))
+
+ settings = ModelSettings(
+ speculative_decoding=True,
+ draft_model="qwen3-0.6b",
+ num_draft_tokens=5,
+ )
+ manager.set_settings("test-model", settings)
+
+ # Reload from file
+ manager2 = ModelSettingsManager(Path(tmpdir))
+ loaded = manager2.get_settings("test-model")
+ assert loaded.speculative_decoding is True
+ assert loaded.draft_model == "qwen3-0.6b"
+ assert loaded.num_draft_tokens == 5
+
def test_model_type_override_persist(self):
"""Test model_type_override survives save/load cycle."""
with tempfile.TemporaryDirectory() as tmpdir: