
v3 api embedder fixes#202

Merged: davidkoski merged 3 commits into main from v3-api-p2 on Apr 14, 2026

@davidkoski (Collaborator) commented Apr 10, 2026

Proposed changes

  • replaces embedder loading, context and container with LLM variant
  • this reuses all the code from LLM/VLM model type and model registries, along with the factory, download and load code
  • should complete [BUG] version 3 todos #189

ml-explore/mlx-swift-examples#468 is the branch adopting this -- between this branch and that I hope to get good notes for a porting guide, but that will be a different PR.

FYI @DePasqualeOrg @rudrankriyam @sxy-trans-n @CodebyCR @ronaldmannak

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@davidkoski davidkoski requested a review from angeloskath April 10, 2026 00:27
Comment on lines -47 to +50
- private var llmTask: Task<LMModelContainer, Error>?
- private var vlmTask: Task<LMModelContainer, Error>?
- private var lfm2Task: Task<LMModelContainer, Error>?
- private var glm4Task: Task<LMModelContainer, Error>?
+ private var llmTask: Task<LLModelContainer, Error>?
+ private var vlmTask: Task<LLModelContainer, Error>?
+ private var lfm2Task: Task<LLModelContainer, Error>?
+ private var glm4Task: Task<LLModelContainer, Error>?

@davidkoski (Collaborator, Author):

Just a typo fix. Trivial

Comment on lines +146 to +148
let container = try await EmbedderModelFactory.shared.loadContainer(
from: downloader, using: tokenizerLoader,
configuration: EmbedderRegistry.nomic_text_v1_5,

@davidkoski (Collaborator, Author):

This now uses an LLM-style ModelFactory. The innards are the same except for the low-level creation of the model itself.

The pre-configured ids move to EmbedderRegistry, just as the LLMs have an LLMRegistry.

Comment on lines -359 to 362
- let resultEmbeddings = await modelContainer.perform {
-     (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in
+ let resultEmbeddings = await modelContainer.perform { context in
+     let tokenizer = context.tokenizer
      let encoded = inputs.map {

@davidkoski (Collaborator, Author), Apr 10, 2026:

The container now works identically to ModelContainer from the LLM side. Well, more or less -- it has been thinned down a bit, but otherwise the same.

Now the closure gets a context rather than a tuple.
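
As a rough illustration of the tuple-to-context change, here is a minimal, self-contained sketch. `SketchContext` and `SketchContainer` are hypothetical stand-ins, not the MLXEmbedders types, and `perform` is synchronous here for brevity, whereas the real call is awaited.

```swift
// Hypothetical stand-ins; the real context and container live in the library.
struct SketchContext {
    let tokenizer: (String) -> [Int]   // stand-in for a tokenizer
    let pooling: ([Int]) -> Float      // stand-in for a pooling step
}

final class SketchContainer {
    private let context: SketchContext

    init(context: SketchContext) {
        self.context = context
    }

    // The closure receives a single context value instead of a positional tuple.
    func perform<R>(_ action: (SketchContext) throws -> R) rethrows -> R {
        try action(context)
    }
}

let container = SketchContainer(
    context: SketchContext(
        // Toy tokenizer: one "token" per word, valued by the word length.
        tokenizer: { text in text.split(separator: " ").map { $0.count } },
        // Toy pooling: mean of the token values.
        pooling: { tokens in
            tokens.isEmpty ? 0 : Float(tokens.reduce(0, +)) / Float(tokens.count)
        }
    )
)

let pooled: Float = container.perform { context in
    let encoded = context.tokenizer("hello embedder world")
    return context.pooling(encoded)
}
```

Passing one named value means new members can be added to the context later without changing every closure signature, which is presumably part of the appeal of matching the LLM side.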

@davidkoski (Collaborator, Author):

I merged the comments back into the MLXLMCommon side. This was otherwise identical.

@davidkoski (Collaborator, Author):

This is now in ModelFactory.swift and reorganized to match the LLM side.

/// return result.map { $0.asArray(Float.self) }
/// }
/// ```
public final class EmbedderModelContainer: Sendable {

@davidkoski (Collaborator, Author):

This is a clone of the ModelContainer for LLMs, but thinned down (none of the deprecated functions, and a couple fewer conveniences that I think are better served by the properties below). Implementation-wise, it is identical.

/// }
/// }
/// ```
public actor ModelContainer {

@davidkoski (Collaborator, Author):

Moves to / is replaced by EmbedderModelContainer.swift.

}

- public protocol EmbeddingModel: Module {
+ public protocol EmbeddingModel: BaseLanguageModel {

@davidkoski (Collaborator, Author):

This is so we can get weight loading for free.

Comment thread Libraries/MLXEmbedders/Load.swift Outdated

@davidkoski (Collaborator, Author):

Replaced by ModelFactory.swift and Load.swift from MLXLMCommon.

Comment on lines -223 to -229
public func loadModelContainer(
from downloader: any Downloader,
using tokenizerLoader: any TokenizerLoader,
configuration: ModelConfiguration,
useLatest: Bool = false,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> ModelContainer {

@davidkoski (Collaborator, Author):

One breaking change is that these free functions are gone; use the EmbedderModelFactory instead. We have equivalents for the LLMs, but the signatures are identical, so keeping both would force callers to use qualified names.

I think the use case for these is less interesting than the LLM case where it did auto-discovery of the two factories. Here the EmbedderModelFactory is trivial to use.

}

/// Registry of model type, e.g 'bert', to functions that can instantiate the model from configuration.
public enum EmbedderTypeRegistry {

@davidkoski (Collaborator, Author):

Matching the LLM side

}

/// Registry of known embedder model configurations.
public class EmbedderRegistry: AbstractModelRegistry, @unchecked Sendable {

@davidkoski (Collaborator, Author):

Matching the LLM side

///
/// This is created using a ``EmbedderModelFactory`` and often used
/// inside a ``EmbedderModelContainer``.
public struct EmbedderModelContext {

@davidkoski (Collaborator, Author):

Just like ModelContext for LLMs

@davidkoski (Collaborator, Author):

See also EmbedderModelContainer

Comment on lines +131 to +133
public final class EmbedderModelFactory: GenericModelFactory<
EmbedderModelContext, EmbedderModelContainer
>

@davidkoski (Collaborator, Author):

Just like for LLMs. In fact it leverages all of the various load functions by implementing the low level config/weight loading.

Comment thread Libraries/MLXEmbedders/Models.swift Outdated

@davidkoski (Collaborator, Author):

We just use ModelConfiguration from MLXLMCommon now. It has a few extra properties, but this library ignores them.


/// Shared instance with default model types.
- public static let shared: ModelTypeRegistry = .init(creators: [
+ public static let shared: ModelTypeRegistry<LanguageModel> = .init(creators: [

@davidkoski (Collaborator, Author):

These are now parameterized with the type of model they produce so that we can use them with the embedders.
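
To make the idea of a registry parameterized by its model type concrete, here is a small sketch. Every name in it (`SketchTypeRegistry`, `SketchLanguageModel`, `SketchEmbeddingModel`, `Llama`, `Bert`) is a hypothetical stand-in, and the creators take no configuration, unlike the real registry.

```swift
// Hypothetical model protocols standing in for the LLM and embedder families.
protocol SketchLanguageModel { var name: String { get } }
protocol SketchEmbeddingModel { var dimensions: Int { get } }

struct Llama: SketchLanguageModel { let name = "llama" }
struct Bert: SketchEmbeddingModel { let dimensions = 768 }

// One registry implementation, parameterized by the kind of model it creates.
final class SketchTypeRegistry<Model> {
    private let creators: [String: () -> Model]

    init(creators: [String: () -> Model]) {
        self.creators = creators
    }

    // Returns nil when the model type string is not registered.
    func createModel(modelType: String) -> Model? {
        creators[modelType]?()
    }
}

// The same class serves both families; only the type argument differs.
let llmTypes = SketchTypeRegistry<any SketchLanguageModel>(
    creators: ["llama": { Llama() }]
)
let embedderTypes = SketchTypeRegistry<any SketchEmbeddingModel>(
    creators: ["bert": { Bert() }]
)

let llama = llmTypes.createModel(modelType: "llama")
let bert = embedderTypes.createModel(modelType: "bert")
let missing = llmTypes.createModel(modelType: "bert")
```

Because the type argument is part of the registry's type, asking the LLM registry for an embedder type simply yields `nil` rather than a model of the wrong family.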

/// call. This is important for things like `ModelContainer` that have to perform async
/// work but also need to prevent other callers for using _any_ of the internal state.
- final class SerialAccessContainer<T>: @unchecked Sendable {
+ package final class SerialAccessContainer<T>: @unchecked Sendable {

@davidkoski (Collaborator, Author):

I didn't know about the `package` access level: visible to every module in the package, but private to anything outside the Package.swift scope. Perfect!

Reply:

Yep, so cool!

/// Models can override this to inspect metadata (e.g. check `metadata["format"] == "mlx"`)
/// and skip or customize sanitization accordingly.
func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: MLXArray]
}

@davidkoski (Collaborator, Author):

Split out the generic parts of language models (preparing the weights) so that this can work:

public func loadWeights(
    modelDirectory: URL, model: BaseLanguageModel,
    quantization: BaseConfiguration.Quantization? = nil,
    perLayerQuantization: BaseConfiguration.PerLayerQuantization? = nil
) throws {

/// - ``loadModelContainer(from:using:configuration:useLatest:progressHandler:)``
///
/// or variants.
public protocol GenericModelFactory<ContextType, ContainerType>: Sendable {

@davidkoski (Collaborator, Author):

This is made generic so we can use it across LLM, VLM, Embedder.

}

/// For backward compatibility: `ModelFactory` refers to an LLM/VLM model factory.
public typealias ModelFactory = GenericModelFactory<ModelContext, ModelContainer>

@davidkoski (Collaborator, Author):

This gives a compatible type name if anyone was using it.
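
The combination of a generic factory protocol and a backward-compatible typealias can be sketched as follows. This is an illustration under assumptions, not the library's code: `SketchFactory` and all the `Sketch…` types are hypothetical, and the real protocol has asynchronous, throwing load functions with more parameters.

```swift
// Hypothetical factory protocol with primary associated types (Swift 5.7+).
protocol SketchFactory<ContextType, ContainerType> {
    associatedtype ContextType
    associatedtype ContainerType

    func makeContext(id: String) -> ContextType
    func makeContainer(context: ContextType) -> ContainerType
}

extension SketchFactory {
    // Shared loading logic, written once for every model family.
    func loadContainer(id: String) -> ContainerType {
        makeContainer(context: makeContext(id: id))
    }
}

struct SketchLLMContext { let id: String }
struct SketchLLMContainer { let context: SketchLLMContext }
struct SketchEmbedderContext { let id: String }
struct SketchEmbedderContainer { let context: SketchEmbedderContext }

struct SketchLLMFactory: SketchFactory {
    func makeContext(id: String) -> SketchLLMContext {
        SketchLLMContext(id: id)
    }
    func makeContainer(context: SketchLLMContext) -> SketchLLMContainer {
        SketchLLMContainer(context: context)
    }
}

struct SketchEmbedderFactory: SketchFactory {
    func makeContext(id: String) -> SketchEmbedderContext {
        SketchEmbedderContext(id: id)
    }
    func makeContainer(context: SketchEmbedderContext) -> SketchEmbedderContainer {
        SketchEmbedderContainer(context: context)
    }
}

// Older call sites can keep using one name for the LLM flavor.
typealias SketchModelFactory = SketchFactory<SketchLLMContext, SketchLLMContainer>

func loadLLM(with factory: some SketchModelFactory, id: String) -> SketchLLMContainer {
    factory.loadContainer(id: id)
}

let llm = loadLLM(with: SketchLLMFactory(), id: "example-llm")
let embedder = SketchEmbedderFactory().loadContainer(id: "example-embedder")
```

Code written against the alias keeps compiling, while the embedder factory gets the shared `loadContainer` logic without duplicating it.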

@angeloskath (Member) left a comment:

👍

Base automatically changed from v3-api-p1 to main April 13, 2026 23:47
@davidkoski (Collaborator, Author):

@DePasqualeOrg once this and #206 are merged, what do you think about giving it the 3.x tag? I have a PR on the mlx-swift-examples side integrated with this PR.

Anything else that needs to be dealt with before the first tag?

@DePasqualeOrg (Contributor):

Sounds good! I'm not aware of anything else that needs to be done. I'll keep an eye out for the new release and update the integration packages to use the new major version when it's out.

@davidkoski davidkoski merged commit 826fbc9 into main Apr 14, 2026
3 checks passed
@davidkoski davidkoski deleted the v3-api-p2 branch April 14, 2026 16:03
ekryski referenced this pull request in ekryski/mlx-swift-lm May 5, 2026
…ward-port main fixes

Reverses the Gemma4 VLM merge direction: restore alpha's pre-merge
1538-line implementation (with all its perf work intact) as the base,
then forward-port the two small targeted fixes from main on top.

Why: main's 1929-line VLM Gemma4 (which the merge initially adopted) was
written independently from alpha's, with structurally different norm
classes (`Gemma4RMSNormZeroShift` + `Gemma4RMSNormNoScale` vs alpha's
fused-RMSNorm-RoPE path), no compiled QKV, and no GPU-resident
masked-scatter. Taking main's text-decoder wholesale would have lost:

- argWhere-based gemma4MaskedScatter (avoids `asArray(Bool.self)`
  full-mask CPU readback during VLM image-token splice — alpha's #53
  perf optimization)
- Compiled QKV for attention
- Fused gate_up_proj + compile()-cached SharedMLP forward
- FusedGateUpSwitchGLU MoE routing for 26B-A4B
- Symbolic .slidingWindow SDPA mask routing

Forward-ported from main on top of alpha's base:

- #211 (system message + modality order): added `Gemma4MessageGenerator`
  struct (system messages stay text-only; user/assistant messages emit
  images first, then videos, then text — matching the Gemma 4 chat
  template). `Gemma4Processor.prepare` now uses it instead of
  `Qwen2VLMessageGenerator()`.
- #202 (modern dim API): `flattenedSource.shape[0]` → `dim(0)` in
  `gemma4MaskedScatter`. Other #202 hunks (`maskArray.shape.last`
  cleanup, unused token-count diagnostics) didn't apply — alpha's
  attention-mask handling is structured differently and the processor
  never had those diagnostics.

Verification:
- `swift build` clean, `swift build --build-tests` clean.
- `swift test --skip Benchmarks` passes 152 tests in 14 suites.
- LLM-side Gemma 4 E2B smoke (3 KV modes × 4 contexts) coherent —
  exercises the same Gemma4TextConfiguration decode path.

Note: full image-input smoke for the VLM path (loading via
VLMModelFactory and running an image+prompt through Gemma4Processor)
would benefit from a new `--method vision` mode in InferenceBenchmark
that takes an image path. Scoped as a follow-up — see comment on the
PR for the proposed shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ekryski referenced this pull request in ekryski/mlx-swift-lm May 5, 2026
The earlier backport (commit 921e16b) restored alpha's pre-merge VLM
Gemma4 to preserve text-decoder perf work. End-to-end smoke against
the actual `mlx-community/gemma-4-e2b-it-4bit` weights revealed alpha's
VLM Gemma4 was broken vs the current HF config schema:

1. `Gemma4VisionConfiguration` required an `image_size` field that
   doesn't exist in the current Gemma 4 vision_config (HF uses
   `position_embedding_size` + `patch_size` instead).
2. Alpha's vision encoder was missing layer types that main implements
   (`Gemma4VisionRMSNorm`, `Gemma4ClippableLinear`, `Gemma4VisionPooler`,
   `Gemma4MultimodalEmbedder`) — alpha's perf work was on the LLM/text-
   decoder side; the VLM path was apparently never tested end-to-end
   against the official Gemma 4 weights.

Conclusion: alpha's VLM Gemma4 is not meaningfully "ahead" of main —
it's just unmaintained. Take main's full 1929-line VLM Gemma4 (with
the post-#180 fixes from #211 system-message + modality order and
#202 modern dim API). The text-decoder perf gap with the LLM-side
Gemma4.swift (fused gate_up_proj, compile()-cached SharedMLP,
FusedGateUpSwitchGLU, etc.) becomes a known follow-up — to be
addressed as part of the LLM/VLM common-code consolidation pass that
lifts shared text-decoder pieces into MLXLMCommon.

Single tweak vs main: added `.slidingWindow` to the
`gemma4AdjustAttentionMask` switch (alpha's mlx-swift exposes the
case; main was authored before it landed).

Vision bench harness — new method to validate VLM end-to-end on real
weights:

- `runGenerationBenchmark` gains `images: [URL] = []`, `chatMessages:
  [Chat.Message]? = nil`, and `useVLM: Bool = false`. When
  `chatMessages` is set, the function builds `UserInput(chat:)` so the
  model's `MessageGenerator` runs and expands images into the chat
  template's image-placeholder tokens (Gemma 4: 280 vision soft tokens
  per image). Pre-existing call sites (text-only) continue to use the
  `messages: [Message]` raw-dict path with no behavioural change.
- `loadOrCacheModel` gains `useVLM: Bool = false`. When true, dispatches
  through `VLMModelFactory.shared.loadContainer`. The cache key is
  salted with `:vlm` so the same model id can be cached separately for
  each modality (Gemma 4 E2B is loadable both ways).
- New `runVisionBenchmark` wires the bench-level method case to a
  default golden-retriever fixture at
  `Tests/Benchmarks/Resources/vlm-test-prompts/test-image1.jpeg`.
  Pass/fail is a substring check against the output (default keyword:
  "dog"; override via env vars).
- New `case "vision":` in the dispatch switch.

CLI surface (benchmark.sh):
- `--method vision` accepted.
- `--image <path>` overrides the default image (relative paths
  resolved against the repo root, where the bench runs).
- `--vision-prompt "..."` overrides the default prompt.
- `--vision-expect <keyword>` overrides the pass/fail keyword.

Verification:
- `swift build` clean, `make` clean.
- Vision smoke on Gemma 4 E2B 4bit + the checked-in golden-retriever
  image: 1/1 PASSED.
  - Prompt: "What animal is in this image? Answer in one word."
  - Output: "Dog"
  - Prepared 304 tokens (22 text + 280 image + 2 wrappers).
  - TTFT: 1445ms · Prefill: 211 tok/s · GPU peak: 3.94GB · KV cache: 7MB.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>