-
Notifications
You must be signed in to change notification settings - Fork 287
switch to swift 6 -- prevent concurrency issues, fix concurrency issues #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| // swift-tools-version: 5.12 | ||
| // swift-tools-version: 6.1 | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switch to swift 6 -- prevent concurrency issues from making it past CI |
||
| // The swift-tools-version declares the minimum version of Swift required to build this package. | ||
|
|
||
| import PackageDescription | ||
|
|
@@ -26,7 +26,7 @@ let package = Package( | |
| targets: ["MLXEmbedders"]), | ||
| ], | ||
| dependencies: [ | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")), | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), | ||
| .package( | ||
| url: "https://github.com/huggingface/swift-transformers", | ||
| .upToNextMinor(from: "1.2.0") | ||
|
|
@@ -45,9 +45,6 @@ let package = Package( | |
| path: "Libraries/MLXLLM", | ||
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No longer needed. |
||
| ] | ||
| ), | ||
| .target( | ||
|
|
@@ -62,9 +59,6 @@ let package = Package( | |
| path: "Libraries/MLXVLM", | ||
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| ), | ||
| .target( | ||
|
|
@@ -78,9 +72,6 @@ let package = Package( | |
| path: "Libraries/MLXLMCommon", | ||
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| ), | ||
| .target( | ||
|
|
@@ -94,9 +85,6 @@ let package = Package( | |
| path: "Libraries/MLXEmbedders", | ||
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| ), | ||
| .testTarget( | ||
|
|
@@ -115,10 +103,7 @@ let package = Package( | |
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")] | ||
| ), | ||
| .testTarget( | ||
| name: "MLXLMIntegrationTests", | ||
|
|
@@ -135,9 +120,6 @@ let package = Package( | |
| path: "Tests/MLXLMIntegrationTests", | ||
| exclude: [ | ||
| "README.md" | ||
| ], | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| ), | ||
| .testTarget( | ||
|
|
@@ -147,10 +129,7 @@ let package = Package( | |
| "MLXVLM", | ||
| "MLXLMCommon", | ||
| ], | ||
| path: "Tests/Benchmarks", | ||
| swiftSettings: [ | ||
| .enableExperimentalFeature("StrictConcurrency") | ||
| ] | ||
| path: "Tests/Benchmarks" | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,13 @@ | |
| import Foundation | ||
| import MLX | ||
| import MLXLLM | ||
| import MLXLMCommon | ||
| import MLXNN | ||
| import MLXOptimizers | ||
| import Tokenizers | ||
| import XCTest | ||
|
|
||
| @testable import MLXLMCommon | ||
|
|
||
| /// See also ChatSessionIntegrationTests | ||
| public class ChatSessionTests: XCTestCase { | ||
|
|
||
|
|
@@ -162,36 +163,43 @@ public class ChatSessionTests: XCTestCase { | |
| // MARK: - KV Cache | ||
|
|
||
| func testCurrentCacheNilBeforeGeneration() async throws { | ||
| let session = ChatSession(model()) | ||
| let cache = await session.currentCache() | ||
| XCTAssertNil(cache) | ||
| let session = ChatSession(model(), generateParameters: generationParameters) | ||
| await session.withCache { cache in | ||
| XCTAssertNil(cache) | ||
| } | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the pattern to fix the tests -- just use the visitor. |
||
| } | ||
|
|
||
| func testCurrentCacheAfterGeneration() async throws { | ||
| let session = ChatSession(model()) | ||
| let session = ChatSession(model(), generateParameters: generationParameters) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #128, those changes were in flight while these calls were added. This just syncs everything up -- this is needed to make the random weight token generation stop. |
||
| _ = try await session.respond(to: "hello") | ||
| let cache = await session.currentCache() | ||
| XCTAssertNotNil(cache) | ||
| await session.withCache { cache in | ||
| XCTAssertNotNil(cache) | ||
| } | ||
| } | ||
|
|
||
| func testInitWithKVCache() async throws { | ||
| // build a cache from an initial session | ||
| let ctx = model() | ||
| let initial = ChatSession(ctx) | ||
| let container = ModelContainer(context: model()) | ||
| let initial = ChatSession(container, generateParameters: generationParameters) | ||
| _ = try await initial.respond(to: "hello") | ||
| guard let cache = await initial.currentCache() else { | ||
| XCTFail("expected cache after generation") | ||
| return | ||
| } | ||
|
|
||
| // restore the cache into a new session and verify generation continues | ||
| let restored = ChatSession(ctx, cache: cache) | ||
| let result = try await restored.respond(to: "hello again") | ||
| XCTAssertGreaterThan(result.count, targetLength, result) | ||
| try await initial.withCache { [targetLength, generationParameters] cache in | ||
| XCTAssertNotNil(cache) | ||
|
|
||
| if let cache { | ||
| // restore the cache into a new session and verify generation continues | ||
| let restored = ChatSession( | ||
| container, | ||
| cache: cache.map { $0.copy() }, | ||
| generateParameters: generationParameters) | ||
| let result = try await restored.respond(to: "hello again") | ||
| XCTAssertGreaterThan(result.count, targetLength, result) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| func testSaveCacheThrowsBeforeGeneration() async throws { | ||
| let session = ChatSession(model()) | ||
| let session = ChatSession(model(), generateParameters: generationParameters) | ||
| let url = FileManager.default.temporaryDirectory | ||
| .appendingPathComponent(UUID().uuidString) | ||
| .appendingPathExtension("safetensors") | ||
|
|
@@ -205,7 +213,7 @@ public class ChatSessionTests: XCTestCase { | |
|
|
||
| func testSaveAndRestoreCache() async throws { | ||
| let ctx = model() | ||
| let initial = ChatSession(ctx) | ||
| let initial = ChatSession(ctx, generateParameters: generationParameters) | ||
| _ = try await initial.respond(to: "hello") | ||
|
|
||
| let url = FileManager.default.temporaryDirectory | ||
|
|
@@ -214,37 +222,46 @@ public class ChatSessionTests: XCTestCase { | |
| try await initial.saveCache(to: url) | ||
|
|
||
| let (loadedCache, _) = try loadPromptCache(url: url) | ||
| let restored = ChatSession(ctx, cache: loadedCache) | ||
| let restored = ChatSession( | ||
| ctx, cache: loadedCache, generateParameters: generationParameters) | ||
| let result = try await restored.respond(to: "hello again") | ||
| XCTAssertGreaterThan(result.count, targetLength, result) | ||
| } | ||
|
|
||
| func testCurrentCacheNilForHistorySessionBeforeGeneration() async throws { | ||
| // .history state should behave like .empty: no cache until first generation | ||
| let history: [Chat.Message] = [.user("hello"), .assistant("hi")] | ||
| let session = ChatSession(model(), history: history) | ||
| let cache = await session.currentCache() | ||
| XCTAssertNil(cache) | ||
| let session = ChatSession( | ||
| model(), history: history, generateParameters: generationParameters) | ||
| await session.withCache { cache in | ||
| XCTAssertNil(cache) | ||
| } | ||
| } | ||
|
|
||
| func testCurrentCacheNonNilForHistorySessionAfterGeneration() async throws { | ||
| // after generation from .history state, cache transitions to .kvcache | ||
| let history: [Chat.Message] = [.user("hello"), .assistant("hi")] | ||
| let session = ChatSession(model(), history: history) | ||
| let session = ChatSession( | ||
| model(), | ||
| history: history, | ||
| generateParameters: generationParameters) | ||
| _ = try await session.respond(to: "hello again") | ||
| let cache = await session.currentCache() | ||
| XCTAssertNotNil(cache) | ||
| await session.withCache { cache in | ||
| XCTAssertNotNil(cache) | ||
| } | ||
| } | ||
|
|
||
| func testCurrentCacheNilAfterClear() async throws { | ||
| // clear() resets to .empty; currentCache() should return nil again | ||
| let session = ChatSession(model()) | ||
| let session = ChatSession(model(), generateParameters: generationParameters) | ||
| _ = try await session.respond(to: "hello") | ||
| let cacheBeforeClear = await session.currentCache() | ||
| XCTAssertNotNil(cacheBeforeClear) | ||
| await session.withCache { cache in | ||
| XCTAssertNotNil(cache) | ||
| } | ||
| await session.clear() | ||
| let cacheAfterClear = await session.currentCache() | ||
| XCTAssertNil(cacheAfterClear) | ||
| await session.withCache { cache in | ||
| XCTAssertNil(cache) | ||
| } | ||
| } | ||
|
|
||
| /// something that looks like a view model | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ import MLX | |
| import MLXLMCommon | ||
| import Testing | ||
|
|
||
| private let cacheCreators: [() -> any KVCache] = [ | ||
| private let cacheCreators: [@Sendable () -> any KVCache] = [ | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix sendable issue |
||
| { KVCacheSimple() }, | ||
| { RotatingKVCache(maxSize: 32) }, | ||
| { QuantizedKVCache() }, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a little too permissive -- it is added for test support so we can just switch it to an internal scoped visitor.