diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 73263a0b8..287821ef2 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -6,8 +6,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import type { HttpError } from './retry.js'; -import { retryWithBackoff } from './retry.js'; +import { retryWithBackoff, HttpError } from './retry.js'; import { setSimulate429 } from './testUtils.js'; import { AuthType } from '../core/contentGenerator.js'; @@ -84,7 +83,6 @@ describe('retryWithBackoff', () => { // 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*. // This ensures a 'catch' handler is present before the promise can reject. // The result is a new promise that resolves when the assertion is met. - // eslint-disable-next-line vitest/valid-expect const assertionPromise = expect(promise).rejects.toThrow( 'Simulated error attempt 3', ); @@ -129,7 +127,7 @@ describe('retryWithBackoff', () => { // Attach the rejection expectation *before* running timers const assertionPromise = - expect(promise).rejects.toThrow('Too Many Requests'); // eslint-disable-line vitest/valid-expect + expect(promise).rejects.toThrow('Too Many Requests'); // Run timers to trigger retries and eventual rejection await vi.runAllTimersAsync(); @@ -197,7 +195,6 @@ describe('retryWithBackoff', () => { // We expect rejections as mockFn fails 5 times const promise1 = runRetry(); // Attach the rejection expectation *before* running timers - // eslint-disable-next-line vitest/valid-expect const assertionPromise1 = expect(promise1).rejects.toThrow(); await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry await assertionPromise1; @@ -212,7 +209,6 @@ describe('retryWithBackoff', () => { const promise2 = runRetry(); // Attach the rejection expectation *before* running timers - // eslint-disable-next-line vitest/valid-expect const assertionPromise2 = expect(promise2).rejects.toThrow(); await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry await assertionPromise2; @@ -573,4 +569,169 @@ describe('retryWithBackoff', () => { expect(fn).toHaveBeenCalledTimes(3); }); }); + + describe('Cerebras rate limiting', () => { + const originalBaseUrl = process.env.OPENAI_BASE_URL; + + beforeEach(() => { + process.env.OPENAI_BASE_URL = 'https://api.cerebras.ai/v1'; + }); + + afterEach(() => { + process.env.OPENAI_BASE_URL = originalBaseUrl; + }); + + it('should wait for the specified seconds when x-ratelimit-reset-tokens-minute header exists for Cerebras API', async () => { + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Too Many Requests'); + error.status = 429; + (error as any).response = { + headers: { + 'x-ratelimit-reset-tokens-minute': '10', + }, + }; + throw error; + }); + + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 1000, + }); + + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(setTimeoutSpy).toHaveBeenCalledWith(expect.any(Function), 10000); + }); + + it('should use Retry-After header for non-Cerebras API even when x-ratelimit-reset-tokens-minute exists', async () => { + process.env.OPENAI_BASE_URL = 'https://api.openai.com/v1'; + + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Too Many Requests'); + error.status = 429; + (error as any).response = { + headers: { + 'x-ratelimit-reset-tokens-minute': '10', + 'retry-after': '5', + }, + }; + throw error; + }); + + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 1000, + }); + + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(setTimeoutSpy).toHaveBeenCalledWith(expect.any(Function), 5000); + }); + + it('should correctly handle floating point values in x-ratelimit-reset-tokens-minute header', async () => { + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Too Many Requests'); + error.status = 429; + (error as any).response = { + headers: { + 'x-ratelimit-reset-tokens-minute': '2.5', + }, + }; + throw error; + }); + + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 1000, + }); + + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(setTimeoutSpy).toHaveBeenCalledWith(expect.any(Function), 2500); + }); + + it('should return 0 when x-ratelimit-reset-tokens-minute header does not exist', async () => { + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Too Many Requests'); + error.status = 429; + (error as any).response = { + headers: {}, + }; + throw error; + }); + + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 1000, + }); + + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(setTimeoutSpy).toHaveBeenCalledWith( + expect.any(Function), + expect.any(Number), + ); + const delay = setTimeoutSpy.mock.calls[0][1] as number; + expect(delay).toBeGreaterThanOrEqual(700); + expect(delay).toBeLessThanOrEqual(1300); + }); + + it('should return 0 for invalid values in x-ratelimit-reset-tokens-minute header', async () => { + const mockFn = vi.fn(async () => { + const error: HttpError = new Error('Too Many Requests'); + error.status = 429; + (error as any).response = { + headers: { + 'x-ratelimit-reset-tokens-minute': 'invalid', + }, + }; + throw error; + }); + + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 1000, + }); + + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(setTimeoutSpy).toHaveBeenCalledWith( + expect.any(Function), + expect.any(Number), + ); + const delay = setTimeoutSpy.mock.calls[0][1] as number; + expect(delay).toBeGreaterThanOrEqual(700); + expect(delay).toBeLessThanOrEqual(1300); + }); + }); }); diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 5a3828d64..55ebf40d0 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -296,6 +296,41 @@ function getRetryAfterDelayMs(error: unknown): number { return 0; } +/** + * Extracts the Cerebras rate limit delay from an error object's headers. + * @param error The error object. + * @returns The delay in milliseconds, or 0 if not found or invalid. + */ +function getCerebrasRateLimitDelayMs(error: unknown): number { + if (typeof error === 'object' && error !== null) { + // Check for error.response.headers (common in axios errors) + if ( + 'response' in error && + typeof (error as { response?: unknown }).response === 'object' && + (error as { response?: unknown }).response !== null + ) { + const response = (error as { response: { headers?: unknown } }).response; + if ( + 'headers' in response && + typeof response.headers === 'object' && + response.headers !== null + ) { + const headers = response.headers as { + 'x-ratelimit-reset-tokens-minute'?: unknown; + }; + const resetTokensHeader = headers['x-ratelimit-reset-tokens-minute']; + if (typeof resetTokensHeader === 'string') { + const resetTokensSeconds = parseFloat(resetTokensHeader); + if (!isNaN(resetTokensSeconds)) { + return Math.ceil(resetTokensSeconds * 1000); + } + } + } + } + } + return 0; +} + /** * Determines the delay duration based on the error, prioritizing Retry-After header. * @param error The error object. @@ -309,7 +344,18 @@ function getDelayDurationAndStatus(error: unknown): { let delayDurationMs = 0; if (errorStatus === 429) { - delayDurationMs = getRetryAfterDelayMs(error); + // Check if this is a Cerebras API request + const isOpenAiCerebras = + typeof process !== 'undefined' && + process.env && + process.env.OPENAI_BASE_URL && + process.env.OPENAI_BASE_URL.includes('cerebras.ai'); + + if (isOpenAiCerebras) { + delayDurationMs = getCerebrasRateLimitDelayMs(error); + } else { + delayDurationMs = getRetryAfterDelayMs(error); + } } return { delayDurationMs, errorStatus }; }