backend/src/analytics_agent/agent/graph.py: 21 changes (17 additions, 4 deletions)
@@ -43,6 +43,7 @@ def build_graph(
enabled_mutations: set[str] | None = None,
context_tools: list | None = None, # pre-built from DB context platforms at request time
engine_tools: list | None = None, # pre-built for MCP data sources (bypasses QueryEngine)
+    suppress_business_context_skill: bool = False,
):
from analytics_agent.agent.chart_generator import chart_node
from analytics_agent.engines.factory import get_registry
@@ -51,6 +52,8 @@ def build_graph(
llm = get_llm(streaming=True)

from analytics_agent.agent.chart_tool import create_chart
+    from analytics_agent.agent.proposals_tool import present_proposals
+    from analytics_agent.agent.results_tool import report_proposal_results

# Context platform tools — built dynamically from DB at request time.
# Falls back to env-var based build only when caller doesn't provide them.
@@ -64,7 +67,10 @@
# Always-on skills (context search etc.) + opt-in write-back skills
from analytics_agent.skills.loader import build_always_on_skill_tools, build_skill_tools

-    skill_tools = build_always_on_skill_tools() + build_skill_tools(enabled_mutations or set())
+    always_on_skills = build_always_on_skill_tools()
+    if suppress_business_context_skill:
+        always_on_skills = [t for t in always_on_skills if t.name != "search_business_context"]
+    skill_tools = always_on_skills + build_skill_tools(enabled_mutations or set())

# Engine tools — MCP data sources supply pre-built tools; native engines use QueryEngine
if engine_tools is not None:
@@ -77,7 +83,9 @@
raise ValueError(f"Engine '{engine_name}' not found.")
engine_tools = [t for t in engine.get_tools() if t.name not in disabled]
chart_tools = [] if "create_chart" in disabled else [create_chart]
-    all_tools = datahub_tools + skill_tools + engine_tools + chart_tools
+    proposal_tools = [] if "present_proposals" in disabled else [present_proposals]
+    results_tools = [] if "report_proposal_results" in disabled else [report_proposal_results]
+    all_tools = datahub_tools + skill_tools + engine_tools + chart_tools + proposal_tools + results_tools

if system_prompt_override:
from analytics_agent.skills.loader import (
@@ -87,12 +95,17 @@
)

system_prompt = system_prompt_override.format(engine_name=engine_name)
-        system_prompt += get_search_business_context_section()
+        if not suppress_business_context_skill:
+            system_prompt += get_search_business_context_section()
system_prompt += get_improve_context_prompt_section()
if enabled_mutations:
system_prompt += get_skill_system_prompt_section(enabled_mutations)
else:
-        system_prompt = build_system_prompt(engine_name, enabled_skills=enabled_mutations)
+        system_prompt = build_system_prompt(
+            engine_name,
+            enabled_skills=enabled_mutations,
+            include_business_context=not suppress_business_context_skill,
+        )

# Enable per-tool error handling so validation errors (e.g. hallucinated
# arguments like filter= on get_entities) are returned as tool messages
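For reference, a minimal call-site sketch for the new flag. It is illustrative only: the engine name and mutation set are invented, and engine_name, disabled, and system_prompt_override are assumed to be build_graph parameters defined above the visible hunks.

# Hypothetical call site, not code from this PR. suppress_business_context_skill=True
# removes the search_business_context tool and skips its system-prompt section;
# adding "present_proposals" to disabled would gate the new proposal tool the
# same way create_chart is gated.
graph = build_graph(
    engine_name="postgres",            # assumed example engine
    enabled_mutations={"save_correction"},
    disabled={"create_chart"},
    suppress_business_context_skill=True,
)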
backend/src/analytics_agent/agent/history.py: 15 changes (12 additions, 3 deletions)
@@ -74,14 +74,23 @@ def build_history(
"input": payload.get("tool_input", {}),
}
)
elif evt in ("TOOL_RESULT", "SQL"):
elif evt in ("TOOL_RESULT", "SQL", "PROPOSALS", "PROPOSAL_RESULTS"):
idx = len(tool_results)
call_id = tool_calls[idx]["id"] if idx < len(tool_calls) else msg.id
if evt == "PROPOSALS":
result_text = orjson.dumps(payload).decode()[:4000]
tool_name = "present_proposals"
elif evt == "PROPOSAL_RESULTS":
result_text = orjson.dumps(payload).decode()[:4000]
tool_name = "report_proposal_results"
else:
result_text = payload.get("result", payload.get("sql", ""))[:4000]
tool_name = payload.get("tool_name", "")
tool_results.append(
{
"id": call_id,
"name": payload.get("tool_name", ""),
"result": payload.get("result", payload.get("sql", ""))[:4000],
"name": tool_name,
"result": result_text,
}
)
elif evt == "TEXT":
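To make the replay mapping concrete, a sketch of how a persisted PROPOSALS event becomes a synthetic present_proposals tool result. The payload shape is assumed to match what present_proposals stores (see proposals_tool.py below); the call id is invented.

import orjson

# Sketch only: an assumed PROPOSALS event payload and the tool_results
# entry the new branch would produce for it.
payload = {
    "prompt": "Based on our conversation, here are 2 improvements:",
    "proposals": [
        {"id": "1", "kind": "new_doc", "title": "Revenue Metrics Guide",
         "detail": "Define net ARR vs gross ARR and specify the revenue table."},
    ],
}
result_text = orjson.dumps(payload).decode()[:4000]  # truncated like other tool results
entry = {"id": "call_abc123", "name": "present_proposals", "result": result_text}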
backend/src/analytics_agent/agent/mcp_app_tool.py: 37 additions (new file)
@@ -0,0 +1,37 @@
"""Side-channel for MCP App tool results.

Mirrors the _pending_charts pattern in chart_tool.py:
- The wrapped tool returns a short marker string (MCP_APP_READY:<app_id>).
- The actual structured payload lives here, keyed by app_id.
- streaming.py pops from this dict when it sees the marker in on_tool_end.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class PendingApp:
app_id: str
connection_key: str
server_name: str
tool_name: str
tool_input: dict
# Structured CallToolResult content (list of content blocks), preserved so
# the frontend can forward it verbatim as `ui/notifications/tool-result`
# params per the MCP Apps spec.
tool_result: Any
resource_uri: str
csp: str | None = None
permissions: list[str] = field(default_factory=list)
# Tool names scoped to this app's connection that the iframe is allowed to
# call via the Phase 2 tool-proxy endpoint. Populated at wrap time from the
# full tool list for the originating connection_key. Persisted in the MCP_APP
# SSE payload so the endpoint can rehydrate it from the DB row.
allowed_tools: list[str] = field(default_factory=list)


# Keyed by app_id; popped once streaming.py emits the MCP_APP SSE event.
_pending_apps: dict[str, PendingApp] = {}
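The consumer half of this pattern lives in streaming.py and is outside this diff. A minimal sketch under that assumption: the handler and emit_sse callback are hypothetical, while the MCP_APP_READY marker, the MCP_APP event name, and _pending_apps come from this module.

from analytics_agent.agent.mcp_app_tool import _pending_apps

# Sketch of the on_tool_end side described in the module docstring.
async def on_tool_end_sketch(output: str, emit_sse) -> None:
    if output.startswith("MCP_APP_READY:"):
        app_id = output.split(":", 1)[1].strip()
        pending = _pending_apps.pop(app_id, None)  # popped exactly once
        if pending is not None:
            await emit_sse("MCP_APP", pending)  # event name per the comment above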
backend/src/analytics_agent/agent/proposals_tool.py: 72 additions (new file)
@@ -0,0 +1,72 @@
from __future__ import annotations

import logging
import uuid
from typing import Literal

from langchain_core.tools import tool
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Side-channel: keyed by prop_id so streaming.py can fetch the payload
# without the model ever seeing the full JSON.
_pending_proposals: dict[str, dict] = {}


class ProposalItem(BaseModel):
id: str
kind: Literal["new_doc", "update_doc", "fix_description"]
title: str
detail: str
target: dict | None = None # e.g. {"urn": "...", "field_path": "..."}


@tool
async def present_proposals(
prompt: str,
proposals: list[dict],
) -> str:
"""
Present a list of improvement proposals to the user for review and selection.
Call this at the end of Step 3 of the /improve-context workflow, after drafting
proposals. The UI will render a card with checkboxes — do NOT print a markdown
list yourself.

Args:
prompt: Short framing sentence shown above the proposals
(e.g. "I found 3 improvements based on our conversation.")
proposals: List of proposal dicts, each with:
- id: unique string identifier (e.g. "1", "2", "3")
- kind: one of "new_doc", "update_doc", "fix_description"
- title: short title for the proposal
- detail: 1-2 sentence description of what to add/change
- target: optional dict with "urn" and/or "field_path" for existing entities

Example:
present_proposals(
prompt="Based on our conversation, here are 3 documentation improvements:",
proposals=[
{"id": "1", "kind": "new_doc", "title": "Revenue Metrics Guide",
"detail": "Define net ARR vs gross ARR and specify the revenue table."},
{"id": "2", "kind": "fix_description", "title": "orders.status column",
"detail": "Current description is empty. Values: pending, confirmed, shipped.",
"target": {"urn": "urn:li:dataset:...", "field_path": "status"}},
]
)
"""
try:
validated = [ProposalItem(**p) for p in proposals]
except Exception as e:
return f"present_proposals: invalid proposals format — {e}"

prop_id = str(uuid.uuid4())
_pending_proposals[prop_id] = {
"prompt": prompt,
"proposals": [p.model_dump() for p in validated],
}

return (
f"PROPOSALS_READY:{prop_id} "
f"({len(validated)} proposals; awaiting user selection)"
)
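As with the MCP-app side channel, the streaming consumer for this marker is not part of the diff. A sketch assuming it mirrors that pattern: emit_sse is hypothetical, while the marker format, _pending_proposals, and the PROPOSALS event name (see history.py above) are from this PR. results_tool.py below follows the same shape with RESULTS_READY, _pending_results, and PROPOSAL_RESULTS.

import re

from analytics_agent.agent.proposals_tool import _pending_proposals

# Sketch only. The tool returns "PROPOSALS_READY:<uuid4>" plus a human-readable
# suffix, so match just the 36-character id.
_MARKER = re.compile(r"PROPOSALS_READY:([0-9a-f-]{36})")

async def handle_proposals_marker(output: str, emit_sse) -> None:
    m = _MARKER.search(output)
    if m:
        payload = _pending_proposals.pop(m.group(1), None)
        if payload is not None:
            await emit_sse("PROPOSALS", payload)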
backend/src/analytics_agent/agent/results_tool.py: 67 additions (new file)
@@ -0,0 +1,67 @@
from __future__ import annotations

import logging
import uuid
from typing import Literal

from langchain_core.tools import tool
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Side-channel: keyed by result_id so streaming.py can fetch the payload
# without the model ever seeing the full JSON.
_pending_results: dict[str, dict] = {}


class ProposalResultItem(BaseModel):
id: str
kind: Literal["new_doc", "update_doc", "fix_description"]
title: str
status: Literal["success", "error"]
urn: str | None = None
error: str | None = None


@tool
async def report_proposal_results(
results: list[dict],
) -> str:
"""
Report the outcomes of writing approved proposals back to DataHub.
Call this ONCE after all save_correction calls have completed in Step 5
of the /improve-context workflow. The UI will render a results card —
do NOT write any additional summary text after calling this tool.

Args:
results: List of result dicts, each with:
- id: proposal id (matches the id from present_proposals)
- kind: one of "new_doc", "update_doc", "fix_description"
- title: proposal title
- status: "success" or "error"
- urn: the URN of the created/updated entity (set on success)
- error: error message (set on error)

Example:
report_proposal_results(results=[
{"id": "1", "kind": "new_doc", "title": "Revenue Metrics Guide",
"status": "success", "urn": "urn:li:corpUser:..."},
{"id": "3", "kind": "fix_description", "title": "orders.status column",
"status": "error", "error": "Permission denied"},
])
"""
try:
validated = [ProposalResultItem(**r) for r in results]
except Exception as e:
return f"report_proposal_results: invalid results format — {e}"

result_id = str(uuid.uuid4())
_pending_results[result_id] = {
"results": [r.model_dump() for r in validated],
}

successes = sum(1 for r in validated if r.status == "success")
return (
f"RESULTS_READY:{result_id} "
f"({successes}/{len(validated)} succeeded)"
)