Skip to content

Commit

Permalink
feat(js/core): added a way to run actions with side-channel data and …
Browse files Browse the repository at this point in the history
…streaming (#1375)
  • Loading branch information
pavelgj authored Nov 26, 2024
1 parent 2b85cb8 commit f4261f9
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 60 deletions.
12 changes: 6 additions & 6 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

import {
getStreamingCallback,
Middleware,
runWithStreamingCallback,
z,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import * as clc from 'colorette';
import { DocumentDataSchema } from '../document.js';
import { resolveFormat } from '../formats/index.js';
Expand All @@ -40,13 +39,14 @@ import {
GenerateResponseData,
MessageData,
MessageSchema,
ModelMiddleware,
Part,
resolveModel,
Role,
ToolDefinitionSchema,
ToolResponsePart,
resolveModel,
} from '../model.js';
import { resolveTools, ToolAction, toToolDefinition } from '../tool.js';
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';

export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
Expand Down Expand Up @@ -78,7 +78,7 @@ export const GenerateUtilParamSchema = z.object({
export async function generateHelper(
registry: Registry,
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
middleware?: ModelMiddleware[]
): Promise<GenerateResponseData> {
// do tracing
return await runInNewSpan(
Expand All @@ -103,7 +103,7 @@ export async function generateHelper(
async function generate(
registry: Registry,
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
middleware?: ModelMiddleware[]
): Promise<GenerateResponseData> {
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
if (model.__action.metadata?.model.stage === 'deprecated') {
Expand Down
6 changes: 3 additions & 3 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
defineAction,
GenkitError,
getStreamingCallback,
Middleware,
SimpleMiddleware,
StreamingCallback,
z,
} from '@genkit-ai/core';
Expand Down Expand Up @@ -299,12 +299,12 @@ export type ModelAction<
> = Action<
typeof GenerateRequestSchema,
typeof GenerateResponseSchema,
{ model: ModelInfo }
typeof GenerateResponseChunkSchema
> & {
__configSchema: CustomOptionsSchema;
};

export type ModelMiddleware = Middleware<
export type ModelMiddleware = SimpleMiddleware<
z.infer<typeof GenerateRequestSchema>,
z.infer<typeof GenerateResponseSchema>
>;
Expand Down
4 changes: 3 additions & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
import {
GenerateRequest,
GenerateRequestSchema,
GenerateResponseChunkSchema,
ModelArgument,
} from './model.js';
import { ToolAction } from './tool.js';
Expand All @@ -36,7 +37,8 @@ export type PromptFn<

export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
I,
typeof GenerateRequestSchema
typeof GenerateRequestSchema,
typeof GenerateResponseChunkSchema
> & {
__action: {
metadata: {
Expand Down
6 changes: 1 addition & 5 deletions js/ai/src/reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ export const RerankerInfoSchema = z.object({
export type RerankerInfo = z.infer<typeof RerankerInfoSchema>;

export type RerankerAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<
typeof RerankerRequestSchema,
typeof RerankerResponseSchema,
{ model: RerankerInfo }
> & {
Action<typeof RerankerRequestSchema, typeof RerankerResponseSchema> & {
__configSchema?: CustomOptions;
};

Expand Down
6 changes: 1 addition & 5 deletions js/ai/src/retriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ export const RetrieverInfoSchema = z.object({
export type RetrieverInfo = z.infer<typeof RetrieverInfoSchema>;

export type RetrieverAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<
typeof RetrieverRequestSchema,
typeof RetrieverResponseSchema,
{ model: RetrieverInfo }
> & {
Action<typeof RetrieverRequestSchema, typeof RetrieverResponseSchema> & {
__configSchema?: CustomOptions;
};

Expand Down
154 changes: 116 additions & 38 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export { JSONSchema7 };
export interface ActionMetadata<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny,
> {
actionType?: ActionType;
name: string;
Expand All @@ -43,7 +43,8 @@ export interface ActionMetadata<
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
streamSchema?: S;
metadata?: Record<string, any>;
}

/**
Expand All @@ -57,16 +58,52 @@ export interface ActionResult<O> {
};
}

/**
* Options (side channel) data to pass to the model.
*/
export interface ActionRunOptions<S> {
/**
* Streaming callback (optional).
*/
onChunk?: StreamingCallback<S>;

/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
}

/**
* Options (side channel) data to pass to the model.
*/
export interface ActionFnArg<S> {
/**
* Streaming callback (optional).
*/
sendChunk: StreamingCallback<S>;

/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
}

/**
* Self-describing, validating, observable, locally and remotely callable function.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = ((input: z.infer<I>) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, M>;
run(input: z.infer<I>): Promise<ActionResult<z.infer<O>>>;
S extends z.ZodTypeAny = z.ZodTypeAny,
> = ((
input: z.infer<I>,
options?: ActionRunOptions<S>
) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, S>;
run(
input: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>>;
};

/**
Expand All @@ -75,7 +112,7 @@ export type Action<
type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
> = {
name:
| string
Expand All @@ -88,49 +125,80 @@ type ActionParams<
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
use?: Middleware<z.infer<I>, z.infer<O>>[];
metadata?: Record<string, any>;
use?: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[];
streamingSchema?: S;
};

export type SimpleMiddleware<I = any, O = any> = (
req: I,
next: (req?: I) => Promise<O>
) => Promise<O>;

export type MiddlewareWithOptions<I = any, O = any, S = any> = (
req: I,
options: ActionRunOptions<S> | undefined,
next: (req?: I, options?: ActionRunOptions<S>) => Promise<O>
) => Promise<O>;

/**
* Middleware function for actions.
*/
export interface Middleware<I = any, O = any> {
(req: I, next: (req?: I) => Promise<O>): Promise<O>;
}
export type Middleware<I = any, O = any, S = any> =
| SimpleMiddleware<I, O>
| MiddlewareWithOptions<I, O, S>;

/**
* Creates an action with provided middleware.
*/
export function actionWithMiddleware<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
action: Action<I, O, M>,
middleware: Middleware<z.infer<I>, z.infer<O>>[]
): Action<I, O, M> {
action: Action<I, O, S>,
middleware: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[]
): Action<I, O, S> {
const wrapped = (async (req: z.infer<I>) => {
return (await wrapped.run(req)).result;
}) as Action<I, O, M>;
}) as Action<I, O, S>;
wrapped.__action = action.__action;
wrapped.run = async (req: z.infer<I>): Promise<ActionResult<z.infer<O>>> => {
wrapped.run = async (
req: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>> => {
let telemetry;
const dispatch = async (index: number, req: z.infer<I>) => {
const dispatch = async (
index: number,
req: z.infer<I>,
opts?: ActionRunOptions<z.infer<S>>
) => {
if (index === middleware.length) {
// end of the chain, call the original model action
const result = await action.run(req);
const result = await action.run(req, opts);
telemetry = result.telemetry;
return result.result;
}

const currentMiddleware = middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
if (currentMiddleware.length === 3) {
return (currentMiddleware as MiddlewareWithOptions<I, O, z.infer<S>>)(
req,
opts,
async (modifiedReq, modifiedOptions) =>
dispatch(index + 1, modifiedReq || req, modifiedOptions || opts)
);
} else if (currentMiddleware.length === 2) {
return (currentMiddleware as SimpleMiddleware<I, O>)(
req,
async (modifiedReq) => dispatch(index + 1, modifiedReq || req, opts)
);
} else {
throw new Error('unspported middleware function shape');
}
};

return { result: await dispatch(0, req), telemetry };
return { result: await dispatch(0, req, options), telemetry };
};
return wrapped;
}
Expand All @@ -141,17 +209,20 @@ export function actionWithMiddleware<
export function action<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
config: ActionParams<I, O, M>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O> {
config: ActionParams<I, O, S>,
fn: (
input: z.infer<I>,
options: ActionFnArg<z.infer<S>>
) => Promise<z.infer<O>>
): Action<I, O, z.infer<S>> {
const actionName =
typeof config.name === 'string'
? config.name
: `${config.name.pluginId}/${config.name.actionId}`;
const actionFn = async (input: I) => {
return (await actionFn.run(input)).result;
const actionFn = async (input: I, options?: ActionRunOptions<z.infer<S>>) => {
return (await actionFn.run(input, options)).result;
};
actionFn.__action = {
name: actionName,
Expand All @@ -161,9 +232,10 @@ export function action<
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;
} as ActionMetadata<I, O, S>;
actionFn.run = async (
input: z.infer<I>
input: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>> => {
input = parseSchema(input, {
schema: config.inputSchema,
Expand All @@ -184,7 +256,10 @@ export function action<
metadata.name = actionName;
metadata.input = input;

const output = await fn(input);
const output = await fn(input, {
context: options?.context,
sendChunk: options?.onChunk ?? ((c) => {}),
});

metadata.output = JSON.stringify(output);
return output;
Expand Down Expand Up @@ -239,13 +314,16 @@ function validateActionId(actionId: string) {
export function defineAction<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
config: ActionParams<I, O, M> & {
config: ActionParams<I, O, S> & {
actionType: ActionType;
},
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: (
input: z.infer<I>,
options: ActionFnArg<z.infer<S>>
) => Promise<z.infer<O>>
): Action<I, O> {
if (isInRuntimeContext()) {
throw new Error(
Expand All @@ -258,10 +336,10 @@ export function defineAction<
} else {
validateActionId(config.name.actionId);
}
const act = action(config, async (i: I): Promise<z.infer<O>> => {
const act = action(config, async (i: I, options): Promise<z.infer<O>> => {
setCustomMetadataAttributes({ subtype: config.actionType });
await registry.initializeAllPlugins();
return await runInActionRuntimeContext(() => fn(i));
return await runInActionRuntimeContext(() => fn(i, options));
});
act.__action.actionType = config.actionType;
registry.registerAction(config.actionType, act);
Expand Down
Loading

0 comments on commit f4261f9

Please sign in to comment.