From fdf2c656fe8226fd524f5f50421df78efd4e6164 Mon Sep 17 00:00:00 2001 From: atlascodesai <76924051+atlascodesai@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:00:04 -0600 Subject: [PATCH 1/5] feat(gemma4): add audio tower and fixtures --- .../Gemma4AudioIntegrationTests.swift | 95 ++ Libraries/MLXLMCommon/LanguageModel.swift | 5 +- Libraries/MLXVLM/Models/Gemma4.swift | 948 +++++++++++++++++- .../Models/Gemma4AudioFeatureExtractor.swift | 278 +++++ Package.swift | 14 +- .../Fixtures/gemma4_e2e_reference.json | 8 + .../Fixtures/gemma4_mel_alignment.json | 1 + .../Fixtures/gemma4_mel_reference.json | 136 +++ .../Fixtures/gemma4_token_alignment.json | 1 + .../MLXLMTests/Gemma4AudioAlignmentTest.swift | 148 +++ Tests/MLXLMTests/Gemma4AudioTests.swift | 157 +++ .../MLXLMTests/Resources/FIXTURES_LICENSES.md | 17 + .../Resources/gemma_audio_librispeech.wav | Bin 0 -> 374800 bytes .../Resources/gemma_speech_long.wav | Bin 0 -> 469608 bytes .../Resources/gemma_speech_test.wav | Bin 0 -> 220720 bytes .../Resources/gemma_speech_test2.wav | Bin 0 -> 190828 bytes 16 files changed, 1776 insertions(+), 32 deletions(-) create mode 100644 IntegrationTesting/IntegrationTestingTests/Gemma4AudioIntegrationTests.swift create mode 100644 Libraries/MLXVLM/Models/Gemma4AudioFeatureExtractor.swift create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_e2e_reference.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_mel_alignment.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_mel_reference.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_token_alignment.json create mode 100644 Tests/MLXLMTests/Gemma4AudioAlignmentTest.swift create mode 100644 Tests/MLXLMTests/Gemma4AudioTests.swift create mode 100644 Tests/MLXLMTests/Resources/FIXTURES_LICENSES.md create mode 100644 Tests/MLXLMTests/Resources/gemma_audio_librispeech.wav create mode 100644 Tests/MLXLMTests/Resources/gemma_speech_long.wav create mode 100644 Tests/MLXLMTests/Resources/gemma_speech_test.wav create mode 100644 Tests/MLXLMTests/Resources/gemma_speech_test2.wav diff --git a/IntegrationTesting/IntegrationTestingTests/Gemma4AudioIntegrationTests.swift b/IntegrationTesting/IntegrationTestingTests/Gemma4AudioIntegrationTests.swift new file mode 100644 index 000000000..6b5f2e79b --- /dev/null +++ b/IntegrationTesting/IntegrationTestingTests/Gemma4AudioIntegrationTests.swift @@ -0,0 +1,95 @@ +// Copyright © 2026 Apple Inc. +// +// Real end-to-end Gemma 4 audio inference. Downloads an audio-capable Gemma 4 +// VLM and asks it to transcribe real speech clips, exercising the full audio +// path: AVAssetReader PCM -> mel feature extractor -> Conformer audio tower -> +// begin/end-of-audio prompt splice -> text. +// +// Speech clips are committed under Tests/MLXLMTests/Resources/. +// +// Run: +// xcodebuild test -project IntegrationTesting.xcodeproj \ +// -scheme IntegrationTesting -destination 'platform=macOS' \ +// -only-testing:IntegrationTestingTests/Gemma4AudioIntegrationTests + +import Foundation +import HuggingFace +import IntegrationTestHelpers +import MLXHuggingFace +import MLXLMCommon +import Testing +import Tokenizers + +private let models = IntegrationTestModels( + downloader: #hubDownloader(), + tokenizerLoader: #huggingFaceTokenizerLoader() +) + +private let resources = URL(fileURLWithPath: #filePath) + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + .appendingPathComponent("Tests/MLXLMTests/Resources") + .path + +struct SpeechCase: Sendable, CustomStringConvertible { + let file: String + let expected: [String] + var description: String { file } +} + +private let speechCases: [SpeechCase] = [ + .init( + file: "gemma_speech_test.wav", + expected: ["quick", "brown", "fox", "lazy", "dog", "river"]), + .init( + file: "gemma_speech_long.wav", + expected: ["weather", "rain", "forecast", "afternoon", "breeze", "evening", "sky"]), +] + +@Suite(.serialized) +struct Gemma4AudioIntegrationTests { + + private func transcribe(model: String, clip: SpeechCase) async throws -> String { + let container = try await models.vlmContainer(for: ModelConfiguration(id: model)) + let session = ChatSession( + container, generateParameters: GenerateParameters(maxTokens: 120, temperature: 0)) + let url = URL(fileURLWithPath: "\(resources)/\(clip.file)") + return try await session.respond( + to: "Transcribe the speech in this audio clip.", + images: [], videos: [], audios: [.url(url)]) + } + + private func assertRecovered(_ answer: String, _ clip: SpeechCase) { + let lower = answer.lowercased() + #expect(!lower.contains(""), "audio path regressed to a wall: \(answer)") + let hits = clip.expected.filter { lower.contains($0) } + #expect( + hits.count >= 3, + "[\(clip.file)] did not recover the spoken words (matched \(hits) in: \(answer))") + } + + @Test(arguments: speechCases) + func gemma4_e4b_transcribes(_ clip: SpeechCase) async throws { + let answer = try await transcribe(model: "mlx-community/gemma-4-e4b-it-4bit", clip: clip) + print("[e4b/\(clip.file)] \(answer)") + assertRecovered(answer, clip) + } + + @Test func gemma4_e4b_perceivesRealSpeech() async throws { + let clip = SpeechCase(file: "gemma_audio_librispeech.wav", expected: []) + let answer = try await transcribe(model: "mlx-community/gemma-4-e4b-it-4bit", clip: clip) + print("[e4b/librispeech] \(answer)") + let lower = answer.lowercased() + #expect(!lower.contains(""), "audio path regressed to a wall") + #expect( + !lower.contains("not provided") && !lower.contains("no audio") + && !lower.contains("haven't provided") && !lower.contains("have not provided"), + "model claims no audio; audio not reaching the model: \(answer)") + let contentWords = ["middle", "class", "welcome", "mr", "mister", "gospel", "apostle"] + let hits = contentWords.filter { lower.contains($0) } + #expect( + hits.count >= 2, + "did not perceive the real-speech content (matched \(hits) in: \(answer))") + } +} diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index 4c85ac1c0..f958c4a1e 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -125,11 +125,14 @@ public struct LMInput { public struct ProcessedAudio { public let samples: MLXArray + public let mask: MLXArray? public init( - samples: MLXArray + samples: MLXArray, + mask: MLXArray? = nil ) { self.samples = samples + self.mask = mask } } diff --git a/Libraries/MLXVLM/Models/Gemma4.swift b/Libraries/MLXVLM/Models/Gemma4.swift index 8ad788ee3..62040739c 100644 --- a/Libraries/MLXVLM/Models/Gemma4.swift +++ b/Libraries/MLXVLM/Models/Gemma4.swift @@ -405,9 +405,69 @@ public struct Gemma4VisionConfiguration: Codable, Sendable { } } +public struct Gemma4AudioConfiguration: Codable, Sendable { + public let hiddenSize: Int + public let numHiddenLayers: Int + public let numAttentionHeads: Int + public let subsamplingConvChannels: [Int] + public let convKernelSize: Int + public let residualWeight: Float + public let attentionChunkSize: Int + public let attentionContextLeft: Int + public let attentionContextRight: Int + public let attentionLogitCap: Float + public let attentionInvalidLogitsValue: Float + public let useClippedLinears: Bool + public let rmsNormEps: Float + public let gradientClipping: Float + public let outputProjDims: Int? + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case numAttentionHeads = "num_attention_heads" + case subsamplingConvChannels = "subsampling_conv_channels" + case convKernelSize = "conv_kernel_size" + case residualWeight = "residual_weight" + case attentionChunkSize = "attention_chunk_size" + case attentionContextLeft = "attention_context_left" + case attentionContextRight = "attention_context_right" + case attentionLogitCap = "attention_logit_cap" + case attentionInvalidLogitsValue = "attention_invalid_logits_value" + case useClippedLinears = "use_clipped_linears" + case rmsNormEps = "rms_norm_eps" + case gradientClipping = "gradient_clipping" + case outputProjDims = "output_proj_dims" + } + + public init(from decoder: any Swift.Decoder) throws { + let c = try decoder.container(keyedBy: CodingKeys.self) + hiddenSize = try c.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1024 + numHiddenLayers = try c.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 12 + numAttentionHeads = try c.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 8 + subsamplingConvChannels = + try c.decodeIfPresent([Int].self, forKey: .subsamplingConvChannels) ?? [128, 32] + convKernelSize = try c.decodeIfPresent(Int.self, forKey: .convKernelSize) ?? 5 + residualWeight = try c.decodeIfPresent(Float.self, forKey: .residualWeight) ?? 0.5 + attentionChunkSize = try c.decodeIfPresent(Int.self, forKey: .attentionChunkSize) ?? 12 + attentionContextLeft = try c.decodeIfPresent(Int.self, forKey: .attentionContextLeft) ?? 13 + attentionContextRight = try c.decodeIfPresent(Int.self, forKey: .attentionContextRight) ?? 0 + attentionLogitCap = try c.decodeIfPresent(Float.self, forKey: .attentionLogitCap) ?? 50.0 + attentionInvalidLogitsValue = + try c.decodeIfPresent(Float.self, forKey: .attentionInvalidLogitsValue) ?? -1e9 + useClippedLinears = + try c.decodeIfPresent(Bool.self, forKey: .useClippedLinears) ?? true + rmsNormEps = try c.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-6 + gradientClipping = + try c.decodeIfPresent(Float.self, forKey: .gradientClipping) ?? 1e10 + outputProjDims = try c.decodeIfPresent(Int.self, forKey: .outputProjDims) + } +} + public struct Gemma4Configuration: Codable, Sendable { public let textConfiguration: Gemma4TextConfiguration public let visionConfiguration: Gemma4VisionConfiguration + public let audioConfiguration: Gemma4AudioConfiguration? public let modelType: String public let quantization: BaseConfiguration.Quantization? public let imageTokenId: Int @@ -428,6 +488,7 @@ public struct Gemma4Configuration: Codable, Sendable { enum CodingKeys: String, CodingKey { case textConfiguration = "text_config" case visionConfiguration = "vision_config" + case audioConfiguration = "audio_config" case modelType = "model_type" case quantization case imageTokenId = "image_token_id" @@ -447,6 +508,8 @@ public struct Gemma4Configuration: Codable, Sendable { Gemma4TextConfiguration.self, forKey: CodingKeys.textConfiguration) visionConfiguration = try c.decode( Gemma4VisionConfiguration.self, forKey: CodingKeys.visionConfiguration) + audioConfiguration = try c.decodeIfPresent( + Gemma4AudioConfiguration.self, forKey: CodingKeys.audioConfiguration) modelType = try c.decodeIfPresent(String.self, forKey: CodingKeys.modelType) ?? "gemma4" quantization = try c.decodeIfPresent( BaseConfiguration.Quantization.self, forKey: CodingKeys.quantization) @@ -1049,7 +1112,6 @@ final class Gemma4TextBackbone: Module { } let finalPerLayerInputs = projectPerLayerInputs(h0, perLayerInputs: processedPerLayerInputs) - let hasExplicitCache = cache != nil let localCache = cache ?? Array(repeating: nil as KVCache?, count: max(firstKVSharedLayerIdx, 1)) let fullMask: MLXFast.ScaledDotProductAttentionMaskMode @@ -1738,6 +1800,716 @@ private final class Gemma4MultimodalEmbedder: Module, UnaryLayer { } } +// MARK: - Audio + +private final class Gemma4AudioRMSNorm: Module, UnaryLayer { + let eps: Float + @ModuleInfo var weight: MLXArray + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + MLXFast.rmsNorm(x, weight: weight, eps: eps) + } +} + +private final class Gemma4AudioClippableLinear: Module, UnaryLayer { + @ModuleInfo(key: "linear") var linear: Linear + @ModuleInfo(key: "input_min") var inputMin: MLXArray? + @ModuleInfo(key: "input_max") var inputMax: MLXArray? + @ModuleInfo(key: "output_min") var outputMin: MLXArray? + @ModuleInfo(key: "output_max") var outputMax: MLXArray? + let useClipping: Bool + + init(inFeatures: Int, outFeatures: Int, bias: Bool = false, useClipping: Bool = true) { + self.useClipping = useClipping + self._linear.wrappedValue = Linear(inFeatures, outFeatures, bias: bias) + if useClipping { + self._inputMin.wrappedValue = MLXArray(-Float.infinity) + self._inputMax.wrappedValue = MLXArray(Float.infinity) + self._outputMin.wrappedValue = MLXArray(-Float.infinity) + self._outputMax.wrappedValue = MLXArray(Float.infinity) + } + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let clippedInput: MLXArray + if let inputMin, let inputMax { + clippedInput = clip(x, min: inputMin, max: inputMax) + } else { + clippedInput = x + } + let projected = linear(clippedInput) + if let outputMin, let outputMax { + return clip(projected, min: outputMin, max: outputMax) + } + return projected + } +} + +/// LayerNorm without bias, matching `nn.LayerNorm(dims, bias=False)` in the Python model. +/// The checkpoint stores a single `weight` parameter at the `norm` key. +private final class Gemma4AudioLayerNorm: Module, UnaryLayer { + @ModuleInfo var weight: MLXArray + let eps: Float + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let xFloat = x.asType(.float32) + let meanVal = MLX.mean(xFloat, axis: -1, keepDims: true) + let variance = MLX.mean((xFloat - meanVal).square(), axis: -1, keepDims: true) + let normalized = (xFloat - meanVal) * rsqrt(variance + eps) + return (normalized * weight.asType(.float32)).asType(x.dtype) + } +} + +private final class Gemma4SSCPConvBlock: Module { + let timeStride: Int = 2 + let padding: (Int, Int, Int, Int) = (1, 1, 1, 1) + + @ModuleInfo(key: "conv") var conv: Conv2d + @ModuleInfo(key: "norm") var norm: Gemma4AudioLayerNorm + + init(config: Gemma4AudioConfiguration, idx: Int) { + let inChannels = idx == 0 ? 1 : config.subsamplingConvChannels[idx - 1] + let outChannels = config.subsamplingConvChannels[idx] + + // Conv2d: MLX expects [B, H, W, C], weight [C_out, kH, kW, C_in] + self._conv.wrappedValue = Conv2d( + inputChannels: inChannels, + outputChannels: outChannels, + kernelSize: 3, + stride: 2, + padding: 0, + bias: false + ) + + self._norm.wrappedValue = Gemma4AudioLayerNorm( + dimensions: outChannels, eps: config.rmsNormEps) + super.init() + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray) -> (MLXArray, MLXArray) { + // x: [B, T, F, C] (MLX channel-last) + // mask: [B, T] (True = invalid/padding) + + // Zero out invalid positions + var x = MLX.where( + expandedDimensions(expandedDimensions(mask, axis: -1), axis: -1), + MLXArray(0.0, dtype: x.dtype), x) + + // Manual padding on T and F dims + x = MLX.padded( + x, + widths: [ + .init((0, 0)), .init((padding.0, padding.1)), + .init((padding.2, padding.3)), .init((0, 0)), + ]) + + x = conv(x) // [B, T_out, F_out, C_out] + + // Downsample mask by time stride + let tOut = x.dim(1) + let downsampled = mask[0..., .stride(by: timeStride)] + let outputMask = downsampled[0..., .. (MLXArray, MLXArray) { + // audioMel: [B, T, F_in] + // Add channel dim: [B, T, F, 1] + var x = expandedDimensions(audioMel, axis: -1) + + var currentMask = mask + (x, currentMask) = layer0(x, mask: currentMask) + (x, currentMask) = layer1(x, mask: currentMask) + + // Flatten F*C -> [B, T, F*C] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let freqBins = x.dim(2) + let channels = x.dim(3) + x = x.reshaped(batchSize, timeSteps, freqBins * channels) + + // Project to hidden_size + x = inputProjLinear(x) + + return (x, currentMask) + } +} + +private final class Gemma4ConformerFeedForward: Module { + let gradientClipping: Float + let residualWeight: Float + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Gemma4AudioClippableLinear + @ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Gemma4AudioClippableLinear + @ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma4AudioRMSNorm + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.residualWeight = config.residualWeight + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + self._ffwLayer1.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 4, + useClipping: config.useClippedLinears) + self._ffwLayer2.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize * 4, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + self._postLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + var h = clip(x, min: -gradientClipping, max: gradientClipping) + h = preLayerNorm(h) + h = ffwLayer1(h) + h = silu(h) + h = ffwLayer2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = postLayerNorm(h) + return residual + h * residualWeight + } +} + +private final class Gemma4AudioRelativePositionEmbedding: Module { + let numHeads: Int + let channels: Int + let headDim: Int + let maxBackward: Int + let maxForward: Int + let invTimescales: MLXArray + + @ModuleInfo(key: "pos_proj") var posProj: Linear + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.channels = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.maxBackward = max(0, config.attentionContextLeft - 1) + self.maxForward = config.attentionContextRight + + self._posProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.invTimescales = + MLXArray(minTimescale) + * MLX.exp(MLXArray(0 ..< numTimescales).asType(.float32) * (-logTimescaleIncrement)) + + super.init() + } + + private func getTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = invTimescales.reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + func callAsFunction(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let contextSize = keys.dim(2) + + let posIndices = MLXArray( + stride(from: maxBackward, through: -maxForward, by: -1).map { Int32($0) } + ) + .reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = getTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = posProj(sinEmb.asType(posProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, numHeads, headDim) + sinEmb = sinEmb.asType(queries.dtype) + + // queries_p: [B, N, U, W, H], keys_p: [B, N, U, H, C] + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + // sin_emb_t: [N, H, maxSpan] + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped(batchSize, numHeads, numBlocks * blockSize, headDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, numHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relativeShift( + termBD, batchSize: batchSize, numHeads: numHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: contextSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } +} + +private final class Gemma4AudioAttention: Module { + let numHeads: Int + let hiddenSize: Int + let headDim: Int + let chunkSize: Int + let maxFutureHorizon: Int + let maxPastHorizon: Int + let contextSize: Int + let invalidLogitsValue: Float + let softcap: Float + let qScale: Float + let kScale: Float + + @ModuleInfo(key: "relative_k_proj") var relativeKProj: Linear + @ParameterInfo(key: "per_dim_scale") var perDimScale: MLXArray + @ModuleInfo(key: "q_proj") var qProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "k_proj") var kProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "v_proj") var vProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "post") var post: Gemma4AudioClippableLinear + + // Relative position embedding (inline) + // Note: relPosInvTimescales is NOT a model parameter — it's a computed constant. + // Store as [Float] to avoid MLX Module treating it as a loadable weight. + private let relPosNumHeads: Int + private let relPosHeadDim: Int + private let relPosMaxBackward: Int + private let relPosMaxForward: Int + private let relPosInvTimescalesData: [Float] + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.hiddenSize = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.chunkSize = config.attentionChunkSize + self.maxFutureHorizon = config.attentionContextRight + self.maxPastHorizon = max(0, config.attentionContextLeft - 1) + self.contextSize = chunkSize + maxPastHorizon + maxFutureHorizon + self.invalidLogitsValue = config.attentionInvalidLogitsValue + self.softcap = config.attentionLogitCap + + self.qScale = pow(Float(headDim), -0.5) / Foundation.log(2.0) + self.kScale = Foundation.log(1 + Foundation.exp(1.0)) / Foundation.log(2.0) + + self._relativeKProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + self._perDimScale.wrappedValue = MLXArray.zeros([headDim]) + self._qProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._kProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._vProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._post.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + + // Relative position embedding setup + self.relPosNumHeads = numHeads + self.relPosHeadDim = headDim + self.relPosMaxBackward = maxPastHorizon + self.relPosMaxForward = maxFutureHorizon + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.relPosInvTimescalesData = (0 ..< numTimescales).map { i in + minTimescale * Foundation.exp(Float(i) * (-logTimescaleIncrement)) + } + + super.init() + } + + private func padDim1(_ x: MLXArray, padLeft: Int, padRight: Int) -> MLXArray { + var widths = Array(repeating: IntOrPair((0, 0)), count: x.ndim) + widths[1] = IntOrPair((padLeft, padRight)) + return MLX.padded(x, widths: widths) + } + + private func convertToBlock(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, chunk_size, ...] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let rest = Array(x.shape.dropFirst(2)) + let numBlocks = (timeSteps + chunkSize - 1) / chunkSize + let padLen = numBlocks * chunkSize - timeSteps + var result = x + if padLen > 0 { + result = padDim1(result, padLeft: 0, padRight: padLen) + } + return result.reshaped([batchSize, numBlocks, chunkSize] + rest) + } + + private func extractBlockContext(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, context_size, ...] + let padLeft = maxPastHorizon + let padRight = maxFutureHorizon + chunkSize - 1 + let padded = padDim1(x, padLeft: padLeft, padRight: padRight) + let tPadded = padded.dim(1) + let numBlocks = (tPadded - contextSize) / chunkSize + 1 + + // Build indices: starts[:, None] + offsets[None, :] + let starts = MLXArray( + stride(from: 0, to: numBlocks * chunkSize, by: chunkSize).map { + Int32($0) + }) + let offsets = MLXArray((0 ..< contextSize).map { Int32($0) }) + let indices = expandedDimensions(starts, axis: 1) + expandedDimensions(offsets, axis: 0) + // indices: [numBlocks, contextSize] + + // Gather using advanced indexing + // padded: [B, T_padded, ...rest] + // We need padded[:, indices] which gives [B, numBlocks, contextSize, ...rest] + return padded[0..., indices] + } + + private func relPosTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = MLXArray(relPosInvTimescalesData).reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relPosRelativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + private func computeRelativePositionLogits(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let ctxSize = keys.dim(2) + + // Past-only relative positions [maxPastHorizon ... 0], matching Google's + // reference (`torch.arange(max_past_horizon, -1, -1)`). pr-192 originally + // used a symmetric [maxBackward ... -maxForward] span, which over-counts + // positions by maxFutureHorizon and misaligns relPosRelativeShift's pad + // math → semantically-wrong attention bias (audio not understood). + let posIndices = MLXArray( + stride(from: relPosMaxBackward, through: 0, by: -1).map { Int32($0) } + ).reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = relPosTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = relativeKProj(sinEmb.asType(relativeKProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, relPosNumHeads, relPosHeadDim) + sinEmb = sinEmb.asType(queries.dtype) + + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped( + batchSize, relPosNumHeads, numBlocks * blockSize, relPosHeadDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, relPosNumHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relPosRelativeShift( + termBD, batchSize: batchSize, numHeads: relPosNumHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: ctxSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } + + func callAsFunction( + _ hiddenStates: MLXArray, mask: MLXArray, causalValidMask: MLXArray + ) -> MLXArray { + let batchSize = hiddenStates.dim(0) + let timeSteps = hiddenStates.dim(1) + let qkvShape = [batchSize, timeSteps, numHeads, headDim] + + var q = qProj(hiddenStates).asType(.float32).reshaped(qkvShape) + var k = kProj(hiddenStates).asType(.float32).reshaped(qkvShape) + let v = vProj(hiddenStates).asType(.float32).reshaped(qkvShape) + + let pds = softplus(perDimScale) + q = q * (qScale * pds) + k = k * kScale + + let queryBlocks = convertToBlock(q) // [B, U, W, N, H] + let keyBlocks = extractBlockContext(k) // [B, U, C, N, H] + let valueBlocks = extractBlockContext(v) // [B, U, C, N, H] + let numBlocks = queryBlocks.dim(1) + + // Build validity condition + let validMask = logicalNot(mask) // True = valid + let extractedValid = extractBlockContext(validMask) // [B, U, C] + // condition: [B, 1, U, W, C] + let condition = + expandedDimensions(expandedDimensions(extractedValid, axis: 1), axis: 3) + * expandedDimensions( + expandedDimensions(expandedDimensions(causalValidMask, axis: 0), axis: 0), axis: 0) + + var logits = computeRelativePositionLogits(queries: queryBlocks, keys: keyBlocks) + logits = tanh(logits / softcap) * softcap + logits = MLX.where( + condition .> 0, logits, MLXArray(invalidLogitsValue, dtype: logits.dtype)) + + let probs = softmax(logits, axis: -1) + // context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + var context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + context = context.reshaped(batchSize, numBlocks * chunkSize, numHeads, headDim) + context = context[0..., .. [B, T, D] and post-project + context = context.reshaped(batchSize, timeSteps, numHeads * headDim) + return post(context) + } +} + +private final class Gemma4ConformerLightConv1d: Module { + let gradientClipping: Float + let causalPadding: Int + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_start") var linearStart: Gemma4AudioClippableLinear + @ModuleInfo(key: "depthwise_conv1d") var depthwiseConv1d: Conv1d + @ModuleInfo(key: "conv_norm") var convNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_end") var linearEnd: Gemma4AudioClippableLinear + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.causalPadding = config.convKernelSize - 1 + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearStart.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 2, + useClipping: config.useClippedLinears) + // Depthwise conv1d: groups = hidden_size so weight shape is [out, kernel, 1] + self._depthwiseConv1d.wrappedValue = Conv1d( + inputChannels: config.hiddenSize, + outputChannels: config.hiddenSize, + kernelSize: config.convKernelSize, + stride: 1, + padding: 0, + groups: config.hiddenSize, + bias: false + ) + self._convNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearEnd.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + + var h = preLayerNorm(x) + h = linearStart(h) + + // GLU: split in half along last dim and gate + let halfDim = h.dim(-1) / 2 + let x1 = h[0..., 0..., .. MLXArray { + var h = feedForward1(x) + + // Attention with pre/post norm and residual + let residual = h + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = normPreAttn(h) + h = selfAttn(h, mask: mask, causalValidMask: causalValidMask) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = residual + normPostAttn(h) + + // Zero out invalid positions before lconv1d + let validityMask = expandedDimensions(logicalNot(mask), axis: -1).asType(h.dtype) + h = h * validityMask + + h = lconv1d(h) + h = feedForward2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + return normOut(h) + } +} + +private final class Gemma4AudioEncoder: Module { + let config: Gemma4AudioConfiguration + + @ModuleInfo(key: "subsample_conv_projection") var subsampleConvProjection: + Gemma4SubSampleConvProjection + @ModuleInfo(key: "layers") var layers: [Gemma4ConformerBlock] + @ModuleInfo(key: "output_proj") var outputProj: Linear? + + init(config: Gemma4AudioConfiguration) { + self.config = config + self._subsampleConvProjection.wrappedValue = Gemma4SubSampleConvProjection(config: config) + self._layers.wrappedValue = (0 ..< config.numHiddenLayers).map { _ in + Gemma4ConformerBlock(config: config) + } + if let outputProjDims = config.outputProjDims { + self._outputProj.wrappedValue = Linear( + config.hiddenSize, outputProjDims, bias: true) + } + super.init() + } + + private func buildCausalValidMask() -> MLXArray { + let chunkSize = config.attentionChunkSize + let maxFutureHorizon = config.attentionContextRight + let maxPastHorizon = max(0, config.attentionContextLeft - 1) + let upperDiagonal = maxPastHorizon + maxFutureHorizon + let ctxSize = chunkSize + maxPastHorizon + maxFutureHorizon + + let lowerCausal = tril(MLXArray.ones([ctxSize, chunkSize])).transposed() + let upperCausal = tril( + MLXArray.ones([chunkSize, ctxSize]), + k: upperDiagonal) + let maskResult = (lowerCausal * upperCausal).asType(.bool) + return maskResult + } + + func callAsFunction(_ audioMel: MLXArray, audioMelMask: MLXArray) -> (MLXArray, MLXArray) { + var (audioEncodings, currentMask) = subsampleConvProjection(audioMel, mask: audioMelMask) + + let causalValidMask = buildCausalValidMask() + + for block in layers { + audioEncodings = block( + audioEncodings, mask: currentMask, causalValidMask: causalValidMask) + } + + if let outputProj { + audioEncodings = outputProj(audioEncodings) + } + + if currentMask.dim(1) != audioEncodings.dim(1) { + let targetLen = audioEncodings.dim(1) + currentMask = currentMask[0..., ..