Skip to content

Commit 08ceed8

Browse files
Harden SwiftBuddy model loading recovery
1 parent dc1cff2 commit 08ceed8

9 files changed

Lines changed: 1554 additions & 93 deletions

Sources/MLXInferenceCore/InferenceEngine.swift

Lines changed: 161 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ public final class InferenceEngine: ObservableObject {
114114
@Published public private(set) var activeContextTokens: Int = 0
115115
@Published public private(set) var maxContextWindow: Int = 0
116116

117+
/// Set when a corrupted/truncated model is detected during inference.
118+
/// The UI should observe this and offer to delete & re-download.
119+
@Published public var corruptedModelId: String? = nil
120+
117121
/// Whether to automatically unload the model when the app backgrounds
118122
/// and reload it when returning to foreground.
119123
/// Defaults to true on iOS (prevents jetsam), false on macOS.
@@ -277,7 +281,44 @@ public final class InferenceEngine: ObservableObject {
277281
state = .error("Device is too hot. Let it cool before loading a model.")
278282
return
279283
}
284+
corruptedModelId = nil
285+
286+
guard ModelStorage.verifyModelIntegrity(for: modelId) else {
287+
await downloadThenLoad(modelId: modelId)
288+
return
289+
}
290+
291+
await loadVerifiedModel(modelId: modelId)
292+
}
293+
294+
/// Downloads `modelId`, re-checks its files on disk, and then hands off to the load path.
///
/// Outcomes:
/// - download cancelled → `state` becomes `.idle`
/// - download failed → `state` becomes `.error` carrying the failure description
/// - download finished but integrity check fails → the model is flagged as corrupted
/// - otherwise → `loadVerifiedModel(modelId:)` is awaited
///
/// - Parameter modelId: Identifier of the model to fetch and load.
private func downloadThenLoad(modelId: String) async {
    print("[InferenceEngine] Model \(modelId) is missing or incomplete. Starting download before load.")

    // Drop whatever is currently loaded before the download begins.
    releaseLoadedModelResources()
    state = .downloading(progress: 0.0, speed: "Preparing...")

    let downloadTask = downloadManager.startDownload(modelId: modelId)

    // Only awaiting the download can throw, so the do/catch is scoped to just that step.
    do {
        try await downloadTask.value
    } catch is CancellationError {
        state = .idle
        return
    } catch {
        state = .error("Failed to download \(modelId): \(error.localizedDescription)")
        return
    }

    state = .downloading(progress: 1.0, speed: "Verifying...")

    if ModelStorage.verifyModelIntegrity(for: modelId) {
        await loadVerifiedModel(modelId: modelId)
    } else {
        // A finished download that still fails verification is surfaced as corruption.
        markModelCorrupted(
            modelId: modelId,
            message: "Model files are incomplete after download. Choose a recovery option."
        )
    }
}
320+
321+
private func loadVerifiedModel(modelId: String) async {
281322
state = .loading
282323
currentModelId = modelId
283324

@@ -312,25 +353,29 @@ public final class InferenceEngine: ObservableObject {
312353
downloader: downloader
313354
)
314355

356+
let speedTracker = DownloadSpeedTracker()
357+
315358
if architecture.supportsVision {
316359
container = try await VLMModelFactory.shared.loadContainer(
317360
from: downloader,
318361
using: TransformersTokenizerLoader(),
319362
configuration: config
320363
) { [weak self] progress in
364+
speedTracker.record(totalBytes: progress.completedUnitCount)
365+
let smoothedSpeed = speedTracker.speedBytesPerSec
366+
321367
Task { @MainActor in
322368
guard let self else { return }
323369
let pct = progress.fractionCompleted
324-
let speedBytesPerSec = progress.userInfo[ProgressUserInfoKey("throughputKey")] as? Double
325-
let speedStr = speedBytesPerSec
370+
let speedStr = smoothedSpeed
326371
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
327372
self.state = .downloading(progress: pct, speed: speedStr)
328373

329374
self.downloadManager.updateProgress(ModelDownloadProgress(
330375
modelId: modelId,
331376
fractionCompleted: pct,
332377
currentFile: "",
333-
speedMBps: speedBytesPerSec.map { $0 / 1_000_000 }
378+
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
334379
))
335380
}
336381
}
@@ -340,19 +385,21 @@ public final class InferenceEngine: ObservableObject {
340385
using: TransformersTokenizerLoader(),
341386
configuration: config
342387
) { [weak self] progress in
388+
speedTracker.record(totalBytes: progress.completedUnitCount)
389+
let smoothedSpeed = speedTracker.speedBytesPerSec
390+
343391
Task { @MainActor in
344392
guard let self else { return }
345393
let pct = progress.fractionCompleted
346-
let speedBytesPerSec = progress.userInfo[ProgressUserInfoKey("throughputKey")] as? Double
347-
let speedStr = speedBytesPerSec
394+
let speedStr = smoothedSpeed
348395
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
349396
self.state = .downloading(progress: pct, speed: speedStr)
350397

351398
self.downloadManager.updateProgress(ModelDownloadProgress(
352399
modelId: modelId,
353400
fractionCompleted: pct,
354401
currentFile: "",
355-
speedMBps: speedBytesPerSec.map { $0 / 1_000_000 }
402+
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
356403
))
357404
}
358405
}
@@ -361,26 +408,85 @@ public final class InferenceEngine: ObservableObject {
361408
downloadManager.clearProgress(modelId: modelId)
362409
downloadManager.lastLoadedModelId = modelId
363410
downloadManager.refresh()
411+
412+
// Verify integrity to catch incomplete downloads before marking as ready
413+
guard ModelStorage.verifyModelIntegrity(for: modelId) else {
414+
throw NSError(domain: "InferenceEngine", code: 1, userInfo: [NSLocalizedDescriptionKey: "Model safetensors files are incomplete. Please delete and re-download."])
415+
}
416+
417+
// Read the model's actual max context length from config.json
418+
if let ctxLen = ModelStorage.readMaxContextLength(for: modelId) {
419+
self.maxContextWindow = ctxLen
420+
print("[InferenceEngine] Model context window: \(ctxLen) tokens")
421+
} else {
422+
self.maxContextWindow = 8192 // conservative fallback for models without explicit limits
423+
print("[InferenceEngine] No explicit context limit found in config.json, defaulting to 8192")
424+
}
425+
364426
state = .ready(modelId: modelId)
365427

366428
} catch {
367429
ExpertStreamingConfig.shared.deactivate()
368430
downloadManager.clearProgress(modelId: modelId)
369431
state = .error("Failed to load \(modelId): \(error.localizedDescription)")
432+
433+
// If the model is incomplete/corrupted, flag it so the UI shows the "Delete & Re-download" button
434+
let nsError = error as NSError
435+
if nsError.domain == "InferenceEngine" && nsError.code == 1 || Self.isModelCorruptionError(error) {
436+
markModelCorrupted(
437+
modelId: modelId,
438+
message: "Model weights are corrupted or incomplete. Choose a recovery option."
439+
)
440+
return
441+
}
442+
370443
container = nil
444+
self.maxContextWindow = 0
445+
self.activeContextTokens = 0
371446
}
372447
}
373448

374449
/// Unload the current model and free all GPU memory.
///
/// Releases the loaded model's resources, clears any pending corruption flag,
/// and moves the engine to the `.idle` state.
375450
public func unload() {
451+
releaseLoadedModelResources()
452+
// An explicit unload also dismisses any outstanding corruption-recovery flag.
corruptedModelId = nil
453+
state = .idle
454+
}
455+
456+
/// Tears down everything tied to the currently loaded model.
///
/// Cancels and clears the generation task, drops the container and the current
/// model id, zeroes both context counters, deactivates expert streaming, and sets
/// the MLX cache limit to 0.
///
/// NOTE(review): the `state = .idle` line below is marked as removed in this diff,
/// so after this commit callers are expected to set `state` themselves.
private func releaseLoadedModelResources() {
376457
generationTask?.cancel()
458+
// Forget the cancelled task so no stale handle is kept around.
generationTask = nil
377459
container = nil
378460
currentModelId = nil
379-
state = .idle
461+
// With no model loaded, the context-window figures are reset to zero.
maxContextWindow = 0
462+
activeContextTokens = 0
380463
ExpertStreamingConfig.shared.deactivate()
381464
MLX.Memory.cacheLimit = 0
382465
}
383466

467+
/// Flags a model as corrupted, releases its resources, and puts the engine into an error state.
///
/// - Parameters:
///   - modelId: Identifier of the failed model. When `nil`, falls back to `currentModelId`,
///     and then to an already recorded `corruptedModelId`.
///   - message: Text carried by the resulting `.error` state.
private func markModelCorrupted(modelId: String?, message: String) {
    // Resolve the id BEFORE releasing resources: `releaseLoadedModelResources()` clears
    // `currentModelId`. The final fallback keeps a previously recorded id, because the
    // generation path can report the same failure twice (once from its catch block and once
    // from the post-stream latch check). On the second report `currentModelId` is already
    // nil, and without this fallback the flag would be overwritten with nil.
    let failedModelId = modelId ?? currentModelId ?? corruptedModelId
    releaseLoadedModelResources()
    state = .error(message)
    corruptedModelId = failedModelId
}
473+
474+
/// Heuristically decides whether `error` points at damaged or partially written model files.
///
/// The decision is a case-insensitive substring search over `error.localizedDescription`.
///
/// - Parameter error: The error to classify.
/// - Returns: `true` when the lowercased description contains any known marker.
private static func isModelCorruptionError(_ error: Error) -> Bool {
    // NOTE(review): plain substring matching — e.g. "pread" also matches inside "spread".
    let markers = ["ssd streaming", "pread", "safetensors", "corrupt", "incomplete"]
    let haystack = error.localizedDescription.lowercased()
    return markers.contains { marker in haystack.contains(marker) }
}
482+
483+
/// Dismisses the corruption-recovery flag.
///
/// Always clears `corruptedModelId`. If the engine is currently in an `.error` state it is
/// returned to `.idle`; any other state is left untouched.
public func clearCorruptionRecovery() {
    corruptedModelId = nil

    // Only an error state is reset; every other state passes through unchanged.
    guard case .error = state else { return }
    state = .idle
}
489+
384490
// MARK: — Generation
385491

386492
public nonisolated func generate(
@@ -442,9 +548,7 @@ public final class InferenceEngine: ObservableObject {
442548
let baseTokens = Int(Double(stringLength) / 3.5)
443549
self.activeContextTokens = baseTokens
444550

445-
// If we have a max length config, expose it
446-
// TODO: Safely extract from ModelConfiguration when MLX exposes it dynamically
447-
self.maxContextWindow = 8192
551+
// maxContextWindow is already set during loadModel() from config.json
448552

449553
let stream: AsyncStream<Generation> = try await container.generate(
450554
input: lmInput,
@@ -485,11 +589,32 @@ public final class InferenceEngine: ObservableObject {
485589
continuation.yield(GenerationToken(text: text, isThinking: thinkingActive))
486590
}
487591
}
592+
} catch let ssdError as SSDStreamingError {
593+
// Corrupted/truncated safetensors — surface a clear, actionable error
594+
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
595+
print("[InferenceEngine] SSD Streaming Error: \(ssdError.localizedDescription)")
596+
continuation.yield(GenerationToken(text: "\n\n[Error: \(msg)]"))
597+
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
488598
} catch {
599+
// Check if the generic error is also an SSD streaming issue
600+
if Self.isModelCorruptionError(error) {
601+
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
602+
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
603+
}
489604
continuation.yield(GenerationToken(text: "\n\n[Error: \(error.localizedDescription)]"))
490605
}
491606

492-
self.state = self.currentModelId.map { .ready(modelId: $0) } ?? .idle
607+
// Also check the latch one final time (for errors that occurred during
608+
// generation and caused the stream to end without throwing)
609+
if let latchedError = SSDStreamingErrorLatch.shared.consume() {
610+
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
611+
print("[InferenceEngine] Latched SSD error after generation: \(latchedError.localizedDescription)")
612+
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
613+
} else if case .error = self.state {
614+
// Already in error state from catch block above
615+
} else {
616+
self.state = self.currentModelId.map { .ready(modelId: $0) } ?? .idle
617+
}
493618
continuation.finish()
494619
}
495620
}
@@ -500,4 +625,29 @@ public final class InferenceEngine: ObservableObject {
500625
generationTask = nil
501626
if let id = currentModelId { state = .ready(modelId: id) }
502627
}
628+
629+
/// Delete corrupted model files and start a fresh download.
630+
/// Called from the UI when the user confirms re-download after corruption is detected.
///
/// Does nothing when `corruptedModelId` is `nil`. If deletion fails, the engine moves to an
/// `.error` state and `corruptedModelId` is left set. On success the flag is cleared and the
/// download-then-load sequence is started in a new task, so this method returns without
/// waiting for the download.
631+
public func deleteCorruptedAndRedownload() {
632+
guard let modelId = corruptedModelId else { return }
633+
634+
releaseLoadedModelResources()
635+
// The downloading state doubles as a progress indicator for the deletion step.
state = .downloading(progress: 0.0, speed: "Deleting corrupted files...")
636+
637+
do {
638+
try ModelStorage.delete(modelId)
639+
print("[InferenceEngine] Successfully deleted corrupted cache directory for \(modelId).")
640+
} catch {
641+
print("[InferenceEngine] FAILED to delete corrupted cache: \(error.localizedDescription)")
642+
state = .error("Failed to delete corrupted model: \(error.localizedDescription)")
643+
// Early exit: `corruptedModelId` is not cleared on this path.
return
644+
}
645+
downloadManager.refresh()
646+
corruptedModelId = nil
647+
648+
print("[InferenceEngine] Deleted corrupted files for \(modelId), starting fresh download")
649+
// This method is synchronous, so the async download is launched in an unstructured task.
Task { @MainActor in
650+
await downloadThenLoad(modelId: modelId)
651+
}
652+
}
503653
}

0 commit comments

Comments
 (0)