Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions omlx/admin/i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
"settings.models.badge.rep_penalty": "rep_penalty:",
"settings.models.badge.tool_result_tokens": "tool_result_tokens:",
"settings.models.badge.force_sampling": "force_sampling",
"settings.models.badge.speculative": "speculative",

"modal.model_settings.section_label": "Model Settings",
"modal.model_settings.model_type": "Model Type",
Expand All @@ -304,6 +305,11 @@
"modal.model_settings.limit_tool_placeholder": "e.g. 2000",
"modal.model_settings.force_sampling": "Force Sampling",
"modal.model_settings.force_sampling_hint": "Override request sampling parameters with configured values",
"modal.model_settings.speculative_decoding": "Speculative Decoding",
"modal.model_settings.speculative_decoding_hint": "Speed up decoding with a smaller draft model",
"modal.model_settings.draft_model": "Draft Model",
"modal.model_settings.select_draft_model": "Select draft model",
"modal.model_settings.num_draft_tokens": "Draft Tokens",
"modal.model_settings.chat_template_kwargs": "Chat Template Kwargs",
"modal.model_settings.chat_template_kwargs_hint": "Parameters passed to chat template",
"modal.model_settings.add_kwarg": "+ Add",
Expand Down
6 changes: 6 additions & 0 deletions omlx/admin/i18n/ja.json
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
"settings.models.badge.rep_penalty": "rep_penalty:",
"settings.models.badge.tool_result_tokens": "tool_result_tokens:",
"settings.models.badge.force_sampling": "force_sampling",
"settings.models.badge.speculative": "speculative",

"modal.model_settings.section_label": "モデル設定",
"modal.model_settings.model_type": "モデルタイプ",
Expand All @@ -304,6 +305,11 @@
"modal.model_settings.limit_tool_placeholder": "例: 2000",
"modal.model_settings.force_sampling": "強制サンプリング",
"modal.model_settings.force_sampling_hint": "設定値でリクエストのサンプリングパラメータを上書きします",
"modal.model_settings.speculative_decoding": "投機的デコーディング",
"modal.model_settings.speculative_decoding_hint": "小さなドラフトモデルでデコード速度を向上",
"modal.model_settings.draft_model": "ドラフトモデル",
"modal.model_settings.select_draft_model": "ドラフトモデルを選択",
"modal.model_settings.num_draft_tokens": "ドラフトトークン数",
"modal.model_settings.chat_template_kwargs": "チャットテンプレート引数",
"modal.model_settings.chat_template_kwargs_hint": "チャットテンプレートに渡すパラメータ",
"modal.model_settings.add_kwarg": "+ 追加",
Expand Down
6 changes: 6 additions & 0 deletions omlx/admin/i18n/ko.json
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
"settings.models.badge.rep_penalty": "rep_penalty:",
"settings.models.badge.tool_result_tokens": "tool_result_tokens:",
"settings.models.badge.force_sampling": "force_sampling",
"settings.models.badge.speculative": "speculative",

