diff --git a/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift b/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift index 93a0d78a6..f67a95aff 100644 --- a/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift +++ b/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift @@ -409,11 +409,11 @@ public func benchmarkVLMLoading( public func benchmarkEmbeddingLoading( from downloader: any Downloader, using tokenizerLoader: any TokenizerLoader, - configuration: MLXEmbedders.ModelConfiguration = .init( + configuration: ModelConfiguration = .init( id: "mlx-community/Qwen3-Embedding-0.6B-8bit"), runs: Int = BenchmarkDefaults.loadingRuns ) async throws -> BenchmarkStats { - _ = try await MLXEmbedders.loadModelContainer( + _ = try await EmbedderModelFactory.shared.loadContainer( from: downloader, using: tokenizerLoader, configuration: configuration ) { _ in } Memory.clearCache() @@ -421,7 +421,7 @@ public func benchmarkEmbeddingLoading( var times: [Double] = [] for i in 1 ... runs { let start = CFAbsoluteTimeGetCurrent() - _ = try await MLXEmbedders.loadModelContainer( + _ = try await EmbedderModelFactory.shared.loadContainer( from: downloader, using: tokenizerLoader, configuration: configuration ) { _ in } let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 diff --git a/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift b/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift index 53c7744e2..6cfed8210 100644 --- a/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift +++ b/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift @@ -11,8 +11,8 @@ import MLXLMCommon import MLXVLM // Both MLXLMCommon and MLXEmbedders define ModelContainer. 
-public typealias LMModelContainer = MLXLMCommon.ModelContainer -public typealias EmbeddingModelContainer = MLXEmbedders.ModelContainer +public typealias LLModelContainer = MLXLMCommon.ModelContainer +public typealias EmbeddingModelContainer = MLXEmbedders.EmbedderModelContainer // MARK: - Error @@ -47,20 +47,20 @@ public actor IntegrationTestModels { private let downloader: any Downloader private let tokenizerLoader: any TokenizerLoader - private var llmTask: Task? - private var vlmTask: Task? - private var lfm2Task: Task? - private var glm4Task: Task? - private var mistral3Task: Task? - private var nemotronTask: Task? - private var qwen35Task: Task? + private var llmTask: Task? + private var vlmTask: Task? + private var lfm2Task: Task? + private var glm4Task: Task? + private var mistral3Task: Task? + private var nemotronTask: Task? + private var qwen35Task: Task? public init(downloader: any Downloader, tokenizerLoader: any TokenizerLoader) { self.downloader = downloader self.tokenizerLoader = tokenizerLoader } - public func llmContainer() async throws -> LMModelContainer { + public func llmContainer() async throws -> LLModelContainer { if let task = llmTask { return try await task.value } @@ -81,7 +81,7 @@ public actor IntegrationTestModels { return try await task.value } - public func vlmContainer() async throws -> LMModelContainer { + public func vlmContainer() async throws -> LLModelContainer { if let task = vlmTask { return try await task.value } @@ -102,7 +102,7 @@ public actor IntegrationTestModels { return try await task.value } - public func lfm2Container() async throws -> LMModelContainer { + public func lfm2Container() async throws -> LLModelContainer { if let task = lfm2Task { return try await task.value } @@ -123,7 +123,7 @@ public actor IntegrationTestModels { return try await task.value } - public func glm4Container() async throws -> LMModelContainer { + public func glm4Container() async throws -> LLModelContainer { if let task = glm4Task { return 
try await task.value } @@ -144,7 +144,7 @@ public actor IntegrationTestModels { return try await task.value } - public func mistral3Container() async throws -> LMModelContainer { + public func mistral3Container() async throws -> LLModelContainer { if let task = mistral3Task { return try await task.value } @@ -165,7 +165,7 @@ public actor IntegrationTestModels { return try await task.value } - public func nemotronContainer() async throws -> LMModelContainer { + public func nemotronContainer() async throws -> LLModelContainer { if let task = nemotronTask { return try await task.value } @@ -186,7 +186,7 @@ public actor IntegrationTestModels { return try await task.value } - public func qwen35Container() async throws -> LMModelContainer { + public func qwen35Container() async throws -> LLModelContainer { if let task = qwen35Task { return try await task.value } @@ -212,8 +212,9 @@ public actor IntegrationTestModels { let tokenizerLoader = self.tokenizerLoader let id = "nomic_text_v1_5" print("Loading embedding model: \(id)") - let container = try await MLXEmbedders.loadModelContainer( - from: downloader, using: tokenizerLoader, configuration: .nomic_text_v1_5, + let container = try await EmbedderModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: EmbedderRegistry.nomic_text_v1_5, progressHandler: logProgress(id) ) print("Loaded embedding model: \(id)") @@ -227,7 +228,7 @@ private let generateParameters = GenerateParameters(maxTokens: 200, temperature: public enum ChatSessionTests { - public static func oneShot(container: LMModelContainer) async throws { + public static func oneShot(container: LLModelContainer) async throws { let session = ChatSession(container, generateParameters: generateParameters) let result = try await streamAndCollect( session.streamResponse( @@ -238,7 +239,7 @@ public enum ChatSessionTests { ) } - public static func oneShotStream(container: LMModelContainer) async throws { + public static func 
oneShotStream(container: LLModelContainer) async throws { let session = ChatSession(container, generateParameters: generateParameters) let result = try await streamAndCollect( session.streamResponse( @@ -249,7 +250,7 @@ public enum ChatSessionTests { ) } - public static func multiTurnConversation(container: LMModelContainer) async throws { + public static func multiTurnConversation(container: LLModelContainer) async throws { let session = ChatSession( container, instructions: "You are a helpful assistant. Keep responses brief.", generateParameters: generateParameters) @@ -268,7 +269,7 @@ public enum ChatSessionTests { ) } - public static func visionModel(container: LMModelContainer) async throws { + public static func visionModel(container: LLModelContainer) async throws { let session = ChatSession(container, generateParameters: generateParameters) let redImage = CIImage(color: .red).cropped( to: CGRect(x: 0, y: 0, width: 100, height: 100)) @@ -283,7 +284,7 @@ public enum ChatSessionTests { ) } - public static func streamDetailsWithTools(container: LMModelContainer) async throws { + public static func streamDetailsWithTools(container: LLModelContainer) async throws { let tools: [ToolSpec] = [weatherToolSchema] let session = ChatSession(container, generateParameters: generateParameters, tools: tools) @@ -334,7 +335,7 @@ public enum ChatSessionTests { } } - public static func toolInvocation(container: LMModelContainer) async throws { + public static func toolInvocation(container: LLModelContainer) async throws { struct EmptyInput: Codable {} struct TimeOutput: Codable { @@ -369,7 +370,7 @@ public enum ChatSessionTests { ) } - public static func promptRehydration(container: LMModelContainer) async throws { + public static func promptRehydration(container: LLModelContainer) async throws { let history: [Chat.Message] = [ .system("You are a helpful assistant."), .user("My name is Bob."), @@ -414,8 +415,9 @@ public enum EmbedderTests { ) async throws { let modelId = 
"mlx-community/gemma-3-1b-it-qat-4bit" print("Loading Gemma 3 embedding model: \(modelId)") - let modelContainer = try await MLXEmbedders.loadModelContainer( - from: downloader, using: tokenizerLoader, configuration: .init(id: modelId), + let modelContainer = try await EmbedderModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: ModelConfiguration(id: modelId), progressHandler: logProgress(modelId) ) print("Loaded Gemma 3 embedding model: \(modelId)") @@ -425,8 +427,8 @@ public enum EmbedderTests { "In the United States, PepsiCo Inc. is a leading soft drink company.", ] - let resultEmbeddings = await modelContainer.perform { - (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in + let resultEmbeddings = await modelContainer.perform { context in + let tokenizer = context.tokenizer let encoded = inputs.map { tokenizer.encode(text: $0, addSpecialTokens: true) } @@ -446,10 +448,10 @@ public enum EmbedderTests { let mask = (padded .!= (tokenizer.eosTokenId ?? 0)) let tokenTypes = MLXArray.zeros(like: padded) - let modelOutput = model( + let modelOutput = context.model( padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask) - let result = pooling( + let result = context.pooling( modelOutput, normalize: true, applyLayerNorm: true ) @@ -488,8 +490,9 @@ public enum EmbedderTests { "search_document: Polar Bears", ] - let resultEmbeddings = await container.perform { - (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in + let resultEmbeddings = await container.perform { context in + let tokenizer = context.tokenizer + let inputs = searchInputs.map { tokenizer.encode(text: $0, addSpecialTokens: true) } @@ -506,8 +509,9 @@ public enum EmbedderTests { }) let mask = (padded .!= tokenizer.eosTokenId ?? 
0) let tokenTypes = MLXArray.zeros(like: padded) - let result = pooling( - model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), + let result = context.pooling( + context.model( + padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), normalize: true, applyLayerNorm: true ) result.eval() @@ -539,9 +543,7 @@ public enum EmbedderTests { public enum ToolCallTests { - // MARK: LFM2 - - public static func lfm2FormatAutoDetection(container: LMModelContainer) async throws { + public static func lfm2FormatAutoDetection(container: LLModelContainer) async throws { let config = await container.configuration try check( config.toolCallFormat == ToolCallFormat.lfm2, @@ -549,7 +551,7 @@ public enum ToolCallTests { ) } - public static func lfm2EndToEndGeneration(container: LMModelContainer) async throws { + public static func lfm2EndToEndGeneration(container: LLModelContainer) async throws { let (result, toolCalls) = try await generateWithTools( container: container, userMessage: "What's the weather in Tokyo?") @@ -572,9 +574,7 @@ public enum ToolCallTests { ) } - // MARK: GLM4 - - public static func glm4FormatAutoDetection(container: LMModelContainer) async throws { + public static func glm4FormatAutoDetection(container: LLModelContainer) async throws { let config = await container.configuration try check( config.toolCallFormat == ToolCallFormat.glm4, @@ -582,7 +582,7 @@ public enum ToolCallTests { ) } - public static func glm4EndToEndGeneration(container: LMModelContainer) async throws { + public static func glm4EndToEndGeneration(container: LLModelContainer) async throws { let (result, toolCalls) = try await generateWithTools( container: container, userMessage: "What's the weather in Paris?") @@ -607,7 +607,7 @@ public enum ToolCallTests { // MARK: Mistral3 - public static func mistral3FormatAutoDetection(container: LMModelContainer) async throws { + public static func mistral3FormatAutoDetection(container: LLModelContainer) async 
throws { let config = await container.configuration try check( config.toolCallFormat == ToolCallFormat.mistral, @@ -615,7 +615,7 @@ public enum ToolCallTests { ) } - public static func mistral3EndToEndGeneration(container: LMModelContainer) async throws { + public static func mistral3EndToEndGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -647,7 +647,7 @@ public enum ToolCallTests { ) } - public static func mistral3MultiToolGeneration(container: LMModelContainer) async throws { + public static func mistral3MultiToolGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -680,7 +680,7 @@ public enum ToolCallTests { // MARK: Nemotron - public static func nemotronFormatAutoDetection(container: LMModelContainer) async throws { + public static func nemotronFormatAutoDetection(container: LLModelContainer) async throws { let config = await container.configuration try check( config.toolCallFormat == ToolCallFormat.xmlFunction, @@ -688,7 +688,7 @@ public enum ToolCallTests { ) } - public static func nemotronEndToEndGeneration(container: LMModelContainer) async throws { + public static func nemotronEndToEndGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -721,7 +721,7 @@ public enum ToolCallTests { ) } - public static func nemotronMultiToolGeneration(container: LMModelContainer) async throws { + public static func nemotronMultiToolGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -755,7 +755,7 @@ public enum ToolCallTests { // MARK: Qwen3.5 - public static func qwen35FormatAutoDetection(container: LMModelContainer) async throws { + public static func qwen35FormatAutoDetection(container: LLModelContainer) async throws { let config = await container.configuration try check( config.toolCallFormat == ToolCallFormat.xmlFunction, @@ -763,7 +763,7 @@ public enum ToolCallTests { ) } - public 
static func qwen35EndToEndGeneration(container: LMModelContainer) async throws { + public static func qwen35EndToEndGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -795,7 +795,7 @@ public enum ToolCallTests { ) } - public static func qwen35MultiToolGeneration(container: LMModelContainer) async throws { + public static func qwen35MultiToolGeneration(container: LLModelContainer) async throws { let input = UserInput( chat: [ .system( @@ -830,7 +830,7 @@ public enum ToolCallTests { // MARK: Helpers private static func generateWithTools( - container: LMModelContainer, + container: LLModelContainer, input: UserInput, maxTokens: Int = 100 ) async throws -> (text: String, toolCalls: [ToolCall]) { @@ -858,7 +858,7 @@ public enum ToolCallTests { } private static func generateWithTools( - container: LMModelContainer, + container: LLModelContainer, userMessage: String ) async throws -> (text: String, toolCalls: [ToolCall]) { let input = UserInput( diff --git a/Libraries/MLXEmbedders/BaseConfiguration.swift b/Libraries/MLXEmbedders/BaseConfiguration.swift deleted file mode 100644 index 58fee60ab..000000000 --- a/Libraries/MLXEmbedders/BaseConfiguration.swift +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright © 2025 Apple Inc. - -import Foundation -import MLX - -/// The fundamental configuration for any MLX-based model. -/// -/// `BaseConfiguration` provides the metadata necessary to identify the model architecture -/// (`modelType`) and describes the quantization parameters used to compress the model's weights. -/// It is designed to be decoded directly from a model repository's `config.json`. -public struct BaseConfiguration: Codable, Sendable { - - /// The architecture identifier (e.g., "bert", "roberta", "xlm-roberta"). - public let modelType: String - - /// Configuration parameters for weight quantization. - /// - /// MLX uses group-wise quantization to reduce memory footprint. 
This struct - /// defines how weights are grouped and the precision (bits) used for each group. - public struct Quantization: Codable, Sendable, Equatable { - - /// Initializes a new quantization configuration. - /// - Parameters: - /// - groupSize: The number of weights that share the same scale and bias. - /// - bits: The bit-depth of the quantized weights (e.g., 4 or 8). - public init(groupSize: Int, bits: Int) { - self.groupSize = groupSize - self.bits = bits - } - - /// The size of the quantization group. - public let groupSize: Int - - /// The number of bits per weight. - public let bits: Int - - /// Internal storage for the quantization mode. - private var _mode: QuantizationMode? = nil - - /// The quantization method to use (defaults to `.affine`). - /// - /// Affine quantization (asymmetric) uses both a scale and a zero-point - /// to map floating point values to integers. - public var mode: QuantizationMode { _mode ?? .affine } - - /// Converts the configuration into a tuple format compatible with `MLX.quantize`. - public var asTuple: (Int, Int, QuantizationMode) { (groupSize, bits, mode) } - - enum CodingKeys: String, CodingKey { - case groupSize = "group_size" - case bits = "bits" - case _mode = "mode" - } - } - - /// Instructions for handling individual layers during the quantization process. - public enum QuantizationOption: Sendable { - /// Do not quantize this specific layer (keep it in high precision). - case skip - /// Quantize this layer using the provided parameters. - case quantize(Quantization) - } - - /// A container for per-layer quantization settings. - /// - /// This allows for "Mixed-Precision" or "Heterogeneous" quantization, where - /// sensitive layers (like the embedding head) can be kept at higher precision - /// while the rest of the model is compressed. - public struct PerLayerQuantization: Sendable { - /// The default quantization for any layer not explicitly named in `perLayerQuantization`. 
- public var quantization: Quantization? = nil - - /// A dictionary mapping layer paths (e.g., "model.embed_tokens") to their quantization options. - public var perLayerQuantization: [String: QuantizationOption] - - public init( - quantization: BaseConfiguration.Quantization? = nil, - perLayerQuantization: [String: BaseConfiguration.QuantizationOption] - ) { - self.quantization = quantization - self.perLayerQuantization = perLayerQuantization - } - - /// Resolves the quantization parameters for a specific layer. - /// - Parameter layer: The path/name of the layer. - /// - Returns: The `Quantization` settings to apply, or `nil` if the layer should be skipped. - public func quantization(layer: String) -> Quantization? { - if let perLayer = perLayerQuantization[layer] { - switch perLayer { - case .skip: - return nil - case .quantize(let quantization): - return quantization - } - } else { - return quantization - } - } - } - - /// An internal container designed to handle the mixed JSON structure found in `config.json`. - /// - /// Quantization configs in MLX often interleave global keys (like `bits`) with - /// specific layer keys (like `model.layers.0...`). This container uses manual - /// decoding to separate these interleaved values. - struct QuantizationContainer: Codable, Sendable { - var quantization: Quantization - var perLayerQuantization: PerLayerQuantization - - /// A custom CodingKey used to iterate over arbitrary layer names in JSON. - internal struct _DictionaryCodingKey: CodingKey { - internal let stringValue: String - internal let intValue: Int? - - internal init(stringValue: String) { - self.stringValue = stringValue - self.intValue = Int(stringValue) - } - - internal init(intValue: Int) { - self.stringValue = "\(intValue)" - self.intValue = intValue - } - } - - init(from decoder: any Decoder) throws { - // 1. Decode global quantization (bits/group_size) from the current level - self.quantization = try Quantization(from: decoder) - - // 2. 
Decode interleaved per-layer values - var perLayerQuantization = [String: QuantizationOption]() - let container = try decoder.container(keyedBy: _DictionaryCodingKey.self) - - for key in container.allKeys { - switch key.stringValue { - case Quantization.CodingKeys.groupSize.rawValue: continue - case Quantization.CodingKeys.bits.rawValue: continue - case Quantization.CodingKeys._mode.rawValue: continue - - default: - // If the value is a boolean 'false', we treat it as .skip - if let f = try? container.decode(Bool.self, forKey: key) { - if !f { - perLayerQuantization[key.stringValue] = .skip - } - } else { - // Otherwise, try to decode a specific Quantization object for this layer - perLayerQuantization[key.stringValue] = .quantize( - try container.decode(Quantization.self, forKey: key)) - } - } - } - self.perLayerQuantization = PerLayerQuantization( - quantization: quantization, perLayerQuantization: perLayerQuantization) - } - - func encode(to encoder: any Encoder) throws { - try quantization.encode(to: encoder) - - var container = encoder.container(keyedBy: _DictionaryCodingKey.self) - for (key, value) in perLayerQuantization.perLayerQuantization { - switch value { - case .skip: - try container.encode(false, forKey: .init(stringValue: key)) - case .quantize(let q): - try container.encode(q, forKey: .init(stringValue: key)) - } - } - } - } - - /// Internal storage for quantization details extracted from `config.json`. - var quantizationContainer: QuantizationContainer? - - /// The default quantization settings. - @available(*, deprecated, message: "Please use perLayerQuantization instead") - public var quantization: Quantization? { - quantizationContainer?.quantization - } - - /// The per-layer quantization settings, including the default fallback. - public var perLayerQuantization: PerLayerQuantization? 
{ - quantizationContainer?.perLayerQuantization - } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case quantizationContainer = "quantization" - } -} diff --git a/Libraries/MLXEmbedders/Configuration.swift b/Libraries/MLXEmbedders/Configuration.swift deleted file mode 100644 index 8da6aade1..000000000 --- a/Libraries/MLXEmbedders/Configuration.swift +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import MLXLMCommon - -private class ModelTypeRegistry: @unchecked Sendable { - - // Note: Using NSLock as we have very small (just dictionary get/set) - // critical sections and expect no contention. This allows the methods - // to remain synchronous. - private let lock = NSLock() - - private var creators: [String: @Sendable (Data) throws -> EmbeddingModel] = [ - "bert": { data in - let configuration = try JSONDecoder.json5().decode(BertConfiguration.self, from: data) - return BertModel(configuration) - }, - "roberta": { data in - let configuration = try JSONDecoder.json5().decode(BertConfiguration.self, from: data) - return BertModel(configuration) - }, - "xlm-roberta": { data in - let configuration = try JSONDecoder.json5().decode(BertConfiguration.self, from: data) - return BertModel(configuration) - }, - "distilbert": { data in - let configuration = try JSONDecoder.json5().decode(BertConfiguration.self, from: data) - return BertModel(configuration) - }, - "nomic_bert": { data in - let configuration = try JSONDecoder.json5().decode( - NomicBertConfiguration.self, from: data) - return NomicBertModel(configuration, pooler: false) - }, - "qwen3": { data in - let configuration = try JSONDecoder.json5().decode(Qwen3Configuration.self, from: data) - return Qwen3Model(configuration) - }, - "gemma3": { data in - let configuration = try JSONDecoder.json5().decode(Gemma3Configuration.self, from: data) - return EmbeddingGemma(configuration) - }, - "gemma3_text": { data in - let configuration = try 
JSONDecoder.json5().decode(Gemma3Configuration.self, from: data) - return EmbeddingGemma(configuration) - }, - "gemma3n": { data in - let configuration = try JSONDecoder.json5().decode(Gemma3Configuration.self, from: data) - return EmbeddingGemma(configuration) - }, - ] - - public func registerModelType( - _ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel - ) { - lock.withLock { - creators[type] = creator - } - } - - public func createModel(configuration: Data, rawValue: String) throws -> EmbeddingModel { - let creator = lock.withLock { - creators[rawValue] - } - guard let creator else { - throw EmbedderError.unsupportedModelType(rawValue) - } - return try creator(configuration) - } - -} - -private let modelTypeRegistry = ModelTypeRegistry() - -public struct ModelType: RawRepresentable, Codable, Sendable { - public let rawValue: String - - public init(rawValue: String) { - self.rawValue = rawValue - } - - public static func registerModelType( - _ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel - ) { - modelTypeRegistry.registerModelType(type, creator: creator) - } - - public func createModel(configuration: Data) throws -> EmbeddingModel { - try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue) - } -} diff --git a/Libraries/MLXEmbedders/EmbedderModelContainer.swift b/Libraries/MLXEmbedders/EmbedderModelContainer.swift new file mode 100644 index 000000000..5ea6425d3 --- /dev/null +++ b/Libraries/MLXEmbedders/EmbedderModelContainer.swift @@ -0,0 +1,114 @@ +// Copyright © 2026 Apple Inc. + +import Foundation +import MLXLMCommon + +/// Container for embedder models that guarantees single threaded access. +/// +/// Wrap models used by e.g. the UI in a ModelContainer. 
Callers can access +/// the model and/or tokenizer (any values from the ``EmbedderModelContext``): +/// +/// ```swift +/// let resultEmbeddings = await modelContainer.perform { context in +/// let tokenizer = context.tokenizer +/// let encoded = inputs.map { +/// tokenizer.encode(text: $0, addSpecialTokens: true) +/// } +/// ... +/// let modelOutput = context.model( +/// padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask) +/// +/// let result = context.pooling( +/// modelOutput, +/// normalize: true, applyLayerNorm: true +/// ) +/// result.eval() +/// return result.map { $0.asArray(Float.self) } +/// } +/// ``` +public final class EmbedderModelContainer: Sendable { + private let context: SerialAccessContainer + + public var configuration: ModelConfiguration { + get async { + await context.read { $0.configuration } + } + } + + public var tokenizer: Tokenizer { + get async { + await context.read { $0.tokenizer } + } + } + + public var poolingStrategy: Pooling.Strategy { + get async { + await context.read { $0.pooling.strategy } + } + } + + public init(context: consuming EmbedderModelContext) { + self.context = .init(context) + } + + /// Perform an action on the ``EmbedderModelContext``. + /// Callers _must_ eval any `MLXArray` before returning as `MLXArray` is not `Sendable`. + /// + /// - Note: The closure receives `EmbedderModelContext` which is not `Sendable`. This is intentional - + /// the closure runs within the actor's isolation, ensuring thread-safe access to the model. + /// - Note: The `sending` keyword indicates the return value is transferred (not shared) across + /// isolation boundaries, allowing non-Sendable types to be safely returned. 
+ public func perform( + _ action: @Sendable (EmbedderModelContext) async throws -> sending R + ) async rethrows -> sending R { + try await context.read { + try await action($0) + } + } + + @available(*, deprecated, message: "use perform(_: (EmbedderModelContext) -> R) instead") + public func perform( + _ action: @Sendable (EmbeddingModel, Tokenizer, Pooling) async throws -> sending R + ) async rethrows -> sending R { + try await context.read { + try await action($0.model, $0.tokenizer, $0.pooling) + } + } + + /// Perform an action on the ``EmbedderModelContext`` with additional (non `Sendable`) context values. + /// Callers _must_ eval any `MLXArray` before returning as + /// `MLXArray` is not `Sendable`. + public func perform( + nonSendable values: consuming V, + _ action: @Sendable (EmbedderModelContext, V) async throws -> R + ) async rethrows -> sending R { + let values = SendableBox(values) + return try await context.read { + try await action($0, values.consume()) + } + } + + /// Update the owned `EmbedderModelContext`. + /// - Parameter action: update action + public func update(_ action: @Sendable (inout EmbedderModelContext) -> Void) async { + await context.update { + action(&$0) + } + } + + // MARK: - Thread-safe convenience methods + + /// The resolved local model directory for the loaded container. + public var modelDirectory: URL { + get async throws { + try (await configuration).modelDirectory + } + } + + /// The resolved local tokenizer directory for the loaded container. + public var tokenizerDirectory: URL { + get async throws { + try (await configuration).tokenizerDirectory + } + } +} diff --git a/Libraries/MLXEmbedders/EmbeddingModel.swift b/Libraries/MLXEmbedders/EmbeddingModel.swift index ca8f90882..9049fcea5 100644 --- a/Libraries/MLXEmbedders/EmbeddingModel.swift +++ b/Libraries/MLXEmbedders/EmbeddingModel.swift @@ -5,97 +5,12 @@ import MLX import MLXLMCommon import MLXNN -/// Container for models that guarantees single threaded access. 
-/// -/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access -/// the model and/or tokenizer: -/// -/// ```swift -/// let promptTokens = await modelContainer.perform { _, tokenizer in -/// tokenizer.encode(text: prompt) -/// } -/// ``` -/// -/// or: -/// -/// ```swift -/// let result = await modelContainer.perform { model, tokenizer in -/// LLM.generate( -/// promptTokens: promptTokens, parameters: generateParameters, model: model, -/// tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens -/// ) { tokens in -/// ... -/// } -/// } -/// ``` -public actor ModelContainer { - let model: EmbeddingModel - let tokenizer: Tokenizer - let pooler: Pooling - - public init( - model: EmbeddingModel, - tokenizer: Tokenizer, - pooler: Pooling = Pooling(strategy: .none) - ) { - self.model = model - self.tokenizer = tokenizer - self.pooler = pooler - } - - /// Build the model and tokenizer without passing non-sendable data over isolation barriers - public init( - modelDirectory: URL, - tokenizerDirectory: URL, - configuration: ModelConfiguration, - tokenizerLoader: any TokenizerLoader - ) async throws { - // Load tokenizer and model in parallel - async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory) - - self.model = try loadSynchronous( - modelDirectory: modelDirectory, modelName: configuration.name) - self.pooler = loadPooling(modelDirectory: modelDirectory, model: model) - - self.tokenizer = try await tokenizerTask - } - - /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as - /// `MLXArray` is not `Sendable`. 
- public func perform(_ action: @Sendable (EmbeddingModel, Tokenizer, Pooling) throws -> R) - rethrows -> R - { - try action(model, tokenizer, pooler) - } -} - -extension Module { - - /// Compute the number of parameters in a possibly quantized model - public func numParameters() -> Int { - return leafModules().flattenedValues().map { - mod -> Int in - if let qlin = mod as? QuantizedLinear { - return qlin.scales.size * qlin.groupSize - } else if let qemb = mod as? QuantizedEmbedding { - return qemb.scales.size * qemb.groupSize - } else { - return mod.parameters().flattenedValues().reduce( - 0, - { - $0 + $1.size - }) - } - }.reduce(0, +) - } -} - public struct EmbeddingModelOutput { public let hiddenStates: MLXArray? public let pooledOutput: MLXArray? } -public protocol EmbeddingModel: Module { +public protocol EmbeddingModel: BaseLanguageModel { var vocabularySize: Int { get } var poolingStrategy: Pooling.Strategy? { get } @@ -105,9 +20,6 @@ public protocol EmbeddingModel: Module { tokenTypeIds: MLXArray?, attentionMask: MLXArray? ) -> EmbeddingModelOutput - - /// Optionally preprocess the weights and modify / remove values as needed. - func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] } extension EmbeddingModel { diff --git a/Libraries/MLXEmbedders/Load.swift b/Libraries/MLXEmbedders/Load.swift deleted file mode 100644 index c93f25e57..000000000 --- a/Libraries/MLXEmbedders/Load.swift +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import MLX -import MLXLMCommon -import MLXNN - -/// Errors encountered during the model loading and initialization process. -/// -/// This enum provides detailed feedback for failures in model type identification, -/// file access, JSON decoding, and missing configuration files. -public enum EmbedderError: LocalizedError { - - /// The specified `model_type` in `config.json` is not supported by the current implementation. 
- case unsupportedModelType(String) - - /// A required file could not be read from the disk. - /// - Parameters: - /// - fileName: The name of the file (e.g., "config.json"). - /// - modelName: The name/ID of the model being loaded. - /// - error: The underlying system error. - case configurationFileError(String, String, Error) - - /// The configuration file exists but contains invalid JSON or missing required fields. - case configurationDecodingError(String, String, DecodingError) - - /// A human-readable description of the error. - public var errorDescription: String? { - switch self { - case .unsupportedModelType(let type): - return "Unsupported model type: \(type)" - case .configurationFileError(let file, let modelName, let error): - return "Error reading '\(file)' for model '\(modelName)': \(error.localizedDescription)" - case .configurationDecodingError(let file, let modelName, let decodingError): - let errorDetail = extractDecodingErrorDetail(decodingError) - return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)" - } - } - - /// Internal helper to provide specific details about JSON decoding failures, - /// such as the exact key path that failed. 
- private func extractDecodingErrorDetail(_ error: DecodingError) -> String { - switch error { - case .keyNotFound(let key, let context): - let path = (context.codingPath + [key]).map { $0.stringValue }.joined(separator: ".") - return "Missing field '\(path)'" - case .typeMismatch(_, let context): - let path = context.codingPath.map { $0.stringValue }.joined(separator: ".") - return "Type mismatch at '\(path)'" - case .valueNotFound(_, let context): - let path = context.codingPath.map { $0.stringValue }.joined(separator: ".") - return "Missing value at '\(path)'" - case .dataCorrupted(let context): - if context.codingPath.isEmpty { - return "Invalid JSON" - } else { - let path = context.codingPath.map { $0.stringValue }.joined(separator: ".") - return "Invalid data at '\(path)'" - } - @unknown default: - return error.localizedDescription - } - } -} - -/// Resolve model and tokenizer directories from a ``ModelConfiguration`` -/// using a ``Downloader``. -/// -/// - Parameters: -/// - downloader: The downloader to use for fetching remote resources. -/// - configuration: The configuration identifying the model. -/// - useLatest: When true, always checks the provider for updates. -/// - progressHandler: A closure to monitor download progress. -/// - Returns: A tuple of (modelDirectory, tokenizerDirectory). 
-func resolveDirectories( - from downloader: any Downloader, - configuration: ModelConfiguration, - useLatest: Bool = false, - progressHandler: @Sendable @escaping (Progress) -> Void -) async throws -> (modelDirectory: URL, tokenizerDirectory: URL) { - let modelDirectory: URL - switch configuration.id { - case .id(let id, let revision): - modelDirectory = try await downloader.download( - id: id, revision: revision, - matching: modelDownloadPatterns, - useLatest: useLatest, - progressHandler: progressHandler) - case .directory(let directory): - modelDirectory = directory - } - - let tokenizerDirectory: URL - switch configuration.tokenizerSource { - case .id(let id, let revision): - tokenizerDirectory = try await downloader.download( - id: id, revision: revision, - matching: tokenizerDownloadPatterns, - useLatest: useLatest, - progressHandler: { _ in }) - case .directory(let directory): - tokenizerDirectory = directory - case nil: - tokenizerDirectory = modelDirectory - } - - return (modelDirectory, tokenizerDirectory) -} - -/// Asynchronously loads the `EmbeddingModel` and its associated `Tokenizer`. -/// -/// This is the primary high-level function for initializing an embedding pipeline. -/// It leverages `async let` to load the tokenizer in parallel while the model -/// structure is being built synchronously. -/// -/// - Parameters: -/// - downloader: The downloader to use for fetching remote resources. -/// - tokenizerLoader: The tokenizer loader to use for loading the tokenizer. -/// - configuration: The model configuration. -/// - useLatest: When true, always checks the provider for updates. -/// - progressHandler: A closure for tracking download progress. -/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`. 
-public func load( - from downloader: any Downloader, - using tokenizerLoader: any TokenizerLoader, - configuration: ModelConfiguration, - useLatest: Bool = false, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } -) async throws -> (EmbeddingModel, Tokenizer) { - let (modelDirectory, tokenizerDirectory) = try await resolveDirectories( - from: downloader, configuration: configuration, useLatest: useLatest, - progressHandler: progressHandler) - - async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory) - let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name) - let tokenizer = try await tokenizerTask - - return (model, tokenizer) -} - -/// Synchronously initializes the model architecture, loads weights, and applies quantization. -/// -/// This function performs the following steps: -/// 1. Reads and decodes `config.json`. -/// 2. Instantiates the specific model class based on `model_type`. -/// 3. Recursively scans the directory for `.safetensors` weight files. -/// 4. Applies quantization if defined in the configuration. -/// 5. Updates the model parameters and performs an initial evaluation (`eval`). -/// -/// - Parameters: -/// - modelDirectory: The local `URL` containing model files. -/// - modelName: The display name of the model for error reporting. -/// - Returns: A fully initialized and weighted `EmbeddingModel`. 
-func loadSynchronous(modelDirectory: URL, modelName: String) throws -> EmbeddingModel { - let configurationURL = modelDirectory.appending(component: "config.json") - let configData: Data - do { - configData = try Data(contentsOf: configurationURL) - } catch { - throw EmbedderError.configurationFileError( - configurationURL.lastPathComponent, modelName, error) - } - - let baseConfig: BaseConfiguration - do { - baseConfig = try JSONDecoder.json5().decode(BaseConfiguration.self, from: configData) - } catch let error as DecodingError { - throw EmbedderError.configurationDecodingError( - configurationURL.lastPathComponent, modelName, error) - } - - let modelType = ModelType(rawValue: baseConfig.modelType) - let model: EmbeddingModel - do { - model = try modelType.createModel(configuration: configData) - } catch let error as DecodingError { - throw EmbedderError.configurationDecodingError( - configurationURL.lastPathComponent, modelName, error) - } - - var weights = [String: MLXArray]() - let enumerator = FileManager.default.enumerator( - at: modelDirectory, includingPropertiesForKeys: nil)! - for case let url as URL in enumerator { - if url.pathExtension == "safetensors" { - let w = try loadArrays(url: url) - for (key, value) in w { - weights[key] = value - } - } - } - - weights = model.sanitize(weights: weights) - - if let perLayerQuantization = baseConfig.perLayerQuantization { - quantize(model: model) { path, module in - if weights["\(path).scales"] != nil { - return perLayerQuantization.quantization(layer: path)?.asTuple - } else { - return nil - } - } - } - - let parameters = ModuleParameters.unflattened(weights) - try model.update(parameters: parameters, verify: [.all]) - - eval(model) - - return model -} - -/// Asynchronously loads a `ModelContainer` for thread-safe model access. -/// -/// The `ModelContainer` is recommended for applications where multiple threads -/// or tasks may need to access the embedding model simultaneously. 
-/// -/// - Parameters: -/// - downloader: The downloader to use for fetching remote resources. -/// - tokenizerLoader: The tokenizer loader to use for loading the tokenizer. -/// - configuration: The model configuration. -/// - useLatest: When true, always checks the provider for updates. -/// - progressHandler: A closure for tracking download progress. -/// - Returns: A thread-safe `ModelContainer` instance. -public func loadModelContainer( - from downloader: any Downloader, - using tokenizerLoader: any TokenizerLoader, - configuration: ModelConfiguration, - useLatest: Bool = false, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } -) async throws -> ModelContainer { - let (modelDirectory, tokenizerDirectory) = try await resolveDirectories( - from: downloader, configuration: configuration, useLatest: useLatest, - progressHandler: progressHandler) - - return try await ModelContainer( - modelDirectory: modelDirectory, - tokenizerDirectory: tokenizerDirectory, - configuration: configuration, - tokenizerLoader: tokenizerLoader) -} - -/// Load an embedding model from a local directory. -/// -/// No downloader is needed — the model and tokenizer are loaded from -/// the given directory. -/// -/// - Parameters: -/// - directory: The local directory containing model files. -/// - tokenizerLoader: The tokenizer loader to use for loading the tokenizer. -/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`. -public func load( - from directory: URL, - using tokenizerLoader: any TokenizerLoader -) async throws -> (EmbeddingModel, Tokenizer) { - let name = - directory.deletingLastPathComponent().lastPathComponent + "/" - + directory.lastPathComponent - async let tokenizerTask = tokenizerLoader.load(from: directory) - let model = try loadSynchronous(modelDirectory: directory, modelName: name) - let tokenizer = try await tokenizerTask - return (model, tokenizer) -} - -/// Load an embedding model container from a local directory. 
-/// -/// No downloader is needed — the model and tokenizer are loaded from -/// the given directory. -/// -/// - Parameters: -/// - directory: The local directory containing model files. -/// - tokenizerLoader: The tokenizer loader to use for loading the tokenizer. -/// - Returns: A thread-safe `ModelContainer` instance. -public func loadModelContainer( - from directory: URL, - using tokenizerLoader: any TokenizerLoader -) async throws -> ModelContainer { - try await ModelContainer( - modelDirectory: directory, - tokenizerDirectory: directory, - configuration: ModelConfiguration(directory: directory), - tokenizerLoader: tokenizerLoader) -} diff --git a/Libraries/MLXEmbedders/ModelFactory.swift b/Libraries/MLXEmbedders/ModelFactory.swift new file mode 100644 index 000000000..8997de7bb --- /dev/null +++ b/Libraries/MLXEmbedders/ModelFactory.swift @@ -0,0 +1,221 @@ +// Copyright © 2026 Apple Inc. + +import Foundation +import MLXLMCommon +import MLXNN + +private func create( + _ configurationType: C.Type, _ modelInit: @escaping (C) -> M +) -> (Data) throws -> M { + { data in + let configuration = try JSONDecoder.json5().decode(C.self, from: data) + return modelInit(configuration) + } +} + +/// Registry of model type, e.g 'bert', to functions that can instantiate the model from configuration. 
+public enum EmbedderTypeRegistry { + + public static let shared: ModelTypeRegistry = .init(creators: [ + "bert": create(BertConfiguration.self) { BertModel($0) }, + "roberta": create(BertConfiguration.self) { BertModel($0) }, + "xlm-roberta": create(BertConfiguration.self) { BertModel($0) }, + "distilbert": create(BertConfiguration.self) { BertModel($0) }, + + "nomic_bert": create(NomicBertConfiguration.self) { NomicBertModel($0, pooler: false) }, + "qwen3": create(Qwen3Configuration.self) { Qwen3Model($0) }, + + "gemma3": create(Gemma3Configuration.self) { EmbeddingGemma($0) }, + "gemma3_text": create(Gemma3Configuration.self) { EmbeddingGemma($0) }, + "gemma3n": create(Gemma3Configuration.self) { EmbeddingGemma($0) }, + ]) + +} + +/// Registry of known embedder model configurations. +public class EmbedderRegistry: AbstractModelRegistry, @unchecked Sendable { + + /// Shared instance with default model configurations. + public static let shared = EmbedderRegistry(modelConfigurations: all()) + + /// BGE Micro v2 (TaylorAI) - optimized for extremely low latency. + public static let bge_micro = ModelConfiguration(id: "TaylorAI/bge-micro-v2") + /// GTE Tiny - a small, efficient embedding model. + public static let gte_tiny = ModelConfiguration(id: "TaylorAI/gte-tiny") + /// MiniLM-L6 - the industry-standard small embedding model. + public static let minilm_l6 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L6-v2") + /// Snowflake Arctic Embed XS. + public static let snowflake_xs = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-xs") + /// MiniLM-L12 - a more accurate version of MiniLM. + public static let minilm_l12 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L12-v2") + /// BGE Small en v1.5. + public static let bge_small = ModelConfiguration(id: "BAAI/bge-small-en-v1.5") + /// Multilingual E5 Small - supports over 100 languages. 
+ public static let multilingual_e5_small = ModelConfiguration( + id: "intfloat/multilingual-e5-small") + /// BGE Base en v1.5. + public static let bge_base = ModelConfiguration(id: "BAAI/bge-base-en-v1.5") + /// Nomic Embed Text v1. + public static let nomic_text_v1 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1") + /// Nomic Embed Text v1.5 - supports Matryoshka embeddings. + public static let nomic_text_v1_5 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1.5") + /// BGE Large en v1.5. + public static let bge_large = ModelConfiguration(id: "BAAI/bge-large-en-v1.5") + /// Snowflake Arctic Embed L. + public static let snowflake_lg = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-l") + /// BGE-M3 - Multi-lingual, Multi-functional, Multi-granularity. + public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3") + /// Mixedbread AI Large v1. + public static let mixedbread_large = ModelConfiguration( + id: "mixedbread-ai/mxbai-embed-large-v1") + /// Qwen3 Embedding 0.6B - 4-bit quantized version. + public static let qwen3_embedding = ModelConfiguration( + id: "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ") + + private static func all() -> [ModelConfiguration] { + [ + bge_micro, + gte_tiny, + minilm_l6, + snowflake_xs, + minilm_l12, + bge_small, + multilingual_e5_small, + bge_base, + nomic_text_v1, + nomic_text_v1_5, + bge_large, + snowflake_lg, + bge_m3, + mixedbread_large, + qwen3_embedding, + ] + } +} + +/// Context of values that work together to provide an ``EmbeddingModel``. +/// +/// This is created using a ``EmbedderModelFactory`` and often used +/// inside a ``EmbedderModelContainer``. 
+public struct EmbedderModelContext { + public var configuration: ModelConfiguration + public var model: any EmbeddingModel + public var tokenizer: any Tokenizer + public let pooling: Pooling + + public init( + configuration: ModelConfiguration, model: any EmbeddingModel, + tokenizer: any Tokenizer, pooling: Pooling + ) { + self.configuration = configuration + self.model = model + self.tokenizer = tokenizer + self.pooling = pooling + } +} + +/// Factory for creating new Embedder models. +/// +/// Callers can use the `shared` instance or create a new instance if custom configuration +/// is required. +/// +/// ```swift +/// let downloader: any Downloader +/// let tokenizerLoader: any TokenizerLoader +/// let modelId = "mlx-community/gemma-3-1b-it-qat-4bit" +/// let modelContainer = try await EmbedderModelFactory.shared.loadContainer( +/// from: downloader, using: tokenizerLoader, configuration: .init(id: modelId), +/// progressHandler: logProgress(modelId) +/// ) +/// ``` +public final class EmbedderModelFactory: GenericModelFactory { + + public typealias ContextType = EmbedderModelContext + public typealias ContainerType = EmbedderModelContainer + + public init( + typeRegistry: ModelTypeRegistry, + modelRegistry: AbstractModelRegistry + ) { + self.typeRegistry = typeRegistry + self.modelRegistry = modelRegistry + } + + /// Shared instance with default behavior. + public static let shared = EmbedderModelFactory( + typeRegistry: EmbedderTypeRegistry.shared, modelRegistry: EmbedderRegistry.shared) + + /// registry of model type, e.g. configuration value `gemma3` -> configuration and init methods + public let typeRegistry: ModelTypeRegistry + + /// registry of model id to configuration, e.g. 
`sentence-transformers/all-MiniLM-L6-v2` + public let modelRegistry: AbstractModelRegistry + + public func _load( + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader + ) async throws -> EmbedderModelContext { + let modelDirectory = configuration.modelDirectory + + // Load config.json once and decode for both base config and model-specific config + let configurationURL = modelDirectory.appending(component: "config.json") + let configData: Data + do { + configData = try Data(contentsOf: configurationURL) + } catch { + throw ModelFactoryError.configurationFileError( + configurationURL.lastPathComponent, configuration.name, error) + } + + let baseConfig: BaseConfiguration + do { + baseConfig = try JSONDecoder.json5().decode(BaseConfiguration.self, from: configData) + } catch let error as DecodingError { + throw ModelFactoryError.configurationDecodingError( + configurationURL.lastPathComponent, configuration.name, error) + } + + let model: EmbeddingModel + do { + model = try await typeRegistry.createModel( + configuration: configData, modelType: baseConfig.modelType) + } catch let error as DecodingError { + throw ModelFactoryError.configurationDecodingError( + configurationURL.lastPathComponent, configuration.name, error) + } + + // Load tokenizer and weights in parallel + async let tokenizerTask = tokenizerLoader.load( + from: configuration.tokenizerDirectory) + + try loadWeights( + modelDirectory: modelDirectory, model: model, + perLayerQuantization: baseConfig.perLayerQuantization) + + let tokenizer = try await tokenizerTask + + // Build a ModelConfiguration for the ModelContext + let tokenizerSource: TokenizerSource? = + configuration.tokenizerDirectory == modelDirectory + ? 
nil + : .directory(configuration.tokenizerDirectory) + let modelConfig = ModelConfiguration( + directory: modelDirectory, + tokenizerSource: tokenizerSource, + defaultPrompt: configuration.defaultPrompt, + extraEOSTokens: configuration.extraEOSTokens, + eosTokenIds: configuration.eosTokenIds, + toolCallFormat: configuration.toolCallFormat) + + let pooling = loadPooling(modelDirectory: modelDirectory, model: model) + + return .init( + configuration: modelConfig, model: model, + tokenizer: tokenizer, pooling: pooling + ) + } + + public func _wrap(_ context: EmbedderModelContext) -> EmbedderModelContainer { + .init(context: context) + } +} diff --git a/Libraries/MLXEmbedders/Models.swift b/Libraries/MLXEmbedders/Models.swift deleted file mode 100644 index bf9dee479..000000000 --- a/Libraries/MLXEmbedders/Models.swift +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import MLXLMCommon - -/// A registry and configuration provider for embedding models. -/// -/// `ModelConfiguration` manages how models are identified (either via Hugging Face Hub IDs or local file URLs) -/// and provides a mechanism to override tokenizer settings. It includes a global registry of -/// well-known models (like BGE, E5, and Snowflake Arctic) to simplify initialization. -/// -/// ### Example -/// ```swift -/// // Using a pre-registered model -/// let config = ModelConfiguration.bge_small -/// -/// // Using a custom local directory -/// let customConfig = ModelConfiguration(directory: myURL) -/// ``` -public struct ModelConfiguration: Sendable { - - /// The backing storage for the model's location. - public enum Identifier: Sendable { - /// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5"). - case id(String, revision: String = "main") - /// A file system URL pointing to a local model directory. - case directory(URL) - } - - /// The model's identifier (ID or Directory). 
- public var id: Identifier - - /// A display-friendly name for the model. - /// - /// For Hub models, this returns the repo ID. For local directories, - /// it returns a path-based name (e.g., "ParentDir/ModelDir"). - public var name: String { - switch id { - case .id(let string, _): - string - case .directory(let url): - url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent - } - } - - /// Where to load the tokenizer from when it differs from the model directory. - /// - /// - `.id`: download from a remote provider (requires a downloader) - /// - `.directory`: load from a local path - /// - `nil`: use the same directory as the model - public let tokenizerSource: TokenizerSource? - - /// Initializes a configuration using a Hub repository ID. - /// - Parameters: - /// - id: The Hugging Face repo ID. - /// - revision: The Git revision to use (defaults to "main"). - /// - tokenizerSource: Optional alternate source for the tokenizer. - public init( - id: String, - revision: String = "main", - tokenizerSource: TokenizerSource? = nil - ) { - self.id = .id(id, revision: revision) - self.tokenizerSource = tokenizerSource - } - - /// Initializes a configuration using a local directory. - /// - Parameters: - /// - directory: The `URL` of the model on disk. - /// - tokenizerSource: Optional alternate source for the tokenizer. - public init( - directory: URL, - tokenizerSource: TokenizerSource? = nil - ) { - self.id = .directory(directory) - self.tokenizerSource = tokenizerSource - } - - // MARK: - Registry Management - - /// Global registry of available model configurations. - @MainActor - public static var registry = [String: ModelConfiguration]() - - /// Registers an array of configurations into the global registry. - /// - Parameter configurations: The models to register. 
- @MainActor - public static func register(configurations: [ModelConfiguration]) { - bootstrap() - - for c in configurations { - registry[c.name] = c - } - } - - /// Retrieves a configuration by its ID or name. - /// - /// If the ID is not found in the registry, a new `ModelConfiguration` is - /// created on-the-fly using the provided string as a Hub ID. - /// - /// - Parameter id: The model name or Hub ID. - /// - Returns: A `ModelConfiguration` instance. - @MainActor - public static func configuration(id: String) -> ModelConfiguration { - bootstrap() - - if let c = registry[id] { - return c - } else { - return ModelConfiguration(id: id) - } - } - - /// Returns all registered model configurations. - @MainActor - public static var models: some Collection & Sendable { - bootstrap() - return Self.registry.values - } -} - -// MARK: - Predefined Models - -extension ModelConfiguration { - /// BGE Micro v2 (TaylorAI) - optimized for extremely low latency. - public static let bge_micro = ModelConfiguration(id: "TaylorAI/bge-micro-v2") - /// GTE Tiny - a small, efficient embedding model. - public static let gte_tiny = ModelConfiguration(id: "TaylorAI/gte-tiny") - /// MiniLM-L6 - the industry-standard small embedding model. - public static let minilm_l6 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L6-v2") - /// Snowflake Arctic Embed XS. - public static let snowflake_xs = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-xs") - /// MiniLM-L12 - a more accurate version of MiniLM. - public static let minilm_l12 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L12-v2") - /// BGE Small en v1.5. - public static let bge_small = ModelConfiguration(id: "BAAI/bge-small-en-v1.5") - /// Multilingual E5 Small - supports over 100 languages. - public static let multilingual_e5_small = ModelConfiguration( - id: "intfloat/multilingual-e5-small") - /// BGE Base en v1.5. 
- public static let bge_base = ModelConfiguration(id: "BAAI/bge-base-en-v1.5") - /// Nomic Embed Text v1. - public static let nomic_text_v1 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1") - /// Nomic Embed Text v1.5 - supports Matryoshka embeddings. - public static let nomic_text_v1_5 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1.5") - /// BGE Large en v1.5. - public static let bge_large = ModelConfiguration(id: "BAAI/bge-large-en-v1.5") - /// Snowflake Arctic Embed L. - public static let snowflake_lg = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-l") - /// BGE-M3 - Multi-lingual, Multi-functional, Multi-granularity. - public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3") - /// Mixedbread AI Large v1. - public static let mixedbread_large = ModelConfiguration( - id: "mixedbread-ai/mxbai-embed-large-v1") - /// Qwen3 Embedding 0.6B - 4-bit quantized version. - public static let qwen3_embedding = ModelConfiguration( - id: "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ") - - private enum BootstrapState: Sendable { - case idle - case bootstrapping - case bootstrapped - } - - /// Internal state to ensure the registry is only populated once. - @MainActor - static private var bootstrapState = BootstrapState.idle - - /// Populates the registry with default models if it hasn't been done already. 
- @MainActor - static func bootstrap() { - switch bootstrapState { - case .idle: - bootstrapState = .bootstrapping - register(configurations: [ - bge_micro, gte_tiny, minilm_l6, snowflake_xs, minilm_l12, - bge_small, multilingual_e5_small, bge_base, nomic_text_v1, - nomic_text_v1_5, bge_large, snowflake_lg, bge_m3, - mixedbread_large, qwen3_embedding, - ]) - bootstrapState = .bootstrapped - case .bootstrapping, .bootstrapped: - break - } - } -} diff --git a/Libraries/MLXEmbedders/Pooling.swift b/Libraries/MLXEmbedders/Pooling.swift index c1c9e1baf..e4ffa2d4f 100644 --- a/Libraries/MLXEmbedders/Pooling.swift +++ b/Libraries/MLXEmbedders/Pooling.swift @@ -63,10 +63,10 @@ func loadPooling(modelDirectory: URL, model: EmbeddingModel) -> Pooling { /// /// `Pooling` takes the sequence of hidden states from a transformer model and collapses them /// into a single vector using strategies like mean, max, or token selection. -public class Pooling: Module { +open class Pooling: Module { /// Supported pooling strategies. - public enum Strategy { + public enum Strategy: Sendable { /// Average all token embeddings (weighted by mask). case mean /// Use the pooled output (CLS) provided by the model. diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index f8e261bdd..969b9ca4d 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -20,7 +20,7 @@ private func create( public enum LLMTypeRegistry { /// Shared instance with default model types. 
- public static let shared: ModelTypeRegistry = .init(creators: [ + public static let shared: ModelTypeRegistry = .init(creators: [ "mistral": create(LlamaConfiguration.self, LlamaModel.init), "llama": create(LlamaConfiguration.self, LlamaModel.init), "phi": create(PhiConfiguration.self, PhiModel.init), @@ -472,9 +472,14 @@ private struct LLMUserInputProcessor: UserInputProcessor { /// let modelContainer = try await LLMModelFactory.shared.loadContainer( /// configuration: LLMRegistry.llama3_8B_4bit) /// ``` -public final class LLMModelFactory: ModelFactory { +public final class LLMModelFactory: GenericModelFactory { - public init(typeRegistry: ModelTypeRegistry, modelRegistry: AbstractModelRegistry) { + public typealias ContextType = ModelContext + public typealias ContainerType = ModelContainer + + public init( + typeRegistry: ModelTypeRegistry, modelRegistry: AbstractModelRegistry + ) { self.typeRegistry = typeRegistry self.modelRegistry = modelRegistry } @@ -484,7 +489,7 @@ public final class LLMModelFactory: ModelFactory { typeRegistry: LLMTypeRegistry.shared, modelRegistry: LLMRegistry.shared) /// registry of model type, e.g. configuration value `llama` -> configuration and init methods - public let typeRegistry: ModelTypeRegistry + public let typeRegistry: ModelTypeRegistry /// registry of model id to configuration, e.g. 
`mlx-community/Llama-3.2-3B-Instruct-4bit` public let modelRegistry: AbstractModelRegistry diff --git a/Libraries/MLXLLM/Models/DeepseekV3.swift b/Libraries/MLXLLM/Models/DeepseekV3.swift index e20fdbe64..8b0767d71 100644 --- a/Libraries/MLXLLM/Models/DeepseekV3.swift +++ b/Libraries/MLXLLM/Models/DeepseekV3.swift @@ -436,7 +436,7 @@ public class DeepseekV3Model: Module, LLMModel, KVCacheDimensionProvider, LoRAMo func dequant(weight: MLXArray, scaleInv: MLXArray) -> MLXArray { let bs = 128 - let (m, n) = (weight.shape[0], weight.shape[1]) + let (m, n) = (weight.dim(0), weight.dim(1)) let padBottom = (bs - m % bs) % bs let padSide = (bs - n % bs) % bs diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift index df1eab41f..c5d4a060b 100644 --- a/Libraries/MLXLLM/Models/Gemma3Text.swift +++ b/Libraries/MLXLLM/Models/Gemma3Text.swift @@ -421,7 +421,7 @@ public class Gemma3TextModel: Module, LLMModel { _ input: LMInput, cache: [KVCache], windowSize: Int? = nil ) throws -> PrepareResult { let promptTokens = input.text.tokens - let promptCount = promptTokens.shape[0] + let promptCount = promptTokens.dim(0) guard promptCount > 0 else { print("Warning: Preparing with empty prompt tokens.") diff --git a/Libraries/MLXLLM/Models/Gemma3nText.swift b/Libraries/MLXLLM/Models/Gemma3nText.swift index 727aeb122..1aaa1d75a 100644 --- a/Libraries/MLXLLM/Models/Gemma3nText.swift +++ b/Libraries/MLXLLM/Models/Gemma3nText.swift @@ -306,7 +306,7 @@ class Gemma3nAttention: Module { var adjustedMask = mask if case .array(let maskArray) = mask { let keysSeqLen = keys.shape[keys.shape.count - 2] - if maskArray.shape.last! 
!= keysSeqLen { + if maskArray.dim(-1) != keysSeqLen { let slicedMask = maskArray[.ellipsis, 0 ..< keysSeqLen].asType(queries.dtype) adjustedMask = .array(slicedMask) } else { @@ -589,7 +589,7 @@ class Gemma3nDecoderLayer: Module { var finalMask = mask if isSliding, case .array(let maskArray) = mask { - let effectiveSeqLen = max(cachePosition?.shape[0] ?? 0, slidingWindow) + let effectiveSeqLen = max(cachePosition?.dim(0) ?? 0, slidingWindow) let minDtype = MLXArray(Float.leastNormalMagnitude, dtype: maskArray.dtype) let slidingWindowMask = tril( @@ -599,7 +599,7 @@ class Gemma3nDecoderLayer: Module { let updatedMask = MLX.where(slidingWindowMask, minDtype, maskArray) let offset = max(0, (cachePosition?.max().item() ?? 0) - effectiveSeqLen + 1) - let maskIndexes = MLXArray(0 ..< min(effectiveSeqLen, updatedMask.shape.last!)) + offset + let maskIndexes = MLXArray(0 ..< min(effectiveSeqLen, updatedMask.dim(-1))) + offset let slicedMask = take(updatedMask, maskIndexes.asType(.int32), axis: -1) finalMask = .array(slicedMask) } @@ -822,7 +822,7 @@ public class Gemma3nLanguageModel: Module { let cacheArray = cache ?? Array(repeating: nil as KVCache?, count: requiredCacheSize) let pastSeenTokens = cacheArray.first??.offset ?? 0 - let cachePosition = MLXArray(pastSeenTokens ..< (pastSeenTokens + h.shape[1])) + let cachePosition = MLXArray(pastSeenTokens ..< (pastSeenTokens + h.dim(1))) var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none @@ -1030,7 +1030,7 @@ public class Gemma3nTextModel: Module, LLMModel { _ input: LMInput, cache: [KVCache], windowSize: Int? 
= nil ) throws -> PrepareResult { let promptTokens = input.text.tokens - let promptCount = promptTokens.shape[0] + let promptCount = promptTokens.dim(0) guard promptCount > 0 else { print("Warning: Preparing with empty prompt tokens.") diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 8733d55c0..545ef2478 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -306,7 +306,7 @@ private class Gemma4Attention: Module { var adjustedMask = mask if case .array(let maskArray) = mask { let keysSeqLen = keys.dim(2) - if maskArray.shape.last! != keysSeqLen { + if maskArray.dim(-1) != keysSeqLen { adjustedMask = .array(maskArray[.ellipsis, 0 ..< keysSeqLen]) } } diff --git a/Libraries/MLXLLM/Models/LFM2.swift b/Libraries/MLXLLM/Models/LFM2.swift index 8d7fc1b45..a46e1d4e7 100644 --- a/Libraries/MLXLLM/Models/LFM2.swift +++ b/Libraries/MLXLLM/Models/LFM2.swift @@ -385,7 +385,7 @@ public class LFM2Model: Module, LLMModel, KVCacheDimensionProvider { var sanitizedParam = param if name.contains("conv.weight") { - if param.shape[param.shape.count - 1] > param.shape[1] { + if param.shape[param.shape.count - 1] > param.dim(1) { sanitizedParam = param.transposed(0, 2, 1) } } diff --git a/Libraries/MLXLLM/Models/LFM2MoE.swift b/Libraries/MLXLLM/Models/LFM2MoE.swift index fcefb2e01..7f27127b7 100644 --- a/Libraries/MLXLLM/Models/LFM2MoE.swift +++ b/Libraries/MLXLLM/Models/LFM2MoE.swift @@ -441,7 +441,7 @@ public class LFM2MoEModel: Module, LLMModel, KVCacheDimensionProvider { for (name, param) in weights { var tensor = param if name.contains("conv.weight") { - if tensor.shape.last! 
> tensor.shape[1] { + if tensor.dim(-1) > tensor.dim(1) { tensor = tensor.transposed(0, 2, 1) } } diff --git a/Libraries/MLXLLM/Models/MiMoV2Flash.swift b/Libraries/MLXLLM/Models/MiMoV2Flash.swift index f9545e8e1..11df1e41b 100644 --- a/Libraries/MLXLLM/Models/MiMoV2Flash.swift +++ b/Libraries/MLXLLM/Models/MiMoV2Flash.swift @@ -408,7 +408,7 @@ public class MiMoV2FlashModel: Module, LLMModel, KVCacheDimensionProvider { func dequant(weight: MLXArray, scaleInv: MLXArray) -> MLXArray { let dtype = weight.dtype let bs = 128 - let (m, n) = (weight.shape[0], weight.shape[1]) + let (m, n) = (weight.dim(0), weight.dim(1)) let padBottom = bs * scaleInv.dim(0) - m let padSide = bs * scaleInv.dim(1) - n diff --git a/Libraries/MLXLLM/Models/MiniMax.swift b/Libraries/MLXLLM/Models/MiniMax.swift index 47bae6a59..7357accc2 100644 --- a/Libraries/MLXLLM/Models/MiniMax.swift +++ b/Libraries/MLXLLM/Models/MiniMax.swift @@ -227,7 +227,7 @@ public class MiniMaxModel: Module, LLMModel, KVCacheDimensionProvider { func dequant(weight: MLXArray, scaleInv: MLXArray) -> MLXArray { let dtype = weight.dtype let bs = 128 - let (m, n) = (weight.shape[0], weight.shape[1]) + let (m, n) = (weight.dim(0), weight.dim(1)) let padBottom = (bs - m % bs) % bs let padSide = (bs - n % bs) % bs diff --git a/Libraries/MLXLMCommon/BaseConfiguration.swift b/Libraries/MLXLMCommon/BaseConfiguration.swift index f37bce65f..1a57fab61 100644 --- a/Libraries/MLXLMCommon/BaseConfiguration.swift +++ b/Libraries/MLXLMCommon/BaseConfiguration.swift @@ -3,25 +3,49 @@ import Foundation import MLX -/// Base ``LanguageModel`` configuration -- provides `modelType` -/// and `quantization` (used in loading the model). +/// The fundamental configuration for any MLX-based model. /// -/// This is used by ``ModelFactory/load(from:using:configuration:useLatest:progressHandler:)`` -/// to determine the type of model to load. 
+/// `BaseConfiguration` provides the metadata necessary to identify the model architecture +/// (`modelType`) and describes the quantization parameters used to compress the model's weights. +/// It is designed to be decoded directly from a model repository's `config.json`. +/// +/// Typically used by ``GenericModelFactory`` implementations during load. public struct BaseConfiguration: Codable, Sendable { + + /// The architecture identifier (e.g., "bert", "roberta", "xlm-roberta"). public let modelType: String + /// Configuration parameters for weight quantization. + /// + /// MLX uses group-wise quantization to reduce memory footprint. This struct + /// defines how weights are grouped and the precision (bits) used for each group. public struct Quantization: Codable, Sendable, Equatable { + + /// Initializes a new quantization configuration. + /// - Parameters: + /// - groupSize: The number of weights that share the same scale and bias. + /// - bits: The bit-depth of the quantized weights (e.g., 4 or 8). public init(groupSize: Int, bits: Int) { self.groupSize = groupSize self.bits = bits } + /// The size of the quantization group. public let groupSize: Int + + /// The number of bits per weight. public let bits: Int + + /// Internal storage for the quantization mode. private var _mode: QuantizationMode? = nil + + /// The quantization method to use (defaults to `.affine`). + /// + /// Affine quantization (asymmetric) uses both a scale and a zero-point + /// to map floating point values to integers. public var mode: QuantizationMode { _mode ?? .affine } + /// Converts the configuration into a tuple format compatible with `MLX.quantize`.
public var asTuple: (Int, Int, QuantizationMode) { (groupSize, bits, mode) } enum CodingKeys: String, CodingKey { @@ -33,13 +57,22 @@ public struct BaseConfiguration: Codable, Sendable { /// handling instructions for ``PerLayerQuantization`` public enum QuantizationOption: Sendable { + /// Do not quantize this specific layer (keep it in high precision). case skip + /// Quantize this layer using the provided parameters. case quantize(Quantization) } - /// Per-layer ``Quantization`` values with optional default. + /// A container for per-layer ``Quantization`` settings. + /// + /// This allows for "Mixed-Precision" or "Heterogeneous" quantization, where + /// sensitive layers (like the embedding head) can be kept at higher precision + /// while the rest of the model is compressed. public struct PerLayerQuantization: Sendable { + /// The default quantization for any layer not explicitly named in `perLayerQuantization`. public var quantization: Quantization? = nil + + /// A dictionary mapping layer paths (e.g., "model.embed_tokens") to their quantization options. public var perLayerQuantization: [String: QuantizationOption] public init( @@ -50,7 +83,9 @@ public struct BaseConfiguration: Codable, Sendable { self.perLayerQuantization = perLayerQuantization } - /// The quantization to apply for the given layer name or nil for no quantization. + /// Resolves the quantization parameters for a specific layer. + /// - Parameter layer: The path/name of the layer. + /// - Returns: The `Quantization` settings to apply, or `nil` if the layer should be skipped. public func quantization(layer: String) -> Quantization? { if let perLayer = perLayerQuantization[layer] { switch perLayer { @@ -65,8 +100,7 @@ public struct BaseConfiguration: Codable, Sendable { } } - /// Special codable to support a mixed key: Int / key: Quantization - /// structure for hereogenous quantization, e.g. + /// An internal container designed to handle the mixed JSON structure found in `config.json`. 
/// /// ``` /// "quantization": { @@ -79,12 +113,14 @@ public struct BaseConfiguration: Codable, Sendable { /// "model.layers.0.self_attn.q_norm": false, /// ``` /// - /// This mixed type structure requires manual decoding. + /// Quantization configs in MLX often interleave global keys (like `bits`) with + /// specific layer keys (like `model.layers.0...`). This container uses manual + /// decoding to separate these interleaved values. struct QuantizationContainer: Codable, Sendable { var quantization: Quantization var perLayerQuantization: PerLayerQuantization - // based on Dictionary's coding key + /// A custom CodingKey used to iterate over arbitrary layer names in JSON. internal struct _DictionaryCodingKey: CodingKey { internal let stringValue: String internal let intValue: Int? @@ -118,11 +154,13 @@ public struct BaseConfiguration: Codable, Sendable { case "quant_method", "linear_class", "quantization_mode": continue default: + // If the value is a boolean 'false', we treat it as .skip if let f = try? container.decode(Bool.self, forKey: key) { if !f { perLayerQuantization[key.stringValue] = .skip } } else { + // Otherwise, try to decode a specific Quantization object for this layer perLayerQuantization[key.stringValue] = .quantize( try container.decode(Quantization.self, forKey: key)) } @@ -147,16 +185,19 @@ public struct BaseConfiguration: Codable, Sendable { } } + /// Internal storage for quantization details extracted from `config.json`. var quantizationContainer: QuantizationContainer? /// EOS token IDs from config.json. Can be a single Int or an array of Ints. public var eosTokenIds: IntOrIntArray? + /// The default quantization settings. @available(*, deprecated, message: "Please use perLayerQuantization instead") public var quantization: Quantization? { quantizationContainer?.quantization } + /// The per-layer quantization settings, including the default fallback. public var perLayerQuantization: PerLayerQuantization? 
{ quantizationContainer?.perLayerQuantization } diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index 822d7ac3a..01b551fca 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -4,6 +4,31 @@ import Foundation import MLX import MLXNN +/// Abstract form of a model that processes language. +public protocol BaseLanguageModel: Module { + /// Optionally preprocess the weights and modify / remove values as needed. + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] + + /// Optionally preprocess the weights with access to safetensor metadata. + /// + /// The default implementation forwards to ``sanitize(weights:)``. + /// Models can override this to inspect metadata (e.g. check `metadata["format"] == "mlx"`) + /// and skip or customize sanitization accordingly. + func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: MLXArray] +} + +extension BaseLanguageModel { + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + weights + } + + public func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: + MLXArray] + { + sanitize(weights: weights) + } +} + /// Time/Height/Width struct to represent information about input images. public struct THW: Sendable { @@ -150,7 +175,7 @@ public enum PrepareResult { /// - calls ``prepare(_:cache:windowSize:)`` to initialize the KVCache and consume the prompt /// - calls ``callAsFunction(_:cache:state:)-9kuvf`` for each token, producing an ``LMOutput`` /// - the ``TokenIterator`` accumulates this information into a ``GenerateResult`` -public protocol LanguageModel: Module { +public protocol LanguageModel: BaseLanguageModel { /// Prepare the cache state and consume the ``LMInput``. 
/// @@ -169,16 +194,6 @@ public protocol LanguageModel: Module { /// create a new array of ``KVCache``: automatic implementation if self /// implements ``KVCacheDimensionProvider`` func newCache(parameters: GenerateParameters?) -> [KVCache] - - /// Optionally preprocess the weights and modify / remove values as needed. - func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] - - /// Optionally preprocess the weights with access to safetensor metadata. - /// - /// The default implementation forwards to ``sanitize(weights:)``. - /// Models can override this to inspect metadata (e.g. check `metadata["format"] == "mlx"`) - /// and skip or customize sanitization accordingly. - func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: MLXArray] } extension LanguageModel { @@ -194,18 +209,6 @@ extension LanguageModel { } } -extension LanguageModel { - public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { - weights - } - - public func sanitize(weights: [String: MLXArray], metadata: [String: String]) -> [String: - MLXArray] - { - sanitize(weights: weights) - } -} - /// Optional protocol that can be implemented by ``LanguageModel`` and will /// provide an automatic implementation of ``LanguageModel/newCache(parameters:)`` public protocol KVCacheDimensionProvider { diff --git a/Libraries/MLXLMCommon/Load.swift b/Libraries/MLXLMCommon/Load.swift index 715f34e20..ac422d2ba 100644 --- a/Libraries/MLXLMCommon/Load.swift +++ b/Libraries/MLXLMCommon/Load.swift @@ -6,13 +6,13 @@ import MLXNN /// Load model weights. /// -/// This is typically called via ``ModelFactory/load(from:using:configuration:useLatest:progressHandler:)``. +/// This is typically called via ``GenericModelFactory/load(from:using:configuration:useLatest:progressHandler:)``. 
/// This function loads all `safetensor` files in the given `modelDirectory`, -/// calls ``LanguageModel/sanitize(weights:metadata:)`` to allow per-model preprocessing, +/// calls ``BaseLanguageModel/sanitize(weights:metadata:)`` to allow per-model preprocessing, /// applies optional quantization, and /// updates the model with the weights. public func loadWeights( - modelDirectory: URL, model: LanguageModel, + modelDirectory: URL, model: BaseLanguageModel, quantization: BaseConfiguration.Quantization? = nil, perLayerQuantization: BaseConfiguration.PerLayerQuantization? = nil ) throws { diff --git a/Libraries/MLXLMCommon/ModelConfiguration.swift b/Libraries/MLXLMCommon/ModelConfiguration.swift index d3d50e3ce..5fbdce2dc 100644 --- a/Libraries/MLXLMCommon/ModelConfiguration.swift +++ b/Libraries/MLXLMCommon/ModelConfiguration.swift @@ -2,7 +2,15 @@ import Foundation -/// Configuration for a given model name with overrides for prompts and tokens. +/// Configuration for a given model: at least an org/name identifier or a directory with the model files. +/// +/// Optionally callers can provide some default values and overrides for: +/// +/// - a default prompt +/// - EOS tokens / strings +/// - tool calling formats +/// +/// Some of these are specific to LLMs and VLMs -- embedding models will ignore those properties. /// /// See e.g. `MLXLM.ModelRegistry` for an example of use. public struct ModelConfiguration: Sendable { @@ -21,13 +29,21 @@ public struct ModelConfiguration: Sendable { } } + /// The backing storage for the model's location. public enum Identifier: Sendable { + /// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5"). case id(String, revision: String = "main") + /// A file system URL pointing to a local model directory. case directory(URL) } + /// The model's identifier (ID or Directory). public var id: Identifier + /// A display-friendly name for the model. + /// + /// For Hub models, this returns the repo ID. 
For local directories, + /// it returns a path-based name (e.g., "ParentDir/ModelDir"). public var name: String { switch id { case .id(let id, _): diff --git a/Libraries/MLXLMCommon/ModelFactory.swift b/Libraries/MLXLMCommon/ModelFactory.swift index df8a8ccc3..880df7575 100644 --- a/Libraries/MLXLMCommon/ModelFactory.swift +++ b/Libraries/MLXLMCommon/ModelFactory.swift @@ -55,7 +55,7 @@ public enum ModelFactoryError: LocalizedError { /// Context of types that work together to provide a ``LanguageModel``. /// -/// A ``ModelContext`` is created by ``ModelFactory/load(from:using:configuration:useLatest:progressHandler:)``. +/// A ``ModelContext`` is created by ``GenericModelFactory/load(from:using:configuration:useLatest:progressHandler:)``. /// This contains the following: /// /// - ``ModelConfiguration``: identifier for the model @@ -63,7 +63,7 @@ public enum ModelFactoryError: LocalizedError { /// - ``UserInputProcessor``: can convert ``UserInput`` into ``LMInput`` /// - `Tokenizer` -- the tokenizer used by ``UserInputProcessor`` /// -/// See also ``ModelFactory/loadContainer(from:using:configuration:useLatest:progressHandler:)`` and +/// See also ``GenericModelFactory/loadContainer(from:using:configuration:useLatest:progressHandler:)`` and /// ``ModelContainer``. public struct ModelContext { public var configuration: ModelConfiguration @@ -84,23 +84,41 @@ public struct ModelContext { /// Protocol for code that can load models. 
/// -/// ## See Also -/// - ``loadModel(from:using:id:revision:useLatest:progressHandler:)`` -/// - ``loadModel(from:using:)`` -/// - ``loadModelContainer(from:using:id:revision:useLatest:progressHandler:)`` -/// - ``loadModelContainer(from:using:)`` -public protocol ModelFactory: Sendable { +/// See concrete implementations in: +/// +/// - `LLMModelFactory` +/// - `VLMModelFactory` +/// - `EmbedderModelFactory` +/// +/// or, if loading LLM/VLMs, use the free functions: +/// +/// - ``loadModel(from:using:configuration:useLatest:progressHandler:)`` +/// - ``loadModelContainer(from:using:configuration:useLatest:progressHandler:)`` +/// +/// or variants. +public protocol GenericModelFactory: Sendable { + + associatedtype ContextType + associatedtype ContainerType: Sendable var modelRegistry: AbstractModelRegistry { get } + /// Low-level load of a ``ResolvedModelConfiguration`` (urls) into a + /// ``ContextType``. This is typically a `struct` that holds the values + /// needed to run inference in the model and is _not_ `Sendable`. func _load( configuration: ResolvedModelConfiguration, tokenizerLoader: any TokenizerLoader - ) async throws -> ModelContext + ) async throws -> ContextType + /// Wrap a ``ContextType`` in a ``ContainerType``. + /// + /// The `ContainerType` is a `Sendable` container for managing the model contained + /// in the `ContextType`. + func _wrap(_ context: ContextType) -> ContainerType } -extension ModelFactory { +extension GenericModelFactory { /// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into /// a ``ModelConfiguration``. @@ -118,10 +136,9 @@ extension ModelFactory { public func contains(id: String) -> Bool { modelRegistry.contains(id: id) } - } -extension ModelFactory { +extension GenericModelFactory { /// Load a model from a ``Downloader`` and ``ModelConfiguration``, /// producing a ``ModelContext``.
@@ -138,7 +155,7 @@ extension ModelFactory { configuration: ModelConfiguration, useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } - ) async throws -> sending ModelContext { + ) async throws -> sending ContextType { let resolved = try await resolve( configuration: configuration, from: downloader, useLatest: useLatest, progressHandler: progressHandler) @@ -153,12 +170,12 @@ extension ModelFactory { configuration: ModelConfiguration, useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } - ) async throws -> ModelContainer { + ) async throws -> ContainerType { let resolved = try await resolve( configuration: configuration, from: downloader, useLatest: useLatest, progressHandler: progressHandler) let context = try await _load(configuration: resolved, tokenizerLoader: tokenizerLoader) - return ModelContainer(context: context) + return _wrap(context) } /// Load a model from a local directory, producing a ``ModelContext``. @@ -168,7 +185,7 @@ extension ModelFactory { public func load( from directory: URL, using tokenizerLoader: any TokenizerLoader - ) async throws -> sending ModelContext { + ) async throws -> sending ContextType { try await _load( configuration: .init(directory: directory), tokenizerLoader: tokenizerLoader) } @@ -177,14 +194,25 @@ extension ModelFactory { public func loadContainer( from directory: URL, using tokenizerLoader: any TokenizerLoader - ) async throws -> ModelContainer { + ) async throws -> ContainerType { let context = try await _load( configuration: .init(directory: directory), tokenizerLoader: tokenizerLoader) - return ModelContainer(context: context) + return _wrap(context) } } +extension GenericModelFactory where ContextType == ModelContext, ContainerType == ModelContainer { + + public func _wrap(_ context: ModelContext) -> ModelContainer { + .init(context: context) + } + +} + +/// For backward compatibility: `ModelFactory` refers to an LLM/VLM model factory. 
+public typealias ModelFactory = GenericModelFactory + /// Resolve a ``ModelConfiguration`` into a ``ResolvedModelConfiguration`` by /// downloading remote sources via a ``Downloader``. /// @@ -227,6 +255,8 @@ public func resolve( tokenizerDirectory: tokenizerDirectory) } +// MARK: - LLM Model Loading Free Functions -- implied ModelFactory + /// Load a model given a ``ModelConfiguration``, downloading via a ``Downloader``. /// /// Returns a ``ModelContext`` holding the model and tokenizer without @@ -373,7 +403,8 @@ public func loadModelContainer( } } -private func load(loader: (ModelFactory) async throws -> sending R) async throws -> sending R { +private func load(loader: (any ModelFactory) async throws -> sending R) async throws -> sending R +{ let factories = ModelFactoryRegistry.shared.modelFactories() var lastError: Error? for factory in factories { @@ -419,7 +450,7 @@ private func load(loader: (ModelFactory) async throws -> sending R) async thr /// ## See Also /// - ``ModelFactoryRegistry`` public protocol ModelFactoryTrampoline { - static func modelFactory() -> ModelFactory? + static func modelFactory() -> (any GenericModelFactory)? } /// Registry of ``ModelFactory`` trampolines. @@ -441,28 +472,30 @@ final public class ModelFactoryRegistry: @unchecked Sendable { public static let shared = ModelFactoryRegistry() private let lock = NSLock() - private var trampolines: [() -> ModelFactory?] + private var trampolines: [() -> (any ModelFactory)?] private init() { self.trampolines = [ { - (NSClassFromString("MLXVLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)? + (NSClassFromString("MLXVLM.TrampolineModelFactory") + as? any ModelFactoryTrampoline.Type)? .modelFactory() }, { - (NSClassFromString("MLXLLM.TrampolineModelFactory") as? ModelFactoryTrampoline.Type)? + (NSClassFromString("MLXLLM.TrampolineModelFactory") + as? any ModelFactoryTrampoline.Type)? .modelFactory() }, ] } - public func addTrampoline(_ trampoline: @escaping () -> ModelFactory?) 
{ + public func addTrampoline(_ trampoline: @escaping () -> (any ModelFactory)?) { lock.withLock { trampolines.append(trampoline) } } - public func modelFactories() -> [ModelFactory] { + public func modelFactories() -> [any ModelFactory] { lock.withLock { trampolines.compactMap { $0() } } diff --git a/Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift b/Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift index 135532717..60b7b0560 100644 --- a/Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift +++ b/Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift @@ -25,7 +25,7 @@ open class AbstractModelRegistry: @unchecked Sendable { } } - /// Returns configuration from ``ModelFactory/modelRegistry``. + /// Returns configuration from ``GenericModelFactory/modelRegistry``. /// /// - Note: If the id doesn't exists in the configuration, this will return a new instance of it. /// If you want to check if the configuration in model registry, you should use ``contains(id:)``. diff --git a/Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift b/Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift index a610f5e0d..d050b8077 100644 --- a/Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift +++ b/Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift @@ -2,7 +2,9 @@ import Foundation -public actor ModelTypeRegistry { +public actor ModelTypeRegistry { + + private var creators: [String: (Data) throws -> T] /// Creates an empty registry. public init() { @@ -10,22 +12,19 @@ public actor ModelTypeRegistry { } /// Creates a registry with given creators. - public init(creators: [String: (Data) throws -> any LanguageModel]) { + public init(creators: [String: (Data) throws -> T]) { self.creators = creators } - private var creators: [String: (Data) throws -> any LanguageModel] - /// Add a new model to the type registry. 
public func registerModelType( - _ type: String, creator: @escaping (Data) throws -> any LanguageModel + _ type: String, creator: @escaping (Data) throws -> T ) { creators[type] = creator } /// Given a `modelType` and configuration data instantiate a new `LanguageModel`. - public func createModel(configuration: Data, modelType: String) throws -> sending LanguageModel - { + public func createModel(configuration: Data, modelType: String) throws -> sending T { guard let creator = creators[modelType] else { throw ModelFactoryError.unsupportedModelType(modelType) } diff --git a/Libraries/MLXLMCommon/Utilities/SerialAccessContainer.swift b/Libraries/MLXLMCommon/Utilities/SerialAccessContainer.swift index 01c1af826..b6c591b97 100644 --- a/Libraries/MLXLMCommon/Utilities/SerialAccessContainer.swift +++ b/Libraries/MLXLMCommon/Utilities/SerialAccessContainer.swift @@ -41,7 +41,7 @@ private actor AsyncMutex { /// Unlike an `actor`, this will guarantee exclusive access for the duration of the async /// call. This is important for things like `ModelContainer` that have to perform async /// work but also need to prevent other callers for using _any_ of the internal state. -final class SerialAccessContainer: @unchecked Sendable { +package final class SerialAccessContainer: @unchecked Sendable { private var value: T private let lock = AsyncMutex() @@ -101,14 +101,14 @@ final class SerialAccessContainer: @unchecked Sendable { /// }.consume() /// } /// ``` -final class SendableBox: @unchecked Sendable { +package final class SendableBox: @unchecked Sendable { private var value: T? 
- init(_ value: consuming T) { + package init(_ value: consuming T) { self.value = consume value } - consuming func consume() -> T { + package consuming func consume() -> T { guard let value else { fatalError("value already consumed") } diff --git a/Libraries/MLXVLM/Models/FastVLM.swift b/Libraries/MLXVLM/Models/FastVLM.swift index dcfe56650..ae69e39f7 100644 --- a/Libraries/MLXVLM/Models/FastVLM.swift +++ b/Libraries/MLXVLM/Models/FastVLM.swift @@ -1087,8 +1087,8 @@ public class FastVLM: Module, VLMModel, KVCacheDimensionProvider { let (_, imageFeatures, _) = visionModel(pixelValues.transposed(0, 2, 3, 1)) let (B, H, W, C) = ( - imageFeatures.shape[0], imageFeatures.shape[1], imageFeatures.shape[2], - imageFeatures.shape[3] + imageFeatures.dim(0), imageFeatures.dim(1), imageFeatures.dim(2), + imageFeatures.dim(3) ) let mmInputs = multimodalProjector(imageFeatures.reshaped(B, H * W, C)) let finalEmbeddings = prepareInputsForMultimodal( diff --git a/Libraries/MLXVLM/Models/Gemma3.swift b/Libraries/MLXVLM/Models/Gemma3.swift index d0c968f75..58301660e 100644 --- a/Libraries/MLXVLM/Models/Gemma3.swift +++ b/Libraries/MLXVLM/Models/Gemma3.swift @@ -865,12 +865,12 @@ private func maskedScatter( // Scatter the scaled image features into the special image token positions let imagePositions = MLXArray(imagePositionIndices) - guard scaledImageFeaturesFlattened.shape[0] == imagePositions.shape[0] else { + guard scaledImageFeaturesFlattened.dim(0) == imagePositions.dim(0) else { fatalError( """ Critical error in maskedScatter: Size mismatch between image features and positions. 
- Image features: \(scaledImageFeaturesFlattened.shape[0]) - Image positions: \(imagePositions.shape[0]) + Image features: \(scaledImageFeaturesFlattened.dim(0)) + Image positions: \(imagePositions.dim(0)) """) } finalEmbeddingFlattened[imagePositions] = scaledImageFeaturesFlattened diff --git a/Libraries/MLXVLM/Models/Gemma4.swift b/Libraries/MLXVLM/Models/Gemma4.swift index 1b5037d97..72cb1b950 100644 --- a/Libraries/MLXVLM/Models/Gemma4.swift +++ b/Libraries/MLXVLM/Models/Gemma4.swift @@ -67,9 +67,9 @@ private func gemma4MaskedScatter( return inputTensor } - guard flattenedSource.shape[0] == targetIndices.count else { + guard flattenedSource.dim(0) == targetIndices.count else { fatalError( - "Masked scatter shape mismatch. source=\(flattenedSource.shape[0]) mask=\(targetIndices.count)" + "Masked scatter shape mismatch. source=\(flattenedSource.dim(0)) mask=\(targetIndices.count)" ) } @@ -193,7 +193,8 @@ private func gemma4AdjustAttentionMask( ) -> MLXFast.ScaledDotProductAttentionMaskMode { switch mask { case .array(let maskArray): - guard let maskLength = maskArray.shape.last, maskLength > keyLength else { + let maskLength = maskArray.dim(-1) + guard maskLength > keyLength else { return mask } let start = maskLength - keyLength @@ -1812,10 +1813,6 @@ public struct Gemma4Processor: UserInputProcessor { var processedImage: LMInput.ProcessedImage? 
if !input.images.isEmpty { - let imagePlaceholderCount = promptTokens.filter { $0 == config.imageTokenId }.count - let boiCount = promptTokens.filter { $0 == config.boiTokenId }.count - let eoiCount = promptTokens.filter { $0 == config.eoiTokenId }.count - let imagePixelsAndFrames = try input.images.map { try preprocess(images: [$0.asCIImage()], processing: input.processing) } @@ -1840,10 +1837,6 @@ public struct Gemma4Processor: UserInputProcessor { } } promptTokens = expandedTokens - - let expandedImageTokenCount = promptTokens.filter { $0 == config.imageTokenId }.count - let expandedBoiCount = promptTokens.filter { $0 == config.boiTokenId }.count - let expandedEoiCount = promptTokens.filter { $0 == config.eoiTokenId }.count } let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index 7c8cd5ff3..e68f70367 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -697,8 +697,8 @@ public class Idefics3: Module, VLMModel, KVCacheDimensionProvider { var segments = [MLXArray]() var start_idx = 0 - let chunkSize = imageFeatures.shape[1] // 64 - let chunkCount = imagePositions.count / chunkSize // Should be imageFeatures.shape[0] + let chunkSize = imageFeatures.dim(1) // 64 + let chunkCount = imagePositions.count / chunkSize // Should be imageFeatures.dim(0) let chunks = (0 ..< chunkCount).map { startIndex in let start = startIndex * chunkSize let end = start + chunkSize diff --git a/Libraries/MLXVLM/Models/LFM2VL.swift b/Libraries/MLXVLM/Models/LFM2VL.swift index 0d99a766a..ec08a72a6 100644 --- a/Libraries/MLXVLM/Models/LFM2VL.swift +++ b/Libraries/MLXVLM/Models/LFM2VL.swift @@ -570,7 +570,7 @@ private enum Language { var sanitizedParam = param if name.contains("conv.weight") { - if param.shape[param.shape.count - 1] > param.shape[1] { + if param.shape[param.shape.count - 1] > param.dim(1) { sanitizedParam = 
param.transposed(0, 2, 1) } } @@ -1075,7 +1075,7 @@ public class LFM2VL: Module, VLMModel, KVCacheDimensionProvider { // Handle conv weight transposition var value = v if newKey.contains("conv.weight") { - if v.shape[v.shape.count - 1] > v.shape[1] { + if v.shape[v.shape.count - 1] > v.dim(1) { value = v.transposed(0, 2, 1) } } diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 25a28d111..0724d4d2b 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXLMCommon +import MLXNN public enum VLMError: LocalizedError, Equatable { case imageRequired @@ -79,7 +80,7 @@ private func create( public enum VLMTypeRegistry { /// Shared instance with default model types. - public static let shared: ModelTypeRegistry = .init(creators: [ + public static let shared: ModelTypeRegistry = .init(creators: [ "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), "qwen2_5_vl": create(Qwen25VLConfiguration.self, Qwen25VL.init), @@ -90,7 +91,6 @@ public enum VLMTypeRegistry { "gemma3": create(Gemma3Configuration.self, Gemma3.init), "gemma4": create(Gemma4Configuration.self, Gemma4.init), "smolvlm": create(SmolVLM2Configuration.self, SmolVLM2.init), - // TODO: see if we can make it work with fastvlm rather than llava_qwen2 "fastvlm": create(FastVLMConfiguration.self, FastVLM.init), "llava_qwen2": create(FastVLMConfiguration.self, FastVLM.init), "pixtral": create(PixtralConfiguration.self, PixtralVLM.init), @@ -287,10 +287,13 @@ public typealias ModelRegistry = VLMRegistry /// let modelContainer = try await VLMModelFactory.shared.loadContainer( /// configuration: VLMRegistry.paligemma3bMix4488bit) /// ``` -public final class VLMModelFactory: ModelFactory { +public final class VLMModelFactory: GenericModelFactory { + + public typealias ContextType = ModelContext + public typealias 
ContainerType = ModelContainer public init( - typeRegistry: ModelTypeRegistry, processorRegistry: ProcessorTypeRegistry, + typeRegistry: ModelTypeRegistry, processorRegistry: ProcessorTypeRegistry, modelRegistry: AbstractModelRegistry ) { self.typeRegistry = typeRegistry @@ -304,7 +307,7 @@ public final class VLMModelFactory: ModelFactory { modelRegistry: VLMRegistry.shared) /// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods - public let typeRegistry: ModelTypeRegistry + public let typeRegistry: ModelTypeRegistry /// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods public let processorRegistry: ProcessorTypeRegistry