@@ -114,6 +114,10 @@ public final class InferenceEngine: ObservableObject {
114114 @Published public private( set) var activeContextTokens : Int = 0
115115 @Published public private( set) var maxContextWindow : Int = 0
116116
117+ /// ID of the model whose files were found corrupted or incomplete after a download, while loading, or during inference; `nil` otherwise.
118+ /// The UI should observe this and offer to delete & re-download.
119+ @Published public var corruptedModelId : String ? = nil
120+
117121 /// Whether to automatically unload the model when the app backgrounds
118122 /// and reload it when returning to foreground.
119123 /// Defaults to true on iOS (prevents jetsam), false on macOS.
@@ -277,7 +281,44 @@ public final class InferenceEngine: ObservableObject {
277281 state = . error( " Device is too hot. Let it cool before loading a model. " )
278282 return
279283 }
284+ corruptedModelId = nil
285+
286+ guard ModelStorage . verifyModelIntegrity ( for: modelId) else {
287+ await downloadThenLoad ( modelId: modelId)
288+ return
289+ }
290+
291+ await loadVerifiedModel ( modelId: modelId)
292+ }
293+
/// Downloads `modelId`, re-checks the files on disk, and then loads the model.
///
/// Whatever is currently loaded is released first. Possible outcomes:
/// - the download is cancelled: `state` goes back to `.idle`;
/// - the download fails: `state` becomes `.error` with the underlying description;
/// - the files still fail the integrity check: the model is flagged through
///   `markModelCorrupted(modelId:message:)`;
/// - otherwise: control passes to `loadVerifiedModel(modelId:)`.
///
/// - Parameter modelId: Identifier of the model to fetch and load.
private func downloadThenLoad(modelId: String) async {
    print(" [InferenceEngine] Model \( modelId) is missing or incomplete. Starting download before load. ")
    releaseLoadedModelResources()
    state = .downloading(progress: 0.0, speed: " Preparing... ")

    let pendingDownload = downloadManager.startDownload(modelId: modelId)

    // Awaiting the download is the only throwing step, so its failure
    // modes are dealt with here and the success path stays un-nested.
    do {
        try await pendingDownload.value
    } catch is CancellationError {
        state = .idle
        return
    } catch {
        state = .error(" Failed to download \( modelId) : \( error.localizedDescription) ")
        return
    }

    state = .downloading(progress: 1.0, speed: " Verifying... ")

    if ModelStorage.verifyModelIntegrity(for: modelId) {
        await loadVerifiedModel(modelId: modelId)
    } else {
        // The download finished but the files on disk are still not usable.
        markModelCorrupted(
            modelId: modelId,
            message: " Model files are incomplete after download. Choose a recovery option. "
        )
    }
}
320+
321+ private func loadVerifiedModel( modelId: String ) async {
281322 state = . loading
282323 currentModelId = modelId
283324
@@ -312,25 +353,29 @@ public final class InferenceEngine: ObservableObject {
312353 downloader: downloader
313354 )
314355
356+ let speedTracker = DownloadSpeedTracker ( )
357+
315358 if architecture. supportsVision {
316359 container = try await VLMModelFactory . shared. loadContainer (
317360 from: downloader,
318361 using: TransformersTokenizerLoader ( ) ,
319362 configuration: config
320363 ) { [ weak self] progress in
364+ speedTracker. record ( totalBytes: progress. completedUnitCount)
365+ let smoothedSpeed = speedTracker. speedBytesPerSec
366+
321367 Task { @MainActor in
322368 guard let self else { return }
323369 let pct = progress. fractionCompleted
324- let speedBytesPerSec = progress. userInfo [ ProgressUserInfoKey ( " throughputKey " ) ] as? Double
325- let speedStr = speedBytesPerSec
370+ let speedStr = smoothedSpeed
326371 . map { String ( format: " %.1f MB/s " , $0 / 1_000_000 ) } ?? " "
327372 self . state = . downloading( progress: pct, speed: speedStr)
328373
329374 self . downloadManager. updateProgress ( ModelDownloadProgress (
330375 modelId: modelId,
331376 fractionCompleted: pct,
332377 currentFile: " " ,
333- speedMBps: speedBytesPerSec . map { $0 / 1_000_000 }
378+ speedMBps: smoothedSpeed . map { $0 / 1_000_000 }
334379 ) )
335380 }
336381 }
@@ -340,19 +385,21 @@ public final class InferenceEngine: ObservableObject {
340385 using: TransformersTokenizerLoader ( ) ,
341386 configuration: config
342387 ) { [ weak self] progress in
388+ speedTracker. record ( totalBytes: progress. completedUnitCount)
389+ let smoothedSpeed = speedTracker. speedBytesPerSec
390+
343391 Task { @MainActor in
344392 guard let self else { return }
345393 let pct = progress. fractionCompleted
346- let speedBytesPerSec = progress. userInfo [ ProgressUserInfoKey ( " throughputKey " ) ] as? Double
347- let speedStr = speedBytesPerSec
394+ let speedStr = smoothedSpeed
348395 . map { String ( format: " %.1f MB/s " , $0 / 1_000_000 ) } ?? " "
349396 self . state = . downloading( progress: pct, speed: speedStr)
350397
351398 self . downloadManager. updateProgress ( ModelDownloadProgress (
352399 modelId: modelId,
353400 fractionCompleted: pct,
354401 currentFile: " " ,
355- speedMBps: speedBytesPerSec . map { $0 / 1_000_000 }
402+ speedMBps: smoothedSpeed . map { $0 / 1_000_000 }
356403 ) )
357404 }
358405 }
@@ -361,26 +408,85 @@ public final class InferenceEngine: ObservableObject {
361408 downloadManager. clearProgress ( modelId: modelId)
362409 downloadManager. lastLoadedModelId = modelId
363410 downloadManager. refresh ( )
411+
412+ // Verify integrity to catch incomplete downloads before marking as ready
413+ guard ModelStorage . verifyModelIntegrity ( for: modelId) else {
414+ throw NSError ( domain: " InferenceEngine " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Model safetensors files are incomplete. Please delete and re-download. " ] )
415+ }
416+
417+ // Read the model's actual max context length from config.json
418+ if let ctxLen = ModelStorage . readMaxContextLength ( for: modelId) {
419+ self . maxContextWindow = ctxLen
420+ print ( " [InferenceEngine] Model context window: \( ctxLen) tokens " )
421+ } else {
422+ self . maxContextWindow = 8192 // conservative fallback for models without explicit limits
423+ print ( " [InferenceEngine] No explicit context limit found in config.json, defaulting to 8192 " )
424+ }
425+
364426 state = . ready( modelId: modelId)
365427
366428 } catch {
367429 ExpertStreamingConfig . shared. deactivate ( )
368430 downloadManager. clearProgress ( modelId: modelId)
369431 state = . error( " Failed to load \( modelId) : \( error. localizedDescription) " )
432+
433+ // If the model is incomplete/corrupted, flag it so the UI shows the "Delete & Re-download" button
434+ let nsError = error as NSError
435+ if nsError. domain == " InferenceEngine " && nsError. code == 1 || Self . isModelCorruptionError ( error) {
436+ markModelCorrupted (
437+ modelId: modelId,
438+ message: " Model weights are corrupted or incomplete. Choose a recovery option. "
439+ )
440+ return
441+ }
442+
370443 container = nil
444+ self . maxContextWindow = 0
445+ self . activeContextTokens = 0
371446 }
372447 }
373448
/// Unloads the current model and returns the engine to `.idle`.
///
/// Releases everything held for the loaded model, clears any pending
/// corruption flag, and only then publishes the idle state.
public func unload() {
    self.releaseLoadedModelResources()
    self.corruptedModelId = nil
    self.state = .idle
}
455+
/// Tears down everything tied to the currently loaded model.
///
/// This does not assign `state`; callers publish whatever state should follow.
private func releaseLoadedModelResources() {
    // Stop and forget any generation that is still running.
    self.generationTask?.cancel()
    self.generationTask = nil

    // Drop the model itself and its bookkeeping.
    self.container = nil
    self.currentModelId = nil
    self.maxContextWindow = 0
    self.activeContextTokens = 0

    // Turn off expert streaming and set the MLX cache limit to zero.
    ExpertStreamingConfig.shared.deactivate()
    MLX.Memory.cacheLimit = 0
}
383466
/// Puts the engine into an error state and flags a model for corruption recovery.
///
/// - Parameters:
///   - modelId: Model to flag. When `nil`, the currently loaded model ID is used.
///   - message: Text published through `state = .error(...)`.
private func markModelCorrupted(modelId: String?, message: String) {
    // Resolve the ID before releasing resources: the release clears `currentModelId`.
    let flaggedId: String?
    if let modelId {
        flaggedId = modelId
    } else {
        flaggedId = currentModelId
    }

    releaseLoadedModelResources()
    state = .error(message)
    corruptedModelId = flaggedId
}
473+
/// Heuristically decides whether `error` points at damaged or partial model files.
///
/// The error's localized description is lowercased and searched for a fixed
/// set of marker substrings; any single match counts.
///
/// - Parameter error: The error to inspect.
/// - Returns: `true` when the description contains at least one marker.
private static func isModelCorruptionError(_ error: Error) -> Bool {
    let markers = [
        " ssd streaming ",
        " pread ",
        " safetensors ",
        " corrupt ",
        " incomplete ",
    ]
    let haystack = error.localizedDescription.lowercased()
    return markers.contains { marker in haystack.contains(marker) }
}
482+
/// Dismisses a pending corruption-recovery prompt.
///
/// Clears `corruptedModelId`. If the engine is currently in an `.error`
/// state it is reset to `.idle`; any other state is left untouched.
public func clearCorruptionRecovery() {
    corruptedModelId = nil

    guard case .error = state else { return }
    state = .idle
}
489+
384490 // MARK: — Generation
385491
386492 public nonisolated func generate(
@@ -442,9 +548,7 @@ public final class InferenceEngine: ObservableObject {
442548 let baseTokens = Int ( Double ( stringLength) / 3.5 )
443549 self . activeContextTokens = baseTokens
444550
445- // If we have a max length config, expose it
446- // TODO: Safely extract from ModelConfiguration when MLX exposes it dynamically
447- self . maxContextWindow = 8192
551+ // maxContextWindow is already set during loadModel() from config.json
448552
449553 let stream : AsyncStream < Generation > = try await container. generate (
450554 input: lmInput,
@@ -485,11 +589,32 @@ public final class InferenceEngine: ObservableObject {
485589 continuation. yield ( GenerationToken ( text: text, isThinking: thinkingActive) )
486590 }
487591 }
592+ } catch let ssdError as SSDStreamingError {
593+ // Corrupted/truncated safetensors — surface a clear, actionable error
594+ let msg = " Model weights are corrupted or incomplete. Please re-download the model. "
595+ print ( " [InferenceEngine] SSD Streaming Error: \( ssdError. localizedDescription) " )
596+ continuation. yield ( GenerationToken ( text: " \n \n [Error: \( msg) ] " ) )
597+ self . markModelCorrupted ( modelId: self . currentModelId, message: msg)
488598 } catch {
599+ // Check if the generic error is also an SSD streaming issue
600+ if Self . isModelCorruptionError ( error) {
601+ let msg = " Model weights are corrupted or incomplete. Please re-download the model. "
602+ self . markModelCorrupted ( modelId: self . currentModelId, message: msg)
603+ }
489604 continuation. yield ( GenerationToken ( text: " \n \n [Error: \( error. localizedDescription) ] " ) )
490605 }
491606
492- self . state = self . currentModelId. map { . ready( modelId: $0) } ?? . idle
607+ // Also check the latch one final time (for errors that occurred during
608+ // generation and caused the stream to end without throwing)
609+ if let latchedError = SSDStreamingErrorLatch . shared. consume ( ) {
610+ let msg = " Model weights are corrupted or incomplete. Please re-download the model. "
611+ print ( " [InferenceEngine] Latched SSD error after generation: \( latchedError. localizedDescription) " )
612+ self . markModelCorrupted ( modelId: self . currentModelId, message: msg)
613+ } else if case . error = self . state {
614+ // Already in error state from catch block above
615+ } else {
616+ self . state = self . currentModelId. map { . ready( modelId: $0) } ?? . idle
617+ }
493618 continuation. finish ( )
494619 }
495620 }
@@ -500,4 +625,29 @@ public final class InferenceEngine: ObservableObject {
500625 generationTask = nil
501626 if let id = currentModelId { state = . ready( modelId: id) }
502627 }
628+
/// Deletes the flagged model's files and kicks off a fresh download-and-load.
///
/// Intended to be called from the UI once the user confirms a re-download.
/// Does nothing when no model is flagged in `corruptedModelId`. If deletion
/// fails, the engine moves to `.error` and `corruptedModelId` stays set.
public func deleteCorruptedAndRedownload() {
    guard let modelId = corruptedModelId else { return }

    releaseLoadedModelResources()
    state = .downloading(progress: 0.0, speed: " Deleting corrupted files... ")

    // Deletion is the only throwing step; bail out early if it fails.
    do {
        try ModelStorage.delete(modelId)
    } catch {
        print(" [InferenceEngine] FAILED to delete corrupted cache: \( error.localizedDescription) ")
        state = .error(" Failed to delete corrupted model: \( error.localizedDescription) ")
        return
    }
    print(" [InferenceEngine] Successfully deleted corrupted cache directory for \( modelId) . ")

    downloadManager.refresh()
    corruptedModelId = nil

    print(" [InferenceEngine] Deleted corrupted files for \( modelId) , starting fresh download ")
    // This method is synchronous, so the async download runs in its own main-actor task.
    Task { @MainActor in
        await self.downloadThenLoad(modelId: modelId)
    }
}
503653}
0 commit comments