Skip to content

Commit

Permalink
refactor: migrate to new openai SDK (#73)
Browse files Browse the repository at this point in the history
* refactor: migrate to new openai SDK

* chore: update deps

* refactor: improve upload errors
  • Loading branch information
sinedied authored May 16, 2024
1 parent 4dae38c commit 0304c3c
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 231 deletions.
10 changes: 5 additions & 5 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ param searchServiceSkuName string
param openAiLocation string // Set in main.parameters.json
param openAiSkuName string = 'S0'
param openAiUrl string = ''
param openAiApiVersion string // Set in main.parameters.json

// Location is not relevant here as it's only for the built-in api
// which is not used here. Static Web App is a global service otherwise
Expand Down Expand Up @@ -100,7 +101,8 @@ module api './core/host/functions.bicep' = {
storageAccountName: storage.outputs.name
managedIdentity: true
appSettings: {
AZURE_OPENAI_API_ENDPOINT: finalOpenAiUrl
AZURE_OPENAI_API_INSTANCE_NAME: openAi.outputs.name
AZURE_OPENAI_API_VERSION: openAiApiVersion
AZURE_OPENAI_API_DEPLOYMENT_NAME: chatDeploymentName
AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME: embeddingsDeploymentName
AZURE_AISEARCH_ENDPOINT: searchUrl
Expand Down Expand Up @@ -294,12 +296,10 @@ output AZURE_TENANT_ID string = tenant().tenantId
output AZURE_RESOURCE_GROUP string = resourceGroup.name

output AZURE_OPENAI_API_ENDPOINT string = finalOpenAiUrl
output AZURE_OPENAI_API_INSTANCE_NAME string = openAi.outputs.name
output AZURE_OPENAI_API_VERSION string = openAiApiVersion
output AZURE_OPENAI_API_DEPLOYMENT_NAME string = chatDeploymentName
output AZURE_OPENAI_API_MODEL string = chatModelName
output AZURE_OPENAI_API_MODEL_VERSION string = chatModelVersion
output AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME string = embeddingsDeploymentName
output AZURE_OPENAI_API_EMBEDDINGS_MODEL string = embeddingsModelName
output AZURE_OPENAI_API_EMBEDDINGS_MODEL_VERSION string = embeddingsModelVersion
output AZURE_STORAGE_URL string = storageUrl
output AZURE_STORAGE_CONTAINER_NAME string = blobContainerName
output AZURE_AISEARCH_ENDPOINT string = searchUrl
Expand Down
3 changes: 3 additions & 0 deletions infra/main.parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"openAiLocation": {
"value": "${AZURE_OPENAI_LOCATION=eastus2}"
},
"openAiApiVersion": {
"value": "${AZURE_OPENAI_API_VERSION=2024-02-01}"
},
"chatModelName": {
"value": "${AZURE_OPENAI_API_MODEL=gpt-4}"
},
Expand Down
348 changes: 133 additions & 215 deletions package-lock.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions packages/api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
"@azure/identity": "^4.2.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.17.0",
"@langchain/azure-openai": "^0.0.8",
"@langchain/community": "^0.0.55",
"@langchain/community": "^0.0.57",
"@langchain/core": "^0.1.61",
"@langchain/openai": "^0.0.31",
"@microsoft/ai-chat-protocol": "^1.0.0-alpha.20240418.1",
"dotenv": "^16.4.5",
"faiss-node": "^0.5.1",
Expand Down
10 changes: 6 additions & 4 deletions packages/api/src/functions/chat-post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Readable } from 'node:stream';
import { HttpRequest, InvocationContext, HttpResponseInit, app } from '@azure/functions';
import { AIChatCompletionRequest, AIChatCompletionDelta } from '@microsoft/ai-chat-protocol';
import { Document } from '@langchain/core/documents';
import { AzureOpenAIEmbeddings, AzureChatOpenAI } from '@langchain/azure-openai';
import { AzureOpenAIEmbeddings, AzureChatOpenAI } from '@langchain/openai';
import { Embeddings } from '@langchain/core/embeddings';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { VectorStore } from '@langchain/core/vectorstores';
Expand All @@ -16,7 +16,7 @@ import { createRetrievalChain } from 'langchain/chains/retrieval';
import 'dotenv/config';
import { badRequest, data, serviceUnavailable } from '../http-response';
import { ollamaChatModel, ollamaEmbeddingsModel, faissStoreFolder } from '../constants';
import { getCredentials } from '../security';
import { getAzureOpenAiTokenProvider, getCredentials } from '../security';

const systemPrompt = `Assistant helps the Consto Real Estate company customers with questions and support requests. Be brief in your answers. Answer only plain text, DO NOT use Markdown.
Answer ONLY with information from the sources below. If there isn't enough information in the sources, say you don't know. Do not generate answers that don't use the sources. If asking a clarifying question to the user would help, ask the question.
Expand Down Expand Up @@ -53,12 +53,14 @@ export async function postChat(request: HttpRequest, context: InvocationContext)

if (azureOpenAiEndpoint) {
const credentials = getCredentials();
const azureADTokenProvider = getAzureOpenAiTokenProvider();

// Initialize models and vector database
embeddings = new AzureOpenAIEmbeddings({ credentials });
embeddings = new AzureOpenAIEmbeddings({ azureADTokenProvider });
model = new AzureChatOpenAI({
// Controls randomness. 0 = deterministic, 1 = maximum randomness
temperature: 0.7,
credentials,
azureADTokenProvider,
});
store = new AzureAISearchVectorStore(embeddings, { credentials });
} else {
Expand Down
9 changes: 6 additions & 3 deletions packages/api/src/functions/documents-post.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { HttpRequest, HttpResponseInit, InvocationContext, app } from '@azure/functions';
import { AzureOpenAIEmbeddings } from '@langchain/azure-openai';
import { AzureOpenAIEmbeddings } from '@langchain/openai';
import { PDFLoader } from 'langchain/document_loaders/fs/pdf';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { AzureAISearchVectorStore } from '@langchain/community/vectorstores/azure_aisearch';
Expand All @@ -9,7 +9,7 @@ import 'dotenv/config';
import { BlobServiceClient } from '@azure/storage-blob';
import { badRequest, serviceUnavailable, ok } from '../http-response';
import { ollamaEmbeddingsModel, faissStoreFolder } from '../constants';
import { getCredentials } from '../security';
import { getAzureOpenAiTokenProvider, getCredentials } from '../security';

export async function postDocuments(request: HttpRequest, context: InvocationContext): Promise<HttpResponseInit> {
const storageUrl = process.env.AZURE_STORAGE_URL;
Expand Down Expand Up @@ -44,7 +44,10 @@ export async function postDocuments(request: HttpRequest, context: InvocationCon
// Generate embeddings and save in database
if (azureOpenAiEndpoint) {
const credentials = getCredentials();
const embeddings = new AzureOpenAIEmbeddings({ credentials });
const azureADTokenProvider = getAzureOpenAiTokenProvider();

// Initialize embeddings model and vector database
const embeddings = new AzureOpenAIEmbeddings({ azureADTokenProvider });
await AzureAISearchVectorStore.fromDocuments(documents, embeddings, { credentials });
} else {
// If no environment variables are set, it means we are running locally
Expand Down
8 changes: 7 additions & 1 deletion packages/api/src/security.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { DefaultAzureCredential } from '@azure/identity';
import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';

const azureOpenAiScope = 'https://cognitiveservices.azure.com/.default';

let credentials: DefaultAzureCredential | undefined;

Expand All @@ -9,3 +11,7 @@ export function getCredentials(): DefaultAzureCredential {
credentials ||= new DefaultAzureCredential();
return credentials;
}

export function getAzureOpenAiTokenProvider() {
return getBearerTokenProvider(getCredentials(), azureOpenAiScope);
}
2 changes: 1 addition & 1 deletion scripts/upload-documents.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async function uploadDocuments(apiUrl, dataFolder) {
}
/* eslint-enable no-await-in-loop */
} catch (error) {
console.error(error);
console.error(`Could not upload documents: ${error.message}`);
}
}

Expand Down

0 comments on commit 0304c3c

Please sign in to comment.