diff --git a/js/core/src/action.ts b/js/core/src/action.ts index a767dcedf0..dfbea2e570 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -157,6 +157,8 @@ export type Action< ): StreamingResponse; }; +export type ActionSchemaMode = 'none' | 'validate' | 'parse'; + /** * Action factory params. */ @@ -165,6 +167,7 @@ export type ActionParams< O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, > = { + /** Name of the action. */ name: | string | { @@ -180,6 +183,7 @@ export type ActionParams< use?: Middleware, z.infer, z.infer>[]; streamSchema?: S; actionType: ActionType; + schemaMode?: ActionSchemaMode; }; export type ActionAsyncParams< @@ -312,10 +316,15 @@ export function action< input: z.infer, options?: ActionRunOptions> ): Promise>> => { - input = parseSchema(input, { - schema: config.inputSchema, - jsonSchema: config.inputJsonSchema, - }); + if (!config.schemaMode || config.schemaMode === 'validate') { + parseSchema(input, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } else if (config.schemaMode === 'parse' && config.inputSchema) { + console.log(' - - - - - - - parsing....', input); + input = config.inputSchema.parse(input); + } let traceId; let spanId; let output = await runInNewSpan( @@ -366,7 +375,15 @@ export function action< }); // if context is explicitly passed in, we run action with the provided context, // otherwise we let upstream context carry through. - const output = await runWithContext(options?.context, actFn); + let output = await runWithContext(options?.context, actFn); + if (!config.schemaMode || config.schemaMode === 'validate') { + parseSchema(output, { + schema: config.outputSchema, + jsonSchema: config.outputJsonSchema, + }); + } else if (config.schemaMode === 'parse' && config.outputSchema) { + output = config.outputSchema.parse(output); + } metadata.output = JSON.stringify(output); return output; @@ -378,10 +395,6 @@ export function action< } } ); - output = parseSchema(output, { - schema: config.outputSchema, - jsonSchema: config.outputJsonSchema, - }); return { result: output, telemetry: { diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 56c3a3d6b9..fcc0564c6b 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -15,7 +15,7 @@ */ import type { z } from 'zod'; -import { ActionFnArg, action, type Action } from './action.js'; +import { ActionFnArg, ActionParams, action, type Action } from './action.js'; import { Registry, type HasRegistry } from './registry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; @@ -35,17 +35,8 @@ export interface FlowConfig< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, -> { - /** Name of the flow. */ +> extends Omit, 'actionType'> { name: string; - /** Schema of the input to the flow. */ - inputSchema?: I; - /** Schema of the output from the flow. */ - outputSchema?: O; - /** Schema of the streaming chunks from the flow. */ - streamSchema?: S; - /** Metadata of the flow used by tooling. */ - metadata?: Record; } /** @@ -115,11 +106,7 @@ function flowAction< return action( { actionType: 'flow', - name: config.name, - inputSchema: config.inputSchema, - outputSchema: config.outputSchema, - streamSchema: config.streamSchema, - metadata: config.metadata, + ...config, }, async ( input, diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index d6af65f60a..60b425a150 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -142,6 +142,98 @@ describe('flow', () => { } ); }); + + it('should parse schema in parse schemaMode', async () => { + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.object({ + foo: z.string().default('default foo'), + }), + outputSchema: z.object({ + input: z + .object({ + foo: z.string().optional(), + }) + .optional(), + bar: z.string().transform((val) => `${val}-transformed`), + }), + schemaMode: 'parse', + }, + async (input) => { + return { input, bar: 'bar' }; + } + ); + + const result = await testFlow({} as any); + + assert.deepStrictEqual(result, { + bar: 'bar-transformed', + input: { + foo: 'default foo', + }, + }); + }); + + it('should only-validate schema in validate schemaMode', async () => { + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.object({ + foo: z.string().default('default foo'), + }), + outputSchema: z.object({ + input: z + .object({ + foo: z.string().optional(), + }) + .optional(), + bar: z.string().transform((val) => `${val}-transformed`), + }), + schemaMode: 'validate', + }, + async (input) => { + return { input, bar: 'bar' }; + } + ); + + const result = await testFlow({} as any); + + assert.deepStrictEqual(result, { + bar: 'bar', + input: {}, + }); + }); + + it('should ignore schema in none schemaMode', async () => { + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.object({ + foo: z.string().transform((val) => `${val}-transformed`), + }), + outputSchema: z.object({ + bar: z.string().transform((val) => `${val}-transformed`), + }), + schemaMode: 'none', + }, + async (input) => { + return { seriously: 'banana', input } as any; + } + ); + + const result = await testFlow({ banana: 'yeah' } as any); + + assert.deepStrictEqual(result, { + input: { + banana: 'yeah', + }, + seriously: 'banana', + }); + }); }); describe('getContext', () => {