Skip to content

add support for token provider #1587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 2 additions & 49 deletions src/azure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import * as Errors from './error';
import { FinalRequestOptions } from './internal/request-options';
import { isObj, readEnv } from './internal/utils';
import { ClientOptions, OpenAI } from './client';
import { buildHeaders, NullableHeaders } from './internal/headers';

/** API Client for interfacing with the Azure OpenAI API. */
export interface AzureClientOptions extends ClientOptions {
Expand Down Expand Up @@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions {

/** API Client for interfacing with the Azure OpenAI API. */
export class AzureOpenAI extends OpenAI {
private _azureADTokenProvider: (() => Promise<string>) | undefined;
deploymentName: string | undefined;
apiVersion: string = '';

Expand Down Expand Up @@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI {
);
}

// define a sentinel value to avoid any typing issues
apiKey ??= API_KEY_SENTINEL;

opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion };

if (!baseURL) {
Expand All @@ -116,11 +111,12 @@ export class AzureOpenAI extends OpenAI {
super({
apiKey,
baseURL,
tokenProvider:
!azureADTokenProvider ? undefined : async () => ({ token: await azureADTokenProvider() }),
...opts,
...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}),
});

this._azureADTokenProvider = azureADTokenProvider;
this.apiVersion = apiVersion;
this.deploymentName = deployment;
}
Expand All @@ -140,47 +136,6 @@ export class AzureOpenAI extends OpenAI {
}
return super.buildRequest(options, props);
}

async _getAzureADToken(): Promise<string | undefined> {
if (typeof this._azureADTokenProvider === 'function') {
const token = await this._azureADTokenProvider();
if (!token || typeof token !== 'string') {
throw new Errors.OpenAIError(
`Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`,
);
}
return token;
}
return undefined;
}

protected override async authHeaders(opts: FinalRequestOptions): Promise<NullableHeaders | undefined> {
return;
}

protected override async prepareOptions(opts: FinalRequestOptions): Promise<void> {
opts.headers = buildHeaders([opts.headers]);

/**
* The user should provide a bearer token provider if they want
* to use Azure AD authentication. The user shouldn't set the
* Authorization header manually because the header is overwritten
* with the Azure AD token if a bearer token provider is provided.
*/
if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) {
return super.prepareOptions(opts);
}

const token = await this._getAzureADToken();
if (token) {
opts.headers.values.set('Authorization', `Bearer ${token}`);
} else if (this.apiKey !== API_KEY_SENTINEL) {
opts.headers.values.set('api-key', this.apiKey);
} else {
throw new Errors.OpenAIError('Unable to handle auth');
}
return super.prepareOptions(opts);
}
}

