From b9bf50bdafef02fffd5b83598a61bbf7d47434f9 Mon Sep 17 00:00:00 2001 From: Aegis-AI Date: Tue, 12 May 2026 13:10:15 -0700 Subject: [PATCH 1/2] Fix GPU Hang in Gemma4 cluster logit projection by vectorizing 2D scatter and implement MTP acceptance rate metric extraction --- Libraries/MLXLLM/Models/Gemma4Text.swift | 7 ++----- Libraries/MLXLMCommon/Evaluate.swift | 25 ++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 6fe3fbcd0..68cb1112e 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -1117,11 +1117,8 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32) let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates]) var output2D = output.reshaped([B * S, vocabSize]) - for bsIdx in 0 ..< B * S { - let idxRow = scatterIdx2D[bsIdx] // [totalCandidates] - let valRow = selectedLogits2D[bsIdx] // [totalCandidates] - output2D[bsIdx, idxRow] = valRow - } + let rowIndices = MLXArray(0 ..< Int32(B * S)).reshaped([B * S, 1]) + output2D[rowIndices, scatterIdx2D] = selectedLogits2D output = output2D.reshaped([B, S, vocabSize]) return output diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 22e1988e8..07d73dd06 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -503,6 +503,8 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int var tokenCount: Int { get } var promptPrefillTime: TimeInterval { get } var streamingError: SSDStreamingError? { get } + var acceptedDraftTokens: Int { get } + var totalDraftTokens: Int { get } } /// Generator of tokens. @@ -549,6 +551,8 @@ public struct TokenIterator: TokenIteratorProtocol { var promptPrefillTime: TimeInterval = 0.0 var streamingError: SSDStreamingError? let ssdErrorLatch = SSDStreamingErrorLatch() + var acceptedDraftTokens = 0 + var totalDraftTokens = 0 /// Initialize a `TokenIterator` with the given tokens. Note: this has been /// replaced with ``init(input:model:cache:parameters:)``. @@ -794,6 +798,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol { // Internal metrics var promptPrefillTime: TimeInterval = 0.0 + var acceptedDraftTokens = 0 + var totalDraftTokens = 0 /// Initialize a `SpeculativeTokenIterator` with the given input. /// @@ -1014,6 +1020,9 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol { trimPromptCache(mainCache, numTokens: numDraft - accepted) trimPromptCache(draftCache, numTokens: Swift.max(numDraft - accepted - 1, 0)) + self.acceptedDraftTokens += accepted + self.totalDraftTokens += draftTokens.count + // Apply dynamic cache quantization after rewind quantizeKVCache(&mainCache) quantizeKVCache(&draftCache) @@ -2228,7 +2237,9 @@ private func generateLoopTask( generationTokenCount: tokenCount, promptTime: promptTime + iterator.promptPrefillTime, generationTime: generateTime, - stopReason: stopReason ?? .cancelled + stopReason: stopReason ?? .cancelled, + acceptedDraftTokens: iterator.acceptedDraftTokens, + totalDraftTokens: iterator.totalDraftTokens ) _ = continuation.yield(handler.infoEvent(info)) @@ -2298,6 +2309,12 @@ public struct GenerateCompletionInfo: Sendable { /// Reason generation stopped. public let stopReason: GenerateStopReason + /// Number of accepted draft tokens (if speculative decoding is active). + public let acceptedDraftTokens: Int + + /// Total number of draft tokens evaluated (if speculative decoding is active). + public let totalDraftTokens: Int + /// The number of tokens processed per second during the prompt phase. public var promptTokensPerSecond: Double { Double(promptTokenCount) / promptTime @@ -2313,13 +2330,17 @@ public struct GenerateCompletionInfo: Sendable { generationTokenCount: Int, promptTime: TimeInterval, generationTime: TimeInterval, - stopReason: GenerateStopReason = .stop + stopReason: GenerateStopReason = .stop, + acceptedDraftTokens: Int = 0, + totalDraftTokens: Int = 0 ) { self.promptTokenCount = promptTokenCount self.generationTokenCount = generationTokenCount self.promptTime = promptTime self.generateTime = generationTime self.stopReason = stopReason + self.acceptedDraftTokens = acceptedDraftTokens + self.totalDraftTokens = totalDraftTokens } public func summary() -> String { From b42d9a08cb7a68c8151666e9b6a710e6275451d8 Mon Sep 17 00:00:00 2001 From: Aegis-AI Date: Tue, 12 May 2026 14:27:23 -0700 Subject: [PATCH 2/2] Address Copilot review: avoid Int32 overflow via MLX.arange --- Libraries/MLXLLM/Models/Gemma4Text.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 68cb1112e..8afbba152 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -1117,7 +1117,7 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32) let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates]) var output2D = output.reshaped([B * S, vocabSize]) - let rowIndices = MLXArray(0 ..< Int32(B * S)).reshaped([B * S, 1]) + let rowIndices = MLXArray.arange(B * S).asType(.int32).reshaped([B * S, 1]) output2D[rowIndices, scatterIdx2D] = selectedLogits2D output = output2D.reshaped([B, S, vocabSize])