diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py
index 8accbfd6..a40eb8aa 100644
--- a/src/agentlab/analyze/agent_xray.py
+++ b/src/agentlab/analyze/agent_xray.py
@@ -7,6 +7,14 @@
 from logging import warning
 from pathlib import Path
 
+from agentlab.llm.chat_api import (
+    AzureChatModel,
+    OpenAIChatModel,
+    OpenRouterChatModel,
+    make_system_message,
+    make_user_message,
+)
+
 import gradio as gr
 import matplotlib.patches as patches
 import matplotlib.pyplot as plt
@@ -15,7 +23,7 @@
 from attr import dataclass
 from browsergym.experiments.loop import StepInfo as BGymStepInfo
 from langchain.schema import BaseMessage, HumanMessage
-from openai import OpenAI
+from openai import AzureOpenAI
 from openai.types.responses import ResponseFunctionToolCall
 from PIL import Image
 
@@ -399,14 +407,27 @@
                 interactive=True,
                 elem_id="prompt_tests_textbox",
             )
-            submit_button = gr.Button(value="Submit")
+            with gr.Row():
+                num_queries_input = gr.Number(
+                    value=3,
+                    label="Number of model queries",
+                    minimum=1,
+                    maximum=10,
+                    step=1,
+                    precision=0,
+                    interactive=True,
+                )
+                submit_button = gr.Button(value="Submit")
             result_box = gr.Textbox(
-                value="", label="Result", show_label=True, interactive=False
+                value="", label="Result", show_label=True, interactive=False, max_lines=20
             )
+            with gr.Row():
+                # Plot component for the action distribution graph
+                action_plot = gr.Plot(label="Action Distribution", show_label=True)
 
             # Define the interaction
             submit_button.click(
-                fn=submit_action, inputs=prompt_tests_textbox, outputs=result_box
+                fn=submit_action, inputs=[prompt_tests_textbox, num_queries_input], outputs=[result_box, action_plot]
             )
 
     # Handle Events #
@@ -843,9 +864,11 @@
 def _page_to_iframe(page: str):
     return page
 
-def submit_action(input_text):
+def submit_action(input_text, num_queries=3):
     global info
     agent_info = info.exp_result.steps_info[info.step].agent_info
+    # Get the current step's action string for comparison
+    step_action_str = info.exp_result.steps_info[info.step].action
     chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2])
     if isinstance(chat_messages[1], BaseMessage):  # TODO remove once langchain is deprecated
         assert isinstance(chat_messages[1], HumanMessage), "Second message should be user"
@@ -858,14 +881,102 @@
     else:
         raise ValueError("Chat messages should be a list of BaseMessage or dict")
 
-    client = OpenAI()
+    client = AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo")
     chat_messages[1]["content"] = input_text
-    completion = client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=chat_messages,
-    )
-    result_text = completion.choices[0].message.content
-    return result_text
+
+    # Query the model N times
+    answers = []
+    actions = []
+    import re
+
+    for _ in range(num_queries):
+        answer = client(chat_messages)
+        content = answer.get("content", "")
+        answers.append(content)
+
+        # Extract the action part using regex
+        action_match = re.search(r"<action>(.*?)</action>", content, re.DOTALL)
+        if action_match:
+            actions.append(action_match.group(1).strip())
+
+    # Prepare the aggregate result
+    result = ""
+
+    # Include full responses first
+    result += "\n\n===== FULL MODEL RESPONSES =====\n\n"
+    result += "\n\n===== MODEL RESPONSE SEPARATION =====\n\n".join(answers)
+
+    # Then add aggregated actions
+    result += "\n\n===== EXTRACTED ACTIONS =====\n\n"
+
+    # Create plot for the action distribution
+    import matplotlib.pyplot as plt
+    from collections import Counter
+
+    # Create a figure for the action distribution
+    fig = plt.figure(figsize=(10, 6))
+
+    if actions:
+        # Count unique actions
+        action_counts = Counter(actions)
+
+        # Use most_common order so the plot and the text output stay consistent
+        most_common_actions = action_counts.most_common()
+
+        # Prepare data for plotting (in most_common order)
+        labels = [f"Action {i+1}" for i in range(len(most_common_actions))]
+        values = [count for _, count in most_common_actions]
+        percentages = [(count / len(actions)) * 100 for count in values]
+
+        # Create bar chart
+        plt.bar(labels, percentages, color="skyblue")
+        plt.xlabel("Actions")
+        plt.ylabel("Percentage (%)")
+        plt.title(f"Action Distribution (from {num_queries} model queries)")
+        plt.ylim(0, 100)  # Set y-axis from 0 to 100%
+
+        # Add percentage labels on top of each bar
+        for i, v in enumerate(percentages):
+            plt.text(i, v + 2, f"{v:.1f}%", ha="center")
+
+        # Add total counts as a text annotation
+        plt.figtext(0.5, 0.01,
+                    f"Total actions extracted: {len(actions)} | Unique actions: {len(action_counts)}",
+                    ha="center", fontsize=10, bbox={"facecolor": "white", "alpha": 0.5, "pad": 5})
+
+        # Display unique actions and their counts in the text result
+        for i, (action, count) in enumerate(most_common_actions):
+            percentage = (count / len(actions)) * 100
+
+            # Check if this action matches the current step's action
+            matches_current_action = step_action_str and action.strip() == step_action_str.strip()
+
+            # Highlight conditions:
+            # 1. The most common action (i == 0)
+            # 2. An action that matches the current step's action
+            if i == 0 and matches_current_action:
+                result += f"**Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
+            elif i == 0:  # Just the most common
+                result += f"**Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%)**:\n**{action}**\n\n"
+            elif matches_current_action:  # Matches current action but not most common
+                result += f"**Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
+            else:  # Regular action
+                result += f"Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%):\n{action}\n\n"
+    else:
+        result += "No actions found in any of the model responses.\n\n"
+
+        # Create an empty plot with a message
+        plt.text(0.5, 0.5, "No actions found in model responses",
+                 ha="center", va="center", fontsize=14)
+        plt.axis("off")  # Hide axes
+
+    plt.tight_layout()
+
+    # Return both the text result and the figure
+    return result, fig
 
 
 def update_prompt_tests():
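A note for reviewers: the extract-and-count step at the heart of `submit_action` is easy to sanity-check in isolation. Below is a minimal sketch of that logic, runnable without Gradio or an Azure deployment; the `fake_responses` list is hypothetical, and the `<action>` tag format is an assumption mirroring the regex used in the diff.

```python
# Minimal sketch (not part of the diff): reproduces the action-aggregation
# logic of submit_action so it can be tested without a live model endpoint.
import re
from collections import Counter

# Hypothetical model outputs; the <action> tag format is assumed to match
# the regex used in submit_action above.
fake_responses = [
    "<think>try the search box</think>\n<action>click('a42')</action>",
    "<think>same idea</think>\n<action>click('a42')</action>",
    "<think>scroll first</think>\n<action>scroll(0, 200)</action>",
]

# Extract the action string from each response, skipping malformed ones
actions = []
for content in fake_responses:
    match = re.search(r"<action>(.*?)</action>", content, re.DOTALL)
    if match:
        actions.append(match.group(1).strip())

# Count identical actions and report them in most_common order,
# the same order the bar chart uses
action_counts = Counter(actions)
for i, (action, count) in enumerate(action_counts.most_common()):
    share = 100 * count / len(actions)
    print(f"Action {i + 1}: {action!r} -> {count}/{len(actions)} ({share:.1f}%)")
# Action 1: "click('a42')" -> 2/3 (66.7%)
# Action 2: 'scroll(0, 200)' -> 1/3 (33.3%)
```

Iterating over `most_common()` in both the plot and the text output is what keeps the "Action 1", "Action 2" labels on the bar chart aligned with the entries in the result textbox.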