"modal.model_settings.section_label": "모델 설정",
"modal.model_settings.model_type": "모델 타입",
Expand All @@ -304,6 +305,11 @@
"modal.model_settings.limit_tool_placeholder": "예: 2000",
"modal.model_settings.force_sampling": "강제 샘플링",
"modal.model_settings.force_sampling_hint": "설정된 값으로 요청의 샘플링 파라미터를 덮어씁니다",
"modal.model_settings.speculative_decoding": "추측 디코딩",
"modal.model_settings.speculative_decoding_hint": "작은 드래프트 모델로 디코딩 속도 향상",
"modal.model_settings.draft_model": "드래프트 모델",
"modal.model_settings.select_draft_model": "드래프트 모델 선택",
"modal.model_settings.num_draft_tokens": "드래프트 토큰 수",
"modal.model_settings.chat_template_kwargs": "채팅 템플릿 파라미터 설정",
"modal.model_settings.chat_template_kwargs_hint": "채팅 템플릿에 전달되는 파라미터",
"modal.model_settings.add_kwarg": "+ 추가",
Expand Down
6 changes: 6 additions & 0 deletions omlx/admin/i18n/zh.json
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@
"settings.models.badge.rep_penalty": "rep_penalty:",
"settings.models.badge.tool_result_tokens": "tool_result_tokens:",
"settings.models.badge.force_sampling": "force_sampling",
"settings.models.badge.speculative": "speculative",
"modal.model_settings.section_label": "模型设置",
"modal.model_settings.model_type": "模型类型",
"modal.model_settings.model_type_auto": "自动检测",
Expand All @@ -280,6 +281,11 @@
"modal.model_settings.limit_tool_placeholder": "例如 2000",
"modal.model_settings.force_sampling": "强制采样",
"modal.model_settings.force_sampling_hint": "用配置值覆盖请求中的采样参数",
"modal.model_settings.speculative_decoding": "推测解码",
"modal.model_settings.speculative_decoding_hint": "使用小型草稿模型加速解码",
"modal.model_settings.draft_model": "草稿模型",
"modal.model_settings.select_draft_model": "选择草稿模型",
"modal.model_settings.num_draft_tokens": "草稿令牌数",
"modal.model_settings.chat_template_kwargs": "聊天模板参数",
"modal.model_settings.chat_template_kwargs_hint": "传递给聊天模板的参数",
"modal.model_settings.add_kwarg": "+ 添加",
Expand Down
56 changes: 51 additions & 5 deletions omlx/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class ModelSettingsRequest(BaseModel):
chat_template_kwargs: Optional[Dict[str, Any]] = None
forced_ct_kwargs: Optional[list[str]] = None
ttl_seconds: Optional[int] = None
speculative_decoding: Optional[bool] = None
draft_model: Optional[str] = None
num_draft_tokens: Optional[int] = None
is_pinned: Optional[bool] = None
is_default: Optional[bool] = None

Expand Down Expand Up @@ -1015,6 +1018,9 @@ async def list_models(is_admin: bool = Depends(require_admin)):
"chat_template_kwargs": settings.chat_template_kwargs,
"forced_ct_kwargs": settings.forced_ct_kwargs,
"ttl_seconds": settings.ttl_seconds,
"speculative_decoding": settings.speculative_decoding,
"draft_model": settings.draft_model,
"num_draft_tokens": settings.num_draft_tokens,
"is_pinned": settings.is_pinned,
"is_default": settings.is_default,
"display_name": settings.display_name,
Expand Down Expand Up @@ -1170,6 +1176,12 @@ async def update_model_settings(
current_settings.forced_ct_kwargs = request.forced_ct_kwargs
if "ttl_seconds" in sent:
current_settings.ttl_seconds = request.ttl_seconds
if "speculative_decoding" in sent:
current_settings.speculative_decoding = request.speculative_decoding or False
if "draft_model" in sent:
current_settings.draft_model = request.draft_model or None
if "num_draft_tokens" in sent:
current_settings.num_draft_tokens = request.num_draft_tokens
if request.is_pinned is not None:
current_settings.is_pinned = request.is_pinned
# Also update the engine pool entry
Expand All @@ -1184,15 +1196,21 @@ async def update_model_settings(
settings_manager.set_settings(model_id, current_settings)

# Warn if engine type actually changed while model is loaded
speculative_changed = (
entry.engine is not None
and any(k in sent for k in ("speculative_decoding", "draft_model", "num_draft_tokens"))
)
requires_reload = (
"model_type_override" in sent
and entry.engine is not None
and entry.engine_type != prev_engine_type
entry.engine is not None
and (
("model_type_override" in sent and entry.engine_type != prev_engine_type)
or speculative_changed
)
)
if requires_reload:
logger.info(
f"Model type changed for loaded model {model_id} "
f"(now {entry.model_type}/{entry.engine_type}). "
f"Settings changed for loaded model {model_id} "
f"(engine_type={entry.engine_type}). "
f"Reload required to take effect."
)

Expand All @@ -1206,6 +1224,34 @@ async def update_model_settings(
}


@router.get("/api/models/{model_id}/draft_candidates")
async def get_draft_candidates(
    model_id: str,
    is_admin: bool = Depends(require_admin),
):
    """Return the loaded LLM/VLM models usable as speculative-decoding drafts.

    Every model in the engine pool other than *model_id* whose type is
    ``llm`` or ``vlm`` is reported, along with its estimated size so the
    UI can surface smaller models as better draft choices.

    Raises:
        HTTPException: 503 when the engine pool is not initialized,
            404 when *model_id* is not present in the pool.
    """
    engine_pool = _get_engine_pool()
    if engine_pool is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    if engine_pool.get_entry(model_id) is None:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")

    # Any other pooled LLM/VLM entry is a viable draft candidate.
    candidates = [
        {"model_id": mid, "estimated_size": ent.estimated_size}
        for mid in engine_pool.get_model_ids()
        if mid != model_id
        and (ent := engine_pool.get_entry(mid))
        and ent.model_type in ("llm", "vlm")
    ]

    return {"candidates": candidates}


@router.get("/api/models/{model_id}/generation_config")
async def get_generation_config(
model_id: str,
Expand Down
26 changes: 26 additions & 0 deletions omlx/admin/static/js/dashboard.js
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@
enableToolResultLimit: false,
max_tool_result_tokens: null,
ctKwargEntries: [],
speculative_decoding: false,
draft_model: '',
num_draft_tokens: 3,
},
draftModelCandidates: [],
savingModelSettings: false,
loadingGenDefaults: false,

Expand Down Expand Up @@ -518,9 +522,28 @@
max_tool_result_tokens: settings.max_tool_result_tokens || null,
ttl_seconds: settings.ttl_seconds ?? null,
ctKwargEntries,
speculative_decoding: settings.speculative_decoding || false,
draft_model: settings.draft_model || '',
num_draft_tokens: settings.num_draft_tokens ?? 3,
};
this.showModelSettingsModal = true;
this.$nextTick(() => lucide.createIcons());

// Fetch draft model candidates for speculative decoding.
// Must load candidates BEFORE the select renders, otherwise
// Alpine resets x-model to '' when no matching option exists.
const savedDraftModel = settings.draft_model || '';
if (model.model_type === 'llm' || model.model_type === 'vlm') {
fetch(`/admin/api/models/${encodeURIComponent(model.id)}/draft_candidates`)
.then(r => r.json())
.then(data => {
this.draftModelCandidates = data.candidates || [];
this.modelSettings.draft_model = savedDraftModel;
})
.catch(() => { this.draftModelCandidates = []; });
} else {
this.draftModelCandidates = [];
}
},

