Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ export type Action<
): StreamingResponse<O, S>;
};

export type ActionSchemaMode = 'none' | 'validate' | 'parse';

/**
* Action factory params.
*/
Expand All @@ -165,6 +167,7 @@ export type ActionParams<
O extends z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> = {
/** Name of the action. */
name:
| string
| {
Expand All @@ -180,6 +183,7 @@ export type ActionParams<
use?: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[];
streamSchema?: S;
actionType: ActionType;
schemaMode?: ActionSchemaMode;
};

export type ActionAsyncParams<
Expand Down Expand Up @@ -312,10 +316,15 @@ export function action<
input: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>> => {
input = parseSchema(input, {
schema: config.inputSchema,
jsonSchema: config.inputJsonSchema,
});
if (!config.schemaMode || config.schemaMode === 'validate') {
parseSchema(input, {
schema: config.inputSchema,
jsonSchema: config.inputJsonSchema,
});
} else if (config.schemaMode === 'parse' && config.inputSchema) {
console.log(' - - - - - - - parsing....', input);
input = config.inputSchema.parse(input);
}
let traceId;
let spanId;
let output = await runInNewSpan(
Expand Down Expand Up @@ -366,7 +375,15 @@ export function action<
});
// if context is explicitly passed in, we run action with the provided context,
// otherwise we let upstream context carry through.
const output = await runWithContext(options?.context, actFn);
let output = await runWithContext(options?.context, actFn);
if (!config.schemaMode || config.schemaMode === 'validate') {
parseSchema(output, {
schema: config.outputSchema,
jsonSchema: config.outputJsonSchema,
});
} else if (config.schemaMode === 'parse' && config.outputSchema) {
output = config.outputSchema.parse(output);
}

metadata.output = JSON.stringify(output);
return output;
Expand All @@ -378,10 +395,6 @@ export function action<
}
}
);
output = parseSchema(output, {
schema: config.outputSchema,
jsonSchema: config.outputJsonSchema,
});
return {
result: output,
telemetry: {
Expand Down
19 changes: 3 additions & 16 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import type { z } from 'zod';
import { ActionFnArg, action, type Action } from './action.js';
import { ActionFnArg, ActionParams, action, type Action } from './action.js';
import { Registry, type HasRegistry } from './registry.js';
import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js';

Expand All @@ -35,17 +35,8 @@ export interface FlowConfig<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** Name of the flow. */
> extends Omit<ActionParams<I, O, S>, 'actionType'> {
name: string;
/** Schema of the input to the flow. */
inputSchema?: I;
/** Schema of the output from the flow. */
outputSchema?: O;
/** Schema of the streaming chunks from the flow. */
streamSchema?: S;
/** Metadata of the flow used by tooling. */
metadata?: Record<string, any>;
}

/**
Expand Down Expand Up @@ -115,11 +106,7 @@ function flowAction<
return action(
{
actionType: 'flow',
name: config.name,
inputSchema: config.inputSchema,
outputSchema: config.outputSchema,
streamSchema: config.streamSchema,
metadata: config.metadata,
...config,
},
async (
input,
Expand Down
92 changes: 92 additions & 0 deletions js/core/tests/flow_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,98 @@ describe('flow', () => {
}
);
});

it('should parse schema in parse schemaMode', async () => {
const testFlow = defineFlow(
registry,
{
name: 'testFlow',
inputSchema: z.object({
foo: z.string().default('default foo'),
}),
outputSchema: z.object({
input: z
.object({
foo: z.string().optional(),
})
.optional(),
bar: z.string().transform((val) => `${val}-transformed`),
}),
schemaMode: 'parse',
},
async (input) => {
return { input, bar: 'bar' };
}
);

const result = await testFlow({} as any);

assert.deepStrictEqual(result, {
bar: 'bar-transformed',
input: {
foo: 'default foo',
},
});
});

it('should only-validate schema in validate schemaMode', async () => {
const testFlow = defineFlow(
registry,
{
name: 'testFlow',
inputSchema: z.object({
foo: z.string().default('default foo'),
}),
outputSchema: z.object({
input: z
.object({
foo: z.string().optional(),
})
.optional(),
bar: z.string().transform((val) => `${val}-transformed`),
}),
schemaMode: 'validate',
},
async (input) => {
return { input, bar: 'bar' };
}
);

const result = await testFlow({} as any);

assert.deepStrictEqual(result, {
bar: 'bar',
input: {},
});
});

it('should ignore schema in none schemaMode', async () => {
const testFlow = defineFlow(
registry,
{
name: 'testFlow',
inputSchema: z.object({
foo: z.string().transform((val) => `${val}-transformed`),
}),
outputSchema: z.object({
bar: z.string().transform((val) => `${val}-transformed`),
}),
schemaMode: 'none',
},
async (input) => {
return { seriously: 'banana', input } as any;
}
);

const result = await testFlow({ banana: 'yeah' } as any);

assert.deepStrictEqual(result, {
input: {
banana: 'yeah',
},
seriously: 'banana',
});
});
});

describe('getContext', () => {
Expand Down