Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions py/packages/genkit/src/genkit/core/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def do_POST(self) -> None: # noqa: N802
post_body = self.rfile.read(content_len)
payload = json.loads(post_body.decode(encoding=encoding))
action = registry.lookup_action_by_key(payload['key'])
action_input = payload.get('input')
context = payload['context'] if 'context' in payload else {}

query = urllib.parse.urlparse(self.path).query
Expand Down Expand Up @@ -215,7 +216,7 @@ async def run_fn():
else:
try:

async def run_fn():
async def run_fn():
return await action.arun_raw(raw_input=payload.get('input'), context=context)

output = run_async(loop, run_fn)
Expand Down Expand Up @@ -376,13 +377,15 @@ async def handle_run_action(

# Run the action.
context = payload.get('context', {})
action_input = payload.get('input')
stream = is_streaming_requested(request)
handler = run_streaming_action if stream else run_standard_action
return await handler(action, payload, context, version)
return await handler(action, payload, action_input, context, version)

async def run_streaming_action(
action: Action,
payload: dict[str, Any],
action_input: Any,
context: dict[str, Any],
version: str,
) -> StreamingResponse | JSONResponse:
Expand Down Expand Up @@ -447,6 +450,7 @@ async def send_chunk(chunk):
async def run_standard_action(
action: Action,
payload: dict[str, Any],
action_input: Any,
context: dict[str, Any],
version: str,
) -> JSONResponse:
Expand Down
62 changes: 62 additions & 0 deletions py/samples/prompt_demo/src/prompt_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pathlib import Path

import structlog
from pydantic import BaseModel

from genkit.ai import Genkit
from genkit.plugins.google_genai import GoogleAI
Expand All @@ -40,6 +41,51 @@ def my_helper(content, *_, **__):
ai.define_helper('my_helper', my_helper)


class OutputSchema(BaseModel):
    """Structured output for the three-greetings prompt: one greeting per requested style."""

    # Greeting rendered in a short style.
    short: str
    # Greeting rendered in a friendly style.
    friendly: str
    # Greeting rendered in pirate speak.
    like_a_pirate: str


@ai.flow(name='simplePrompt')
async def simple_prompt(input: str = ''):
    """Run a fixed hello prompt.

    Args:
        input: Accepted for the flow signature but not used by the prompt.

    Returns:
        The model response from ``ai.generate``.
    """
    greeting_prompt = 'You are a helpful AI assistant named Walt, say hello'
    return await ai.generate(prompt=greeting_prompt)


@ai.flow(name='simpleTemplate')
async def simple_template(input: str = ''):
    """Greet a user by name using an f-string prompt template.

    Args:
        input: Name to greet. Falls back to 'Fred' when empty, which
            preserves the original behavior for the default call.

    Returns:
        The model response from ``ai.generate``.
    """
    # Fix: the original hard-coded name = 'Fred' and silently ignored `input`;
    # honoring the argument keeps the default behavior identical.
    name = input or 'Fred'
    return await ai.generate(prompt=f'You are a helpful AI assistant named Walt. Say hello to {name}.')


# Prompt with a single templated variable ({{name}}), defined in code
# in the dotprompt template style.
hello_dotprompt = ai.define_prompt(
    input_schema={'name': str},
    prompt='You are a helpful AI assistant named Walt. Say hello to {{name}}',
)


class NameInput(BaseModel):
    """Input schema for the simpleDotprompt flow."""

    # Name to greet; defaults to 'Fred'.
    name: str = 'Fred'


@ai.flow(name='simpleDotprompt')
async def simple_dotprompt(input: NameInput):
    """Invoke the hello dotprompt with the flow's name input.

    Args:
        input: Validated input carrying the name to greet.

    Returns:
        The prompt execution result.
    """
    prompt_args = {'name': input.name}
    return await hello_dotprompt(input=prompt_args)


# Structured-output prompt: output_schema=OutputSchema constrains the model
# to return one greeting per field (short, friendly, like_a_pirate).
three_greetings_prompt = ai.define_prompt(
    input_schema={'name': str},
    output_schema=OutputSchema,
    prompt='You are a helpful AI assistant named Walt. Say hello to {{name}}, write a response for each of the styles requested',
)


@ai.flow(name='threeGreetingsPrompt')
async def three_greetings(input: str = 'Fred') -> OutputSchema:
    """Produce three greeting styles for `input` via the structured prompt.

    Args:
        input: Name to greet; defaults to 'Fred'.

    Returns:
        The structured output parsed into ``OutputSchema``.
    """
    return (await three_greetings_prompt(input={'name': input})).output


async def main():
# List actions to verify loading
actions = ai.registry.list_serializable_actions()
Expand All @@ -60,6 +106,22 @@ async def main():

await logger.ainfo('Prompt Execution Result', text=response.text)

res = await simple_prompt()
await logger.ainfo('Flow: simplePrompt', text=res.text)

res = await simple_template()
await logger.ainfo('Flow: simpleTemplate', text=res.text)

res = await simple_dotprompt(NameInput(name='Fred'))
await logger.ainfo('Flow: simpleDotprompt', text=res.text)

res = await three_greetings()
await logger.ainfo('Flow: threeGreetingsPrompt', output=res)

# Call one of the prompts just to validate everything is hooked up properly
res = await hello_dotprompt(input={'name': 'Bob'})
await logger.ainfo('Prompt: hello_dotprompt', text=res.text)


if __name__ == '__main__':
ai.run_main(main())
Loading