diff --git a/automa_ai/agents/agent_factory.py b/automa_ai/agents/agent_factory.py
index 2a7de82..75e1b9c 100644
--- a/automa_ai/agents/agent_factory.py
+++ b/automa_ai/agents/agent_factory.py
@@ -180,7 +180,6 @@ def get_agent(self):
         return self.__call__()
 
     def __call__(self) -> BaseAgent:
-        load_tool_plugins()
 
         chat_model = resolve_chat_model(
             self.chat_model,
diff --git a/automa_ai/agents/langgraph_chatagent.py b/automa_ai/agents/langgraph_chatagent.py
index 0020e9b..5647704 100644
--- a/automa_ai/agents/langgraph_chatagent.py
+++ b/automa_ai/agents/langgraph_chatagent.py
@@ -31,6 +31,8 @@
 memory = MemorySaver()
 logger = logging.getLogger(__name__)
 
+STREAM_SOURCE_MARKER_PREFIX = "[[source:"
+STREAM_SOURCE_MARKER_SUFFIX = "]]"
 
 
 class GenericLangGraphChatAgent(BaseAgent):
@@ -466,6 +468,7 @@ async def _forward_subagent_events(
                     "response_type": "text",
                     "is_task_complete": False,
                     "require_user_input": False,
+                    "source": e.source,
                     "content": content_str,
                 }
             )
@@ -475,12 +478,21 @@
 
     @staticmethod
     def _format_subagent_event(event: StreamEvent) -> str:
-        # content_str = f"\n\n[{event.source}] "
-        content_str = ""
+        content_body = event.content
         if event.metadata and event.metadata.get("final"):
-            content_str += "(final) "
-        content_str += event.content
-        return content_str
+            content_body = f"(final) {content_body}"
+        return GenericLangGraphChatAgent._attach_source_marker(
+            content_body, event.source
+        )
+
+    @staticmethod
+    def _attach_source_marker(content: str, source: str | None) -> str:
+        if not source:
+            return content
+        return (
+            f"{STREAM_SOURCE_MARKER_PREFIX}{source}{STREAM_SOURCE_MARKER_SUFFIX} "
+            f"{content}"
+        )
 
     @staticmethod
     def _normalize_chunk_content(chunk: AIMessageChunk) -> str | None:
diff --git a/automa_ai/agents/langgraph_chatagent_test.py b/automa_ai/agents/langgraph_chatagent_test.py
index b0774e7..28fd486 100644
--- a/automa_ai/agents/langgraph_chatagent_test.py
+++ b/automa_ai/agents/langgraph_chatagent_test.py
@@ -68,6 +68,8 @@ async def test_forward_subagent_events_emits_text():
     item = await asyncio.wait_for(output_queue.get(), timeout=1)
     task.cancel()
     assert item["response_type"] == "text"
+    assert item["source"] == "subagent:test"
+    assert item["content"].startswith("[[source:subagent:test]] ")
     assert "(final)" in item["content"]
 
 
diff --git a/examples/travel_blackboard_demo/ui.py b/examples/travel_blackboard_demo/ui.py
index 6772995..7c95b9a 100644
--- a/examples/travel_blackboard_demo/ui.py
+++ b/examples/travel_blackboard_demo/ui.py
@@ -3,6 +3,7 @@
 import asyncio
 import json
 import os
+import re
 import uuid
 
 from pathlib import Path
@@ -14,6 +15,7 @@
 BASE_DIR = Path(__file__).resolve().parent
 BLACKBOARD_BASE_DIR = BASE_DIR / ".demo_blackboards"
 ORCHESTRATOR_URL = os.getenv("TRAVEL_ORCHESTRATOR_URL", "http://localhost:33000")
+SOURCE_MARKER_RE = re.compile(r"^\[\[source:(?P<source>[^\]]+)]]\s*")
 
 
 @st.cache_resource
@@ -52,7 +54,7 @@ async def stream_reply(prompt: str, session_id: str):
                 text_fragments = [
                     p.get("text")
                     for p in parts
-                    if p.get("kind") == "text" and p.get("text")
+                    if p.get("kind") == "text" and p.get("text") and not p.get("text").strip().startswith("**Tool")
                 ]
                 if text_fragments:
                     text_part = "\n".join(text_fragments)
@@ -66,7 +68,12 @@
             text_part = chunk["data"]
 
         if text_part:
-            yield text_part
+            source = None
+            marker_match = SOURCE_MARKER_RE.match(text_part)
+            if marker_match:
+                source = marker_match.group("source")
+                text_part = SOURCE_MARKER_RE.sub("", text_part, count=1)
+            yield {"text": text_part, "source": source}
 
 
 def main() -> None:
@@ -111,16 +118,25 @@ def main() -> None:
 
         with st.chat_message("assistant"):
            placeholder = st.empty()
+           agent_status = st.empty()
            full_reply = st.session_state["messages"][assistant_index]["content"]
            st.session_state["is_streaming"] = True
 
            async def consume_stream():
                nonlocal full_reply
-               async for token in stream_reply(prompt, st.session_state["session_id"]):
-                   full_reply += token
-                   st.session_state["messages"][assistant_index]["content"] = full_reply
-                   placeholder.markdown(full_reply + "▌")
+               async for event in stream_reply(prompt, st.session_state["session_id"]):
+                   source = event.get("source")
+                   if source:
+                       agent_name = source.replace("subagent:", "")
+                       agent_status.markdown(f"## Active agent: `{agent_name}`")
+
+                   token = event.get("text", "")
+                   if token:
+                       full_reply += token
+                       st.session_state["messages"][assistant_index]["content"] = full_reply
+                       placeholder.markdown(full_reply + "▌")
                placeholder.markdown(full_reply)
+               agent_status.empty()
 
            try:
                asyncio.run(consume_stream())
@@ -132,4 +148,4 @@ async def consume_stream():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()