Skip to content

Sync to ml-explore/mlx-swift-lm #22

Open
Gajesh2007 wants to merge 70 commits into
Layr-Labs:mainfrom
ml-explore:main
Open

Sync to ml-explore/mlx-swift-lm #22
Gajesh2007 wants to merge 70 commits into
Layr-Labs:mainfrom
ml-explore:main

Conversation

@Gajesh2007

Copy link
Copy Markdown
Member

No description provided.

dirvine and others added 30 commits May 7, 2026 13:58
#149)

Qwen35Language.LanguageModel.callAsFunction assumes inputs is always 2D
[batch, seq], but text-only callers like WiredMemoryUtils.tune and
TokenIterator can pass 1D [seq] token arrays. This causes
getRopeIndex() and subsequent dim(1) calls to crash with
"SmallVector out of range" when accessing a non-existent dimension.

Add an ndim check at the top of callAsFunction to expand 1D inputs
to 2D before any dimension-dependent logic runs.

Fixes #148
* Add coherence integration tests

* Consolidate task registration
#170)

`TokenRing.loadPrompt` used `prompt.dim(0)` to count tokens, which
returns 1 for VLM models that pass [1, n]-shaped prompts. This caused
the ring buffer to be incorrectly sized, leading to a broadcast shape
crash on the next `append` call during generation.

Flatten the prompt to 1D upfront so the token count and all downstream
slicing work correctly regardless of input shape.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The tools schema was passed to the chat template (so the model knew
about the tools) but never forwarded to the ToolCallProcessor that
parses the response. Without it, array and other non-string parameter
types in XML tool call formats (Qwen 3.5) were returned as raw strings
instead of being decoded to their proper types.

Fixes #159.
…#174)

GlmOcr never forwarded tools or additionalContext to applyChatTemplate
(added after #140, missed the pattern). SmolVLM2's video path was also
missing them while its other two branches had them.
* fix: use preconcurrency CoreImage import
feat: add ParoQuant model support
                                                                                                
Load PARO-quantized models (AutoAWQ format) with pairwise Givens rotation applied at runtime via a
Metal kernel.                                                                                    
                   
Key additions:                                                                                    
- RotateQuantizedLinear: Metal kernel for pairwise Givens rotation + quantized matmul. Rotation
state derived once at load time (thread-safe under concurrent inference).                         
- ParoQuantLoader: AutoAWQ→MLX weight conversion, rotation layer patching, fused in_proj_ba
splitting for Mamba projections.                                                                  
- maybeQuantizeKVCache fix: handle hybrid caches (attention + Mamba) where cache[0] isn't a KV    
attention cache.                                                                              
- Unit tests covering pair packing, AWQ conversion, quantization round-trip, and concurrent       
safety.
* Handle stringified JSON tool call arguments
* fix build
)

MLXArray(-Float.infinity) creates a float32 scalar. which() then promotes
the entire bf16 segsum output to fp32, doubling memory per SSM layer.

At L=2048 this wastes ~960MB on Qwen3.5-35B (30 GDN layers, 4 heads)
and ~24GB on Nemotron-30B (48 Mamba layers, 64 heads) per prefill.

