v3 api embedder fixes #202
| private var llmTask: Task<LMModelContainer, Error>? | ||
| private var vlmTask: Task<LMModelContainer, Error>? | ||
| private var lfm2Task: Task<LMModelContainer, Error>? | ||
| private var glm4Task: Task<LMModelContainer, Error>? | ||
| private var llmTask: Task<LLModelContainer, Error>? | ||
| private var vlmTask: Task<LLModelContainer, Error>? | ||
| private var lfm2Task: Task<LLModelContainer, Error>? | ||
| private var glm4Task: Task<LLModelContainer, Error>? |
Just a typo fix. Trivial
| let container = try await EmbedderModelFactory.shared.loadContainer( | ||
| from: downloader, using: tokenizerLoader, | ||
| configuration: EmbedderRegistry.nomic_text_v1_5, |
This now uses an LLM-style ModelFactory. The innards are the same except for the load-level creation of the model itself.
The pre-configured ids move to EmbedderRegistry, just as the LLMs have an LLMRegistry.
| let resultEmbeddings = await modelContainer.perform { | ||
| (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in | ||
| let resultEmbeddings = await modelContainer.perform { context in | ||
| let tokenizer = context.tokenizer | ||
| let encoded = inputs.map { |
The container now works identically to ModelContainer from the LLM side. Well, more or less -- it has been thinned down a bit, but otherwise the same.
Now the closure gets a context rather than a tuple.
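To make the tuple-to-context change concrete, here is a minimal sketch. Every type in it (`ToyTokenizer`, `ToyContext`, `ToyContainer`) is a hypothetical stand-in rather than the library's API, and `perform` is synchronous here for brevity, whereas the real call in the diff is awaited.

```swift
// Hypothetical stand-ins; the real types live in the embedder and common modules.
struct ToyTokenizer {
    // Splits on spaces and maps each word to its character count.
    func encode(_ text: String) -> [Int] {
        text.split(separator: " ").map { $0.count }
    }
}

// Bundles what the closure used to receive as separate tuple elements.
struct ToyContext {
    let tokenizer: ToyTokenizer
    let poolingName: String
}

final class ToyContainer {
    private let context: ToyContext

    init(context: ToyContext) {
        self.context = context
    }

    // The closure receives a single context value instead of a
    // (model, tokenizer, pooling) tuple.
    func perform<R>(_ action: (ToyContext) throws -> R) rethrows -> R {
        try action(context)
    }
}

let container = ToyContainer(
    context: ToyContext(tokenizer: ToyTokenizer(), poolingName: "mean"))

// Call sites pull what they need off the context.
let encoded = container.perform { context in
    ["hello world", "mlx"].map { context.tokenizer.encode($0) }
}
print(encoded)  // [[5, 5], [3]]
```

One benefit of the context shape is that adding a property later does not change the closure signature at every call site.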
I merged the comments back into the MLXLMCommon side. This was otherwise identical.
This is now in ModelFactory.swift and reorganized to match the LLM side.
| /// return result.map { $0.asArray(Float.self) } | ||
| /// } | ||
| /// ``` | ||
| public final class EmbedderModelContainer: Sendable { |
This is a clone of the ModelContext for LLMs, but thinned down (none of the deprecated functions, a couple fewer conveniences that I think are better served by the properties below). Implementation-wise, identical.
| /// } | ||
| /// } | ||
| /// ``` | ||
| public actor ModelContainer { |
Moves to / is replaced by EmbedderModelContainer.swift.
| } | ||
|
|
||
| public protocol EmbeddingModel: Module { | ||
| public protocol EmbeddingModel: BaseLanguageModel { |
This is so we can get weight loading for free.
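A rough illustration of why refining the base protocol gives weight loading "for free": the loader is written once against the base protocol, so any refining protocol's models can be passed to it. The names below (`ToyBaseModel`, `ToyEmbeddingModel`, `ToyBert`, and this `loadWeights(into:from:)`) are hypothetical, and plain `[Float]` arrays stand in for `MLXArray`.

```swift
// Hypothetical stand-in for the shared base protocol.
protocol ToyBaseModel: AnyObject {
    var parameters: [String: [Float]] { get set }
    // Lets a model rename or drop checkpoint keys before they are applied.
    func sanitize(weights: [String: [Float]]) -> [String: [Float]]
}

extension ToyBaseModel {
    // Default: accept the weights unchanged.
    func sanitize(weights: [String: [Float]]) -> [String: [Float]] { weights }
}

// Refining the base protocol is what makes the shared loader usable.
protocol ToyEmbeddingModel: ToyBaseModel {
    var embeddingSize: Int { get }
}

// Written once against the base protocol, shared by every model family.
func loadWeights(into model: any ToyBaseModel, from checkpoint: [String: [Float]]) {
    model.parameters = model.sanitize(weights: checkpoint)
}

final class ToyBert: ToyEmbeddingModel {
    var parameters: [String: [Float]] = [:]
    let embeddingSize = 2

    // Drops a key this model does not use, as a sanitize override might.
    func sanitize(weights: [String: [Float]]) -> [String: [Float]] {
        weights.filter { $0.key != "pooler.unused" }
    }
}

let bert = ToyBert()
loadWeights(into: bert, from: ["embed.weight": [0.1, 0.2], "pooler.unused": [9.0]])
print(bert.parameters.keys.sorted())  // ["embed.weight"]
```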
Replaced by ModelFactory.swift and Load.swift from MLXLMCommon.
| public func loadModelContainer( | ||
| from downloader: any Downloader, | ||
| using tokenizerLoader: any TokenizerLoader, | ||
| configuration: ModelConfiguration, | ||
| useLatest: Bool = false, | ||
| progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } | ||
| ) async throws -> ModelContainer { |
One break is that these free functions are gone; use the EmbedderModelFactory instead. We have these for the LLMs, but the signatures are the same, so keeping them here would require qualified names at call sites.
I think the use case for these is less interesting than the LLM case, where they did auto-discovery of the two factories. Here the EmbedderModelFactory is trivial to use.
| } | ||
|
|
||
| /// Registry of model type, e.g 'bert', to functions that can instantiate the model from configuration. | ||
| public enum EmbedderTypeRegistry { |
Matching the LLM side
| } | ||
|
|
||
| /// Registry of known embedder model configurations. | ||
| public class EmbedderRegistry: AbstractModelRegistry, @unchecked Sendable { |
Matching the LLM side
| /// | ||
| /// This is created using a ``EmbedderModelFactory`` and often used | ||
| /// inside a ``EmbedderModelContainer``. | ||
| public struct EmbedderModelContext { |
Just like ModelContext for LLMs
See also EmbedderModelContainer
| public final class EmbedderModelFactory: GenericModelFactory< | ||
| EmbedderModelContext, EmbedderModelContainer | ||
| > |
Just like for LLMs. In fact it leverages all of the various load functions by implementing the low level config/weight loading.
We just use ModelConfiguration from MLXLMCommon now. It has a few extra properties, but they are ignored in this library.
|
|
||
| /// Shared instance with default model types. | ||
| public static let shared: ModelTypeRegistry = .init(creators: [ | ||
| public static let shared: ModelTypeRegistry<LanguageModel> = .init(creators: [ |
These are now parameterized with the type of model they produce so that we can use them with the embedders.
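As a sketch of what parameterizing the registry buys, the toy below uses one generic registry class for both model families. `ToyTypeRegistry` and every other name are hypothetical; the library's `ModelTypeRegistry` creators presumably also take configuration data, which is omitted here.

```swift
// Hypothetical stand-ins; not the library's registry.
protocol ToyLanguageModel { var name: String { get } }
protocol ToyEmbedder { var dimensions: Int { get } }

struct ToyLlama: ToyLanguageModel { let name = "llama" }
struct ToyBertEmbedder: ToyEmbedder { let dimensions = 768 }

enum ToyRegistryError: Error { case unsupportedModelType(String) }

// One registry implementation, parameterized by the kind of model it produces.
final class ToyTypeRegistry<Model> {
    private let creators: [String: () -> Model]

    init(creators: [String: () -> Model]) {
        self.creators = creators
    }

    // Looks up the creator for a model type string such as "bert".
    func createModel(modelType: String) throws -> Model {
        guard let creator = creators[modelType] else {
            throw ToyRegistryError.unsupportedModelType(modelType)
        }
        return creator()
    }
}

// The same class serves both families; only the type argument differs.
let llmTypes = ToyTypeRegistry<any ToyLanguageModel>(creators: ["llama": { ToyLlama() }])
let embedderTypes = ToyTypeRegistry<any ToyEmbedder>(creators: ["bert": { ToyBertEmbedder() }])

let embedder = try embedderTypes.createModel(modelType: "bert")
print(embedder.dimensions)  // 768
```

Asking `llmTypes` for `"bert"` throws in this sketch, since each registry only knows its own family's types.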
| /// call. This is important for things like `ModelContainer` that have to perform async | ||
| /// work but also need to prevent other callers for using _any_ of the internal state. | ||
| final class SerialAccessContainer<T>: @unchecked Sendable { | ||
| package final class SerialAccessContainer<T>: @unchecked Sendable { |
I didn't know about this -- private for the Package.swift scope. Perfect!
| /// Models can override this to inspect metadata (e.g. check `metadata["format"] == "mlx"`) | ||
| /// and skip or customize sanitization accordingly. | ||
| func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: MLXArray] | ||
| } |
Split out the generic parts of language models (preparing the weights) so that this can work:
public func loadWeights(
    modelDirectory: URL, model: BaseLanguageModel,
    quantization: BaseConfiguration.Quantization? = nil,
    perLayerQuantization: BaseConfiguration.PerLayerQuantization? = nil
) throws {
| /// - ``loadModelContainer(from:using:configuration:useLatest:progressHandler:)`` | ||
| /// | ||
| /// or variants. | ||
| public protocol GenericModelFactory<ContextType, ContainerType>: Sendable { |
This is made generic so we can use it across LLM, VLM, Embedder.
| } | ||
|
|
||
| /// For backward compatibility: `ModelFactory` refers to an LLM/VLM model factory. | ||
| public typealias ModelFactory = GenericModelFactory<ModelContext, ModelContainer> |
This gives a compatible type name if anyone was using it.
@DePasqualeOrg once this and #206 are merged, what do you think about giving it the 3.x tag? I have a PR on the mlx-swift-examples side integrated with this PR. Anything else that needs to be dealt with before the first tag?
Sounds good! I'm not aware of anything else that needs to be done. I'll keep an eye out for the new release and update the integration packages to use the new major version when it's out.
- replaces embedder loading, context and container with LLM variant
- this reuses all the code from LLM/VLM model type and model registries, along with the factory, download and load code
…ward-port main fixes

Reverses the Gemma4 VLM merge direction: restore alpha's pre-merge 1538-line implementation (with all its perf work intact) as the base, then forward-port the two small targeted fixes from main on top.

Why: main's 1929-line VLM Gemma4 (which the merge initially adopted) was written independently from alpha's, with structurally different norm classes (`Gemma4RMSNormZeroShift` + `Gemma4RMSNormNoScale` vs alpha's fused-RMSNorm-RoPE path), no compiled QKV, and no GPU-resident masked-scatter. Taking main's text-decoder wholesale would have lost:

- argWhere-based gemma4MaskedScatter (avoids `asArray(Bool.self)` full-mask CPU readback during VLM image-token splice; alpha's #53 perf optimization)
- Compiled QKV for attention
- Fused gate_up_proj + compile()-cached SharedMLP forward
- FusedGateUpSwitchGLU MoE routing for 26B-A4B
- Symbolic .slidingWindow SDPA mask routing

Forward-ported from main on top of alpha's base:

- #211 (system message + modality order): added `Gemma4MessageGenerator` struct (system messages stay text-only; user/assistant messages emit images first, then videos, then text, matching the Gemma 4 chat template). `Gemma4Processor.prepare` now uses it instead of `Qwen2VLMessageGenerator()`.
- #202 (modern dim API): `flattenedSource.shape[0]` → `dim(0)` in `gemma4MaskedScatter`. Other #202 hunks (`maskArray.shape.last` cleanup, unused token-count diagnostics) didn't apply: alpha's attention-mask handling is structured differently and the processor never had those diagnostics.

Verification:

- `swift build` clean, `swift build --build-tests` clean.
- `swift test --skip Benchmarks` passes 152 tests in 14 suites.
- LLM-side Gemma 4 E2B smoke (3 KV modes × 4 contexts) coherent; exercises the same Gemma4TextConfiguration decode path.

Note: full image-input smoke for the VLM path (loading via VLMModelFactory and running an image+prompt through Gemma4Processor) would benefit from a new `--method vision` mode in InferenceBenchmark that takes an image path. Scoped as a follow-up; see comment on the PR for the proposed shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
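The modality ordering described for `Gemma4MessageGenerator` in the commit message above can be sketched roughly as follows. This is not the actual struct: `ToyRole`, `ToyMessage`, `contentParts(for:)`, and the dictionary shape of each part are assumptions made for illustration only.

```swift
// Hypothetical message model; the real generator works on the library's chat types.
enum ToyRole { case system, user, assistant }

struct ToyMessage {
    let role: ToyRole
    let text: String
    var imageCount = 0
    var videoCount = 0
}

// Builds the content parts for one message in the described order.
func contentParts(for message: ToyMessage) -> [[String: String]] {
    let textPart = ["type": "text", "text": message.text]
    // System messages stay text-only.
    if message.role == .system { return [textPart] }
    // User/assistant messages: images first, then videos, then text.
    let images = Array(repeating: ["type": "image"], count: message.imageCount)
    let videos = Array(repeating: ["type": "video"], count: message.videoCount)
    return images + videos + [textPart]
}

let parts = contentParts(
    for: ToyMessage(role: .user, text: "Describe this.", imageCount: 1, videoCount: 1))
print(parts.map { $0["type"]! })  // ["image", "video", "text"]
```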
The earlier backport (commit 921e16b) restored alpha's pre-merge VLM Gemma4 to preserve text-decoder perf work. End-to-end smoke against the actual `mlx-community/gemma-4-e2b-it-4bit` weights revealed alpha's VLM Gemma4 was broken vs the current HF config schema:

1. `Gemma4VisionConfiguration` required an `image_size` field that doesn't exist in the current Gemma 4 vision_config (HF uses `position_embedding_size` + `patch_size` instead).
2. Alpha's vision encoder was missing layer types that main implements (`Gemma4VisionRMSNorm`, `Gemma4ClippableLinear`, `Gemma4VisionPooler`, `Gemma4MultimodalEmbedder`). Alpha's perf work was on the LLM/text-decoder side; the VLM path was apparently never tested end-to-end against the official Gemma 4 weights.

Conclusion: alpha's VLM Gemma4 is not meaningfully "ahead" of main; it's just unmaintained. Take main's full 1929-line VLM Gemma4 (with the post-#180 fixes from #211 system-message + modality order and #202 modern dim API). The text-decoder perf gap with the LLM-side Gemma4.swift (fused gate_up_proj, compile()-cached SharedMLP, FusedGateUpSwitchGLU, etc.) becomes a known follow-up, to be addressed as part of the LLM/VLM common-code consolidation pass that lifts shared text-decoder pieces into MLXLMCommon.

Single tweak vs main: added `.slidingWindow` to the `gemma4AdjustAttentionMask` switch (alpha's mlx-swift exposes the case; main was authored before it landed).

Vision bench harness (new method to validate VLM end-to-end on real weights):

- `runGenerationBenchmark` gains `images: [URL] = []`, `chatMessages: [Chat.Message]? = nil`, and `useVLM: Bool = false`. When `chatMessages` is set, the function builds `UserInput(chat:)` so the model's `MessageGenerator` runs and expands images into the chat template's image-placeholder tokens (Gemma 4: 280 vision soft tokens per image). Pre-existing call sites (text-only) continue to use the `messages: [Message]` raw-dict path with no behavioural change.
- `loadOrCacheModel` gains `useVLM: Bool = false`. When true, dispatches through `VLMModelFactory.shared.loadContainer`. The cache key is salted with `:vlm` so the same model id can be cached separately for each modality (Gemma 4 E2B is loadable both ways).
- New `runVisionBenchmark` wires the bench-level method case to a default golden-retriever fixture at `Tests/Benchmarks/Resources/vlm-test-prompts/test-image1.jpeg`. Pass/fail is a substring check against the output (default keyword: "dog"; override via env vars).
- New `case "vision":` in the dispatch switch.

CLI surface (benchmark.sh):

- `--method vision` accepted.
- `--image <path>` overrides the default image (relative paths resolved against the repo root, where the bench runs).
- `--vision-prompt "..."` overrides the default prompt.
- `--vision-expect <keyword>` overrides the pass/fail keyword.

Verification:

- `swift build` clean, `make` clean.
- Vision smoke on Gemma 4 E2B 4bit + the checked-in golden-retriever image: 1/1 PASSED.
  - Prompt: "What animal is in this image? Answer in one word."
  - Output: "Dog"
  - Prepared 304 tokens (22 text + 280 image + 2 wrappers).
  - TTFT: 1445ms · Prefill: 211 tok/s · GPU peak: 3.94GB · KV cache: 7MB.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
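For the cache-key salting and the pass/fail check mentioned in this commit message, a small sketch under stated assumptions: the helper names `cacheKey(modelID:useVLM:)` and `visionCheckPasses(output:expected:)` are hypothetical, and case-insensitive matching is an assumption inferred from the output "Dog" passing against the keyword "dog".

```swift
import Foundation

// Hypothetical helper: salting the key with ":vlm" lets the same model id
// be cached separately for the text-only and vision paths.
func cacheKey(modelID: String, useVLM: Bool) -> String {
    useVLM ? modelID + ":vlm" : modelID
}

// Hypothetical helper: pass/fail as a substring check against the output.
// Case-insensitive matching is an assumption, not confirmed by the source.
func visionCheckPasses(output: String, expected: String = "dog") -> Bool {
    output.lowercased().contains(expected.lowercased())
}

let modelID = "mlx-community/gemma-4-e2b-it-4bit"
print(cacheKey(modelID: modelID, useVLM: false))  // mlx-community/gemma-4-e2b-it-4bit
print(cacheKey(modelID: modelID, useVLM: true))   // mlx-community/gemma-4-e2b-it-4bit:vlm
print(visionCheckPasses(output: "Dog"))           // true
```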
Proposed changes
ml-explore/mlx-swift-examples#468 is the branch adopting this -- between this branch and that I hope to get good notes for a porting guide, but that will be a different PR.
FYI @DePasqualeOrg @rudrankriyam @sxy-trans-n @CodebyCR @ronaldmannak
Checklist
Put an `x` in the boxes that apply.
`pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes