diff --git a/src/api.ts b/src/api.ts index dd4426c..5dc9787 100644 --- a/src/api.ts +++ b/src/api.ts @@ -1,6 +1,7 @@ import axios, { AxiosError } from 'axios'; import FormData from 'form-data'; import { createReadStream } from 'fs'; +import { ReadStream } from 'fs'; import { v4 as uuidv4 } from 'uuid'; import { LiteralClient } from '.'; @@ -21,6 +22,7 @@ import { PersistedGeneration } from './generation'; import { + Attachment, CleanThreadFields, Dataset, DatasetExperiment, @@ -327,6 +329,28 @@ function addGenerationsToDatasetQueryBuilder(generationIds: string[]) { `; } +type UploadFileBaseParams = { + id?: Maybe; + threadId?: string; + mime?: Maybe; +}; +type UploadFileParamsWithPath = UploadFileBaseParams & { + path: string; +}; +type UploadFileParamsWithContent = UploadFileBaseParams & { + content: + | ReadableStream + | ReadStream + | Buffer + | File + | Blob + | ArrayBuffer; +}; +type CreateAttachmentParams = { + name?: string; + metadata?: Maybe>; +}; + export class API { /** @ignore */ private client: LiteralClient; @@ -596,19 +620,25 @@ export class API { * @returns An object containing the `objectKey` of the uploaded file and the signed `url`, or `null` values if the upload fails. * @throws {Error} Throws an error if neither `content` nor `path` is provided, or if the server response is invalid. */ + + async uploadFile(params: UploadFileParamsWithContent): Promise<{ + objectKey: Maybe; + url: Maybe; + }>; + async uploadFile(params: UploadFileParamsWithPath): Promise<{ + objectKey: Maybe; + url: Maybe; + }>; async uploadFile({ content, path, id, threadId, mime - }: { - content?: Maybe; - path?: Maybe; - id?: Maybe; - threadId: string; - mime?: Maybe; - }) { + }: UploadFileParamsWithContent & UploadFileParamsWithPath): Promise<{ + objectKey: Maybe; + url: Maybe; + }> { if (!content && !path) { throw new Error('Either content or path must be provided'); } @@ -678,6 +708,52 @@ export class API { } } + async createAttachment( + params: UploadFileParamsWithContent & CreateAttachmentParams + ): Promise; + async createAttachment( + params: UploadFileParamsWithPath & CreateAttachmentParams + ): Promise; + async createAttachment( + params: UploadFileParamsWithContent & + UploadFileParamsWithPath & + CreateAttachmentParams + ): Promise { + if (params.content instanceof Blob) { + params.content = Buffer.from(await params.content.arrayBuffer()); + } + if (params.content instanceof ArrayBuffer) { + params.content = Buffer.from(params.content); + } + + const threadFromStore = this.client._currentThread(); + const stepFromStore = this.client._currentStep(); + + if (threadFromStore) { + params.threadId = threadFromStore.id; + } + + const { objectKey, url } = await this.uploadFile(params); + + const attachment = new Attachment({ + name: params.name, + objectKey, + mime: params.mime, + metadata: params.metadata, + url + }); + + if (stepFromStore) { + if (!stepFromStore.attachments) { + stepFromStore.attachments = []; + } + + stepFromStore.attachments.push(attachment); + } + + return attachment; + } + // Generation /** * Retrieves a paginated list of Generations based on the provided filters and sorting order. diff --git a/src/index.ts b/src/index.ts index 9750dfa..5f44b65 100644 --- a/src/index.ts +++ b/src/index.ts @@ -49,6 +49,18 @@ export class LiteralClient { return this.step({ ...data, type: 'run' }); } + _currentThread(): Thread | null { + const store = storage.getStore(); + + return store?.currentThread || null; + } + + _currentStep(): Step | null { + const store = storage.getStore(); + + return store?.currentStep || null; + } + /** * Gets the current thread from the context. * WARNING : this will throw if run outside of a thread context. diff --git a/src/instrumentation/openai.ts b/src/instrumentation/openai.ts index f489c95..786dea0 100644 --- a/src/instrumentation/openai.ts +++ b/src/instrumentation/openai.ts @@ -13,9 +13,7 @@ import { IGenerationMessage, LiteralClient, Maybe, - Step, - StepConstructor, - Thread + StepConstructor } from '..'; // Define a generic type for the original function to be wrapped @@ -310,25 +308,13 @@ const processOpenAIOutput = async ( tags: tags }; - let threadFromStore: Thread | null = null; - try { - threadFromStore = client.getCurrentThread(); - } catch (error) { - // Ignore error thrown if getCurrentThread is called outside of a context - } - - let stepFromStore: Step | null = null; - try { - stepFromStore = client.getCurrentStep(); - } catch (error) { - // Ignore error thrown if getCurrentStep is called outside of a context - } + const threadFromStore = client._currentThread(); + const stepFromStore = client._currentStep(); const parent = stepFromStore || threadFromStore; if ('data' in output) { // Image Generation - const stepData: StepConstructor = { name: inputs.model || 'openai', type: 'llm', diff --git a/src/openai.ts b/src/openai.ts index 359c135..48c5bec 100644 --- a/src/openai.ts +++ b/src/openai.ts @@ -42,21 +42,16 @@ class OpenAIAssistantSyncer { ); const mime = 'image/png'; - const { objectKey } = await this.client.api.uploadFile({ - threadId: litThreadId, - id: attachmentId, - content: file.body, - mime - }); - - const attachment = new Attachment({ - name: content.image_file.file_id, - id: attachmentId, - objectKey, - mime - }); - - attachments.push(attachment); + if (file.body) { + const attachment = await this.client.api.createAttachment({ + threadId: litThreadId, + id: attachmentId, + content: file.body, + mime + }); + + attachments.push(attachment); + } } else if (content.type === 'text') { output.content += content.text.value; } diff --git a/tests/api.test.ts b/tests/api.test.ts index 490573f..3c436be 100644 --- a/tests/api.test.ts +++ b/tests/api.test.ts @@ -1,14 +1,7 @@ import 'dotenv/config'; -import { createReadStream } from 'fs'; import { v4 as uuidv4 } from 'uuid'; -import { - Attachment, - ChatGeneration, - Dataset, - LiteralClient, - Score -} from '../src'; +import { ChatGeneration, Dataset, LiteralClient, Score } from '../src'; describe('End to end tests for the SDK', function () { let client: LiteralClient; @@ -336,42 +329,6 @@ describe('End to end tests for the SDK', function () { expect(scores[1].scorer).toBe('openai:gpt-3.5-turbo'); }); - it('should test attachment', async function () { - const thread = await client.thread({ id: uuidv4() }); - // Upload an attachment - const fileStream = createReadStream('./tests/chainlit-logo.png'); - const mime = 'image/png'; - - const { objectKey } = await client.api.uploadFile({ - threadId: thread.id, - content: fileStream, - mime - }); - - const attachment = new Attachment({ - name: 'test', - objectKey, - mime - }); - - const step = await thread - .step({ - name: 'test', - type: 'run', - attachments: [attachment] - }) - .send(); - - await new Promise((resolve) => setTimeout(resolve, 1000)); - - const fetchedStep = await client.api.getStep(step.id!); - expect(fetchedStep?.attachments?.length).toBe(1); - expect(fetchedStep?.attachments![0].objectKey).toBe(objectKey); - expect(fetchedStep?.attachments![0].url).toBeDefined(); - - await client.api.deleteThread(thread.id); - }); - it('should get project id', async () => { const projectId = await client.api.getProjectId(); expect(projectId).toEqual(expect.any(String)); diff --git a/tests/attachments.test.ts b/tests/attachments.test.ts new file mode 100644 index 0000000..8877b8d --- /dev/null +++ b/tests/attachments.test.ts @@ -0,0 +1,96 @@ +import 'dotenv/config'; +import { createReadStream, readFileSync } from 'fs'; + +import { Attachment, LiteralClient, Maybe } from '../src'; + +const url = process.env.LITERAL_API_URL; +const apiKey = process.env.LITERAL_API_KEY; +if (!url || !apiKey) { + throw new Error('Missing environment variables'); +} +const client = new LiteralClient(apiKey, url); + +const filePath = './tests/chainlit-logo.png'; +const mime = 'image/png'; + +function removeVariableParts(url: string) { + return url.split('X-Amz-Date')[0].split('X-Goog-Date')[0]; +} + +describe('Attachments', () => { + describe('Uploading a file', () => { + const stream = createReadStream(filePath); + const buffer = readFileSync(filePath); + const arrayBuffer = buffer.buffer; + const blob = new Blob([buffer]); + // We wrap the blob in a blob and simulate the structure of a File + const file = new Blob([blob], { type: 'image/png' }); + + it.each([ + { type: 'Stream', content: stream! }, + { type: 'Buffer', content: buffer! }, + { type: 'ArrayBuffer', content: arrayBuffer! }, + { type: 'Blob', content: blob! }, + { type: 'File', content: file! } + ])('handles $type objects', async function ({ type, content }) { + const attachment = await client.api.createAttachment({ + content, + mime, + name: `Attachment ${type}`, + metadata: { type } + }); + + const step = await client + .run({ + name: `Test ${type}`, + attachments: [attachment] + }) + .send(); + + await new Promise((resolve) => setTimeout(resolve, 1000)); + + const fetchedStep = await client.api.getStep(step.id!); + + const urlWithoutVariables = removeVariableParts(attachment.url!); + const fetchedUrlWithoutVariables = removeVariableParts( + fetchedStep?.attachments![0].url as string + ); + + expect(fetchedStep?.attachments?.length).toBe(1); + expect(fetchedStep?.attachments![0].objectKey).toEqual( + attachment.objectKey + ); + expect(fetchedStep?.attachments![0].name).toEqual(attachment.name); + expect(fetchedStep?.attachments![0].metadata).toEqual( + attachment.metadata + ); + expect(urlWithoutVariables).toEqual(fetchedUrlWithoutVariables); + }); + }); + + describe('Handling context', () => { + it('attaches the attachment to the step in the context', async () => { + const stream = createReadStream(filePath); + + let stepId: Maybe; + let attachment: Maybe; + + await client.run({ name: 'Attachment test ' }).wrap(async () => { + stepId = client.getCurrentStep().id!; + attachment = await client.api.createAttachment({ + content: stream!, + mime, + name: 'Attachment', + metadata: { type: 'Stream' } + }); + }); + + await new Promise((resolve) => setTimeout(resolve, 1000)); + + const fetchedStep = await client.api.getStep(stepId!); + + expect(fetchedStep?.attachments?.length).toBe(1); + expect(fetchedStep?.attachments![0].id).toEqual(attachment!.id); + }); + }); +}); diff --git a/tests/chainlit-logo.png b/tests/chainlit-logo.png index ee5cccb..30c4a4f 100644 Binary files a/tests/chainlit-logo.png and b/tests/chainlit-logo.png differ