diff --git a/Libraries/MLXLMCommon/ChatSession.swift b/Libraries/MLXLMCommon/ChatSession.swift index 147d9797b..be8b0d71d 100644 --- a/Libraries/MLXLMCommon/ChatSession.swift +++ b/Libraries/MLXLMCommon/ChatSession.swift @@ -499,21 +499,19 @@ public final class ChatSession { await cache.read { _ in } } - /// Returns the current KV cache, if one has been built. + /// Visit the current cache value, if realized as a `[KVCache]`. /// - /// Returns `nil` if no generation has occurred yet (cache is still empty) or if the - /// session is in history-rehydration mode and generation has not started. - /// - /// The returned array holds references to the live cache objects — do not use them - /// concurrently with an active ``respond(to:role:images:videos:)`` or - /// ``streamResponse(_:)`` call on the same session. To persist the cache - /// across process launches, use ``saveCache(to:)`` instead. - public func currentCache() async -> [KVCache]? { - await cache.read { cache in - if case .kvcache(let array) = cache { - return array + /// This method is meant for test support. + func withCache<R>(_ body: @Sendable ([KVCache]?) async throws -> R) async rethrows + -> R? 
+ { + try await cache.read { cache in + switch cache { + case .kvcache(let cache): + return try await body(cache) + default: + return try await body(nil) } - return nil } } @@ -526,10 +524,14 @@ public final class ChatSession { /// - Throws: ``ChatSessionError/noCacheAvailable`` if no generation has occurred yet, /// or any error thrown by the underlying file write public func saveCache(to url: URL) async throws { - guard let kvCache = await currentCache() else { - throw ChatSessionError.noCacheAvailable + try await cache.read { cache in + switch cache { + case .kvcache(let cache): + try savePromptCache(url: url, cache: cache) + default: + throw ChatSessionError.noCacheAvailable + } } - try savePromptCache(url: url, cache: kvCache) } } diff --git a/Package.swift b/Package.swift index 9c5154c12..b2a8c2528 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 5.12 +// swift-tools-version: 6.1 // The swift-tools-version declares the minimum version of Swift required to build this package. 
import PackageDescription @@ -26,7 +26,7 @@ let package = Package( targets: ["MLXEmbedders"]), ], dependencies: [ - .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")), + .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), .package( url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.2.0") @@ -45,9 +45,6 @@ let package = Package( path: "Libraries/MLXLLM", exclude: [ "README.md" - ], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") ] ), .target( @@ -62,9 +59,6 @@ let package = Package( path: "Libraries/MLXVLM", exclude: [ "README.md" - ], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") ] ), .target( @@ -78,9 +72,6 @@ let package = Package( path: "Libraries/MLXLMCommon", exclude: [ "README.md" - ], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") ] ), .target( @@ -94,9 +85,6 @@ let package = Package( path: "Libraries/MLXEmbedders", exclude: [ "README.md" - ], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") ] ), .testTarget( @@ -115,10 +103,7 @@ let package = Package( exclude: [ "README.md" ], - resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") - ] + resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")] ), .testTarget( name: "MLXLMIntegrationTests", @@ -135,9 +120,6 @@ let package = Package( path: "Tests/MLXLMIntegrationTests", exclude: [ "README.md" - ], - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") ] ), .testTarget( @@ -147,10 +129,7 @@ let package = Package( "MLXVLM", "MLXLMCommon", ], - path: "Tests/Benchmarks", - swiftSettings: [ - .enableExperimentalFeature("StrictConcurrency") - ] + path: "Tests/Benchmarks" ), ] ) diff --git a/Tests/MLXLMTests/ChatSessionTests.swift b/Tests/MLXLMTests/ChatSessionTests.swift index 
6cf87b87a..b98f8e425 100644 --- a/Tests/MLXLMTests/ChatSessionTests.swift +++ b/Tests/MLXLMTests/ChatSessionTests.swift @@ -3,12 +3,13 @@ import Foundation import MLX import MLXLLM -import MLXLMCommon import MLXNN import MLXOptimizers import Tokenizers import XCTest +@testable import MLXLMCommon + /// See also ChatSessionIntegrationTests public class ChatSessionTests: XCTestCase { @@ -162,36 +163,43 @@ public class ChatSessionTests: XCTestCase { // MARK: - KV Cache func testCurrentCacheNilBeforeGeneration() async throws { - let session = ChatSession(model()) - let cache = await session.currentCache() - XCTAssertNil(cache) + let session = ChatSession(model(), generateParameters: generationParameters) + await session.withCache { cache in + XCTAssertNil(cache) + } } func testCurrentCacheAfterGeneration() async throws { - let session = ChatSession(model()) + let session = ChatSession(model(), generateParameters: generationParameters) _ = try await session.respond(to: "hello") - let cache = await session.currentCache() - XCTAssertNotNil(cache) + await session.withCache { cache in + XCTAssertNotNil(cache) + } } func testInitWithKVCache() async throws { // build a cache from an initial session - let ctx = model() - let initial = ChatSession(ctx) + let container = ModelContainer(context: model()) + let initial = ChatSession(container, generateParameters: generationParameters) _ = try await initial.respond(to: "hello") - guard let cache = await initial.currentCache() else { - XCTFail("expected cache after generation") - return - } - // restore the cache into a new session and verify generation continues - let restored = ChatSession(ctx, cache: cache) - let result = try await restored.respond(to: "hello again") - XCTAssertGreaterThan(result.count, targetLength, result) + try await initial.withCache { [targetLength, generationParameters] cache in + XCTAssertNotNil(cache) + + if let cache { + // restore the cache into a new session and verify generation continues + let 
restored = ChatSession( + container, + cache: cache.map { $0.copy() }, + generateParameters: generationParameters) + let result = try await restored.respond(to: "hello again") + XCTAssertGreaterThan(result.count, targetLength, result) + } + } } func testSaveCacheThrowsBeforeGeneration() async throws { - let session = ChatSession(model()) + let session = ChatSession(model(), generateParameters: generationParameters) let url = FileManager.default.temporaryDirectory .appendingPathComponent(UUID().uuidString) .appendingPathExtension("safetensors") @@ -205,7 +213,7 @@ public class ChatSessionTests: XCTestCase { func testSaveAndRestoreCache() async throws { let ctx = model() - let initial = ChatSession(ctx) + let initial = ChatSession(ctx, generateParameters: generationParameters) _ = try await initial.respond(to: "hello") let url = FileManager.default.temporaryDirectory @@ -214,7 +222,8 @@ public class ChatSessionTests: XCTestCase { try await initial.saveCache(to: url) let (loadedCache, _) = try loadPromptCache(url: url) - let restored = ChatSession(ctx, cache: loadedCache) + let restored = ChatSession( + ctx, cache: loadedCache, generateParameters: generationParameters) let result = try await restored.respond(to: "hello again") XCTAssertGreaterThan(result.count, targetLength, result) } @@ -222,29 +231,37 @@ public class ChatSessionTests: XCTestCase { func testCurrentCacheNilForHistorySessionBeforeGeneration() async throws { // .history state should behave like .empty: no cache until first generation let history: [Chat.Message] = [.user("hello"), .assistant("hi")] - let session = ChatSession(model(), history: history) - let cache = await session.currentCache() - XCTAssertNil(cache) + let session = ChatSession( + model(), history: history, generateParameters: generationParameters) + await session.withCache { cache in + XCTAssertNil(cache) + } } func testCurrentCacheNonNilForHistorySessionAfterGeneration() async throws { // after generation from .history state, cache 
transitions to .kvcache let history: [Chat.Message] = [.user("hello"), .assistant("hi")] - let session = ChatSession(model(), history: history) + let session = ChatSession( + model(), + history: history, + generateParameters: generationParameters) _ = try await session.respond(to: "hello again") - let cache = await session.currentCache() - XCTAssertNotNil(cache) + await session.withCache { cache in + XCTAssertNotNil(cache) + } } func testCurrentCacheNilAfterClear() async throws { // clear() resets to .empty; currentCache() should return nil again - let session = ChatSession(model()) + let session = ChatSession(model(), generateParameters: generationParameters) _ = try await session.respond(to: "hello") - let cacheBeforeClear = await session.currentCache() - XCTAssertNotNil(cacheBeforeClear) + await session.withCache { cache in + XCTAssertNotNil(cache) + } await session.clear() - let cacheAfterClear = await session.currentCache() - XCTAssertNil(cacheAfterClear) + await session.withCache { cache in + XCTAssertNil(cache) + } } /// something that looks like a view model diff --git a/Tests/MLXLMTests/KVCacheTests.swift b/Tests/MLXLMTests/KVCacheTests.swift index fe342bb7b..156f393aa 100644 --- a/Tests/MLXLMTests/KVCacheTests.swift +++ b/Tests/MLXLMTests/KVCacheTests.swift @@ -3,7 +3,7 @@ import MLX import MLXLMCommon import Testing -private let cacheCreators: [() -> any KVCache] = [ +private let cacheCreators: [@Sendable () -> any KVCache] = [ { KVCacheSimple() }, { RotatingKVCache(maxSize: 32) }, { QuantizedKVCache() },