Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Libraries/BenchmarkHelpers/BenchmarkHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,19 @@ public func benchmarkVLMLoading(
public func benchmarkEmbeddingLoading(
from downloader: any Downloader,
using tokenizerLoader: any TokenizerLoader,
configuration: MLXEmbedders.ModelConfiguration = .init(
configuration: ModelConfiguration = .init(
id: "mlx-community/Qwen3-Embedding-0.6B-8bit"),
runs: Int = BenchmarkDefaults.loadingRuns
) async throws -> BenchmarkStats {
_ = try await MLXEmbedders.loadModelContainer(
_ = try await EmbedderModelFactory.shared.loadContainer(
from: downloader, using: tokenizerLoader, configuration: configuration
) { _ in }
Memory.clearCache()

var times: [Double] = []
for i in 1 ... runs {
let start = CFAbsoluteTimeGetCurrent()
_ = try await MLXEmbedders.loadModelContainer(
_ = try await EmbedderModelFactory.shared.loadContainer(
from: downloader, using: tokenizerLoader, configuration: configuration
) { _ in }
let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
Expand Down
108 changes: 54 additions & 54 deletions Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import MLXLMCommon
import MLXVLM

// Both MLXLMCommon and MLXEmbedders define ModelContainer.
public typealias LMModelContainer = MLXLMCommon.ModelContainer
public typealias EmbeddingModelContainer = MLXEmbedders.ModelContainer
public typealias LLModelContainer = MLXLMCommon.ModelContainer
public typealias EmbeddingModelContainer = MLXEmbedders.EmbedderModelContainer

// MARK: - Error

Expand Down Expand Up @@ -47,20 +47,20 @@ public actor IntegrationTestModels {
private let downloader: any Downloader
private let tokenizerLoader: any TokenizerLoader

private var llmTask: Task<LMModelContainer, Error>?
private var vlmTask: Task<LMModelContainer, Error>?
private var lfm2Task: Task<LMModelContainer, Error>?
private var glm4Task: Task<LMModelContainer, Error>?
private var mistral3Task: Task<LMModelContainer, Error>?
private var nemotronTask: Task<LMModelContainer, Error>?
private var qwen35Task: Task<LMModelContainer, Error>?
private var llmTask: Task<LLModelContainer, Error>?
private var vlmTask: Task<LLModelContainer, Error>?
private var lfm2Task: Task<LLModelContainer, Error>?
private var glm4Task: Task<LLModelContainer, Error>?
private var mistral3Task: Task<LLModelContainer, Error>?
private var nemotronTask: Task<LLModelContainer, Error>?
private var qwen35Task: Task<LLModelContainer, Error>?

public init(downloader: any Downloader, tokenizerLoader: any TokenizerLoader) {
self.downloader = downloader
self.tokenizerLoader = tokenizerLoader
}

public func llmContainer() async throws -> LMModelContainer {
public func llmContainer() async throws -> LLModelContainer {
if let task = llmTask {
return try await task.value
}
Expand All @@ -81,7 +81,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func vlmContainer() async throws -> LMModelContainer {
public func vlmContainer() async throws -> LLModelContainer {
if let task = vlmTask {
return try await task.value
}
Expand All @@ -102,7 +102,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func lfm2Container() async throws -> LMModelContainer {
public func lfm2Container() async throws -> LLModelContainer {
if let task = lfm2Task {
return try await task.value
}
Expand All @@ -123,7 +123,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func glm4Container() async throws -> LMModelContainer {
public func glm4Container() async throws -> LLModelContainer {
if let task = glm4Task {
return try await task.value
}
Expand All @@ -144,7 +144,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func mistral3Container() async throws -> LMModelContainer {
public func mistral3Container() async throws -> LLModelContainer {
if let task = mistral3Task {
return try await task.value
}
Expand All @@ -165,7 +165,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func nemotronContainer() async throws -> LMModelContainer {
public func nemotronContainer() async throws -> LLModelContainer {
if let task = nemotronTask {
return try await task.value
}
Expand All @@ -186,7 +186,7 @@ public actor IntegrationTestModels {
return try await task.value
}

public func qwen35Container() async throws -> LMModelContainer {
public func qwen35Container() async throws -> LLModelContainer {
if let task = qwen35Task {
return try await task.value
}
Expand All @@ -212,8 +212,9 @@ public actor IntegrationTestModels {
let tokenizerLoader = self.tokenizerLoader
let id = "nomic_text_v1_5"
print("Loading embedding model: \(id)")
let container = try await MLXEmbedders.loadModelContainer(
from: downloader, using: tokenizerLoader, configuration: .nomic_text_v1_5,
let container = try await EmbedderModelFactory.shared.loadContainer(
from: downloader, using: tokenizerLoader,
configuration: EmbedderRegistry.nomic_text_v1_5,
Comment on lines +215 to +217

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now using an LLM style ModelFactory -- indeed it is the same innards except for the load level creation of the model itself.

The pre-configured ids move to EmbedderRegistry, just like the LLMs have a LLMRegistry

progressHandler: logProgress(id)
)
print("Loaded embedding model: \(id)")
Expand All @@ -227,7 +228,7 @@ private let generateParameters = GenerateParameters(maxTokens: 200, temperature:

public enum ChatSessionTests {

public static func oneShot(container: LMModelContainer) async throws {
public static func oneShot(container: LLModelContainer) async throws {
let session = ChatSession(container, generateParameters: generateParameters)
let result = try await streamAndCollect(
session.streamResponse(
Expand All @@ -238,7 +239,7 @@ public enum ChatSessionTests {
)
}

public static func oneShotStream(container: LMModelContainer) async throws {
public static func oneShotStream(container: LLModelContainer) async throws {
let session = ChatSession(container, generateParameters: generateParameters)
let result = try await streamAndCollect(
session.streamResponse(
Expand All @@ -249,7 +250,7 @@ public enum ChatSessionTests {
)
}

public static func multiTurnConversation(container: LMModelContainer) async throws {
public static func multiTurnConversation(container: LLModelContainer) async throws {
let session = ChatSession(
container, instructions: "You are a helpful assistant. Keep responses brief.",
generateParameters: generateParameters)
Expand All @@ -268,7 +269,7 @@ public enum ChatSessionTests {
)
}

public static func visionModel(container: LMModelContainer) async throws {
public static func visionModel(container: LLModelContainer) async throws {
let session = ChatSession(container, generateParameters: generateParameters)
let redImage = CIImage(color: .red).cropped(
to: CGRect(x: 0, y: 0, width: 100, height: 100))
Expand All @@ -283,7 +284,7 @@ public enum ChatSessionTests {
)
}

public static func streamDetailsWithTools(container: LMModelContainer) async throws {
public static func streamDetailsWithTools(container: LLModelContainer) async throws {
let tools: [ToolSpec] = [weatherToolSchema]
let session = ChatSession(container, generateParameters: generateParameters, tools: tools)

Expand Down Expand Up @@ -334,7 +335,7 @@ public enum ChatSessionTests {
}
}

public static func toolInvocation(container: LMModelContainer) async throws {
public static func toolInvocation(container: LLModelContainer) async throws {
struct EmptyInput: Codable {}

struct TimeOutput: Codable {
Expand Down Expand Up @@ -369,7 +370,7 @@ public enum ChatSessionTests {
)
}

public static func promptRehydration(container: LMModelContainer) async throws {
public static func promptRehydration(container: LLModelContainer) async throws {
let history: [Chat.Message] = [
.system("You are a helpful assistant."),
.user("My name is Bob."),
Expand Down Expand Up @@ -414,8 +415,9 @@ public enum EmbedderTests {
) async throws {
let modelId = "mlx-community/gemma-3-1b-it-qat-4bit"
print("Loading Gemma 3 embedding model: \(modelId)")
let modelContainer = try await MLXEmbedders.loadModelContainer(
from: downloader, using: tokenizerLoader, configuration: .init(id: modelId),
let modelContainer = try await EmbedderModelFactory.shared.loadContainer(
from: downloader, using: tokenizerLoader,
configuration: ModelConfiguration(id: modelId),
progressHandler: logProgress(modelId)
)
print("Loaded Gemma 3 embedding model: \(modelId)")
Expand All @@ -425,8 +427,8 @@ public enum EmbedderTests {
"In the United States, PepsiCo Inc. is a leading soft drink company.",
]

let resultEmbeddings = await modelContainer.perform {
(model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in
let resultEmbeddings = await modelContainer.perform { context in
let tokenizer = context.tokenizer
let encoded = inputs.map {
Comment on lines -428 to 432

@davidkoski davidkoski Apr 10, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The container now works identically to ModelContainer from the LLM side. Well, more or less -- it has been thinned down a bit, but otherwise the same.

Now the closure gets a context rather than a tuple.

tokenizer.encode(text: $0, addSpecialTokens: true)
}
Expand All @@ -446,10 +448,10 @@ public enum EmbedderTests {
let mask = (padded .!= (tokenizer.eosTokenId ?? 0))
let tokenTypes = MLXArray.zeros(like: padded)

let modelOutput = model(
let modelOutput = context.model(
padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask)

let result = pooling(
let result = context.pooling(
modelOutput,
normalize: true, applyLayerNorm: true
)
Expand Down Expand Up @@ -488,8 +490,9 @@ public enum EmbedderTests {
"search_document: Polar Bears",
]

let resultEmbeddings = await container.perform {
(model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in
let resultEmbeddings = await container.perform { context in
let tokenizer = context.tokenizer

let inputs = searchInputs.map {
tokenizer.encode(text: $0, addSpecialTokens: true)
}
Expand All @@ -506,8 +509,9 @@ public enum EmbedderTests {
})
let mask = (padded .!= tokenizer.eosTokenId ?? 0)
let tokenTypes = MLXArray.zeros(like: padded)
let result = pooling(
model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask),
let result = context.pooling(
context.model(
padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask),
normalize: true, applyLayerNorm: true
)
result.eval()
Expand Down Expand Up @@ -539,17 +543,15 @@ public enum EmbedderTests {

public enum ToolCallTests {

// MARK: LFM2

public static func lfm2FormatAutoDetection(container: LMModelContainer) async throws {
public static func lfm2FormatAutoDetection(container: LLModelContainer) async throws {
let config = await container.configuration
try check(
config.toolCallFormat == ToolCallFormat.lfm2,
"Expected .lfm2 tool call format, got: \(String(describing: config.toolCallFormat))"
)
}

public static func lfm2EndToEndGeneration(container: LMModelContainer) async throws {
public static func lfm2EndToEndGeneration(container: LLModelContainer) async throws {
let (result, toolCalls) = try await generateWithTools(
container: container,
userMessage: "What's the weather in Tokyo?")
Expand All @@ -572,17 +574,15 @@ public enum ToolCallTests {
)
}

// MARK: GLM4

public static func glm4FormatAutoDetection(container: LMModelContainer) async throws {
public static func glm4FormatAutoDetection(container: LLModelContainer) async throws {
let config = await container.configuration
try check(
config.toolCallFormat == ToolCallFormat.glm4,
"Expected .glm4 tool call format, got: \(String(describing: config.toolCallFormat))"
)
}

public static func glm4EndToEndGeneration(container: LMModelContainer) async throws {
public static func glm4EndToEndGeneration(container: LLModelContainer) async throws {
let (result, toolCalls) = try await generateWithTools(
container: container,
userMessage: "What's the weather in Paris?")
Expand All @@ -607,15 +607,15 @@ public enum ToolCallTests {

// MARK: Mistral3

public static func mistral3FormatAutoDetection(container: LMModelContainer) async throws {
public static func mistral3FormatAutoDetection(container: LLModelContainer) async throws {
let config = await container.configuration
try check(
config.toolCallFormat == ToolCallFormat.mistral,
"Expected .mistral tool call format, got: \(String(describing: config.toolCallFormat))"
)
}

public static func mistral3EndToEndGeneration(container: LMModelContainer) async throws {
public static func mistral3EndToEndGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -647,7 +647,7 @@ public enum ToolCallTests {
)
}

public static func mistral3MultiToolGeneration(container: LMModelContainer) async throws {
public static func mistral3MultiToolGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -680,15 +680,15 @@ public enum ToolCallTests {

// MARK: Nemotron

public static func nemotronFormatAutoDetection(container: LMModelContainer) async throws {
public static func nemotronFormatAutoDetection(container: LLModelContainer) async throws {
let config = await container.configuration
try check(
config.toolCallFormat == ToolCallFormat.xmlFunction,
"Expected .xmlFunction tool call format, got: \(String(describing: config.toolCallFormat))"
)
}

public static func nemotronEndToEndGeneration(container: LMModelContainer) async throws {
public static func nemotronEndToEndGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -721,7 +721,7 @@ public enum ToolCallTests {
)
}

public static func nemotronMultiToolGeneration(container: LMModelContainer) async throws {
public static func nemotronMultiToolGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -755,15 +755,15 @@ public enum ToolCallTests {

// MARK: Qwen3.5

public static func qwen35FormatAutoDetection(container: LMModelContainer) async throws {
public static func qwen35FormatAutoDetection(container: LLModelContainer) async throws {
let config = await container.configuration
try check(
config.toolCallFormat == ToolCallFormat.xmlFunction,
"Expected .xmlFunction tool call format, got: \(String(describing: config.toolCallFormat))"
)
}

public static func qwen35EndToEndGeneration(container: LMModelContainer) async throws {
public static func qwen35EndToEndGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -795,7 +795,7 @@ public enum ToolCallTests {
)
}

public static func qwen35MultiToolGeneration(container: LMModelContainer) async throws {
public static func qwen35MultiToolGeneration(container: LLModelContainer) async throws {
let input = UserInput(
chat: [
.system(
Expand Down Expand Up @@ -830,7 +830,7 @@ public enum ToolCallTests {
// MARK: Helpers

private static func generateWithTools(
container: LMModelContainer,
container: LLModelContainer,
input: UserInput,
maxTokens: Int = 100
) async throws -> (text: String, toolCalls: [ToolCall]) {
Expand Down Expand Up @@ -858,7 +858,7 @@ public enum ToolCallTests {
}

private static func generateWithTools(
container: LMModelContainer,
container: LLModelContainer,
userMessage: String
) async throws -> (text: String, toolCalls: [ToolCall]) {
let input = UserInput(
Expand Down
Loading
Loading