diff --git a/gui/src/redux/thunks/streamThunkWrapper.tsx b/gui/src/redux/thunks/streamThunkWrapper.tsx index 4454246cd7d..2cf30306573 100644 --- a/gui/src/redux/thunks/streamThunkWrapper.tsx +++ b/gui/src/redux/thunks/streamThunkWrapper.tsx @@ -8,43 +8,63 @@ import { ThunkApiType } from "../store"; import { cancelStream } from "./cancelStream"; import { saveCurrentSession } from "./session"; +const OVERLOADED_RETRIES = 3; +const OVERLOADED_DELAY_MS = 1000; + +function isOverloadedErrorMessage(message?: string | null): boolean { + if (!message) return false; + const lower = message.toLowerCase(); + return lower.includes("overloaded") || lower.includes("malformed json"); +} + export const streamThunkWrapper = createAsyncThunk< void, () => Promise, ThunkApiType ->("chat/streamWrapper", async (runStream, { dispatch, extra, getState }) => { - try { - await runStream(); - const state = getState(); - if (!state.session.isInEdit) { - await dispatch( - saveCurrentSession({ - openNewSession: false, - generateTitle: true, - }), - ); - } - } catch (e) { - await dispatch(cancelStream()); - dispatch(setDialogMessage()); - dispatch(setShowDialog(true)); +>("chat/streamWrapper", async (runStream, { dispatch, getState }) => { + for (let attempt = 0; attempt <= OVERLOADED_RETRIES; attempt++) { + try { + await runStream(); + const state = getState(); + if (!state.session.isInEdit) { + await dispatch( + saveCurrentSession({ + openNewSession: false, + generateTitle: true, + }), + ); + } + return; + } catch (e) { + // Get the selected model from the state for error analysis + const state = getState(); + const selectedModel = selectSelectedChatModel(state); + const { parsedError, statusCode, message, modelTitle, providerName } = + analyzeError(e, selectedModel); - // Get the selected model from the state for error analysis - const state = getState(); - const selectedModel = selectSelectedChatModel(state); + const shouldRetry = + isOverloadedErrorMessage(message) && attempt < OVERLOADED_RETRIES; - const { parsedError, statusCode, modelTitle, providerName } = analyzeError( - e, - selectedModel, - ); + if (shouldRetry) { + await dispatch(cancelStream()); + const delayMs = OVERLOADED_DELAY_MS * 2 ** attempt; + await new Promise((resolve) => setTimeout(resolve, delayMs)); + await dispatch(cancelStream()); + } else { + await dispatch(cancelStream()); + dispatch(setDialogMessage()); + dispatch(setShowDialog(true)); - const errorData = { - error_type: statusCode ? `HTTP ${statusCode}` : "Unknown", - error_message: parsedError, - model_provider: providerName, - model_title: modelTitle, - }; + const errorData = { + error_type: statusCode ? `HTTP ${statusCode}` : "Unknown", + error_message: parsedError, + model_provider: providerName, + model_title: modelTitle, + }; - posthog.capture("gui_stream_error", errorData); + posthog.capture("gui_stream_error", errorData); + return; + } + } } });