diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index cce61087e..f3fa1066c 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -8,6 +8,23 @@ import MLX import MLXLMCommon import MLXNN +// M-RoPE for the Qwen2-VL language model is a direct port of the Qwen2.5-VL +// implementation added in ml-explore/mlx-swift-lm#239 (`Qwen25VL.swift`); the LM +// spatial-position scheme is identical between the two models. Each ported piece +// below carries a `// #239: Qwen25VL ` marker. Map: +// positionIdsKey/ropeDeltasKey ............ Qwen25VL.swift:13-14 +// Attention mropeSectionRaw/_invFreq ...... :54,58 ; mropeCosSin :108 ; branch :144 +// DecoderLayer/Qwen2Model positionIds ..... :208 / :240 +// LanguageModel state + decode delta ...... :279-316 +// getRopeIndex ............................ :1027-1182 +// inputEmbeddings/prepare/callAsFunction .. :941 / :982 / :1184 +// +// Per-call decoder state is plumbed through `LMOutput.State` rather than stored as +// instance vars, so the model stays a pure function and one instance can serve many +// concurrent sessions without their MROPE state colliding. // #239: Qwen25VL:13-14 +private let positionIdsKey = LMOutput.Key("qwen2vl.positionIds") +private let ropeDeltasKey = LMOutput.Key("qwen2vl.ropeDeltas") + // MARK: - Language private enum Language { @@ -45,7 +62,11 @@ private enum Language { let kvHeads: Int let headDim: Int let scale: Float - let mropeSection: [Int] + let mropeSection: [Int] // cumulative section indices (for half-dim split) + let mropeSectionRaw: [Int] // raw section sizes [16, 24, 24] (for full-dim split) + // Leading underscore makes Module's weight loader skip this property — + // invFreq is computed from ropeTheta+headDim, not a trained weight. 
+ private let _invFreq: MLXArray @ModuleInfo(key: "q_proj") var wq: Linear @ModuleInfo(key: "k_proj") var wk: Linear @@ -67,6 +88,8 @@ private enum Language { self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { + // Raw sections e.g. [16, 24, 24] — used for splitting full-dim cos/sin + self.mropeSectionRaw = array // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() self.mropeSection = sequence(state: (0, array.makeIterator())) { state in if let v = state.1.next() { @@ -81,12 +104,42 @@ private enum Language { fatalError("rope_scaling['mrope_section'] must be an array of integers") } + // Compute inv_freq for MROPE (same formula as Python) + // inv_freq = 1.0 / (theta ^ (arange(0, dim, 2) / dim)) + let freqIndices = MLXArray(stride(from: 0, to: headDim, by: 2)).asType(.float32) + let base = MLXArray(args.ropeTheta) + self._invFreq = 1.0 / pow(base, freqIndices / Float(headDim)) + self._rotaryEmbedding.wrappedValue = RoPE( dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) } + /// Compute cos/sin for MROPE from 3D position IDs. + /// Matches Python apply_mrope: start with temporal, overwrite H/W ranges. 
+        // #239: Qwen25VL.swift Attention.mropeCosSin (:108)
+        private func mropeCosSin(positionIds: MLXArray) -> (MLXArray, MLXArray) {
+            // positionIds: [3, batch, seq] (axis 0 = temporal / height / width)
+            let invFreqExpanded = _invFreq.reshaped(1, 1, -1, 1)  // [1, 1, dim/2, 1]
+            let posExpanded = positionIds[0..., 0..., .newAxis, 0...].asType(.float32)  // [3, batch, 1, seq]
+            var freqs = matmul(invFreqExpanded, posExpanded)  // [3, batch, dim/2, seq]
+            freqs = freqs.transposed(0, 1, 3, 2)  // [3, batch, seq, dim/2]
+
+            var freqsT = freqs[0]  // start from the temporal frequencies: [batch, seq, dim/2]
+            var offset = mropeSectionRaw[0]  // first section stays temporal
+            for dim in 1 ..< mropeSectionRaw.count {  // overwrite the remaining sections in order
+                let length = mropeSectionRaw[dim]
+                freqsT[0..., 0..., offset ..< (offset + length)] =
+                    freqs[dim][0..., 0..., offset ..< (offset + length)]
+                offset += length
+            }
+
+            let emb = concatenated([freqsT, freqsT], axis: -1)  // [batch, seq, dim]
+            return (MLX.cos(emb), MLX.sin(emb))
+        }
+
         public func callAsFunction(
-            _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
+            _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?,
+            positionIds: MLXArray? = nil
         ) -> MLXArray {
             let (B, L) = (x.dim(0), x.dim(1))
 
@@ -99,9 +152,20 @@ private enum Language {
             keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
             values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
 
-            let offset = cache?.offset ?? 0
-            queries = rotaryEmbedding(queries, offset: offset)
-            keys = rotaryEmbedding(keys, offset: offset)
+            if let positionIds {
+                // MROPE path: compute 3D-aware cos/sin from position IDs
+                // #239: Qwen25VL.swift Attention.callAsFunction positionIds branch (:144)
+                let (cosValues, sinValues) = mropeCosSin(positionIds: positionIds)
+                let cos = cosValues[0..., .newAxis, 0..., 0...]  // [batch, 1, seq, dim]: broadcast over heads
+                let sin = sinValues[0..., .newAxis, 0..., 0...]  // [batch, 1, seq, dim]: broadcast over heads
+ queries = (queries * cos) + (QwenVL.rotateHalf(queries) * sin) + keys = (keys * cos) + (QwenVL.rotateHalf(keys) * sin) + } else { + // Simple sequential RoPE (no-image / text-only path) + let offset = cache?.offset ?? 0 + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + } let output = attentionWithCacheUpdate( queries: queries, @@ -153,9 +217,11 @@ private enum Language { } public func callAsFunction( - _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?, + positionIds: MLXArray? = nil ) -> MLXArray { - var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + // #239: Qwen25VL.swift Qwen25VLDecoderLayer.callAsFunction (:208) — thread positionIds + var r = attention(inputLayerNorm(x), mask: mask, cache: cache, positionIds: positionIds) let h = x + r r = mlp(postAttentionLayerNorm(h)) let out = h + r @@ -184,7 +250,8 @@ private enum Language { } public func callAsFunction( - _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil, + positionIds: MLXArray? = nil ) -> MLXArray { var h: MLXArray if let inputEmbedding { @@ -197,8 +264,9 @@ private enum Language { let mask = createAttentionMask(h: h, cache: cache?.first) + // #239: Qwen25VL.swift Qwen25Model.callAsFunction (:240) — thread positionIds for (i, layer) in layers.enumerated() { - h = layer(h, mask: mask, cache: cache?[i]) + h = layer(h, mask: mask, cache: cache?[i], positionIds: positionIds) } return norm(h) @@ -222,17 +290,213 @@ private enum Language { } public func callAsFunction( - _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + _ inputs: MLXArray?, cache: [KVCache]? = nil, + state: LMOutput.State?, + inputEmbedding: MLXArray? = nil, + positionIds: MLXArray? 
= nil
         ) -> LMOutput {
-            var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding)
+            // #239: Qwen25VL.swift LanguageModel.callAsFunction state + decode delta (:279-316)
+            var state = state ?? .init()
+            var effectivePositionIds = positionIds ?? state[positionIdsKey]
+            if state[positionIdsKey] != nil {
+                state[positionIdsKey] = nil  // prefill positions are single-use; drop them from the returned state
+            }
+
+            // Decode steps: reconstruct positions from ropeDeltas + cache offset.
+            if effectivePositionIds == nil, let ropeDeltas = state[ropeDeltasKey],
+                let cache, let input = inputs ?? inputEmbedding
+            {
+                let batch = input.dim(0)
+                let seqLength = input.dim(1)
+                let lastCacheOffset = cache.last?.offset ?? 0
+
+                var delta = MLXArray(lastCacheOffset).asType(.int32) + ropeDeltas.asType(.int32)
+
+                var base = MLXArray(0 ..< seqLength).asType(.int32)
+                base = base[.newAxis, 0...]
+                base = broadcast(base, to: [batch, seqLength])
+
+                if delta.dim(0) == 1 && batch > 1 {
+                    delta = repeated(delta, count: batch, axis: 0)
+                }
+
+                base = base + delta.reshaped(-1, 1)  // [batch, 1]: one offset per row, broadcast over seq
+
+                effectivePositionIds = base[.newAxis, 0..., 0...]
+                effectivePositionIds = broadcast(effectivePositionIds!, to: [3, batch, seqLength])
+            }
+
+            var out = model(
+                inputs, cache: cache, inputEmbedding: inputEmbedding,
+                positionIds: effectivePositionIds)
             if let lmHead {
                 out = lmHead(out)
             } else {
                 out = model.embedTokens.asLinear(out)
             }
-            return LMOutput(logits: out)
+            return LMOutput(logits: out, state: state)
         }
     }
+
+    /// Build 3-D MROPE position IDs (temporal, height, width) from input_ids +
+    /// image/video grid sizes, plus the rope deltas used to continue positions
+    /// during decode. Port of Qwen2-VL's `get_rope_index`.
+    // #239: Qwen25VL.swift getRopeIndex (:1027-1182), verbatim
+    static func getRopeIndex(
+        inputIds: MLXArray,
+        imageGridTHW: [THW]?,
+        videoGridTHW: [THW]?,
+        spatialMergeSize: Int,
+        imageTokenId: Int,
+        videoTokenId: Int,
+        attentionMask: MLXArray?
= nil + ) -> (MLXArray, MLXArray) { + + let (batchSize, seqLength) = (inputIds.dim(0), inputIds.dim(1)) + + guard inputIds.ndim > 0, imageGridTHW != nil || videoGridTHW != nil else { + var positionIds = MLXArray(0 ..< seqLength).asType(.int32) + positionIds = broadcast(positionIds[.newAxis, 0...], to: [batchSize, seqLength]) + let positionIds3D = broadcast( + positionIds[.newAxis, 0..., 0...], to: [3, batchSize, seqLength]) + let zeros = MLXArray.zeros([batchSize], dtype: .int32) + return (positionIds3D, zeros) + } + + var positionIds = ones(like: inputIds).asType(.int32) + positionIds = broadcast(positionIds[.newAxis, 0..., 0...], to: [3, batchSize, seqLength]) + + var mropePositionDeltas: [Int] = [] + let mask = attentionMask ?? ones(like: inputIds) + + for batchIdx in 0 ..< batchSize { + var batchInputIds = inputIds[batchIdx, 0...] + batchInputIds = `where`( + mask[batchIdx, 0...] .== 1, batchInputIds, zeros(like: batchInputIds)) + + let imageNums = ((batchInputIds .== MLXArray(imageTokenId)).asType(.int32).sum()).item( + Int.self) + let videoNums = ((batchInputIds .== MLXArray(videoTokenId)).asType(.int32).sum()).item( + Int.self) + + let inputTokens = batchInputIds.asArray(Int32.self).map { Int($0) } + var llmPosIdsList: [MLXArray] = [] + + var st = 0 + var remainImages = imageNums + var remainVideos = videoNums + var imageIndex = 0 + var videoIndex = 0 + + for _ in 0 ..< (imageNums + videoNums) { + let edImage: Int + if remainImages > 0, let idx = inputTokens[st...].firstIndex(of: imageTokenId) { + edImage = idx + } else { + edImage = inputTokens.count + 1 + } + + let edVideo: Int + if remainVideos > 0, let idx = inputTokens[st...].firstIndex(of: videoTokenId) { + edVideo = idx + } else { + edVideo = inputTokens.count + 1 + } + + let (t, h, w, ed): (Int, Int, Int, Int) + if edImage < edVideo { + guard let grid = imageGridTHW, imageIndex < grid.count else { break } + (t, h, w) = grid[imageIndex].values + imageIndex += 1 + remainImages -= 1 + ed = edImage + } 
else { + guard let grid = videoGridTHW, videoIndex < grid.count else { break } + (t, h, w) = grid[videoIndex].values + videoIndex += 1 + remainVideos -= 1 + ed = edVideo + } + + let llmGridT = t + let llmGridH = h / spatialMergeSize + let llmGridW = w / spatialMergeSize + + let stIdx: Int + if let lastArray = llmPosIdsList.last { + stIdx = lastArray.max().item(Int.self) + 1 + } else { + stIdx = 0 + } + + // Text tokens before this visual block + let textLen = ed - st + if textLen > 0 { + var index = MLXArray(0 ..< textLen).reshaped([1, textLen]) + index = broadcast(index, to: [3, textLen]) + index = index + MLXArray(stIdx) + llmPosIdsList.append(index) + } + + // 3D position IDs for visual tokens (temporal, height, width) + var tIndex = MLXArray(0 ..< llmGridT).reshaped([llmGridT, 1]) + tIndex = broadcast(tIndex, to: [llmGridT, llmGridH * llmGridW]) + tIndex = tIndex.flattened() + + var hIndex = MLXArray(0 ..< llmGridH).reshaped([1, llmGridH, 1]) + hIndex = broadcast(hIndex, to: [llmGridT, llmGridH, llmGridW]) + hIndex = hIndex.flattened() + + var wIndex = MLXArray(0 ..< llmGridW).reshaped([1, 1, llmGridW]) + wIndex = broadcast(wIndex, to: [llmGridT, llmGridH, llmGridW]) + wIndex = wIndex.flattened() + + let visualPosIds = stacked([tIndex, hIndex, wIndex]) + MLXArray(textLen + stIdx) + llmPosIdsList.append(visualPosIds) + + st = ed + llmGridT * llmGridH * llmGridW + } + + // Remaining text tokens after last visual block + if st < inputTokens.count { + let stIdx: Int + if let lastArray = llmPosIdsList.last { + stIdx = lastArray.max().item(Int.self) + 1 + } else { + stIdx = 0 + } + + let textLen = inputTokens.count - st + var tIndex = MLXArray(0 ..< textLen).reshaped([1, textLen]) + tIndex = broadcast(tIndex, to: [3, textLen]) + llmPosIdsList.append(tIndex + MLXArray(stIdx)) + } + + if !llmPosIdsList.isEmpty { + let llmPositions = concatenated(llmPosIdsList, axis: 1) // [3, seq] + + let expandedMask = broadcast( + mask[batchIdx, 0...][.newAxis, .newAxis, 0...], to: 
[3, 1, seqLength])
+                let expandedPositions = llmPositions[0..., .newAxis, 0...]
+                let newPositions = `where`(
+                    expandedMask, expandedPositions,
+                    positionIds[0..., batchIdx ..< batchIdx + 1, 0...])
+
+                positionIds[0..., batchIdx ..< batchIdx + 1, 0...] = newPositions  // update this row only
+
+                let maxPosId = llmPositions.max().item(Int.self)
+                mropePositionDeltas.append(maxPosId + 1 - inputTokens.count)
+            }
+        }
+
+        let deltas: MLXArray
+        if mropePositionDeltas.isEmpty {
+            deltas = MLXArray.zeros([batchSize], dtype: .int32)
+        } else {
+            deltas = MLXArray(mropePositionDeltas.map { Int32($0) })
+        }
+        return (positionIds, deltas)
+    }
 }
 
 // MARK: - Vision