Fix: match the -inf scalar to the accumulator dtype.
* feat: expose speculative decoding in ChatSession (#181)
* feat: add generateTask overload for SpeculativeTokenIterator

Add SpeculativeDecodingConfig struct and wire SpeculativeTokenIterator
into ChatSession.streamMap, enabling ~2-3x generation speedup with
no API break for existing callers.

Exposes a public generateTask(iterator: consuming SpeculativeTokenIterator)
overload so that ChatSession can obtain the (stream, task) pair needed
for clean early-termination handling — matching the pattern used by
the existing TokenIterator overload.

---------

Co-authored-by: David Koski <46639364+davidkoski@users.noreply.github.com>
Fix EmbeddingGemma weight loading (`sanitize(weights:)`)

Two bugs prevented loading any `mlx-community/embeddinggemma-*` checkpoint:

1. **Init-order crash** — `sanitize(weights:)` assigned the dense head via `self._dense.wrappedValue = ...`, which fatals when `module != nil`. Route through `update(modules:)` instead.

2. **Wrong hidden size** — the dense head was initialized with `config.intermediateSize` (the backbone MLP dim), but the actual checkpoint expands 4× to `hiddenSize * 4`. Read the dim directly from `dense.0.weight` shape instead.

---
Co-authored-by: Anatoly Samoilenko anatoly.samoilenko@gmail.com
Co-authored-by: David Koski dkoski@apple.com
…ustion (#226)

CIContext() default options cache IOSurface-backed GPU textures for intermediate
filter results. Each frame in the VLM image pipeline (tone-curve → bicubic
resample → color-matrix normalize) allocates multiple IOSurfaces that are held
in the context cache across calls.

Processing a large library (hundreds of videos, each yielding several frames)
accumulates thousands of cached surfaces and hits the macOS per-process
IOSurface kernel limit of 16384, causing render failures:

  IOSurface creation failed: e00002be (likely per client IOSurface limit of
  16384 reached)
  -[CIContext _startTaskToRender:...] Render failed because of failure to
  allocate intermediate.

Fix: pass `.cacheIntermediates: false` when constructing the shared context.
Batch inference never re-renders the same CIImage twice, so the intermediate
cache provides no benefit. The change eliminates the surface accumulation while
leaving all functional behavior intact.

Co-authored-by: Vladimir <vladimir@sinitcin.com>
Softmax was applied to all 128 expert scores before top-k selection, so
selected weights came from the full distribution rather than just the
top-k. Moved softmax to after selection so it operates only on the chosen
experts. Drops the renormalization step since softmax on top-k already
produces a valid distribution.

Also fuses norm + scale from 3 dispatches into one MLXFast.rmsNorm call.
…227)

SwitchGLU variant for models that ship a single fused gate_up_proj weight
of shape [numExperts, 2*hiddenDims, inputDims] instead of separate
gate_proj / up_proj. Gemma 4 26B MoE uses this layout.

One gatherMM dispatch for the combined projection, split, then activate.
Same gatherSort optimization as SwitchGLU.
Fix UserInput init not populating self.images/self.videos (#182)

Property observers don't fire during initialization, so prompt.didSet — which keeps self.images/self.videos in sync — was never called by two of the inits:

init(prompt: String, images:videos:tools:additionalContext:) built a .chat prompt from its parameters but never copied images/videos into self.
init(prompt: Prompt, images:videos:...) had an explicit case .chat: break, silently dropping the parameters.
Both inits now mirror init(chat:processing:tools:additionalContext:): extract images and videos from the chat messages, or use the explicit parameters for non-chat prompts.

Add regression tests for UserInput init image/video sync (#182)

Three tests pinning the contract that self.images/self.videos are populated at construction, not deferred to the next prompt assignment:

testInitFromPromptStringPopulatesImages — string-based init reflects the images parameter immediately.
testInitFromPromptEnumPopulatesImagesForChat — the .chat branch of the Prompt-based init derives images from chat messages rather than dropping them.
testInitFromPromptStringPopulatesVideos — symmetric coverage for the videos parameter.
* pipeline prefill chunks with asyncEval -- 10x on GDN models

eval(cache) between prefill chunks was a blocking sync -- CPU idle while
GPU worked, then built next chunk serially. Python mlx-lm avoids this
because eval is deferred until a value is read.

asyncEval on cache state per chunk lets CPU build ahead while GPU
pipelines through. One terminal eval(cache) after the loop.

Prefill throughput (tok/s), Qwen3.6-35B-A3B-4bit, M5 Max 128GB:

  ctx  | before | after  | speedup
  128  | 260    | 696    | 2.7x
  512  | 235    | 2201   | 9.4x
  1k   | 244    | 3130   | 12.8x
  2k   | 270    | 3937   | 14.6x

No regression on dense models (Gemma 4 E2B, Llama 3.2 3B).
Decode unchanged -- different code path.
* fix gated delta state precision -- fp32 state to match Python

State was typed as q.dtype (bf16), losing precision across T-step
recurrence. Python mlx-lm keeps state in fp32. Aligns Swift with that.

- kernel source: write o_state as StT not InT
- dispatch: add StT template, use state.dtype for output
- gatedDeltaUpdate: create state in fp32, upcast if needed

Before: kernel showed ~0.25 max diff vs ops fallback at T>1.
After: matches ops path.

Tested on M5 Max 128GB, Qwen3.6-35B-A3B-4bit.

Co-authored-by: tturney <tturney@psyguard.ai>
Removed the command to show SDK build version during the build process. This was informative, but ultimately causing problems with installations and not actually guarding anything.
* models should not mutate state during eval

- fixes #157

* sync gptoss with mlx-lm -- no missing sinks
Fix Gemma4TextBackbone.callAsFunction crash on 1D token input

When callers construct LMInput directly (e.g. for manual KV-cache reuse), tokens arrive as 1D (L,) instead of 2D (1, L). This causes processedPerLayerInputs to be 3D, and the 4D subscript on finalPerLayerInputs crashes in mlx_array_dim. The first request succeeds because autoregressive step adds .newAxis; the crash hits on the initial prefill of continuation requests.

Fix: expand 1D inputs / 2D inputsEmbeds to add a leading batch dim at the top of callAsFunction. Zero-copy reshape; no behavior change for callers already passing the canonical 2D shape.
Fix Gemma4VisionPooler kernel derivation to use padded sequence length

Previously, the kernel was derived from the real patch count, which could yield kernel=2 at the 280-token budget instead of the expected 3, causing real patches to map to zero rows in the einsum output.

Extract gemma4VisionPoolingKernel(paddedPatchCount:outputLength:) and derive the kernel from pooledHiddenStates.dim(1) to consistently return pool=3 across all supported budgets. Add regression tests for all five budgets {70, 140, 280, 560, 1120}.
Adds `benchmarkLLMGeneration` and `LLMGenerationStats` to `BenchmarkHelpers`,
covering prefill (prompt processing) and decode (generation) throughput in
a single helper. Mirrors the surface of the existing loading and tokenization
helpers (warm-up + multi-run timing, returns BenchmarkStats).

Internally relies on the `GenerateCompletionInfo` emitted by the existing
generation stream, so prefill/decode times are measured the same way as
`llm-tool` and ChatSession surface them. Prompt defaults to text from
`BenchmarkDefaults.textSource`; pass an explicit `prompt:` to override.
`temperature` is 0 so trials are deterministic and tps comparisons are
reproducible.

Useful for evaluating model-level performance changes (fusion patches,
custom kernels, KV cache tweaks) without having to wire up a separate
CLI tool. Existing `BenchmarkHelpers` users (loading/tokenization/download)
are unchanged.
When attentionKeqV is true, v_proj does not exist and v shares k's
values. The previous code did `v = k` and then transposed v, but k
had already been transposed to (B, nKvHeads, L, head_dim) earlier
in the function. The second transpose reversed that layout, producing
a shape incompatible with the attention call and crashing in
broadcast_shapes.

Reproduces with mlx-community/gemma-4-31b-it-4bit. Fix: in the nil-vProj
branch, apply vNorm directly to the already-transposed k and skip the
extra transpose. The vProj branch is unchanged in behavior.
* Adopt GemmaFunctionParser to accomodate Gemma4 tool calls.
aleroot and others added 30 commits June 1, 2026 07:58
* Fix Gemma4Text quantized KV cache attention
Fix Qwen2.5-VL MROPE, rope_deltas, and invFreq loading to match Python mlx-vlm parity; plumb MROPE state through LMOutput.State for concurrent session safety; misc cleanup (bicubic video resampling, simpler apply_mrope, resize routing
- use a normal Linear layer and multiply by scale:  matches gemma3n and mlx-lm

Co-authored-by: kr1s0404 <tylerlolz0404@gmail.com>
* support for audio resources

- see #192, #194
- this adds representation of audio in UserInput and LMInput
Qwen2.5-VL ships max_pixels = 12,845,056 in its preprocessor config, ~12x the
1280*28*28 budget the model card recommends for image tasks. Default to the
recommended budget; a caller can still set any budget via
UserInput.Processing.minPixels / maxPixels.
Drop hasExplicitCache && from the sharedKV/offset ternaries in Gemma4TextBackbone; shared-KV layers should only gate on idx >= firstKVSharedLayerIdx.

Without the fix, no-cache forwards (embedding extraction, retrieval, batched eval) re-project K/V on shared layers, silently violating Gemma 4's invariant. Cached generation is unaffected (hasExplicitCache was always true on that path).
Add Multi-Token Prediction (MTP) speculative decoding for Gemma 4

Implements MTP speculative decoding in MLXLMCommon and MLXVLM:

- `MTPDrafterModel` protocol with `draftBlock(target:..., queryOffset:)`
  contract; `MTPDrafterContext` / `MTPDrafterContainer` mirror the
  existing `ModelContext` / `ModelContainer` split
- `createBidirectionalMask` / `createBidirectionalSlidingWindowMask`
  helpers for the drafter's attention path
- `Gemma4AssistantDraftModel` — 4-layer Q-only drafter cross-attending
  to the target's pooled full/sliding-attention K/V; loaded via
  `MTPDrafterModelFactory` with `gemma4_assistant` registry entry
- Three `LMOutput.Key` declarations (`mtpLastHiddenStatesKey`,
  `mtpSharedKVStatesKey`, `mtpEmitFlagKey`) for cross-model state
  exchange; Gemma4 target wired to emit on opt-in
- `MTPSpeculativeTokenIterator` driving the accept/reject round loop;
  `generate(...)` / `generateTokens(...)` public overloads
- `GenerateCompletionInfo` extended with MTP counters
  (`proposedDraftTokens`, `acceptedDraftTokens`, `passthroughReason`)
- Sticky passthrough fallback for mid-stream KV-cache quantization onset
- sharedKV snapshot trimmed in lockstep with cache rewind after partial
  acceptance
- Verified hidden sliced at the accepted-bonus position each round
  (aligns with mlx-lm's `verify.hidden[:, accepted:accepted+1, :]`)
- Fixture-anchored tests (masks, drafter forward, Rung 4 token parity)
  moved to IntegrationTesting; fixtures fetched from
  `angelsbrood/gemma4-mtp-fixtures` on HF

Empirical results at temp=0 (31b-it-8bit + 31B-assistant-bf16):
  bs=4 mt=64: 60.6% acceptance, 13.6 tok/s
  bs=6 mt=64: 29.8% acceptance,  8.6 tok/s
…hunked (#337)

* fix Gemma4 prepare() to honor windowSize -- prompt prefill is never chunked
Update Falcon H1/H1R inference behavior to match current mlx-lm more closely.

This fixes tied embedding output projection and scaling, routes attention through the common cache-aware causal path, advances Mamba cache metadata during generation, reports per-layer KV head counts correctly, and chunks SSM prefill to reduce long-prompt memory and latency.

Also mirror upstream ArraysCache length handling for hybrid and batched paths: preserve left padding and lengths through filter, extend, copy, and serialization; forward length preparation through CacheList; and use per-row lengths in Falcon H1 convolution plus chunked SSM state updates.

Add regression coverage for the shared causal attention mask path, SSM mask metadata, nested hybrid cache length propagation, and Mamba metadata copy behavior.
* fix: correct LFM2-MoE sigmoid routing and expert-bias selection

The MoE block applied softmax to the gate and folded expert_bias into the combination weights. LFM2-MoE is sigmoid-gated: the bias steers top-k selection only, and the weights come from the unbiased sigmoid, scaled by routed_scaling_factor. Mirrors ml-explore/mlx-lm#1354.
* fix speculative decode tests

- fix #315
- run models in float16 (more typical) and avoid the float32/tf32 mismatch
Store GatedDelta convolution state contiguously and advance array-cache metadata after each recurrent step, matching upstream mlx-lm behavior for Qwen3.5/Qwen3-Next style models.

Keep left-padding masks active after recurrent cache state initialization, add coverage for ArraysCache metadata advancement, and align Qwen3 RoPE setup with the shared rope initializer.
swift.org's signed prebuilt swift-syntax artifacts are keyed per
swift-syntax tag and exact toolchain, and 600.x/601.x artifacts are no
longer published for current toolchains. A package graph that resolves
below 602 silently falls back to compiling swift-syntax from source
(~200 build tasks). Raising the floor keeps consumers on the prebuilt
path. See #339 for measurements.
Lets a caller check whether the registry can instantiate a given model_type
without attempting a throwing, allocating createModel — e.g. to decide before
a multi-GB download whether a Hub repo's architecture is runnable. The creators
dictionary is private, so this check is only possible inside the type.
* add nemotron labs diffusion

Co-authored-by: Sachin Desai <sdesai@salesforce.com>
… Gemma3 / Qwen2VL / LFM2VL / Pixtral / Mistral3) (#344)

Honor windowSize for chunked prefill in 6 MLXVLM models

Single-pass prefill allocates transient buffers proportional to prompt
length; chunking to windowSize=512 reduces peak memory from ~17.7 GB to
~5.1 GB on an 8k-token Gemma4 prompt (#336).

Extends the Gemma4/#337 loop pattern to remaining single-pass models:
- FastVLM, Gemma3, Qwen2VL, LFM2VL: embed-only path
- Pixtral, Mistral3: inputIds + inputsEmbeds path

Deferred (need non-trivial tensor slicing): Qwen25VL, GlmOcr, Qwen35,
Qwen3VL, Paligemma.
Use a deterministic high-margin transition model for exact speculative-vs-greedy equality coverage. Keep Gemma3 in the suite as a smoke test because real MLX batched verification and token-by-token decoding can choose different argmaxes when logits are close.

Co-authored-by: Alessio Pollero <alessio.pollero@steerai.ai>
* Add Gemma 4 12B unified
…he helper (#356)

homeDirectoryForCurrentUser is API_UNAVAILABLE on iOS, which broke the IntegrationTestHelpers SPM library on iOS.
This change centralizes Hugging Face cache path resolution into two helper functions in IntegrationTestHelpers:
hfCacheDir()
hfSnapshotDir(modelId:revision:)

Co-authored-by: hehua2008 <hegan2010@gmail.com>
Speculative decoding is workload and hardware-sensitive. A draft model can look good in theory but fail to pay off if acceptance is low or if the extra model pressure is too high.

The change set prepares the project for edge-aware speculative decoding by adding the two missing foundations:

- observe speculative decoding quality/performance
- avoid applying auxiliary-model speculation when memory pressure makes it likely to hurt
Adds a basic Swift model conversion API for quantizing safetensors-backed LLMs through MLXLMCommon and LLMModelFactory. Fix #266 .
* Declare context as nonisolated(unsafe) on SDK < 26.

Otherwise, the code doesn't compile under Swift strict concurrency.
Adds kvScheme to GenerateParameters alongside existing kvBits. Provides
a string-based scheme selector for KV cache compression strategies.

Built-in: "affine4", "affine8" (equivalent to kvBits 4/8).
kvScheme overrides kvBits when set. Unrecognized schemes pass through
for custom KVCache implementations.

Plumbing only, no new compression algorithms. Prepares the API for
WHT-based and other non-affine KV compression schemes.
…342)

Fix Gemma 4 E-series QAT checkpoint loading for KV-shared layers

KV-shared layers own no k_proj/v_proj/k_norm, but Gemma4Attention declared
them non-optional and allocated them unconditionally. Full-precision/PTQ
checkpoints carry the redundant tensors so they loaded fine; QAT checkpoints
prune them, causing a keyNotFound failure at the first shared layer (15).

Make k_proj/k_norm optional, gate allocation on KV-owning layers only, and
drop redundant KV tensors in sanitize for PTQ compat. Fixes loading of
gemma-4-E2B-it-qat-4bit.
* Add LFM2.5 bidirectional encoders (Embedding + ColBERT) to MLXEmbedders

Adds a self-contained LFM2.5 bidirectional encoder supporting both LiquidAI retrieval heads: LFM2.5-Embedding-350M (CLS-pooled dense vector, cosine) and LFM2.5-ColBERT-350M (per-token 1024->128 projection, MaxSim late interaction). A single LFM2BidirectionalModel branches on the config "mlx" head; registered under model_type "lfm2" with EmbedderRegistry entries.

Architecture, relative to the causal generative LFM2 in MLXLLM: non-causal GQA attention, a centered depthwise short-conv, an additive pad mask plus a conv keep-mask, and a pooling/projection head (no LM head). sanitize applies only the conv-transpose guard (not a blanket model. prefix, which would corrupt the bare dense.weight). Quantization is data-driven (.scales), so bf16 and pre-quantized int4/int8 checkpoints load unchanged.

Tests cover config decoding, forward shapes/pooling for both heads, and float32 parity vs the Hugging Face reference vectors (>0.999 cosine), gated on local fixtures so they skip cleanly when absent. Note: MLX-compute tests must be run via xcodebuild (the Metal library is unavailable under 'swift test').
…rs (#330)

Mirrors the LLM-side sanitize drop merged in #342 to the MLXVLM Gemma4
text backbone: KV-shared layers (the last `num_kv_shared_layers`) reuse an
earlier layer's K/V and own no `k_proj`/`v_proj`/`k_norm`. After #327,
those module slots are optional and not built for shared layers. QAT
checkpoints already omit the redundant tensors; some PTQ checkpoints still
ship them and would fail to load against the smaller tree.

Scope the drop to the text backbone only — `vision_tower` / `audio_tower`
share the `layers.N.self_attn.{k,v}_proj` naming, so an unguarded
predicate would amputate tower layers >= firstKVSharedLayer.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.