Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Libraries/MLXLLM/Models/Gemma4Text.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,8 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens
let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32)
let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates])
var output2D = output.reshaped([B * S, vocabSize])
for bsIdx in 0 ..< B * S {
let idxRow = scatterIdx2D[bsIdx] // [totalCandidates]
let valRow = selectedLogits2D[bsIdx] // [totalCandidates]
output2D[bsIdx, idxRow] = valRow
}
let rowIndices = MLXArray.arange(B * S).asType(.int32).reshaped([B * S, 1])
output2D[rowIndices, scatterIdx2D] = selectedLogits2D
output = output2D.reshaped([B, S, vocabSize])

return output
Expand Down
25 changes: 23 additions & 2 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int
var tokenCount: Int { get }
var promptPrefillTime: TimeInterval { get }
var streamingError: SSDStreamingError? { get }
var acceptedDraftTokens: Int { get }
var totalDraftTokens: Int { get }
}

/// Generator of tokens.
Expand Down Expand Up @@ -549,6 +551,8 @@ public struct TokenIterator: TokenIteratorProtocol {
var promptPrefillTime: TimeInterval = 0.0
var streamingError: SSDStreamingError?
let ssdErrorLatch = SSDStreamingErrorLatch()
var acceptedDraftTokens = 0
var totalDraftTokens = 0

/// Initialize a `TokenIterator` with the given tokens. Note: this has been
/// replaced with ``init(input:model:cache:parameters:)``.
Expand Down Expand Up @@ -794,6 +798,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {

// Internal metrics
var promptPrefillTime: TimeInterval = 0.0
var acceptedDraftTokens = 0
var totalDraftTokens = 0

/// Initialize a `SpeculativeTokenIterator` with the given input.
///
Expand Down Expand Up @@ -1014,6 +1020,9 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
trimPromptCache(mainCache, numTokens: numDraft - accepted)
trimPromptCache(draftCache, numTokens: Swift.max(numDraft - accepted - 1, 0))

self.acceptedDraftTokens += accepted
self.totalDraftTokens += draftTokens.count

// Apply dynamic cache quantization after rewind
quantizeKVCache(&mainCache)
quantizeKVCache(&draftCache)
Expand Down Expand Up @@ -2228,7 +2237,9 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
generationTokenCount: tokenCount,
promptTime: promptTime + iterator.promptPrefillTime,
generationTime: generateTime,
stopReason: stopReason ?? .cancelled
stopReason: stopReason ?? .cancelled,
acceptedDraftTokens: iterator.acceptedDraftTokens,
totalDraftTokens: iterator.totalDraftTokens
)
_ = continuation.yield(handler.infoEvent(info))

Expand Down Expand Up @@ -2298,6 +2309,12 @@ public struct GenerateCompletionInfo: Sendable {
/// Reason generation stopped.
public let stopReason: GenerateStopReason

/// Number of draft tokens accepted during speculative decoding (0 by default, e.g. when speculative decoding is not active).
public let acceptedDraftTokens: Int

/// Total number of draft tokens evaluated during speculative decoding (0 by default, e.g. when speculative decoding is not active).
public let totalDraftTokens: Int

/// The number of tokens processed per second during the prompt phase.
public var promptTokensPerSecond: Double {
Double(promptTokenCount) / promptTime
Expand All @@ -2313,13 +2330,17 @@ public struct GenerateCompletionInfo: Sendable {
generationTokenCount: Int,
promptTime: TimeInterval,
generationTime: TimeInterval,
stopReason: GenerateStopReason = .stop
stopReason: GenerateStopReason = .stop,
acceptedDraftTokens: Int = 0,
totalDraftTokens: Int = 0
) {
self.promptTokenCount = promptTokenCount
self.generationTokenCount = generationTokenCount
self.promptTime = promptTime
self.generateTime = generationTime
self.stopReason = stopReason
self.acceptedDraftTokens = acceptedDraftTokens
self.totalDraftTokens = totalDraftTokens
}

public func summary() -> String {
Expand Down
Loading