Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions Libraries/MLXLMCommon/ChatSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -499,21 +499,19 @@ public final class ChatSession {
await cache.read { _ in }
}

/// Returns the current KV cache, if one has been built.
/// Visits the current cache, passing the `[KVCache]` to `body` if one has been realized, or `nil` otherwise.
///
/// Returns `nil` if no generation has occurred yet (cache is still empty) or if the
/// session is in history-rehydration mode and generation has not started.
///
/// The returned array holds references to the live cache objects — do not use them
/// concurrently with an active ``respond(to:role:images:videos:)`` or
/// ``streamResponse(_:)`` call on the same session. To persist the cache
/// across process launches, use ``saveCache(to:)`` instead.
public func currentCache() async -> [KVCache]? {

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a little too permissive -- it was only added for test support, so we can just switch it to an internally scoped visitor.

await cache.read { cache in
if case .kvcache(let array) = cache {
return array
/// This method is meant for test support.
func withCache<R: Sendable>(_ body: @Sendable ([KVCache]?) async throws -> R) async rethrows
-> R?
{
try await cache.read { cache in
switch cache {
case .kvcache(let cache):
return try await body(cache)
default:
return try await body(nil)
}
return nil
}
}

Expand All @@ -526,10 +524,14 @@ public final class ChatSession {
/// - Throws: ``ChatSessionError/noCacheAvailable`` if no generation has occurred yet,
/// or any error thrown by the underlying file write
public func saveCache(to url: URL) async throws {
guard let kvCache = await currentCache() else {
throw ChatSessionError.noCacheAvailable
try await cache.read { cache in
switch cache {
case .kvcache(let cache):
try savePromptCache(url: url, cache: cache)
default:
throw ChatSessionError.noCacheAvailable
}
}
try savePromptCache(url: url, cache: kvCache)
}
}

Expand Down
29 changes: 4 additions & 25 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// swift-tools-version: 5.12
// swift-tools-version: 6.1

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch to Swift 6 -- this prevents concurrency issues from making it past CI.

// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription
Expand Down Expand Up @@ -26,7 +26,7 @@ let package = Package(
targets: ["MLXEmbedders"]),
],
dependencies: [
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")),
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")),
.package(
url: "https://github.com/huggingface/swift-transformers",
.upToNextMinor(from: "1.2.0")
Expand All @@ -45,9 +45,6 @@ let package = Package(
path: "Libraries/MLXLLM",
exclude: [
"README.md"
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

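No longer needed -- the Swift 6 language mode already enforces strict concurrency checking, so the experimental `StrictConcurrency` flag is redundant.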

]
),
.target(
Expand All @@ -62,9 +59,6 @@ let package = Package(
path: "Libraries/MLXVLM",
exclude: [
"README.md"
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
Expand All @@ -78,9 +72,6 @@ let package = Package(
path: "Libraries/MLXLMCommon",
exclude: [
"README.md"
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
Expand All @@ -94,9 +85,6 @@ let package = Package(
path: "Libraries/MLXEmbedders",
exclude: [
"README.md"
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.testTarget(
Expand All @@ -115,10 +103,7 @@ let package = Package(
exclude: [
"README.md"
],
resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")]
),
.testTarget(
name: "MLXLMIntegrationTests",
Expand All @@ -135,9 +120,6 @@ let package = Package(
path: "Tests/MLXLMIntegrationTests",
exclude: [
"README.md"
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.testTarget(
Expand All @@ -147,10 +129,7 @@ let package = Package(
"MLXVLM",
"MLXLMCommon",
],
path: "Tests/Benchmarks",
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
path: "Tests/Benchmarks"
),
]
)
Expand Down
79 changes: 48 additions & 31 deletions Tests/MLXLMTests/ChatSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
import MLXNN
import MLXOptimizers
import Tokenizers
import XCTest

@testable import MLXLMCommon

/// See also ChatSessionIntegrationTests
public class ChatSessionTests: XCTestCase {

Expand Down Expand Up @@ -162,36 +163,43 @@ public class ChatSessionTests: XCTestCase {
// MARK: - KV Cache

func testCurrentCacheNilBeforeGeneration() async throws {
let session = ChatSession(model())
let cache = await session.currentCache()
XCTAssertNil(cache)
let session = ChatSession(model(), generateParameters: generationParameters)
await session.withCache { cache in
XCTAssertNil(cache)
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the pattern to fix the tests -- just use the visitor.

}

func testCurrentCacheAfterGeneration() async throws {
let session = ChatSession(model())
let session = ChatSession(model(), generateParameters: generationParameters)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #128 -- those changes were in flight while these calls were added. This just syncs everything up: passing `generateParameters` is needed to make token generation with the random weights stop.

_ = try await session.respond(to: "hello")
let cache = await session.currentCache()
XCTAssertNotNil(cache)
await session.withCache { cache in
XCTAssertNotNil(cache)
}
}

func testInitWithKVCache() async throws {
// build a cache from an initial session
let ctx = model()
let initial = ChatSession(ctx)
let container = ModelContainer(context: model())
let initial = ChatSession(container, generateParameters: generationParameters)
_ = try await initial.respond(to: "hello")
guard let cache = await initial.currentCache() else {
XCTFail("expected cache after generation")
return
}

// restore the cache into a new session and verify generation continues
let restored = ChatSession(ctx, cache: cache)
let result = try await restored.respond(to: "hello again")
XCTAssertGreaterThan(result.count, targetLength, result)
try await initial.withCache { [targetLength, generationParameters] cache in
XCTAssertNotNil(cache)

if let cache {
// restore the cache into a new session and verify generation continues
let restored = ChatSession(
container,
cache: cache.map { $0.copy() },
generateParameters: generationParameters)
let result = try await restored.respond(to: "hello again")
XCTAssertGreaterThan(result.count, targetLength, result)
}
}
}

func testSaveCacheThrowsBeforeGeneration() async throws {
let session = ChatSession(model())
let session = ChatSession(model(), generateParameters: generationParameters)
let url = FileManager.default.temporaryDirectory
.appendingPathComponent(UUID().uuidString)
.appendingPathExtension("safetensors")
Expand All @@ -205,7 +213,7 @@ public class ChatSessionTests: XCTestCase {

func testSaveAndRestoreCache() async throws {
let ctx = model()
let initial = ChatSession(ctx)
let initial = ChatSession(ctx, generateParameters: generationParameters)
_ = try await initial.respond(to: "hello")

let url = FileManager.default.temporaryDirectory
Expand All @@ -214,37 +222,46 @@ public class ChatSessionTests: XCTestCase {
try await initial.saveCache(to: url)

let (loadedCache, _) = try loadPromptCache(url: url)
let restored = ChatSession(ctx, cache: loadedCache)
let restored = ChatSession(
ctx, cache: loadedCache, generateParameters: generationParameters)
let result = try await restored.respond(to: "hello again")
XCTAssertGreaterThan(result.count, targetLength, result)
}

func testCurrentCacheNilForHistorySessionBeforeGeneration() async throws {
// .history state should behave like .empty: no cache until first generation
let history: [Chat.Message] = [.user("hello"), .assistant("hi")]
let session = ChatSession(model(), history: history)
let cache = await session.currentCache()
XCTAssertNil(cache)
let session = ChatSession(
model(), history: history, generateParameters: generationParameters)
await session.withCache { cache in
XCTAssertNil(cache)
}
}

func testCurrentCacheNonNilForHistorySessionAfterGeneration() async throws {
// after generation from .history state, cache transitions to .kvcache
let history: [Chat.Message] = [.user("hello"), .assistant("hi")]
let session = ChatSession(model(), history: history)
let session = ChatSession(
model(),
history: history,
generateParameters: generationParameters)
_ = try await session.respond(to: "hello again")
let cache = await session.currentCache()
XCTAssertNotNil(cache)
await session.withCache { cache in
XCTAssertNotNil(cache)
}
}

func testCurrentCacheNilAfterClear() async throws {
// clear() resets to .empty; withCache should see nil again
let session = ChatSession(model())
let session = ChatSession(model(), generateParameters: generationParameters)
_ = try await session.respond(to: "hello")
let cacheBeforeClear = await session.currentCache()
XCTAssertNotNil(cacheBeforeClear)
await session.withCache { cache in
XCTAssertNotNil(cache)
}
await session.clear()
let cacheAfterClear = await session.currentCache()
XCTAssertNil(cacheAfterClear)
await session.withCache { cache in
XCTAssertNil(cache)
}
}

/// something that looks like a view model
Expand Down
2 changes: 1 addition & 1 deletion Tests/MLXLMTests/KVCacheTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import MLX
import MLXLMCommon
import Testing

private let cacheCreators: [() -> any KVCache] = [
private let cacheCreators: [@Sendable () -> any KVCache] = [

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix `Sendable` issue.

{ KVCacheSimple() },
{ RotatingKVCache(maxSize: 32) },
{ QuantizedKVCache() },
Expand Down
Loading