diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 6fe3fbcd0..8afbba152 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -1117,11 +1117,8 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32) let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates]) var output2D = output.reshaped([B * S, vocabSize]) - for bsIdx in 0 ..< B * S { - let idxRow = scatterIdx2D[bsIdx] // [totalCandidates] - let valRow = selectedLogits2D[bsIdx] // [totalCandidates] - output2D[bsIdx, idxRow] = valRow - } + let rowIndices = MLXArray.arange(B * S).asType(.int32).reshaped([B * S, 1]) + output2D[rowIndices, scatterIdx2D] = selectedLogits2D output = output2D.reshaped([B, S, vocabSize]) return output diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 22e1988e8..07d73dd06 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -503,6 +503,8 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int var tokenCount: Int { get } var promptPrefillTime: TimeInterval { get } var streamingError: SSDStreamingError? { get } + var acceptedDraftTokens: Int { get } + var totalDraftTokens: Int { get } } /// Generator of tokens. @@ -549,6 +551,8 @@ public struct TokenIterator: TokenIteratorProtocol { var promptPrefillTime: TimeInterval = 0.0 var streamingError: SSDStreamingError? let ssdErrorLatch = SSDStreamingErrorLatch() + var acceptedDraftTokens = 0 + var totalDraftTokens = 0 /// Initialize a `TokenIterator` with the given tokens. Note: this has been /// replaced with ``init(input:model:cache:parameters:)``. @@ -794,6 +798,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol { // Internal metrics var promptPrefillTime: TimeInterval = 0.0 + var acceptedDraftTokens = 0 + var totalDraftTokens = 0 /// Initialize a `SpeculativeTokenIterator` with the given input. /// @@ -1014,6 +1020,9 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol { trimPromptCache(mainCache, numTokens: numDraft - accepted) trimPromptCache(draftCache, numTokens: Swift.max(numDraft - accepted - 1, 0)) + self.acceptedDraftTokens += accepted + self.totalDraftTokens += draftTokens.count + // Apply dynamic cache quantization after rewind quantizeKVCache(&mainCache) quantizeKVCache(&draftCache) @@ -2228,7 +2237,9 @@ private func generateLoopTask( generationTokenCount: tokenCount, promptTime: promptTime + iterator.promptPrefillTime, generationTime: generateTime, - stopReason: stopReason ?? .cancelled + stopReason: stopReason ?? .cancelled, + acceptedDraftTokens: iterator.acceptedDraftTokens, + totalDraftTokens: iterator.totalDraftTokens ) _ = continuation.yield(handler.infoEvent(info)) @@ -2298,6 +2309,12 @@ public struct GenerateCompletionInfo: Sendable { /// Reason generation stopped. public let stopReason: GenerateStopReason + /// Number of accepted draft tokens (if speculative decoding is active). + public let acceptedDraftTokens: Int + + /// Total number of draft tokens evaluated (if speculative decoding is active). + public let totalDraftTokens: Int + /// The number of tokens processed per second during the prompt phase. public var promptTokensPerSecond: Double { Double(promptTokenCount) / promptTime @@ -2313,13 +2330,17 @@ public struct GenerateCompletionInfo: Sendable { generationTokenCount: Int, promptTime: TimeInterval, generationTime: TimeInterval, - stopReason: GenerateStopReason = .stop + stopReason: GenerateStopReason = .stop, + acceptedDraftTokens: Int = 0, + totalDraftTokens: Int = 0 ) { self.promptTokenCount = promptTokenCount self.generationTokenCount = generationTokenCount self.promptTime = promptTime self.generateTime = generationTime self.stopReason = stopReason + self.acceptedDraftTokens = acceptedDraftTokens + self.totalDraftTokens = totalDraftTokens } public func summary() -> String {