diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py index 9aafdf94ae..2f5deea2b8 100644 --- a/py/packages/genkit/src/genkit/core/reflection.py +++ b/py/packages/genkit/src/genkit/core/reflection.py @@ -158,6 +158,7 @@ def do_POST(self) -> None: # noqa: N802 post_body = self.rfile.read(content_len) payload = json.loads(post_body.decode(encoding=encoding)) action = registry.lookup_action_by_key(payload['key']) + action_input = payload.get('input') context = payload['context'] if 'context' in payload else {} query = urllib.parse.urlparse(self.path).query @@ -215,7 +216,7 @@ async def run_fn(): else: try: - async def run_fn(): + async def run_fn(): return await action.arun_raw(raw_input=payload.get('input'), context=context) output = run_async(loop, run_fn) @@ -376,13 +377,15 @@ async def handle_run_action( # Run the action. context = payload.get('context', {}) + action_input = payload.get('input') stream = is_streaming_requested(request) handler = run_streaming_action if stream else run_standard_action - return await handler(action, payload, context, version) + return await handler(action, payload, action_input, context, version) async def run_streaming_action( action: Action, payload: dict[str, Any], + action_input: Any, context: dict[str, Any], version: str, ) -> StreamingResponse | JSONResponse: @@ -447,6 +450,7 @@ async def send_chunk(chunk): async def run_standard_action( action: Action, payload: dict[str, Any], + action_input: Any, context: dict[str, Any], version: str, ) -> JSONResponse: diff --git a/py/samples/prompt_demo/src/prompt_demo.py b/py/samples/prompt_demo/src/prompt_demo.py index 7ea72592f3..723821090e 100644 --- a/py/samples/prompt_demo/src/prompt_demo.py +++ b/py/samples/prompt_demo/src/prompt_demo.py @@ -18,6 +18,7 @@ from pathlib import Path import structlog +from pydantic import BaseModel from genkit.ai import Genkit from genkit.plugins.google_genai import GoogleAI @@ -40,6 +41,51 @@ def my_helper(content, *_, **__): ai.define_helper('my_helper', my_helper) +class OutputSchema(BaseModel): + short: str + friendly: str + like_a_pirate: str + + +@ai.flow(name='simplePrompt') +async def simple_prompt(input: str = ''): + return await ai.generate(prompt='You are a helpful AI assistant named Walt, say hello') + + +@ai.flow(name='simpleTemplate') +async def simple_template(input: str = ''): + name = 'Fred' + return await ai.generate(prompt=f'You are a helpful AI assistant named Walt. Say hello to {name}.') + + +hello_dotprompt = ai.define_prompt( + input_schema={'name': str}, + prompt='You are a helpful AI assistant named Walt. Say hello to {{name}}', +) + + +class NameInput(BaseModel): + name: str = 'Fred' + + +@ai.flow(name='simpleDotprompt') +async def simple_dotprompt(input: NameInput): + return await hello_dotprompt(input={'name': input.name}) + + +three_greetings_prompt = ai.define_prompt( + input_schema={'name': str}, + output_schema=OutputSchema, + prompt='You are a helpful AI assistant named Walt. Say hello to {{name}}, write a response for each of the styles requested', +) + + +@ai.flow(name='threeGreetingsPrompt') +async def three_greetings(input: str = 'Fred') -> OutputSchema: + response = await three_greetings_prompt(input={'name': input}) + return response.output + + async def main(): # List actions to verify loading actions = ai.registry.list_serializable_actions() @@ -60,6 +106,22 @@ async def main(): await logger.ainfo('Prompt Execution Result', text=response.text) + res = await simple_prompt() + await logger.ainfo('Flow: simplePrompt', text=res.text) + + res = await simple_template() + await logger.ainfo('Flow: simpleTemplate', text=res.text) + + res = await simple_dotprompt(NameInput(name='Fred')) + await logger.ainfo('Flow: simpleDotprompt', text=res.text) + + res = await three_greetings() + await logger.ainfo('Flow: threeGreetingsPrompt', output=res) + + # Call one of the prompts just to validate everything is hooked up properly + res = await hello_dotprompt(input={'name': 'Bob'}) + await logger.ainfo('Prompt: hello_dotprompt', text=res.text) + if __name__ == '__main__': ai.run_main(main())