const _deployments_endpoints = new Set([
Expand All @@ -194,5 +149,3 @@ const _deployments_endpoints = new Set([
'/batches',
'/images/edits',
]);

const API_KEY_SENTINEL = '<Missing Key>';
22 changes: 13 additions & 9 deletions src/beta/realtime/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,24 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
}
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
props: { model: string; dangerouslyAllowBrowser?: boolean },
): Promise<OpenAIRealtimeWebSocket> {
await client._setToken();
return new OpenAIRealtimeWebSocket(props, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
): Promise<OpenAIRealtimeWebSocket> {
const token = await client._getAzureADToken();
const isToken = await client._setToken();
function onURL(url: URL) {
if (client.apiKey !== '<Missing Key>') {
url.searchParams.set('api-key', client.apiKey);
if (isToken) {
url.searchParams.set('Authorization', `Bearer ${client.apiKey}`);
} else {
if (token) {
url.searchParams.set('Authorization', `Bearer ${token}`);
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
url.searchParams.set('api-key', client.apiKey);
}
}
const deploymentName = options.deploymentName ?? client.deploymentName;
Expand Down
24 changes: 14 additions & 10 deletions src/beta/realtime/ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,16 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
});
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
props: { model: string; options?: WS.ClientOptions | undefined },
): Promise<OpenAIRealtimeWS> {
await client._setToken();
return new OpenAIRealtimeWS(props, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
): Promise<OpenAIRealtimeWS> {
const deploymentName = options.deploymentName ?? client.deploymentName;
Expand Down Expand Up @@ -82,15 +90,11 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
}
}

async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
if (client.apiKey !== '<Missing Key>') {
return { 'api-key': client.apiKey };
async function getAzureHeaders(client: Pick<AzureOpenAI, '_setToken' | 'apiKey'>) {
const isToken = await client._setToken();
if (isToken) {
return { Authorization: `Bearer ${isToken}` };
} else {
const token = await client._getAzureADToken();
if (token) {
return { Authorization: `Bearer ${token}` };
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
return { 'api-key': client.apiKey };
}
}
56 changes: 51 additions & 5 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,20 @@ import {
} from './internal/utils/log';
import { isEmptyObj } from './internal/utils/values';

export interface AccessToken {
token: string;
}
export type TokenProvider = () => Promise<AccessToken>;

export interface ClientOptions {
/**
* Defaults to process.env['OPENAI_API_KEY'].
*/
apiKey?: string | undefined;

/**
* A function that returns a token to use for authentication.
*/
tokenProvider?: TokenProvider | undefined;
/**
* Defaults to process.env['OPENAI_ORG_ID'].
*/
Expand Down Expand Up @@ -307,6 +315,7 @@ export class OpenAI {
#encoder: Opts.RequestEncoder;
protected idempotencyHeader?: string;
private _options: ClientOptions;
private _tokenProvider: TokenProvider | undefined;

/**
* API Client for interfacing with the OpenAI API.
Expand All @@ -330,11 +339,18 @@ export class OpenAI {
organization = readEnv('OPENAI_ORG_ID') ?? null,
project = readEnv('OPENAI_PROJECT_ID') ?? null,
webhookSecret = readEnv('OPENAI_WEBHOOK_SECRET') ?? null,
tokenProvider,
...opts
}: ClientOptions = {}) {
if (apiKey === undefined) {
if (apiKey === undefined && !tokenProvider) {
throw new Errors.OpenAIError(
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
);
}

if (tokenProvider && apiKey) {
throw new Errors.OpenAIError(
"The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).",
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
);
}

Expand All @@ -343,6 +359,7 @@ export class OpenAI {
organization,
project,
webhookSecret,
tokenProvider,
...opts,
baseURL: baseURL || `https://api.openai.com/v1`,
};
Expand Down Expand Up @@ -370,7 +387,8 @@ export class OpenAI {

this._options = options;

this.apiKey = apiKey;
this.apiKey = apiKey ?? 'Missing Key';
this._tokenProvider = tokenProvider;
this.organization = organization;
this.project = project;
this.webhookSecret = webhookSecret;
Expand All @@ -390,6 +408,7 @@ export class OpenAI {
fetch: this.fetch,
fetchOptions: this.fetchOptions,
apiKey: this.apiKey,
tokenProvider: this._tokenProvider,
organization: this.organization,
project: this.project,
webhookSecret: this.webhookSecret,
Expand Down Expand Up @@ -438,6 +457,31 @@ export class OpenAI {
return Errors.APIError.generate(status, error, message, headers);
}

async _setToken(): Promise<boolean> {
if (typeof this._tokenProvider === 'function') {
try {
const token = await this._tokenProvider();
if (!token || typeof token.token !== 'string') {
throw new Errors.OpenAIError(
`Expected 'tokenProvider' argument to return a string but it returned ${token}`,
);
}
this.apiKey = token.token;
return true;
} catch (err: any) {
if (err instanceof Errors.OpenAIError) {
throw err;
}
throw new Errors.OpenAIError(
`Failed to get token from 'tokenProvider' function: ${err.message}`,
// @ts-ignore
{ cause: err },
);
}
}
return false;
}

buildURL(
path: string,
query: Record<string, unknown> | null | undefined,
Expand All @@ -464,7 +508,9 @@ export class OpenAI {
/**
* Used as a callback for mutating the given `FinalRequestOptions` object.
*/
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {}
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {
await this._setToken();
}

/**
* Used as a callback for mutating the given `RequestInit` object.
Expand Down
92 changes: 92 additions & 0 deletions tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -719,4 +719,96 @@ describe('retries', () => {
).toEqual(JSON.stringify({ a: 1 }));
expect(count).toEqual(3);
});

describe('auth', () => {
test('apiKey', async () => {
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
apiKey: 'My API Key',
});
const { req } = await client.buildRequest({ path: '/foo', method: 'get' });
expect(req.headers.get('authorization')).toEqual('Bearer My API Key');
});

test('token', async () => {
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
return new Response(JSON.stringify({}), { headers: headers ?? [] });
};
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
tokenProvider: async () => ({ token: 'my token' }),
fetch: testFetch,
});
expect(
(await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get(
'authorization',
),
).toEqual('Bearer my token');
});

test('token is refreshed', async () => {
let fail = true;
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
if (fail) {
fail = false;
return new Response(undefined, {
status: 429,
headers: {
'Retry-After': '0.1',
},
});
}
return new Response(JSON.stringify({}), {
headers: headers ?? [],
});
};
let counter = 0;
async function tokenProvider() {
return { token: `token-${counter++}` };
}
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
tokenProvider,
fetch: testFetch,
});
expect(
(
await client.chat.completions
.create({
model: '',
messages: [{ role: 'system', content: 'Hello' }],
})
.asResponse()
).headers.get('authorization'),
).toEqual('Bearer token-1');
});

test('mutual exclusive', () => {
try {
new OpenAI({
baseURL: 'http://localhost:5000/',
tokenProvider: async () => ({ token: 'my token' }),
apiKey: 'my api key',
});
} catch (error: any) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toEqual(
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
);
}
});

test('at least one', () => {
try {
new OpenAI({
baseURL: 'http://localhost:5000/',
});
} catch (error: any) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toEqual(
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
);
}
});
});
});
Loading
Loading