-
Notifications
You must be signed in to change notification settings - Fork 288
v3 api embedder fixes #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,8 +11,8 @@ import MLXLMCommon | |
| import MLXVLM | ||
|
|
||
| // Both MLXLMCommon and MLXEmbedders define ModelContainer. | ||
| public typealias LMModelContainer = MLXLMCommon.ModelContainer | ||
| public typealias EmbeddingModelContainer = MLXEmbedders.ModelContainer | ||
| public typealias LLModelContainer = MLXLMCommon.ModelContainer | ||
| public typealias EmbeddingModelContainer = MLXEmbedders.EmbedderModelContainer | ||
|
|
||
| // MARK: - Error | ||
|
|
||
|
|
@@ -47,20 +47,20 @@ public actor IntegrationTestModels { | |
| private let downloader: any Downloader | ||
| private let tokenizerLoader: any TokenizerLoader | ||
|
|
||
| private var llmTask: Task<LMModelContainer, Error>? | ||
| private var vlmTask: Task<LMModelContainer, Error>? | ||
| private var lfm2Task: Task<LMModelContainer, Error>? | ||
| private var glm4Task: Task<LMModelContainer, Error>? | ||
| private var mistral3Task: Task<LMModelContainer, Error>? | ||
| private var nemotronTask: Task<LMModelContainer, Error>? | ||
| private var qwen35Task: Task<LMModelContainer, Error>? | ||
| private var llmTask: Task<LLModelContainer, Error>? | ||
| private var vlmTask: Task<LLModelContainer, Error>? | ||
| private var lfm2Task: Task<LLModelContainer, Error>? | ||
| private var glm4Task: Task<LLModelContainer, Error>? | ||
| private var mistral3Task: Task<LLModelContainer, Error>? | ||
| private var nemotronTask: Task<LLModelContainer, Error>? | ||
| private var qwen35Task: Task<LLModelContainer, Error>? | ||
|
|
||
| public init(downloader: any Downloader, tokenizerLoader: any TokenizerLoader) { | ||
| self.downloader = downloader | ||
| self.tokenizerLoader = tokenizerLoader | ||
| } | ||
|
|
||
| public func llmContainer() async throws -> LMModelContainer { | ||
| public func llmContainer() async throws -> LLModelContainer { | ||
| if let task = llmTask { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -81,7 +81,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func vlmContainer() async throws -> LMModelContainer { | ||
| public func vlmContainer() async throws -> LLModelContainer { | ||
| if let task = vlmTask { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -102,7 +102,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func lfm2Container() async throws -> LMModelContainer { | ||
| public func lfm2Container() async throws -> LLModelContainer { | ||
| if let task = lfm2Task { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -123,7 +123,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func glm4Container() async throws -> LMModelContainer { | ||
| public func glm4Container() async throws -> LLModelContainer { | ||
| if let task = glm4Task { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -144,7 +144,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func mistral3Container() async throws -> LMModelContainer { | ||
| public func mistral3Container() async throws -> LLModelContainer { | ||
| if let task = mistral3Task { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -165,7 +165,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func nemotronContainer() async throws -> LMModelContainer { | ||
| public func nemotronContainer() async throws -> LLModelContainer { | ||
| if let task = nemotronTask { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -186,7 +186,7 @@ public actor IntegrationTestModels { | |
| return try await task.value | ||
| } | ||
|
|
||
| public func qwen35Container() async throws -> LMModelContainer { | ||
| public func qwen35Container() async throws -> LLModelContainer { | ||
| if let task = qwen35Task { | ||
| return try await task.value | ||
| } | ||
|
|
@@ -212,8 +212,9 @@ public actor IntegrationTestModels { | |
| let tokenizerLoader = self.tokenizerLoader | ||
| let id = "nomic_text_v1_5" | ||
| print("Loading embedding model: \(id)") | ||
| let container = try await MLXEmbedders.loadModelContainer( | ||
| from: downloader, using: tokenizerLoader, configuration: .nomic_text_v1_5, | ||
| let container = try await EmbedderModelFactory.shared.loadContainer( | ||
| from: downloader, using: tokenizerLoader, | ||
| configuration: EmbedderRegistry.nomic_text_v1_5, | ||
| progressHandler: logProgress(id) | ||
| ) | ||
| print("Loaded embedding model: \(id)") | ||
|
|
@@ -227,7 +228,7 @@ private let generateParameters = GenerateParameters(maxTokens: 200, temperature: | |
|
|
||
| public enum ChatSessionTests { | ||
|
|
||
| public static func oneShot(container: LMModelContainer) async throws { | ||
| public static func oneShot(container: LLModelContainer) async throws { | ||
| let session = ChatSession(container, generateParameters: generateParameters) | ||
| let result = try await streamAndCollect( | ||
| session.streamResponse( | ||
|
|
@@ -238,7 +239,7 @@ public enum ChatSessionTests { | |
| ) | ||
| } | ||
|
|
||
| public static func oneShotStream(container: LMModelContainer) async throws { | ||
| public static func oneShotStream(container: LLModelContainer) async throws { | ||
| let session = ChatSession(container, generateParameters: generateParameters) | ||
| let result = try await streamAndCollect( | ||
| session.streamResponse( | ||
|
|
@@ -249,7 +250,7 @@ public enum ChatSessionTests { | |
| ) | ||
| } | ||
|
|
||
| public static func multiTurnConversation(container: LMModelContainer) async throws { | ||
| public static func multiTurnConversation(container: LLModelContainer) async throws { | ||
| let session = ChatSession( | ||
| container, instructions: "You are a helpful assistant. Keep responses brief.", | ||
| generateParameters: generateParameters) | ||
|
|
@@ -268,7 +269,7 @@ public enum ChatSessionTests { | |
| ) | ||
| } | ||
|
|
||
| public static func visionModel(container: LMModelContainer) async throws { | ||
| public static func visionModel(container: LLModelContainer) async throws { | ||
| let session = ChatSession(container, generateParameters: generateParameters) | ||
| let redImage = CIImage(color: .red).cropped( | ||
| to: CGRect(x: 0, y: 0, width: 100, height: 100)) | ||
|
|
@@ -283,7 +284,7 @@ public enum ChatSessionTests { | |
| ) | ||
| } | ||
|
|
||
| public static func streamDetailsWithTools(container: LMModelContainer) async throws { | ||
| public static func streamDetailsWithTools(container: LLModelContainer) async throws { | ||
| let tools: [ToolSpec] = [weatherToolSchema] | ||
| let session = ChatSession(container, generateParameters: generateParameters, tools: tools) | ||
|
|
||
|
|
@@ -334,7 +335,7 @@ public enum ChatSessionTests { | |
| } | ||
| } | ||
|
|
||
| public static func toolInvocation(container: LMModelContainer) async throws { | ||
| public static func toolInvocation(container: LLModelContainer) async throws { | ||
| struct EmptyInput: Codable {} | ||
|
|
||
| struct TimeOutput: Codable { | ||
|
|
@@ -369,7 +370,7 @@ public enum ChatSessionTests { | |
| ) | ||
| } | ||
|
|
||
| public static func promptRehydration(container: LMModelContainer) async throws { | ||
| public static func promptRehydration(container: LLModelContainer) async throws { | ||
| let history: [Chat.Message] = [ | ||
| .system("You are a helpful assistant."), | ||
| .user("My name is Bob."), | ||
|
|
@@ -414,8 +415,9 @@ public enum EmbedderTests { | |
| ) async throws { | ||
| let modelId = "mlx-community/gemma-3-1b-it-qat-4bit" | ||
| print("Loading Gemma 3 embedding model: \(modelId)") | ||
| let modelContainer = try await MLXEmbedders.loadModelContainer( | ||
| from: downloader, using: tokenizerLoader, configuration: .init(id: modelId), | ||
| let modelContainer = try await EmbedderModelFactory.shared.loadContainer( | ||
| from: downloader, using: tokenizerLoader, | ||
| configuration: ModelConfiguration(id: modelId), | ||
| progressHandler: logProgress(modelId) | ||
| ) | ||
| print("Loaded Gemma 3 embedding model: \(modelId)") | ||
|
|
@@ -425,8 +427,8 @@ public enum EmbedderTests { | |
| "In the United States, PepsiCo Inc. is a leading soft drink company.", | ||
| ] | ||
|
|
||
| let resultEmbeddings = await modelContainer.perform { | ||
| (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in | ||
| let resultEmbeddings = await modelContainer.perform { context in | ||
| let tokenizer = context.tokenizer | ||
| let encoded = inputs.map { | ||
|
Comment on lines
-428
to
432
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The container now works identically to the LLM ModelContainer: the closure now gets a context rather than a tuple. |
||
| tokenizer.encode(text: $0, addSpecialTokens: true) | ||
| } | ||
|
|
@@ -446,10 +448,10 @@ public enum EmbedderTests { | |
| let mask = (padded .!= (tokenizer.eosTokenId ?? 0)) | ||
| let tokenTypes = MLXArray.zeros(like: padded) | ||
|
|
||
| let modelOutput = model( | ||
| let modelOutput = context.model( | ||
| padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask) | ||
|
|
||
| let result = pooling( | ||
| let result = context.pooling( | ||
| modelOutput, | ||
| normalize: true, applyLayerNorm: true | ||
| ) | ||
|
|
@@ -488,8 +490,9 @@ public enum EmbedderTests { | |
| "search_document: Polar Bears", | ||
| ] | ||
|
|
||
| let resultEmbeddings = await container.perform { | ||
| (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in | ||
| let resultEmbeddings = await container.perform { context in | ||
| let tokenizer = context.tokenizer | ||
|
|
||
| let inputs = searchInputs.map { | ||
| tokenizer.encode(text: $0, addSpecialTokens: true) | ||
| } | ||
|
|
@@ -506,8 +509,9 @@ public enum EmbedderTests { | |
| }) | ||
| let mask = (padded .!= tokenizer.eosTokenId ?? 0) | ||
| let tokenTypes = MLXArray.zeros(like: padded) | ||
| let result = pooling( | ||
| model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), | ||
| let result = context.pooling( | ||
| context.model( | ||
| padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), | ||
| normalize: true, applyLayerNorm: true | ||
| ) | ||
| result.eval() | ||
|
|
@@ -539,17 +543,15 @@ public enum EmbedderTests { | |
|
|
||
| public enum ToolCallTests { | ||
|
|
||
| // MARK: LFM2 | ||
|
|
||
| public static func lfm2FormatAutoDetection(container: LMModelContainer) async throws { | ||
| public static func lfm2FormatAutoDetection(container: LLModelContainer) async throws { | ||
| let config = await container.configuration | ||
| try check( | ||
| config.toolCallFormat == ToolCallFormat.lfm2, | ||
| "Expected .lfm2 tool call format, got: \(String(describing: config.toolCallFormat))" | ||
| ) | ||
| } | ||
|
|
||
| public static func lfm2EndToEndGeneration(container: LMModelContainer) async throws { | ||
| public static func lfm2EndToEndGeneration(container: LLModelContainer) async throws { | ||
| let (result, toolCalls) = try await generateWithTools( | ||
| container: container, | ||
| userMessage: "What's the weather in Tokyo?") | ||
|
|
@@ -572,17 +574,15 @@ public enum ToolCallTests { | |
| ) | ||
| } | ||
|
|
||
| // MARK: GLM4 | ||
|
|
||
| public static func glm4FormatAutoDetection(container: LMModelContainer) async throws { | ||
| public static func glm4FormatAutoDetection(container: LLModelContainer) async throws { | ||
| let config = await container.configuration | ||
| try check( | ||
| config.toolCallFormat == ToolCallFormat.glm4, | ||
| "Expected .glm4 tool call format, got: \(String(describing: config.toolCallFormat))" | ||
| ) | ||
| } | ||
|
|
||
| public static func glm4EndToEndGeneration(container: LMModelContainer) async throws { | ||
| public static func glm4EndToEndGeneration(container: LLModelContainer) async throws { | ||
| let (result, toolCalls) = try await generateWithTools( | ||
| container: container, | ||
| userMessage: "What's the weather in Paris?") | ||
|
|
@@ -607,15 +607,15 @@ public enum ToolCallTests { | |
|
|
||
| // MARK: Mistral3 | ||
|
|
||
| public static func mistral3FormatAutoDetection(container: LMModelContainer) async throws { | ||
| public static func mistral3FormatAutoDetection(container: LLModelContainer) async throws { | ||
| let config = await container.configuration | ||
| try check( | ||
| config.toolCallFormat == ToolCallFormat.mistral, | ||
| "Expected .mistral tool call format, got: \(String(describing: config.toolCallFormat))" | ||
| ) | ||
| } | ||
|
|
||
| public static func mistral3EndToEndGeneration(container: LMModelContainer) async throws { | ||
| public static func mistral3EndToEndGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -647,7 +647,7 @@ public enum ToolCallTests { | |
| ) | ||
| } | ||
|
|
||
| public static func mistral3MultiToolGeneration(container: LMModelContainer) async throws { | ||
| public static func mistral3MultiToolGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -680,15 +680,15 @@ public enum ToolCallTests { | |
|
|
||
| // MARK: Nemotron | ||
|
|
||
| public static func nemotronFormatAutoDetection(container: LMModelContainer) async throws { | ||
| public static func nemotronFormatAutoDetection(container: LLModelContainer) async throws { | ||
| let config = await container.configuration | ||
| try check( | ||
| config.toolCallFormat == ToolCallFormat.xmlFunction, | ||
| "Expected .xmlFunction tool call format, got: \(String(describing: config.toolCallFormat))" | ||
| ) | ||
| } | ||
|
|
||
| public static func nemotronEndToEndGeneration(container: LMModelContainer) async throws { | ||
| public static func nemotronEndToEndGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -721,7 +721,7 @@ public enum ToolCallTests { | |
| ) | ||
| } | ||
|
|
||
| public static func nemotronMultiToolGeneration(container: LMModelContainer) async throws { | ||
| public static func nemotronMultiToolGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -755,15 +755,15 @@ public enum ToolCallTests { | |
|
|
||
| // MARK: Qwen3.5 | ||
|
|
||
| public static func qwen35FormatAutoDetection(container: LMModelContainer) async throws { | ||
| public static func qwen35FormatAutoDetection(container: LLModelContainer) async throws { | ||
| let config = await container.configuration | ||
| try check( | ||
| config.toolCallFormat == ToolCallFormat.xmlFunction, | ||
| "Expected .xmlFunction tool call format, got: \(String(describing: config.toolCallFormat))" | ||
| ) | ||
| } | ||
|
|
||
| public static func qwen35EndToEndGeneration(container: LMModelContainer) async throws { | ||
| public static func qwen35EndToEndGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -795,7 +795,7 @@ public enum ToolCallTests { | |
| ) | ||
| } | ||
|
|
||
| public static func qwen35MultiToolGeneration(container: LMModelContainer) async throws { | ||
| public static func qwen35MultiToolGeneration(container: LLModelContainer) async throws { | ||
| let input = UserInput( | ||
| chat: [ | ||
| .system( | ||
|
|
@@ -830,7 +830,7 @@ public enum ToolCallTests { | |
| // MARK: Helpers | ||
|
|
||
| private static func generateWithTools( | ||
| container: LMModelContainer, | ||
| container: LLModelContainer, | ||
| input: UserInput, | ||
| maxTokens: Int = 100 | ||
| ) async throws -> (text: String, toolCalls: [ToolCall]) { | ||
|
|
@@ -858,7 +858,7 @@ public enum ToolCallTests { | |
| } | ||
|
|
||
| private static func generateWithTools( | ||
| container: LMModelContainer, | ||
| container: LLModelContainer, | ||
| userMessage: String | ||
| ) async throws -> (text: String, toolCalls: [ToolCall]) { | ||
| let input = UserInput( | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This now uses an LLM-style ModelFactory -- indeed, it has the same innards except for the load-level creation of the model itself.
The pre-configured IDs move to EmbedderRegistry, just as the LLMs have an LLMRegistry.