Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 67 additions & 10 deletions Type4Me/ASR/DeepgramASRClient.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import Foundation
import os

enum DeepgramASRError: Error, LocalizedError {
enum DeepgramASRError: Error, LocalizedError, Equatable {
case unsupportedProvider
case handshakeTimedOut
case closedBeforeHandshake(code: Int, reason: String?)
case closed(code: Int, reason: String?)

var errorDescription: String? {
switch self {
Expand All @@ -17,10 +18,47 @@ enum DeepgramASRError: Error, LocalizedError {
return "Deepgram WebSocket closed before handshake completed (\(code)): \(reason)"
}
return "Deepgram WebSocket closed before handshake completed (\(code))"
case .closed(let code, let reason):
if let reason, !reason.isEmpty {
return "Deepgram WebSocket closed unexpectedly (\(code)): \(reason)"
}
return "Deepgram WebSocket closed unexpectedly (\(code))"
}
}

static func unexpectedClose(
code: URLSessionWebSocketTask.CloseCode,
reason: String?
) -> DeepgramASRError? {
switch code {
case .normalClosure, .goingAway, .noStatusReceived:
return nil
default:
return .closed(code: Int(code.rawValue), reason: reason)
}
}
}

actor DeepgramCloseTracker {

private var closeError: DeepgramASRError?

func recordClose(
code: URLSessionWebSocketTask.CloseCode,
reason: String?
) {
guard closeError == nil,
let error = DeepgramASRError.unexpectedClose(code: code, reason: reason)
else { return }
closeError = error
}

func consumeCloseError() -> DeepgramASRError? {
defer { closeError = nil }
return closeError
}
}

actor DeepgramASRClient: SpeechRecognizer {

private let logger = Logger(
Expand All @@ -32,6 +70,7 @@ actor DeepgramASRClient: SpeechRecognizer {
private var receiveTask: Task<Void, Never>?
private var session: URLSession?
private var sessionDelegate: DeepgramWebSocketDelegate?
private var closeTracker: DeepgramCloseTracker?

private var eventContinuation: AsyncStream<RecognitionEvent>.Continuation?
private var _events: AsyncStream<RecognitionEvent>?
Expand Down Expand Up @@ -69,12 +108,17 @@ actor DeepgramASRClient: SpeechRecognizer {
request.setValue("Token \(deepgramConfig.apiKey)", forHTTPHeaderField: "Authorization")

let connectionGate = DeepgramConnectionGate()
let delegate = DeepgramWebSocketDelegate(connectionGate: connectionGate)
let closeTracker = DeepgramCloseTracker()
let delegate = DeepgramWebSocketDelegate(
connectionGate: connectionGate,
closeTracker: closeTracker
)
let session = URLSession(configuration: options.urlSessionConfiguration, delegate: delegate, delegateQueue: nil)
let task = session.webSocketTask(with: request)
task.resume()

self.connectionGate = connectionGate
self.closeTracker = closeTracker
sessionDelegate = delegate
self.session = session
webSocketTask = task
Expand Down Expand Up @@ -108,6 +152,7 @@ actor DeepgramASRClient: SpeechRecognizer {
session?.invalidateAndCancel()
session = nil
sessionDelegate = nil
closeTracker = nil
eventContinuation?.finish()
eventContinuation = nil
_events = nil
Expand Down Expand Up @@ -135,7 +180,10 @@ actor DeepgramASRClient: SpeechRecognizer {
logger.info("Deepgram receive loop ended: \(String(describing: error), privacy: .public)")
let didRequestClose = await self.didRequestClose
let audioPacketCount = await self.audioPacketCount
if didRequestClose || audioPacketCount > 0 {
if let closeError = await self.closeTracker?.consumeCloseError() {
await self.emitEvent(.error(closeError))
await self.emitEvent(.completed)
} else if didRequestClose || audioPacketCount > 0 {
await self.emitEvent(.completed)
} else {
await self.emitEvent(.error(error))
Expand Down Expand Up @@ -188,9 +236,14 @@ actor DeepgramASRClient: SpeechRecognizer {
final class DeepgramWebSocketDelegate: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {

private let connectionGate: DeepgramConnectionGate
private let closeTracker: DeepgramCloseTracker

init(connectionGate: DeepgramConnectionGate) {
init(
connectionGate: DeepgramConnectionGate,
closeTracker: DeepgramCloseTracker = DeepgramCloseTracker()
) {
self.connectionGate = connectionGate
self.closeTracker = closeTracker
}

func urlSession(
Expand Down Expand Up @@ -224,13 +277,17 @@ final class DeepgramWebSocketDelegate: NSObject, URLSessionWebSocketDelegate, UR
// Post-handshake closes are normal session endings, not errors.
let reasonText = reason.flatMap { String(data: $0, encoding: .utf8) }
Task {
guard await !connectionGate.hasOpened else { return }
await connectionGate.markFailure(
DeepgramASRError.closedBeforeHandshake(
code: Int(closeCode.rawValue),
reason: reasonText
if await !connectionGate.hasOpened {
await connectionGate.markFailure(
DeepgramASRError.closedBeforeHandshake(
code: Int(closeCode.rawValue),
reason: reasonText
)
)
)
return
}

await closeTracker.recordClose(code: closeCode, reason: reasonText)
}
}
}
34 changes: 28 additions & 6 deletions Type4Me/Session/RecognitionSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ actor RecognitionSession {
private var audioChunkContinuation: AsyncStream<Data>.Continuation?
private var audioChunkSenderTask: Task<Void, Never>?
private var uploadFailureFlag: UploadFailureFlag?
private var lastStreamingError: Error?

// MARK: - Prompt context (selected text + clipboard captured at recording start)

Expand Down Expand Up @@ -153,6 +154,7 @@ actor RecognitionSession {
hasEmittedReadyForCurrentSession = false
injectionAborted = false
pendingLLMError = nil
lastStreamingError = nil
state = .starting

// Load credentials for selected provider
Expand Down Expand Up @@ -494,7 +496,7 @@ actor RecognitionSession {
return result
} catch {
DebugFileLogger.log("stop: fresh LLM FAILED +\(ContinuousClock.now - stopT0) error=\(error)")
await self.setPendingLLMError(error)
self.setPendingLLMError(error)
return nil
}
}
Expand All @@ -514,13 +516,21 @@ actor RecognitionSession {
// use whatever streaming produced rather than re-sending everything.
let uploadFailed = uploadFailureFlag?.failed == true
let hasUsableStreamingResult = !currentTranscript.confirmedSegments.isEmpty
let needsBatchFallback = uploadFailed || (!asrTeardownClean && !hasUsableStreamingResult)
if !asrTeardownClean && !uploadFailed && hasUsableStreamingResult {
let streamingFailed = Self.shouldAttemptBatchFallback(
uploadFailed: uploadFailed,
asrTeardownClean: asrTeardownClean,
streamingError: lastStreamingError
)
let needsBatchFallback = streamingFailed
&& (uploadFailed || lastStreamingError != nil || !hasUsableStreamingResult)
if streamingFailed && !needsBatchFallback {
DebugFileLogger.log("stop: drain timeout but streaming has confirmed text, skipping batch fallback")
}
if needsBatchFallback {
let partialText = currentTranscript.composedText
DebugFileLogger.log("stop: streaming failed (partial=\(partialText.count) chars, uploadFailed=\(uploadFailed)), attempting batch fallback")
DebugFileLogger.log(
"stop: streaming failed (partial=\(partialText.count) chars, uploadFailed=\(uploadFailed), hasStreamingError=\(lastStreamingError != nil)), attempting batch fallback"
)
let fullAudio = audioEngine.getRecordedAudio()
if !fullAudio.isEmpty, let config = currentConfig {
onASREvent?(.processingResult(text: partialText.isEmpty ? "重新识别中..." : partialText))
Expand All @@ -538,6 +548,7 @@ actor RecognitionSession {
}
}
uploadFailureFlag = nil
lastStreamingError = nil

// Combine confirmed segments + any trailing unconfirmed partial.
let effectiveText = currentTranscript.displayText
Expand Down Expand Up @@ -767,6 +778,7 @@ actor RecognitionSession {
}

case .error(let error):
lastStreamingError = error
logger.error("ASR error: \(error)")

case .completed:
Expand Down Expand Up @@ -836,7 +848,7 @@ actor RecognitionSession {
let failureFlag = UploadFailureFlag()
self.uploadFailureFlag = failureFlag

audioChunkSenderTask = Task.detached { [weak self] in
audioChunkSenderTask = Task.detached {
var chunkCount = 0
var lastLogTime: ContinuousClock.Instant?
for await data in stream {
Expand Down Expand Up @@ -934,7 +946,7 @@ actor RecognitionSession {
return result
} catch {
DebugFileLogger.log("speculative LLM: failed \(error)")
await self.setPendingLLMError(error)
self.setPendingLLMError(error)
return nil
}
}
Expand Down Expand Up @@ -991,6 +1003,14 @@ actor RecognitionSession {
}
}

static func shouldAttemptBatchFallback(
uploadFailed: Bool,
asrTeardownClean: Bool,
streamingError: Error?
) -> Bool {
uploadFailed || !asrTeardownClean || streamingError != nil
}

// MARK: - Batch Fallback

/// Try to transcribe full audio via the same provider.
Expand Down Expand Up @@ -1115,6 +1135,8 @@ actor RecognitionSession {
currentTranscript = .empty
hasEmittedReadyForCurrentSession = false
currentConfig = nil
uploadFailureFlag = nil
lastStreamingError = nil
SystemVolumeManager.restore()
}

Expand Down
51 changes: 51 additions & 0 deletions Type4MeTests/DeepgramWebSocketDelegateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,55 @@ final class DeepgramWebSocketDelegateTests: XCTestCase {
let opened = await gate.hasOpened
XCTAssertTrue(opened)
}

func testDidCloseWith_afterHandshakeRecordsUnexpectedCloseError() async throws {
let gate = DeepgramConnectionGate()
let closeTracker = DeepgramCloseTracker()
let delegate = DeepgramWebSocketDelegate(connectionGate: gate, closeTracker: closeTracker)
let session = URLSession(configuration: .ephemeral)
let task = session.webSocketTask(with: URL(string: "wss://example.com/socket")!)

delegate.urlSession(session, webSocketTask: task, didOpenWithProtocol: nil)
try await Task.sleep(for: .milliseconds(20))

delegate.urlSession(
session,
webSocketTask: task,
didCloseWith: .policyViolation,
reason: Data("bad payload".utf8)
)
try await Task.sleep(for: .milliseconds(20))

let error = await closeTracker.consumeCloseError()
guard let error,
case let DeepgramASRError.closed(code, reason) = error
else {
return XCTFail("Expected tracked post-handshake close error, got \(String(describing: error))")
}

XCTAssertEqual(code, Int(URLSessionWebSocketTask.CloseCode.policyViolation.rawValue))
XCTAssertEqual(reason, "bad payload")
}

func testDidCloseWith_afterHandshakeIgnoresNormalClosure() async throws {
let gate = DeepgramConnectionGate()
let closeTracker = DeepgramCloseTracker()
let delegate = DeepgramWebSocketDelegate(connectionGate: gate, closeTracker: closeTracker)
let session = URLSession(configuration: .ephemeral)
let task = session.webSocketTask(with: URL(string: "wss://example.com/socket")!)

delegate.urlSession(session, webSocketTask: task, didOpenWithProtocol: nil)
try await Task.sleep(for: .milliseconds(20))

delegate.urlSession(
session,
webSocketTask: task,
didCloseWith: .normalClosure,
reason: Data("ok".utf8)
)
try await Task.sleep(for: .milliseconds(20))

let error = await closeTracker.consumeCloseError()
XCTAssertNil(error)
}
}
10 changes: 10 additions & 0 deletions Type4MeTests/RecognitionSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,14 @@ final class RecognitionSessionTests: XCTestCase {
let mode = await session.currentModeForTesting()
XCTAssertEqual(mode.id, ProcessingMode.directId)
}

func testShouldAttemptBatchFallbackWhenStreamingErrorWasObserved() {
let shouldFallback = RecognitionSession.shouldAttemptBatchFallback(
uploadFailed: false,
asrTeardownClean: true,
streamingError: DeepgramASRError.closed(code: 1008, reason: "policy violation")
)

XCTAssertTrue(shouldFallback)
}
}