async saveModelSettings() {
Expand Down Expand Up @@ -569,6 +592,9 @@
? chatTemplateKwargs : null,
forced_ct_kwargs: forcedCtKwargs.length > 0
? forcedCtKwargs : null,
speculative_decoding: this.modelSettings.speculative_decoding,
draft_model: this.modelSettings.draft_model || null,
num_draft_tokens: this.modelSettings.num_draft_tokens || null,
};
})()),
});
Expand Down
37 changes: 37 additions & 0 deletions omlx/admin/templates/dashboard/_modal_model_settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,43 @@ <h4 class="text-xl font-bold tracking-tight text-neutral-900" x-text="selectedMo
</button>
</div>

<!-- Speculative Decoding (LLM/VLM only) -->
<template x-if="selectedModel && (selectedModel.model_type === 'llm' || selectedModel.model_type === 'vlm')">
<div class="p-4 bg-neutral-50 rounded-xl space-y-3">
<div class="flex items-center justify-between">
<div>
<span class="text-sm font-medium text-neutral-700">{{ t('modal.model_settings.speculative_decoding') }}</span>
<p class="text-xs text-neutral-500 mt-0.5">{{ t('modal.model_settings.speculative_decoding_hint') }}</p>
</div>
<button type="button" @click="modelSettings.speculative_decoding = !modelSettings.speculative_decoding"
:class="modelSettings.speculative_decoding ? 'bg-black' : 'bg-neutral-200'"
class="relative w-11 h-6 rounded-full transition-colors duration-300 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-black">
<span :class="modelSettings.speculative_decoding ? 'translate-x-5' : 'translate-x-0'"
class="block w-5 h-5 bg-white rounded-full shadow-sm transform transition-transform duration-300 absolute top-0.5 left-0.5"></span>
</button>
</div>
<template x-if="modelSettings.speculative_decoding">
<div class="space-y-3 pt-2 border-t border-neutral-200">
<div>
<label class="text-xs font-medium text-neutral-600 mb-1 block">{{ t('modal.model_settings.draft_model') }}</label>
<select x-model="modelSettings.draft_model"
class="w-full px-4 py-2.5 border border-neutral-200 rounded-xl text-sm focus:ring-2 focus:ring-neutral-900 focus:border-transparent transition-all bg-white">
<option value="">{{ t('modal.model_settings.select_draft_model') }}</option>
<template x-for="c in draftModelCandidates" :key="c.model_id">
<option :value="c.model_id" x-text="c.model_id"></option>
</template>
</select>
</div>
<div>
<label class="text-xs font-medium text-neutral-600 mb-1 block">{{ t('modal.model_settings.num_draft_tokens') }}</label>
<input type="number" x-model.number="modelSettings.num_draft_tokens" min="1" max="10" placeholder="3"
class="w-full px-4 py-2.5 border border-neutral-200 rounded-xl text-sm focus:ring-2 focus:ring-neutral-900 focus:border-transparent transition-all">
</div>
</div>
</template>
</div>
</template>

