Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Dam-Buty committed Jul 8, 2024
1 parent f5f9aa7 commit 5e03413
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
literalai-client-*.tgz

# Logs
logs
*.log
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"types": "./dist/index.d.ts",
"scripts": {
"build": "tsup ./src",
"test": "jest --runInBand --watchAll=false",
"test": "jest --detectOpenHandles --runInBand --watchAll=false",
"prepare": "husky install"
},
"author": "Literal AI",
Expand Down
27 changes: 24 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ export type PaginatedResponse<T> = {
pageInfo: PageInfo;
};

function isPlainObject(value: unknown): value is Record<string, any> {
if (typeof value !== 'object' || value === null) {
return false;
}

const prototype = Object.getPrototypeOf(value);
return prototype === null || prototype === Object.prototype;
}

/**
* Represents a utility class with serialization capabilities.
*/
Expand Down Expand Up @@ -197,6 +206,12 @@ export class Thread extends ThreadFields {
return this;
}

/**
* Sends the thread to the API, handling disabled state and setting the end time if not already set.
* @param cb The callback function to run within the context of the thread.
* @param updateThread Optional update function to modify the thread after the callback.
* @returns The output of the wrapped callback function.
*/
async wrap<Output>(
cb: (thread: Thread) => Output | Promise<Output>,
updateThread?:
Expand All @@ -217,7 +232,7 @@ export class Thread extends ThreadFields {
Object.assign(this, updatedThread);
}

await this.upsert();
this.upsert().catch(console.error);

return output;
}
Expand Down Expand Up @@ -367,6 +382,12 @@ export class Step extends StepFields {
return this;
}

/**
* Sends the step to the API, handling disabled state and setting the end time if not already set.
* @param cb The callback function to run within the context of the step.
* @param updateStep Optional update function to modify the step after the callback.
* @returns The output of the wrapped callback function.
*/
async wrap<Output>(
cb: (step: Step) => Output | Promise<Output>,
updateStep?:
Expand All @@ -383,7 +404,7 @@ export class Step extends StepFields {
() => cb(this)
);

this.output = { output };
this.output = isPlainObject(output) ? output : { output };
this.endTime = new Date().toISOString();

if (updateStep) {
Expand All @@ -394,7 +415,7 @@ export class Step extends StepFields {
Object.assign(this, updatedStep);
}

await this.send();
this.send().catch(console.error);

return output;
}
Expand Down
51 changes: 50 additions & 1 deletion tests/wrappers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ if (!url || !apiKey) {

const client = new LiteralClient(apiKey, url);

function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}

describe('Wrapper', () => {
it('handles failing step', async () => {
let threadId: Maybe<string>;
Expand Down Expand Up @@ -58,6 +62,7 @@ describe('Wrapper', () => {
});
});

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const step = await client.api.getStep(stepId!);

Expand Down Expand Up @@ -113,6 +118,7 @@ describe('Wrapper', () => {
});
});

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const run = await client.api.getStep(runId!);
const retrieveStep = await client.api.getStep(retrieveStepId!);
Expand All @@ -131,7 +137,7 @@ describe('Wrapper', () => {
expect(completionStep!.threadId).toEqual(thread!.id);
expect(completionStep!.parentId).toEqual(run!.id);
expect(completionStep!.output).toEqual({
output: { content: 'Paris is a city in Europe' }
content: 'Paris is a city in Europe'
});
});

Expand All @@ -153,6 +159,7 @@ describe('Wrapper', () => {
});
});

await sleep(1000);
const run = await client.api.getStep(runId!);
const step = await client.api.getStep(stepId!);

Expand Down Expand Up @@ -210,6 +217,7 @@ describe('Wrapper', () => {
{ metadata: { key: 'thread-value' } }
);

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const step = await client.api.getStep(stepId!);

Expand Down Expand Up @@ -242,6 +250,7 @@ describe('Wrapper', () => {
(output) => ({ metadata: { assistantMessage: output.content } })
);

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const step = await client.api.getStep(stepId!);

Expand Down Expand Up @@ -272,6 +281,7 @@ describe('Wrapper', () => {
});
});

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const step = await client.api.getStep(stepId!);

Expand Down Expand Up @@ -299,6 +309,7 @@ describe('Wrapper', () => {
});
});

await sleep(1000);
const thread = await client.api.getThread(threadId!);
const step = await client.api.getStep(stepId!);

Expand All @@ -313,6 +324,7 @@ describe('Wrapper', () => {
.thread({ name: 'Test Wrappers Thread' })
.upsert();

await sleep(1000);
const thread = await client.api.getThread(threadId);

const wrappedThreadId = await thread!.wrap(async () => {
Expand All @@ -327,6 +339,7 @@ describe('Wrapper', () => {
.run({ name: 'Test Wrappers Thread' })
.send();

await sleep(1000);
const step = await client.api.getStep(stepId!);

const wrappedStepId = await step!.wrap(async () => {
Expand All @@ -336,4 +349,40 @@ describe('Wrapper', () => {
expect(wrappedStepId).toEqual(stepId);
});
});

describe('Concurrency', () => {
it("doesn't mix up threads and steps", async () => {
let firstThreadId: Maybe<string>;
let secondThreadId: Maybe<string>;
let firstStep: Maybe<Step>;
let secondStep: Maybe<Step>;

await Promise.all([
client.thread({ name: 'Thread 1' }).wrap(async () => {
firstThreadId = client.getCurrentThread()!.id;

return client
.step({ name: 'Step 1', type: 'assistant_message' })
.wrap(async () => {
firstStep = client.getCurrentStep();
return 'Paris is a city in Europe';
});
}),
client.thread({ name: 'Thread 2' }).wrap(async () => {
secondThreadId = client.getCurrentThread()!.id;

return client
.step({ name: 'Step 2', type: 'assistant_message' })
.wrap(async () => {
secondStep = client.getCurrentStep();
return 'London is a city in Europe';
});
})
]);

expect(firstThreadId).not.toEqual(secondThreadId);
expect(firstStep?.threadId).toEqual(firstThreadId);
expect(secondStep?.threadId).toEqual(secondThreadId);
});
});
});

0 comments on commit 5e03413

Please sign in to comment.