Commit ad9dc47

chore: added context rendering in the templates

1 parent 31879e1 commit ad9dc47

File tree

2 files changed: +111 -24 lines changed
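
What the change enables, in brief: the context dict passed to a prompt call is now forwarded into dotprompt template rendering, and render() itself gains a context parameter, instead of the helpers always reading ActionRunContext._current_context(). A minimal usage sketch, assuming a Genkit instance `ai` configured with the echo test model as in the tests below (the prompt text and email value are illustrative, not part of this commit):

    my_prompt = ai.define_prompt(
        model='echoModel',
        prompt='hello {{name}} ({{@auth.email}})',
    )

    # The context dict now reaches the template, so @-variables such as
    # {{@auth.email}} resolve at render time.
    response = await my_prompt(
        {'name': 'foo'},
        context={'auth': {'email': 'user@example.com'}},  # hypothetical value
    )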

py/packages/genkit/src/genkit/blocks/prompt.py

Lines changed: 66 additions & 19 deletions
@@ -192,7 +192,7 @@ async def __call__(
         """
         return await generate_action(
             self._registry,
-            await self.render(input=input, config=config),
+            await self.render(input=input, config=config, context=context),
             on_chunk=on_chunk,
             middleware=self._use,
             context=context if context else ActionRunContext._current_context(),
@@ -236,6 +236,7 @@ async def render(
         self,
         input: dict[str, Any] | None = None,
         config: GenerationCommonConfig | dict[str, Any] | None = None,
+        context: dict[str, Any] | None = None,
     ) -> GenerateActionOptions:
         """Renders the prompt with the given input and configuration.

@@ -277,18 +278,26 @@ async def render(
                 input,
                 options,
                 self._cache_prompt,
-                ActionRunContext._current_context() or {}
+                context
             )
             resolved_msgs.append(result)
         if options.messages:
-            resolved_msgs += options.messages
+            resolved_msgs.extend(
+                await render_message_prompt(
+                    self._registry,
+                    input,
+                    options,
+                    self._cache_prompt,
+                    context
+                )
+            )
         if options.prompt:
             result = await render_user_prompt(
                 self._registry,
                 input,
                 options,
                 self._cache_prompt,
-                ActionRunContext._current_context() or {}
+                context
             )
             resolved_msgs.append(result)

@@ -437,14 +446,7 @@ async def to_generate_action_options(
         raise Exception('No model configured.')
     resolved_msgs: list[Message] = []
     if options.system:
-        result = await render_system_prompt(
-            registry,
-            None,
-            options,
-            PromptCache(),
-            ActionRunContext._current_context() or {}
-        )
-        resolved_msgs.append(result)
+        resolved_msgs.append(Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system)))
     if options.messages:
         resolved_msgs += options.messages
     if options.prompt:
@@ -551,10 +553,9 @@ async def render_system_prompt(
             context,
             prompt_cache.system,
             input,
-
             PromptMetadata(
                 input=PromptInputConfig(
-
+                    schema=options.input_schema,
                 )
             ),
         )
@@ -565,6 +566,7 @@ async def render_system_prompt(
         content=_normalize_prompt_arg(options.system)
     )

+
 async def render_dotprompt_to_parts(
     context: dict[str, Any],
     prompt_function: PromptFunction,
@@ -585,9 +587,10 @@ async def render_dotprompt_to_parts(
     Raises:
         Exception: If the template produces more than one message.
     """
+    merged_input = input_
     rendered = await prompt_function(
         data=DataArgument[dict[str, Any]](
-            input=input_,
+            input=merged_input,
             context=context,
         ),
         options=options,
@@ -606,6 +609,50 @@ async def render_dotprompt_to_parts(



+async def render_message_prompt(
+    registry: Registry,
+    input: dict[str, Any],
+    options: PromptConfig,
+    prompt_cache: PromptCache,
+    context: dict[str, Any] | None = None,
+) -> list[Message]:
+    """Renders the user prompt for a prompt action."""
+    if isinstance(options.messages, str):
+        if prompt_cache.messages is None:
+            prompt_cache.messages = await registry.dotprompt.compile(options.messages)
+
+        if options.metadata:
+            context = {**context, "state": options.metadata.get("state")}
+
+        messages_ = None
+        if isinstance(options.messages, list):
+            messages_ = [e.model_dump() for e in options.messages]
+
+        rendered = await prompt_cache.messages(
+            data=DataArgument[dict[str, Any]](
+                input=input,
+                context=context,
+                messages=messages_,
+            ),
+            options=PromptMetadata(
+                input=PromptInputConfig(
+                )
+            ),
+        )
+        return [Message.model_validate(e.model_dump()) for e in rendered.messages]
+
+    elif isinstance(options.messages, list):
+        return options.messages
+
+    return [
+        Message(
+            role=Role.USER,
+            content=_normalize_prompt_arg(options.prompt)
+        )
+    ]
+
+
+
 async def render_user_prompt(
     registry: Registry,
     input: dict[str, Any],
@@ -615,17 +662,17 @@ async def render_user_prompt(
 ) -> Message:
     """Renders the user prompt for a prompt action."""
     if isinstance(options.prompt, str):
-        if prompt_cache.prompt is None:
-            prompt_cache.prompt = await registry.dotprompt.compile(options.prompt)
+        if prompt_cache.user_prompt is None:
+            prompt_cache.user_prompt = await registry.dotprompt.compile(options.prompt)

         if options.metadata:
             context = {**context, "state": options.metadata.get("state")}

         return Message(
-            role=Role.SYSTEM,
+            role=Role.USER,
             content=await render_dotprompt_to_parts(
                 context,
-                prompt_cache.system,
+                prompt_cache.user_prompt,
                 input,
                 PromptMetadata(
                     input=PromptInputConfig(
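
The render() path above can also be exercised directly. A sketch under the same assumptions as the snippet near the top (`my_prompt` defined via ai.define_prompt, hypothetical context values):

    # render() now threads `context` into the prompt-rendering helpers
    # (render_message_prompt, render_user_prompt, and the system-prompt path)
    # rather than each helper reading ActionRunContext._current_context().
    opts = await my_prompt.render(
        input={'name': 'foo'},
        context={'auth': {'email': 'user@example.com'}},
    )
    # opts is a GenerateActionOptions; its messages should already contain
    # the rendered system/user text.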

py/packages/genkit/tests/genkit/blocks/prompt_test.py

Lines changed: 45 additions & 5 deletions
@@ -124,7 +124,7 @@ def test_tool(input: ToolInput):
         tools=['testTool'],
         tool_choice=ToolChoice.REQUIRED,
         max_turns=5,
-        input_schema=PromptInput,
+        input_schema=PromptInput.model_json_schema(),
         output_constrained=True,
         output_format='json',
         description='a prompt descr',
@@ -143,7 +143,7 @@ def test_tool(input: ToolInput):

 test_cases_parse_partial_json = [
     (
-        "renders user prompt",
+        "renders system prompt",
         {
             "model": "echoModel",
             "config": {"banana": "ripe"},
@@ -158,28 +158,68 @@ def test_tool(input: ToolInput):
         },
         {"name": "foo"},
         GenerationCommonConfig.model_validate({"temperature": 11}),
+        {},
         """[ECHO] system: "hello foo (bar)" {"temperature":11.0}"""
+    ),
+    (
+        "renders user prompt",
+        {
+            "model": "echoModel",
+            "config": {"banana": "ripe"},
+            "input_schema": {
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'},
+                },
+            },  # Note: Schema representation might need adjustment
+            "prompt": "hello {{name}} ({{@state.name}})",
+            "metadata": {"state": {"name": "bar_system"}}
+        },
+        {"name": "foo"},
+        GenerationCommonConfig.model_validate({"temperature": 11}),
+        {},
+        """[ECHO] user: "hello foo (bar_system)" {"temperature":11.0}"""
+    ),
+    (
+        "renders user prompt with context",
+        {
+            "model": "echoModel",
+            "config": {"banana": "ripe"},
+            "input_schema": {
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'},
+                },
+            },  # Note: Schema representation might need adjustment
+            "prompt": "hello {{name}} ({{@state.name}}, {{@auth.email}})",
+            "metadata": {"state": {"name": "bar"}}
+        },
+        {"name": "foo"},
+        GenerationCommonConfig.model_validate({"temperature": 11}),
+        { "auth": { "email": '[email protected]' } },
+        """[ECHO] user: "hello foo (bar, [email protected])" {"temperature":11.0}"""
     )
 ]
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    'test_case, prompt, input, input_option, want_rendered',
+    'test_case, prompt, input, input_option, context, want_rendered',
     test_cases_parse_partial_json,
     ids=[tc[0] for tc in test_cases_parse_partial_json],
 )
-async def test_prompt_with_system(
+async def test_prompt_rendering_dotprompt(
     test_case: str,
     prompt: dict[str, Any],
     input: dict[str, Any],
     input_option: GenerationCommonConfig,
+    context: dict[str, Any],
     want_rendered: str
 ) -> None:
     """Test system prompt rendering."""
     ai, *_ = setup_test()

     my_prompt = ai.define_prompt(**prompt)

-    response = await my_prompt(input, input_option)
+    response = await my_prompt(input, input_option, context=context)

     assert response.text == want_rendered

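
As the new test cases show, two sources feed the template's @-variables: the prompt's own metadata supplies @state (merged in render_user_prompt via context = {**context, "state": options.metadata.get("state")}), while the per-call context supplies everything else, e.g. @auth. A sketch mirroring the "renders user prompt with context" case (the email value is hypothetical):

    my_prompt = ai.define_prompt(
        model='echoModel',
        prompt='hello {{name}} ({{@state.name}}, {{@auth.email}})',
        metadata={'state': {'name': 'bar'}},
    )
    response = await my_prompt(
        {'name': 'foo'},
        GenerationCommonConfig.model_validate({'temperature': 11}),
        context={'auth': {'email': 'user@example.com'}},
    )
    # With the echo test model, response.text should be:
    # [ECHO] user: "hello foo (bar, user@example.com)" {"temperature":11.0}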

0 commit comments