@@ -667,11 +931,18 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
         self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
     }
 
+    /// Builds the multimodal input embedding for one prefill step.
+    ///
+    /// Returns the embeddings paired with the prefill-only MROPE state
+    /// (positionIds + ropeDeltas) — both nil on the no-image path. The
+    /// caller seeds `LMOutput.State` with these so subsequent decode steps
+    /// reconstruct positions from `ropeDeltas + cacheOffset` without
+    /// mutating the model.
     private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, frames: [THW]?)
-        -> MLXArray
+        -> (embeds: MLXArray, positionIds: MLXArray?, ropeDeltas: MLXArray?)
{ guard let pixelValues, let frames else { - return languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]) + return (languageModel.model.embedTokens(inputIds[.newAxis, .ellipsis]), nil, nil) } // Get the input embeddings from the language model @@ -685,10 +956,23 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { } // Insert special image tokens in the input_ids - return QwenVL.mergeInputIdsWithImageFeatures( + let merged = QwenVL.mergeInputIdsWithImageFeatures( inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates, imageTokenId: config.baseConfiguration.imageTokenId, videoTokenId: config.baseConfiguration.videoTokenId) + + // Compute MROPE 3D position IDs for spatial awareness + // #239: Qwen25VL.swift inputEmbeddings (:941) + let inputIds2D = inputIds.ndim == 1 ? inputIds[.newAxis, 0...] : inputIds + let (positionIds, ropeDeltas) = Language.getRopeIndex( + inputIds: inputIds2D, + imageGridTHW: frames, + videoGridTHW: nil, + spatialMergeSize: config.visionConfiguration.spatialMergeSize, + imageTokenId: config.baseConfiguration.imageTokenId, + videoTokenId: config.baseConfiguration.videoTokenId) + + return (merged, positionIds, ropeDeltas) } public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws @@ -714,17 +998,33 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { allFrames.append(contentsOf: videoFrames) } - let inputEmbeddings = self.inputEmbeddings( + let (embeds, positionIds, ropeDeltas) = self.inputEmbeddings( inputIds: input.text.tokens, pixelValues: allPixels, frames: allFrames.isEmpty ? nil : allFrames) - let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings) + // #239: Qwen25VL.swift prepare(_:cache:windowSize:) (:982) + // Seed per-call decoder state with the prefill-only MROPE positions + + // ropeDeltas (both nil on the no-image path). 
The LMOutput's `state` + // returned here is consumed by subsequent decode steps via + // `callAsFunction(_:cache:state:)`. + var state = LMOutput.State() + if let positionIds { + state[positionIdsKey] = positionIds + } + if let ropeDeltas { + state[ropeDeltasKey] = ropeDeltas + } + + let result = languageModel(nil, cache: cache, state: state, inputEmbedding: embeds) return .logits(result) } - public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray { - languageModel(inputs, cache: cache).logits + // #239: Qwen25VL.swift Model.callAsFunction(_:LMInput.Text,cache:,state:) (:1184) + public func callAsFunction( + _ input: LMInput.Text, cache: [any KVCache]?, state: LMOutput.State? + ) -> LMOutput { + languageModel(input.tokens, cache: cache, state: state) } public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {