diff --git a/src/azure.ts b/src/azure.ts index 490b82b9f..0fb2264a6 100644 --- a/src/azure.ts +++ b/src/azure.ts @@ -3,7 +3,6 @@ import * as Errors from './error'; import { FinalRequestOptions } from './internal/request-options'; import { isObj, readEnv } from './internal/utils'; import { ClientOptions, OpenAI } from './client'; -import { buildHeaders, NullableHeaders } from './internal/headers'; /** API Client for interfacing with the Azure OpenAI API. */ export interface AzureClientOptions extends ClientOptions { @@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions { /** API Client for interfacing with the Azure OpenAI API. */ export class AzureOpenAI extends OpenAI { - private _azureADTokenProvider: (() => Promise) | undefined; deploymentName: string | undefined; apiVersion: string = ''; @@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI { ); } - // define a sentinel value to avoid any typing issues - apiKey ??= API_KEY_SENTINEL; - opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion }; if (!baseURL) { @@ -116,11 +111,12 @@ export class AzureOpenAI extends OpenAI { super({ apiKey, baseURL, + tokenProvider: + !azureADTokenProvider ? undefined : async () => ({ token: await azureADTokenProvider() }), ...opts, ...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}), }); - this._azureADTokenProvider = azureADTokenProvider; this.apiVersion = apiVersion; this.deploymentName = deployment; } @@ -140,47 +136,6 @@ export class AzureOpenAI extends OpenAI { } return super.buildRequest(options, props); } - - async _getAzureADToken(): Promise { - if (typeof this._azureADTokenProvider === 'function') { - const token = await this._azureADTokenProvider(); - if (!token || typeof token !== 'string') { - throw new Errors.OpenAIError( - `Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`, - ); - } - return token; - } - return undefined; - } - - protected override async authHeaders(opts: FinalRequestOptions): Promise { - return; - } - - protected override async prepareOptions(opts: FinalRequestOptions): Promise { - opts.headers = buildHeaders([opts.headers]); - - /** - * The user should provide a bearer token provider if they want - * to use Azure AD authentication. The user shouldn't set the - * Authorization header manually because the header is overwritten - * with the Azure AD token if a bearer token provider is provided. - */ - if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) { - return super.prepareOptions(opts); - } - - const token = await this._getAzureADToken(); - if (token) { - opts.headers.values.set('Authorization', `Bearer ${token}`); - } else if (this.apiKey !== API_KEY_SENTINEL) { - opts.headers.values.set('api-key', this.apiKey); - } else { - throw new Errors.OpenAIError('Unable to handle auth'); - } - return super.prepareOptions(opts); - } } const _deployments_endpoints = new Set([ @@ -194,5 +149,3 @@ const _deployments_endpoints = new Set([ '/batches', '/images/edits', ]); - -const API_KEY_SENTINEL = ''; diff --git a/src/beta/realtime/websocket.ts b/src/beta/realtime/websocket.ts index 2bf0b75d5..136a740b0 100644 --- a/src/beta/realtime/websocket.ts +++ b/src/beta/realtime/websocket.ts @@ -94,20 +94,24 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter { } } + static async create( + client: Pick, + props: { model: string; dangerouslyAllowBrowser?: boolean }, + ): Promise { + await client._setToken(); + return new OpenAIRealtimeWebSocket(props, client); + } + static async azure( - client: Pick, + client: Pick, options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {}, ): Promise { - const token = await client._getAzureADToken(); + const isToken = await client._setToken(); function onURL(url: URL) { - if (client.apiKey !== '') { - url.searchParams.set('api-key', client.apiKey); + if (isToken) { + url.searchParams.set('Authorization', `Bearer ${client.apiKey}`); } else { - if (token) { - url.searchParams.set('Authorization', `Bearer ${token}`); - } else { - throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); - } + url.searchParams.set('api-key', client.apiKey); } } const deploymentName = options.deploymentName ?? client.deploymentName; diff --git a/src/beta/realtime/ws.ts b/src/beta/realtime/ws.ts index 3f51dfc4b..ac7c3ed3f 100644 --- a/src/beta/realtime/ws.ts +++ b/src/beta/realtime/ws.ts @@ -51,8 +51,16 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { }); } + static async create( + client: Pick, + props: { model: string; options?: WS.ClientOptions | undefined }, + ): Promise { + await client._setToken(); + return new OpenAIRealtimeWS(props, client); + } + static async azure( - client: Pick, + client: Pick, options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {}, ): Promise { const deploymentName = options.deploymentName ?? client.deploymentName; @@ -82,15 +90,11 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter { } } -async function getAzureHeaders(client: Pick) { - if (client.apiKey !== '') { - return { 'api-key': client.apiKey }; +async function getAzureHeaders(client: Pick) { + const isToken = await client._setToken(); + if (isToken) { + return { Authorization: `Bearer ${isToken}` }; } else { - const token = await client._getAzureADToken(); - if (token) { - return { Authorization: `Bearer ${token}` }; - } else { - throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.'); - } + return { 'api-key': client.apiKey }; } } diff --git a/src/client.ts b/src/client.ts index 81102c552..1bcdc8816 100644 --- a/src/client.ts +++ b/src/client.ts @@ -191,12 +191,20 @@ import { } from './internal/utils/log'; import { isEmptyObj } from './internal/utils/values'; +export interface AccessToken { + token: string; +} +export type TokenProvider = () => Promise; + export interface ClientOptions { /** * Defaults to process.env['OPENAI_API_KEY']. */ apiKey?: string | undefined; - + /** + * A function that returns a token to use for authentication. + */ + tokenProvider?: TokenProvider | undefined; /** * Defaults to process.env['OPENAI_ORG_ID']. */ @@ -307,6 +315,7 @@ export class OpenAI { #encoder: Opts.RequestEncoder; protected idempotencyHeader?: string; private _options: ClientOptions; + private _tokenProvider: TokenProvider | undefined; /** * API Client for interfacing with the OpenAI API. @@ -330,11 +339,18 @@ export class OpenAI { organization = readEnv('OPENAI_ORG_ID') ?? null, project = readEnv('OPENAI_PROJECT_ID') ?? null, webhookSecret = readEnv('OPENAI_WEBHOOK_SECRET') ?? null, + tokenProvider, ...opts }: ClientOptions = {}) { - if (apiKey === undefined) { + if (apiKey === undefined && !tokenProvider) { + throw new Errors.OpenAIError( + 'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.', + ); + } + + if (tokenProvider && apiKey) { throw new Errors.OpenAIError( - "The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).", + 'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.', ); } @@ -343,6 +359,7 @@ export class OpenAI { organization, project, webhookSecret, + tokenProvider, ...opts, baseURL: baseURL || `https://api.openai.com/v1`, }; @@ -370,7 +387,8 @@ export class OpenAI { this._options = options; - this.apiKey = apiKey; + this.apiKey = apiKey ?? 'Missing Key'; + this._tokenProvider = tokenProvider; this.organization = organization; this.project = project; this.webhookSecret = webhookSecret; @@ -390,6 +408,7 @@ export class OpenAI { fetch: this.fetch, fetchOptions: this.fetchOptions, apiKey: this.apiKey, + tokenProvider: this._tokenProvider, organization: this.organization, project: this.project, webhookSecret: this.webhookSecret, @@ -438,6 +457,31 @@ export class OpenAI { return Errors.APIError.generate(status, error, message, headers); } + async _setToken(): Promise { + if (typeof this._tokenProvider === 'function') { + try { + const token = await this._tokenProvider(); + if (!token || typeof token.token !== 'string') { + throw new Errors.OpenAIError( + `Expected 'tokenProvider' argument to return a string but it returned ${token}`, + ); + } + this.apiKey = token.token; + return true; + } catch (err: any) { + if (err instanceof Errors.OpenAIError) { + throw err; + } + throw new Errors.OpenAIError( + `Failed to get token from 'tokenProvider' function: ${err.message}`, + // @ts-ignore + { cause: err }, + ); + } + } + return false; + } + buildURL( path: string, query: Record | null | undefined, @@ -464,7 +508,9 @@ export class OpenAI { /** * Used as a callback for mutating the given `FinalRequestOptions` object. */ - protected async prepareOptions(options: FinalRequestOptions): Promise {} + protected async prepareOptions(options: FinalRequestOptions): Promise { + await this._setToken(); + } /** * Used as a callback for mutating the given `RequestInit` object. diff --git a/tests/index.test.ts b/tests/index.test.ts index c8b4b819c..23020c47a 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -719,4 +719,96 @@ describe('retries', () => { ).toEqual(JSON.stringify({ a: 1 })); expect(count).toEqual(3); }); + + describe('auth', () => { + test('apiKey', async () => { + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + apiKey: 'My API Key', + }); + const { req } = await client.buildRequest({ path: '/foo', method: 'get' }); + expect(req.headers.get('authorization')).toEqual('Bearer My API Key'); + }); + + test('token', async () => { + const testFetch = async (url: any, { headers }: RequestInit = {}): Promise => { + return new Response(JSON.stringify({}), { headers: headers ?? [] }); + }; + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + tokenProvider: async () => ({ token: 'my token' }), + fetch: testFetch, + }); + expect( + (await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get( + 'authorization', + ), + ).toEqual('Bearer my token'); + }); + + test('token is refreshed', async () => { + let fail = true; + const testFetch = async (url: any, { headers }: RequestInit = {}): Promise => { + if (fail) { + fail = false; + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + return new Response(JSON.stringify({}), { + headers: headers ?? [], + }); + }; + let counter = 0; + async function tokenProvider() { + return { token: `token-${counter++}` }; + } + const client = new OpenAI({ + baseURL: 'http://localhost:5000/', + tokenProvider, + fetch: testFetch, + }); + expect( + ( + await client.chat.completions + .create({ + model: '', + messages: [{ role: 'system', content: 'Hello' }], + }) + .asResponse() + ).headers.get('authorization'), + ).toEqual('Bearer token-1'); + }); + + test('mutual exclusive', () => { + try { + new OpenAI({ + baseURL: 'http://localhost:5000/', + tokenProvider: async () => ({ token: 'my token' }), + apiKey: 'my api key', + }); + } catch (error: any) { + expect(error).toBeInstanceOf(Error); + expect(error.message).toEqual( + 'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.', + ); + } + }); + + test('at least one', () => { + try { + new OpenAI({ + baseURL: 'http://localhost:5000/', + }); + } catch (error: any) { + expect(error).toBeInstanceOf(Error); + expect(error.message).toEqual( + 'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.', + ); + } + }); + }); }); diff --git a/tests/lib/azure.test.ts b/tests/lib/azure.test.ts index 49e3df1c0..a6a30d9f0 100644 --- a/tests/lib/azure.test.ts +++ b/tests/lib/azure.test.ts @@ -268,9 +268,9 @@ describe('instantiate azure client', () => { ); }); - test.skip('AAD token is refreshed', async () => { + test('AAD token is refreshed', async () => { let fail = true; - const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise => { + const testFetch = async (url: any, { headers }: RequestInit = {}): Promise => { if (fail) { fail = false; return new Response(undefined, { @@ -280,8 +280,8 @@ describe('instantiate azure client', () => { }, }); } - return new Response(JSON.stringify({ auth: (req?.headers as Headers).get('authorization') }), { - headers: { 'content-type': 'application/json' }, + return new Response(JSON.stringify({}), { + headers: headers ?? [], }); }; let counter = 0; @@ -295,13 +295,15 @@ describe('instantiate azure client', () => { fetch: testFetch, }); expect( - await client.chat.completions.create({ - model, - messages: [{ role: 'system', content: 'Hello' }], - }), - ).toStrictEqual({ - auth: 'Bearer token-1', - }); + ( + await client.chat.completions + .create({ + model, + messages: [{ role: 'system', content: 'Hello' }], + }) + .asResponse() + ).headers.get('authorization'), + ).toEqual('Bearer token-1'); }); });