<!-- Chat Template Kwargs -->
<div class="p-4 bg-neutral-50 rounded-xl space-y-3">
<div class="flex items-center justify-between">
Expand Down
7 changes: 6 additions & 1 deletion omlx/admin/templates/dashboard/_settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ <h4 class="text-lg font-semibold text-neutral-900 mb-2">{{ t('settings.models.no
</div>

<!-- Sampling Settings Row (if any settings configured) - hidden for embedding/reranker models -->
<div x-show="(!model.model_type || model.model_type === 'llm' || model.model_type === 'vlm') && model.settings && (model.settings.max_context_window || model.settings.max_tokens || model.settings.temperature !== null || model.settings.top_p !== null || model.settings.top_k !== null || model.settings.repetition_penalty !== null || model.settings.force_sampling || model.settings.max_tool_result_tokens)"
<div x-show="(!model.model_type || model.model_type === 'llm' || model.model_type === 'vlm') && model.settings && (model.settings.max_context_window || model.settings.max_tokens || model.settings.temperature !== null || model.settings.top_p !== null || model.settings.top_k !== null || model.settings.repetition_penalty !== null || model.settings.force_sampling || model.settings.max_tool_result_tokens || model.settings.speculative_decoding)"
class="px-6 pb-4 -mt-2">
<div class="ml-11 flex flex-wrap gap-2">
<template x-if="model.settings?.max_context_window">
Expand Down Expand Up @@ -817,6 +817,11 @@ <h4 class="text-lg font-semibold text-neutral-900 mb-2">{{ t('settings.models.no
force_sampling
</span>
</template>
<template x-if="model.settings?.speculative_decoding">
<span class="inline-flex items-center px-2 py-0.5 text-[10px] font-medium rounded bg-purple-50 text-purple-700 border border-purple-200">
{{ t('settings.models.badge.speculative') }}
</span>
</template>
<template x-if="model.settings?.chat_template_kwargs">
<template x-for="[k, v] in Object.entries(model.settings.chat_template_kwargs)" :key="k">
<span class="inline-flex items-center px-2 py-0.5 text-[10px] font-medium rounded bg-indigo-50 text-indigo-700 border border-indigo-200">
Expand Down
24 changes: 24 additions & 0 deletions omlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
scheduler_config: Any | None = None,
stream_interval: int = 1,
enable_thinking: bool | None = None,
draft_model_path: str | None = None,
num_draft_tokens: int = 3,
):
"""
Initialize the batched engine.
Expand All @@ -53,12 +55,16 @@ def __init__(
scheduler_config: Optional scheduler configuration
stream_interval: Tokens to batch before streaming (1=every token)
enable_thinking: Enable thinking mode for reasoning models (passed to chat_template_kwargs)
draft_model_path: Optional draft model path for speculative decoding
num_draft_tokens: Number of tokens to draft per speculative step
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._scheduler_config = scheduler_config
self._stream_interval = stream_interval
self._enable_thinking = enable_thinking
self._draft_model_path = draft_model_path
self._num_draft_tokens = num_draft_tokens

self._model = None
self._tokenizer = None
Expand Down Expand Up @@ -167,6 +173,24 @@ def _load_model_sync():
)

await self._engine.engine.start()

# Load draft model for speculative decoding
if self._draft_model_path:
def _load_draft_sync():
draft_model, _ = load(self._draft_model_path)
return draft_model

draft_model = await loop.run_in_executor(
get_mlx_executor(), _load_draft_sync
)
self._engine.engine.scheduler.set_draft_model(
draft_model, self._num_draft_tokens
)
logger.info(
f"Speculative decoding enabled: draft={self._draft_model_path}, "
f"num_draft_tokens={self._num_draft_tokens}"
)

self._loaded = True
logger.info(f"BatchedEngine loaded: {self._model_name}")

Expand Down
Loading