@@ -192,7 +192,7 @@ async def __call__(
         """
         return await generate_action(
             self._registry,
-            await self.render(input=input, config=config),
+            await self.render(input=input, config=config, context=context),
             on_chunk=on_chunk,
             middleware=self._use,
             context=context if context else ActionRunContext._current_context(),
@@ -236,6 +236,7 @@ async def render(
         self,
         input: dict[str, Any] | None = None,
         config: GenerationCommonConfig | dict[str, Any] | None = None,
+        context: dict[str, Any] | None = None,
     ) -> GenerateActionOptions:
         """Renders the prompt with the given input and configuration.

@@ -277,18 +278,26 @@ async def render(
                 input,
                 options,
                 self._cache_prompt,
-                ActionRunContext._current_context() or {}
+                context
             )
             resolved_msgs.append(result)
         if options.messages:
-            resolved_msgs += options.messages
+            resolved_msgs.extend(
+                await render_message_prompt(
+                    self._registry,
+                    input,
+                    options,
+                    self._cache_prompt,
+                    context
+                )
+            )
         if options.prompt:
             result = await render_user_prompt(
                 self._registry,
                 input,
                 options,
                 self._cache_prompt,
-                ActionRunContext._current_context() or {}
+                context
             )
             resolved_msgs.append(result)

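The hunk above threads the caller-supplied `context` dict from `render()` into the system, messages, and user prompt renderers instead of having each helper re-read `ActionRunContext._current_context()`. A rough usage sketch of the updated surface; the prompt handle `qa_prompt` and the context keys are illustrative assumptions, not part of this PR:

```python
# Illustrative sketch only: qa_prompt and the context keys are assumed names,
# not defined in this PR.
render_opts = await qa_prompt.render(
    input={'question': 'What does render() return?'},
    config={'temperature': 0.2},
    context={'auth': {'uid': 'user-123'}},  # now forwarded to every render_* helper
)

# __call__ forwards the same context into render(), so direct invocation renders
# with the explicit context rather than only the ambient ActionRunContext.
response = await qa_prompt(
    input={'question': 'What does render() return?'},
    context={'auth': {'uid': 'user-123'}},
)
```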
@@ -437,14 +446,7 @@ async def to_generate_action_options(
         raise Exception('No model configured.')
     resolved_msgs: list[Message] = []
     if options.system:
-        result = await render_system_prompt(
-            registry,
-            None,
-            options,
-            PromptCache(),
-            ActionRunContext._current_context() or {}
-        )
-        resolved_msgs.append(result)
+        resolved_msgs.append(Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system)))
     if options.messages:
         resolved_msgs += options.messages
     if options.prompt:
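With this hunk, `to_generate_action_options` no longer renders the system string through dotprompt; it wraps it directly in a system `Message` via `_normalize_prompt_arg`. That helper is not shown in this diff; judging only from how it is used here, it presumably coerces a plain string (or an existing part list) into a list of parts. A sketch under that assumption, using `Part`/`TextPart` as the typing names already imported in this module:

```python
# Assumed behaviour of _normalize_prompt_arg, which this diff calls but does not
# define: coerce a str | Part | list[Part] prompt argument into a list of parts.
# Sketch only, not the actual implementation.
def _normalize_prompt_arg(prompt: str | Part | list[Part] | None) -> list[Part] | None:
    if prompt is None:
        return None
    if isinstance(prompt, str):
        # A bare string becomes a single text part.
        return [TextPart(text=prompt)]
    if isinstance(prompt, list):
        # Already a list of parts: pass through unchanged.
        return prompt
    # A single part: wrap it in a list.
    return [prompt]
```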
@@ -551,10 +553,9 @@ async def render_system_prompt(
                 context,
                 prompt_cache.system,
                 input,
-
                 PromptMetadata(
                     input=PromptInputConfig(
-
+                        schema=options.input_schema,
                     )
                 ),
             )
@@ -565,6 +566,7 @@ async def render_system_prompt(
         content=_normalize_prompt_arg(options.system)
     )

+
 async def render_dotprompt_to_parts(
     context: dict[str, Any],
     prompt_function: PromptFunction,
@@ -585,9 +587,10 @@ async def render_dotprompt_to_parts(
     Raises:
         Exception: If the template produces more than one message.
     """
+    merged_input = input_
     rendered = await prompt_function(
         data=DataArgument[dict[str, Any]](
-            input=input_,
+            input=merged_input,
             context=context,
         ),
         options=options,
@@ -606,6 +609,50 @@ async def render_dotprompt_to_parts(



+async def render_message_prompt(
+    registry: Registry,
+    input: dict[str, Any],
+    options: PromptConfig,
+    prompt_cache: PromptCache,
+    context: dict[str, Any] | None = None,
+) -> list[Message]:
+    """Renders the messages of a prompt action."""
+    if isinstance(options.messages, str):
+        if prompt_cache.messages is None:
+            prompt_cache.messages = await registry.dotprompt.compile(options.messages)
+
+        if options.metadata:
+            context = {**(context or {}), "state": options.metadata.get("state")}
+
+        messages_ = None
+        if isinstance(options.messages, list):
+            messages_ = [e.model_dump() for e in options.messages]
+
+        rendered = await prompt_cache.messages(
+            data=DataArgument[dict[str, Any]](
+                input=input,
+                context=context,
+                messages=messages_,
+            ),
+            options=PromptMetadata(
+                input=PromptInputConfig(
+                )
+            ),
+        )
+        return [Message.model_validate(e.model_dump()) for e in rendered.messages]
+
+    elif isinstance(options.messages, list):
+        return options.messages
+
+    return [
+        Message(
+            role=Role.USER,
+            content=_normalize_prompt_arg(options.prompt)
+        )
+    ]
+
+
+
 async def render_user_prompt(
     registry: Registry,
     input: dict[str, Any],
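The new `render_message_prompt` above handles three shapes of `options.messages`: a dotprompt template string (compiled once and cached on `prompt_cache.messages`, then rendered with the current input and context), an already-materialized list of `Message` objects (returned unchanged), and a fallback built from `options.prompt`. A hedged sketch of how the first two cases might be exercised; the `ai.define_prompt` arguments and template text here are illustrative assumptions, not taken from this PR:

```python
# Sketch only: prompt names and template bodies are illustrative assumptions.
# A string template is compiled via registry.dotprompt.compile() and rendered
# against the current input/context on every call.
templated = ai.define_prompt(
    system='You are a terse assistant.',
    messages='{{role "user"}}Summarize this in one line: {{text}}',
)

# A list of Message objects skips dotprompt entirely and is returned as-is by
# render_message_prompt().
literal = ai.define_prompt(
    system='You are a terse assistant.',
    messages=[Message(role=Role.USER, content=[TextPart(text='ping')])],
)
```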
@@ -615,17 +662,17 @@ async def render_user_prompt(
 ) -> Message:
     """Renders the user prompt for a prompt action."""
     if isinstance(options.prompt, str):
-        if prompt_cache.prompt is None:
-            prompt_cache.prompt = await registry.dotprompt.compile(options.prompt)
+        if prompt_cache.user_prompt is None:
+            prompt_cache.user_prompt = await registry.dotprompt.compile(options.prompt)

         if options.metadata:
             context = {**context, "state": options.metadata.get("state")}

         return Message(
-            role=Role.SYSTEM,
+            role=Role.USER,
             content=await render_dotprompt_to_parts(
                 context,
-                prompt_cache.system,
+                prompt_cache.user_prompt,
                 input,
                 PromptMetadata(
                     input=PromptInputConfig(
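The rename from `prompt_cache.prompt` to `prompt_cache.user_prompt`, together with the new `prompt_cache.messages` slot used by `render_message_prompt`, implies one compiled dotprompt render function per template kind, populated lazily. The actual `PromptCache` definition is outside this diff; its implied shape is roughly:

```python
# Implied shape of PromptCache (defined elsewhere, not in this diff): one compiled
# dotprompt PromptFunction per template kind, filled in on first use.
from dataclasses import dataclass


@dataclass
class PromptCache:
    system: PromptFunction | None = None
    user_prompt: PromptFunction | None = None
    messages: PromptFunction | None = None
```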