diff --git a/maab/agents/mlzero_default/mlzero_default.sh b/maab/agents/mlzero_default/mlzero_default.sh
index 6432cfe2..2d1358ba 100755
--- a/maab/agents/mlzero_default/mlzero_default.sh
+++ b/maab/agents/mlzero_default/mlzero_default.sh
@@ -57,9 +57,10 @@ fi
 mlzero \
     -i "$TRAINING_PATH" \
     -o "$OUTPUT_DIR" \
-    -n 8 \
+    -n 3 \
     -v 1 \
-    --initial-instruction "complete the task in 30 minutes"
+    --continuous_improvement \
+    --initial-instruction "complete the task in 10 minutes"
 
 # Check if the process was successful
 if [ $? -ne 0 ]; then
diff --git a/pyproject.toml b/pyproject.toml
index 89f40b16..1525fadc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "torchaudio",
     "torchvision",
     "pandas>=2.2",
+    "reportlab>=4.4.3",
     "streamlit>=1.37",
     "streamlit-aggrid>=1.0.2",
     "streamlit-extras>=0.4",
diff --git a/src/autogluon/assistant/agents/tool_selector_agent.py b/src/autogluon/assistant/agents/tool_selector_agent.py
index 0d5751b0..d629ee9c 100644
--- a/src/autogluon/assistant/agents/tool_selector_agent.py
+++ b/src/autogluon/assistant/agents/tool_selector_agent.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import List, Union
 
 from ..prompts import ToolSelectorPrompt
 from .base_agent import BaseAgent
@@ -10,14 +10,15 @@
 class ToolSelectorAgent(BaseAgent):
     """
-    Select the most appropriate tool based on data description and task requirements.
+    Select and rank the most appropriate tools based on data description and task requirements.
 
     Agent Input:
     - data_prompt: Text string containing data prompt
     - description: Description of the task/data from previous analysis
 
     Agent Output:
-    - str: Selected tool name
+    - List[str]: Prioritized list of tool names
+    - str: Selected tool name (for backward compatibility)
     """
 
     def __init__(self, config, manager, llm_config, prompt_template):
@@ -39,8 +40,8 @@ def __init__(self, config, manager, llm_config, prompt_template):
             multi_turn=self.tool_selector_llm_config.multi_turn,
         )
 
-    def __call__(self) -> Tuple[str, str]:
-        self.manager.log_agent_start("ToolSelectorAgent: choosing the most appropriate ML library for the task.")
+    def __call__(self) -> Union[List[str], str]:
+        self.manager.log_agent_start("ToolSelectorAgent: choosing and ranking ML libraries for the task.")
 
         # Build prompt for tool selection
         prompt = self.tool_selector_prompt.build()
@@ -54,8 +55,9 @@
 
         response = self.tool_selector_llm.assistant_chat(prompt)
 
-        selected_tool = self.tool_selector_prompt.parse(response)
+        tools = self.tool_selector_prompt.parse(response)
 
-        self.manager.log_agent_end(f"ToolSelectorAgent: selected {selected_tool}.")
+        tools_str = ", ".join(tools)
+        self.manager.log_agent_end(f"ToolSelectorAgent: selected tools in priority order: {tools_str}")
 
-        return selected_tool
+        return tools
diff --git a/src/autogluon/assistant/coding_agent.py b/src/autogluon/assistant/coding_agent.py
index 2794b2c1..899c37e2 100644
--- a/src/autogluon/assistant/coding_agent.py
+++ b/src/autogluon/assistant/coding_agent.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 import uuid
 from datetime import datetime
 from pathlib import Path
@@ -17,7 +18,7 @@ def run_agent(
     input_data_folder,
     output_folder=None,
     config_path=None,
-    max_iterations=5,
+    max_iterations=10,  # Default higher for MCTS search
     continuous_improvement=None,
     enable_meta_prompting=None,
     enable_per_iteration_instruction=False,
@@ -26,6 +27,24 @@
     manager=None,
     verbosity=1,
 ):
+    """
+    Run the AutoGluon Assistant with MCTS-based search strategy.
+
+    Args:
+        input_data_folder: Path to input data directory
+        output_folder: Path to output directory
+        config_path: Path to configuration file
+        max_iterations: Maximum number of iterations
+        continuous_improvement: Whether to continue after finding a valid solution
+        enable_meta_prompting: Whether to enable meta-prompting
+        enable_per_iteration_instruction: Whether to ask for user input at each iteration
+        initial_user_input: Initial user instruction
+        extract_archives_to: Path to extract archives to
+        verbosity: Verbosity level
+
+    Returns:
+        None
+    """
     # Get the directory of the current file
     current_file_dir = Path(__file__).parent
 
@@ -36,7 +55,7 @@
     # Generate a random UUID4
     random_uuid = uuid.uuid4()
     # Create the folder name using the pattern
-    folder_name = f"mlzero-{current_datetime}-{random_uuid}"
+    folder_name = f"mlzero-mcts-{current_datetime}-{random_uuid}"
     # Create the full path for the new folder
     output_folder = os.path.join(working_dir, folder_name)
 
@@ -47,7 +66,7 @@
     output_dir.mkdir(parents=False, exist_ok=True)
 
     configure_logging(verbosity=verbosity, output_dir=output_dir)
-    from .managers import Manager
+    from .managers.node_manager import NodeManager
 
     if extract_archives_to is not None:
         if extract_archives_to and extract_archives_to != input_data_folder:
@@ -100,7 +119,8 @@
         config.enable_meta_prompting = enable_meta_prompting
 
     if manager is None:
-        manager = Manager(
+        # Create a new NodeManager instance
+        manager = NodeManager(
             input_data_folder=input_data_folder,
             output_folder=output_folder,
             config=config,
@@ -108,30 +128,50 @@
             initial_user_input=initial_user_input,
         )
 
-    manager.set_initial_user_input(
-        enable_per_iteration_instruction=enable_per_iteration_instruction, initial_user_input=initial_user_input
-    )
+    # Initialize the manager (generate initial prompts)
+    manager.initialize()
 
-    while manager.time_step + 1 < max_iterations:
-        logger.brief(f"Starting iteration {manager.time_step + 1}!")
+    # Execute the MCTS search
+    iteration = 0
+    start_time = time.time()
 
-        manager.step()
+    while iteration < max_iterations:
+        # Log the current iteration
+        logger.brief(f"Starting MCTS iteration {iteration + 1}/{max_iterations}")
 
-        # Generate code
-        manager.update_python_code()
-        manager.update_bash_script()
+        # Perform one step of the Monte Carlo Tree Search
+        success = manager.step()
 
-        successful = manager.execute_code()
-
-        if successful:
+        if success:
+            # Create a best run copy when we find a successful solution
             manager.create_best_run_copy()
+
+            # If not in continuous improvement mode, we can stop
             if not config.continuous_improvement:
+                logger.brief("Stopping search - solution found and continuous improvement is disabled")
                 break
+        elif success is None:
+            logger.brief("Stopping search - all nodes are terminal.")
+            break
+        else:
+            pass
 
-    if manager.time_step + 1 >= max_iterations:
+        # Increment iteration counter
+        iteration += 1
+
+    # Check if we've exceeded the maximum iterations
+    if iteration >= max_iterations:
         logger.warning(f"[bold red]Warning: Reached maximum iterations ({max_iterations})[/bold red]")
 
+    manager.visualize_results()
     manager.report_token_usage()
-    manager.get_validation_score_summary()
-    logger.brief(f"output saved in {output_dir}.")
+    # Report token usage and validation score summary
     manager.cleanup()
+
+    # Log summary
+    elapsed_time = time.time() - start_time
+    logger.brief(f"MCTS search completed in {elapsed_time:.2f} seconds")
+    logger.brief(f"Total nodes explored: {manager.time_step + 1}")
+    logger.brief(f"Best validation score: {manager.best_validation_score}")
+    logger.brief(f"Tools used: {', '.join(manager.used_tools)}")
+    logger.brief(f"Output saved in {output_dir}")
diff --git a/src/autogluon/assistant/configs/anthropic.yaml b/src/autogluon/assistant/configs/anthropic.yaml
index abcd275e..297bdc75 100644
--- a/src/autogluon/assistant/configs/anthropic.yaml
+++ b/src/autogluon/assistant/configs/anthropic.yaml
@@ -12,7 +12,7 @@ max_num_tutorials: 5
 max_user_input_length: 2048
 max_error_message_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
diff --git a/src/autogluon/assistant/configs/bedrock.yaml b/src/autogluon/assistant/configs/bedrock.yaml
index 4da35453..c536fba1 100644
--- a/src/autogluon/assistant/configs/bedrock.yaml
+++ b/src/autogluon/assistant/configs/bedrock.yaml
@@ -1,6 +1,15 @@
-# Bedrock Configuration
+# Tutorial Prompt Generator Configuration
 
-per_execution_timeout: 86400
+per_execution_timeout: 7200
+
+# MCTS (Monte Carlo Tree Search) parameters
+exploration_constant: 1.414 # Controls exploration vs exploitation trade-off in UCT formula (higher = more exploration)
+max_debug_depth: 3 # Maximum depth of debug nodes in the search tree
+failure_offset: 2 # Number of failures to ignore before applying failure penalty
+failure_penalty_weight: 0.5 # Weight of the penalty for failed executions in UCT calculation
+initial_root_children: 5 # Maximum number of child nodes from root before considering fully expanded
+max_debug_children: 2 # Maximum number of debug child nodes for a single parent node
+max_evolve_children: 4 # Maximum number of evolution child nodes for a single parent node
 
 # Data Perception
 max_file_group_size_to_show: 5
@@ -10,9 +19,8 @@ max_chars_per_file: 768
 num_tutorial_retrievals: 30
 max_num_tutorials: 5
 max_user_input_length: 2048
-max_error_message_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
@@ -20,7 +28,11 @@ optimize_system_resources: False
 cleanup_unused_env: True
 enable_meta_prompting: False
 
+# Default LLM Configuration
+# For each agent (coder, etc.) you can use a different one
 llm: &default_llm
+  # Note: bedrock is only supported in limited AWS regions
+  # and requires AWS credentials
   provider: bedrock
   model: "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
   max_tokens: 65535
@@ -44,9 +56,7 @@ bash_coder:
 
 executer:
   <<: *default_llm # Merge llm_config
-  max_stdout_length: 8192
-  max_stderr_length: 2048
-
+
 meta_prompting:
   <<: *default_llm # Merge llm_config
   multi_turn: False
@@ -80,4 +90,4 @@ task_descriptor:
 tool_selector:
   <<: *default_llm # Merge llm_config
   temperature: 0.
-  top_p: 1.
\ No newline at end of file
+  top_p: 1.
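For readers new to the UCT parameters introduced above, here is a minimal standalone sketch (illustrative only, not part of the patch; the function name and signature are invented) of how exploration_constant, failure_offset, and failure_penalty_weight combine. It mirrors the Node.uct_value logic added later in node_manager.py.

import math

def uct_value(visits, parent_visits, validated_visits, validated_reward, failure_visits,
              best_score, worst_score, exploration_constant=1.414,
              failure_offset=2, failure_penalty_weight=0.5):
    # Unvisited nodes are always explored first.
    if visits == 0:
        return float("inf")
    # Exploitation: mean validation score, min-max normalized and weighted by
    # the fraction of visits that produced a validation score.
    if validated_visits > 0 and best_score is not None and worst_score is not None and best_score > worst_score:
        avg_score = validated_reward / validated_visits
        normalized = (avg_score - worst_score) / (best_score - worst_score)
        exploitation = (validated_visits / visits) * normalized
    else:
        exploitation = 1.0 if validated_visits > 0 else 0.0
    # Penalize repeated failures, ignoring the first `failure_offset` of them.
    penalty = -failure_penalty_weight * max(0, failure_visits - failure_offset) / visits
    # Standard UCT exploration bonus.
    exploration = exploration_constant * math.sqrt(math.log(max(1, parent_visits)) / visits)
    return exploitation + penalty + exploration

# Example: uct_value(visits=4, parent_visits=10, validated_visits=2, validated_reward=1.5,
#                    failure_visits=1, best_score=0.9, worst_score=0.6)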
diff --git a/src/autogluon/assistant/configs/data_visualizer.yaml b/src/autogluon/assistant/configs/data_visualizer.yaml
index cc283323..6b7aaa1d 100644
--- a/src/autogluon/assistant/configs/data_visualizer.yaml
+++ b/src/autogluon/assistant/configs/data_visualizer.yaml
@@ -11,7 +11,7 @@ num_tutorial_retrievals: 30
 max_num_tutorials: 5
 max_user_input_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
@@ -82,7 +82,7 @@ python_coder:
 
     {tool_prompt}
 
-    {best_code_prompt}
+    {code_improvement_prompt}
 
    Please provide the complete Python script that accomplishes these tasks, ensuring it's ready to run given the appropriate data inputs.
 
@@ -96,7 +96,7 @@ python_coder:
    {user_input_truncate_end_2048}
 
    ### Previous Errors
-   {all_error_analyses}
+   {all_previous_error_analyses}
 
    ### Tutorials for Reference
    {tutorial_prompt}
diff --git a/src/autogluon/assistant/configs/default.yaml b/src/autogluon/assistant/configs/default.yaml
index 6c2b89c8..c536fba1 100644
--- a/src/autogluon/assistant/configs/default.yaml
+++ b/src/autogluon/assistant/configs/default.yaml
@@ -1,6 +1,15 @@
 # Tutorial Prompt Generator Configuration
 
-per_execution_timeout: 86400
+per_execution_timeout: 7200
+
+# MCTS (Monte Carlo Tree Search) parameters
+exploration_constant: 1.414 # Controls exploration vs exploitation trade-off in UCT formula (higher = more exploration)
+max_debug_depth: 3 # Maximum depth of debug nodes in the search tree
+failure_offset: 2 # Number of failures to ignore before applying failure penalty
+failure_penalty_weight: 0.5 # Weight of the penalty for failed executions in UCT calculation
+initial_root_children: 5 # Maximum number of child nodes from root before considering fully expanded
+max_debug_children: 2 # Maximum number of debug child nodes for a single parent node
+max_evolve_children: 4 # Maximum number of evolution child nodes for a single parent node
 
 # Data Perception
 max_file_group_size_to_show: 5
@@ -11,7 +20,7 @@ num_tutorial_retrievals: 30
 max_num_tutorials: 5
 max_user_input_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
@@ -26,10 +35,6 @@ llm: &default_llm
   # and requires AWS credentials
   provider: bedrock
   model: "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
-  #provider: openai
-  #model: gpt-4o-2024-08-06
-  #provider: anthropic
-  # model: claude-3-7-sonnet-20250219
   max_tokens: 65535
   proxy_url: null
   temperature: 0.1
diff --git a/src/autogluon/assistant/configs/openai.yaml b/src/autogluon/assistant/configs/openai.yaml
index 5eb2e5e5..35fe0d57 100644
--- a/src/autogluon/assistant/configs/openai.yaml
+++ b/src/autogluon/assistant/configs/openai.yaml
@@ -12,7 +12,7 @@ max_num_tutorials: 5
 max_user_input_length: 2048
 max_error_message_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
diff --git a/src/autogluon/assistant/configs/sagemaker.yaml b/src/autogluon/assistant/configs/sagemaker.yaml
index a81f9394..ea827529 100644
--- a/src/autogluon/assistant/configs/sagemaker.yaml
+++ b/src/autogluon/assistant/configs/sagemaker.yaml
@@ -11,7 +11,7 @@ num_tutorial_retrievals: 30
 max_num_tutorials: 5
 max_user_input_length: 2048
 max_tutorial_length: 32768
-create_venv: false
+configure_env: false
 condense_tutorials: True
 use_tutorial_summary: True
 continuous_improvement: False
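The renamed placeholders in the python_coder template above ({code_improvement_prompt}, {all_previous_error_analyses}) presumably have to match the field names the prompt builder exposes on the manager (the old Manager already had an all_previous_error_analyses property). A small hypothetical sketch of how such a template fragment might be rendered; the template text and all filler values below are invented for illustration.

# Render a prompt-template fragment with the renamed placeholders.
template = (
    "{tool_prompt}\n\n"
    "{code_improvement_prompt}\n\n"
    "### Previous Errors\n"
    "{all_previous_error_analyses}\n\n"
    "### Tutorials for Reference\n"
    "{tutorial_prompt}\n"
)
prompt = template.format(
    tool_prompt="Use AutoGluon Tabular presets for this dataset.",
    code_improvement_prompt="Previous best script scored 0.82; try stronger presets.",
    all_previous_error_analyses="Iteration 0: KeyError caused by a missing 'label' column.",
    tutorial_prompt="TabularPredictor(label='label').fit(train_data)",
)
print(prompt)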
diff --git a/src/autogluon/assistant/constants.py b/src/autogluon/assistant/constants.py index 6242212e..2d2177a7 100644 --- a/src/autogluon/assistant/constants.py +++ b/src/autogluon/assistant/constants.py @@ -8,6 +8,9 @@ LOGO_NIGHT_PATH = PACKAGE_ROOT / "webui" / "static" / "sidebar_icon.png" LOGO_PATH = PACKAGE_ROOT / "webui" / "static" / "page_icon.png" +### Default Library +DEFAULT_LIBRARY = "machine learning" + ### WebUI VALID_CODING_LANGUAGES = ["python", "bash"] diff --git a/src/autogluon/assistant/managers/__init__.py b/src/autogluon/assistant/managers/__init__.py index 1c1738ef..f1e57d3e 100644 --- a/src/autogluon/assistant/managers/__init__.py +++ b/src/autogluon/assistant/managers/__init__.py @@ -1 +1 @@ -from .manager import Manager +from .node_manager import Node, NodeManager diff --git a/src/autogluon/assistant/managers/manager.py b/src/autogluon/assistant/managers/manager.py deleted file mode 100644 index 27feb334..00000000 --- a/src/autogluon/assistant/managers/manager.py +++ /dev/null @@ -1,649 +0,0 @@ -import logging -import os -import shutil -import uuid -from pathlib import Path -from typing import List, Optional - -from ..agents import ( - CoderAgent, - DataPerceptionAgent, - DescriptionFileRetrieverAgent, - ErrorAnalyzerAgent, - ExecuterAgent, - MetaPromptingAgent, - RerankerAgent, - RetrieverAgent, - TaskDescriptorAgent, - ToolSelectorAgent, -) -from ..constants import ENV_FOLDER_NAME -from ..llm import ChatLLMFactory -from ..tools_registry import registry -from ..utils import get_user_input_webui - -# Basic configuration -logging.basicConfig(level=logging.INFO) - -# Create a logger -logger = logging.getLogger(__name__) - - -class Manager: - def __init__( - self, - input_data_folder: str, - output_folder: str, - config: str, - initial_user_input: str, - enable_per_iteration_instruction: bool, - ): - """Initialize Manager with required paths and config from YAML file. 
- - Args: - input_data_folder: Path to input data directory - output_folder: Path to output directory - config_path: Path to YAML configuration file - initial_user_input: Initial user instruction - enable_per_iteration_instruction: If asking for per iteration user input - """ - self.time_step = -1 - self.best_step = -1 - self.last_successful_step = -1 - self.best_step_saved = -1 - - # Store required paths - self.input_data_folder = input_data_folder - self.output_folder = output_folder - - # Validate paths - for path, name in [(input_data_folder, "input_data_folder")]: - if not Path(path).exists(): - raise FileNotFoundError(f"{name} not found: {path}") - - # Create output folder if it doesn't exist - Path(output_folder).mkdir(parents=True, exist_ok=True) - - self.config = config - - self.set_initial_user_input( - enable_per_iteration_instruction=enable_per_iteration_instruction, initial_user_input=initial_user_input - ) - - self.target_prompt_instance = None - - self.dp_agent = DataPerceptionAgent( - config=self.config, - manager=self, - input_data_folder=self.input_data_folder, - reader_llm_config=self.config.reader, - reader_prompt_template=None, # TODO: add it to argument - ) - - self.dfr_agent = DescriptionFileRetrieverAgent( - config=self.config, - manager=self, - llm_config=self.config.description_file_retriever, - prompt_template=None, # TODO: add it to argument - ) - - self.td_agent = TaskDescriptorAgent( - config=self.config, - manager=self, - llm_config=self.config.task_descriptor, - prompt_template=None, # TODO: add it to argument - ) - - self.ts_agent = ToolSelectorAgent( - config=self.config, - manager=self, - llm_config=self.config.tool_selector, - prompt_template=None, # TODO: add it to argument - ) - - # Initialize meta-prompting - self.enable_meta_prompting = getattr(self.config, "enable_meta_prompting", False) - # Set up meta-prompting LLM config if enabled - self.meta_prompting_agent = MetaPromptingAgent( - config=self.config, - manager=self, - llm_config=self.config.meta_prompting, - ) - - # Initialize prompts - self.generate_initial_prompts() - - self.user_inputs: List[str] = [] - self.error_messages: List[str] = [] - self.error_analyses: List[str] = [] - self.python_codes: List[str] = [] - self.python_file_paths: List[str] = [] - self.bash_scripts: List[str] = [] - self.tutorial_retrievals: List[str] = [] - self.tutorial_prompts: List[str] = [] - self.val_scores: List[Optional[float]] = [] - - self.error_analyzer = ErrorAnalyzerAgent( - config=self.config, - manager=self, - llm_config=self.config.error_analyzer, - prompt_template=None, # TODO: Add prompt_template to argument - ) - - self.retriever = RetrieverAgent( - config=self.config, - manager=self, - llm_config=self.config.retriever, - prompt_template=None, # TODO: Add prompt_template to argument - ) - - self.reranker = RerankerAgent( - config=self.config, - manager=self, - llm_config=self.config.reranker, - prompt_template=None, # TODO: Add prompt_template to argument - ) - - self.python_coder = CoderAgent( - config=self.config, - manager=self, - language="python", - coding_mode="coder", - llm_config=self.config.python_coder, - prompt_template=None, - ) # TODO: Add prompt_template to argument - self.bash_coder = CoderAgent( - config=self.config, - manager=self, - language="bash", - coding_mode="coder", - llm_config=self.config.bash_coder, - prompt_template=None, - ) # TODO: Add prompt_template to argument - - self.executer = ExecuterAgent( - config=self.config, - manager=self, - language="bash", - 
timeout=self.config.per_execution_timeout, - executer_llm_config=self.config.executer, - executer_prompt_template=None, - ) # TODO: Add prompt_template to argument - - def generate_initial_prompts(self): - self.data_prompt = self.dp_agent() - - self.description_files = self.dfr_agent() - - self.task_description = self.td_agent() - - self.selected_tool = self.ts_agent() - - # TODO: remove the hard code for "create_venv" (add in tool registry if need installation) - if self.selected_tool.lower() in ["machine learning", "huggingface", "fairseq"]: - self.config.create_venv = True - - # Get tool-specific template and requirements if they exist - tool_info = registry.get_tool(self.selected_tool) - if not tool_info: - raise ValueError(f"Tool {self.selected_tool} not found in registry") - # Get tool-specific prompt - self.tool_prompt = tool_info.get("prompt_template", "") - if isinstance(self.tool_prompt, list): - self.tool_prompt = "\n".join(self.tool_prompt) - - @property - def user_input(self) -> str: - assert self.time_step >= 0, "No user input because the prompt generator is not stepped yet." - assert len(self.user_inputs) == self.time_step + 1, "user input is not updated yet" - return self.user_inputs[self.time_step] - - @property - def python_code(self) -> str: - assert self.time_step >= 0, "No python code because the prompt generator is not stepped yet." - assert len(self.python_codes) == self.time_step + 1, "python code is not updated yet" - return self.python_codes[self.time_step] - - @property - def python_file_path(self) -> str: - assert self.time_step >= 0, "No python file path because the prompt generator is not stepped yet." - assert len(self.python_file_paths) == self.time_step + 1, "python file path is not updated yet" - return self.python_file_paths[self.time_step] - - @property - def previous_python_code(self) -> str: - if self.time_step >= 1: - return self.python_codes[self.time_step - 1] - else: - return "" - - @property - def bash_script(self) -> str: - assert self.time_step >= 0, "No bash script because the prompt generator is not stepped yet." - assert len(self.bash_scripts) == self.time_step + 1, "bash script is not updated yet" - return self.bash_scripts[self.time_step] - - @property - def previous_bash_script(self) -> str: - if self.time_step >= 1: - return self.bash_scripts[self.time_step - 1] - else: - return "" - - @property - def error_message(self) -> str: - assert self.time_step >= 0, "No error message because the prompt generator is not stepped yet." - assert len(self.error_messages) == self.time_step + 1, "error message is not updated yet" - return self.error_messages[self.time_step] - - @property - def previous_error_message(self) -> str: - if self.time_step >= 1: - return self.error_messages[self.time_step - 1] - else: - return "" - - @property - def error_analysis(self) -> str: - assert self.time_step >= 0, "No error prompt because the prompt generator is not stepped yet." 
- assert len(self.error_analyses) == self.time_step + 1, "error prompt is not updated yet" - return self.error_analyses[self.time_step] - - @property - def previous_error_analysis(self) -> str: - if self.time_step >= 1: - return self.error_analyses[self.time_step - 1] - else: - return "" - - @property - def all_previous_error_analyses(self) -> str: - if self.time_step >= 1: - return "\n\n".join(self.error_analyses[: self.time_step]) - else: - return "" - - @property - def tutorial_prompt(self) -> str: - assert self.time_step >= 0, "No tutorial prompt because the prompt generator is not stepped yet." - assert len(self.tutorial_prompts) == self.time_step + 1, "tutorial prompt is not updated yet" - return self.tutorial_prompts[self.time_step] - - @property - def previous_tutorial_prompt(self) -> str: - if self.time_step >= 1: - return self.tutorial_prompts[self.time_step - 1] - else: - return "" - - @property - def tutorial_retrieval(self) -> str: - assert self.time_step >= 0, "No tutorial retrieval because the prompt generator is not stepped yet." - assert len(self.tutorial_retrievals) == self.time_step + 1, "tutorial retrieval is not updated yet" - return self.tutorial_retrievals[self.time_step] - - @property - def previous_tutorial_retrieval(self) -> str: - if self.time_step >= 1: - return self.tutorial_retrievals[self.time_step - 1] - else: - return "" - - @property - def common_env_file(self) -> str: - return registry.registry_path / "_common" / "requirements.txt" - - @property - def selected_tool_env_file(self) -> str: - tool_path = registry.get_tool(self.selected_tool)["path"] - return registry.registry_path / tool_path / "requirements.txt" - - @property - def iteration_folder(self) -> str: - if self.time_step >= 0: - iter_folder = os.path.join(self.output_folder, f"generation_iter_{self.time_step}") - else: - iter_folder = os.path.join(self.output_folder, "initialization") - os.makedirs(iter_folder, exist_ok=True) - return iter_folder - - @property - def per_iteration_output_folder(self) -> str: - iter_output_folder = os.path.join(self.iteration_folder, "output") - os.makedirs(iter_output_folder, exist_ok=True) - return iter_output_folder - - @property - def validation_score(self) -> Optional[float]: - """Get the current validation score.""" - assert self.time_step >= 0, "No validation score because the prompt generator is not stepped yet." 
- assert len(self.val_scores) == self.time_step + 1, "validation score is not updated yet" - return self.val_scores[self.time_step] - - @property - def best_validation_score(self) -> Optional[float]: - """Get the best validation score found so far.""" - if self.best_step >= 0 and self.best_step < len(self.val_scores): - return self.val_scores[self.best_step] - return None - - def set_initial_user_input(self, enable_per_iteration_instruction, initial_user_input): - self.enable_per_iteration_instruction = enable_per_iteration_instruction - self.initial_user_input = initial_user_input - - def step(self): - """Step the prompt generator forward.""" - self.time_step += 1 - - user_input = self.initial_user_input - # Get per iter user inputs if needed - if self.enable_per_iteration_instruction: - if self.time_step > 0: - logger.brief( - f"[bold green]Previous iteration info is stored in:[/bold green] {os.path.join(self.output_folder, f'iteration_{self.time_step - 1}')}" - ) - else: - logger.brief( - f"[bold green]Initialization info is stored in:[/bold green] {os.path.join(self.output_folder, 'initialization')}" - ) - if user_input is None: - user_input = "" - if os.environ.get("AUTOGLUON_WEBUI", "false").lower() == "true": - # If running in WebUI, get user input from stdin - user_input += "\n" + get_user_input_webui( - f"Enter your inputs for current iteration (iter {self.time_step}) (press Enter to skip): " - ) - else: - user_input += "\n" + input( - f"Enter your inputs for current iteration (iter {self.time_step}) (press Enter to skip): " - ) - - assert len(self.user_inputs) == self.time_step - self.user_inputs.append(user_input) - - if self.time_step > 0: - previous_error_analysis = self.error_analyzer() - - assert len(self.error_analyses) == self.time_step - 1 - self.error_analyses.append(previous_error_analysis) - - retrieved_tutorials = self.retriever() - assert len(self.tutorial_retrievals) == self.time_step - self.tutorial_retrievals.append(retrieved_tutorials) - - tutorial_prompt = self.reranker() - assert len(self.tutorial_prompts) == self.time_step - self.tutorial_prompts.append(tutorial_prompt) - - def write_code_script(self, script, output_code_file): - with open(output_code_file, "w") as file: - file.write(script) - - def update_python_code(self): - """Update the current Python code.""" - assert len(self.python_codes) == self.time_step - assert len(self.python_file_paths) == self.time_step - - python_code = self.python_coder() - - python_file_path = os.path.join(self.iteration_folder, "generated_code.py") - - self.write_code_script(python_code, python_file_path) - - self.python_codes.append(python_code) - self.python_file_paths.append(python_file_path) - - def update_bash_script(self): - """Update the current bash script.""" - assert len(self.bash_scripts) == self.time_step - - bash_script = self.bash_coder() - - bash_file_path = os.path.join(self.iteration_folder, "execution_script.sh") - - self.write_code_script(bash_script, bash_file_path) - - self.bash_scripts.append(bash_script) - - def execute_code(self): - planner_decision, planner_error_summary, validation_score, planner_prompt, stderr, stdout = self.executer( - code_to_execute=self.bash_script, - code_to_analyze=self.python_code, - execution_task=self.task_description, - execution_data=self.data_prompt, - ) - - self.save_and_log_states(stderr, "stderr", per_iteration=True, add_uuid=False) - self.save_and_log_states(stdout, "stdout", per_iteration=True, add_uuid=False) - - # Track validation scores and update best step - 
assert len(self.val_scores) == self.time_step - self.val_scores.append(validation_score) - - # Update best step if we have a better validation score (higher is better) - if validation_score is not None: - if self.best_step == -1 or validation_score > self.val_scores[self.best_step]: - self.best_step = self.time_step - logger.brief( - f"[bold green]New best validation score: {validation_score:.4f} at step {self.time_step}[/bold green]" - ) - else: - logger.brief( - f"[bold yellow]Current validation score: {validation_score:.4f} (best: {self.val_scores[self.best_step]:.4f} at step {self.best_step})[/bold yellow]" - ) - self.remove_env_folder(self.iteration_folder) - - # Save validation score information - self.save_and_log_states( - content=f"Step: {self.time_step}\nValidation Score: {validation_score}\nBest Step: {self.best_step}\nBest Score: {self.best_validation_score}", - save_name="validation_score.txt", - per_iteration=True, - add_uuid=False, - ) - - if planner_decision == "FIX": - logger.brief(f"[bold red]Code generation failed in iteration[/bold red] {self.time_step}!") - # Add suggestions to the error message to guide next iteration - error_message = f"stderr: {stderr}\n\n" if stderr else "" - error_message += ( - f"Error summary from planner (the error can appear in stdout if it's catched): {planner_error_summary}" - ) - self.update_error_message(error_message=error_message) - self.remove_env_folder(self.iteration_folder) - return False - elif planner_decision == "SUCCESS": - self.last_successful_step = self.time_step - logger.brief(f"[bold green]Code generation successful at iteration[/bold green] {self.time_step}") - if validation_score is not None: - logger.brief(f"[bold green]Final validation score: {validation_score:.4f}[/bold green]") - if self.best_step >= 0: - logger.brief( - f"[bold green]Best validation score achieved: {self.best_validation_score:.4f} at step {self.best_step}[/bold green]" - ) - self.update_error_message(error_message="") - return True - else: - logger.warning(f"###INVALID Planner Output: {planner_decision}###") - self.update_error_message(error_message="") - self.remove_env_folder(self.iteration_folder) - return False - - def update_error_message(self, error_message: str): - """Update the current error message.""" - assert len(self.error_messages) == self.time_step - self.error_messages.append(error_message) - - def get_validation_score_summary(self) -> str: - """Get a summary of all validation scores.""" - if not self.val_scores: - return "No validation scores available." 
- - summary = ["Validation Score Summary:"] - for i, score in enumerate(self.val_scores): - marker = " (BEST)" if i == self.best_step else "" - summary.append(f"Step {i}: {score if score is not None else 'N/A'}{marker}") - - if self.best_step >= 0: - summary.append(f"\nBest score: {self.best_validation_score:.4f} at step {self.best_step}") - - return "\n".join(summary) - - def save_and_log_states(self, content, save_name, per_iteration=False, add_uuid=False): - if add_uuid: - # Split filename and extension - name, ext = os.path.splitext(save_name) - # Generate 4-digit UUID (using first 4 characters of hex) - uuid_suffix = str(uuid.uuid4()).replace("-", "")[:4] - save_name = f"{name}_{uuid_suffix}{ext}" - - if per_iteration: - states_dir = os.path.join(self.iteration_folder, "states") - else: - states_dir = os.path.join(self.output_folder, "states") - os.makedirs(states_dir, exist_ok=True) - output_file = os.path.join(states_dir, save_name) - - logger.info(f"Saving {output_file}...") - with open(output_file, "w") as file: - if content is not None: - if isinstance(content, list): - # Join list elements with newlines - file.write("\n".join(str(item) for item in content)) - else: - # Handle as string (original behavior) - file.write(content) - else: - file.write("") - - def log_agent_start(self, message: str): - logger.brief(message) - - def log_agent_end(self, message: str): - logger.brief(message) - - def report_token_usage(self): - token_usage_path = os.path.join(self.output_folder, "token_usage.json") - usage = ChatLLMFactory.get_total_token_usage(save_path=token_usage_path) - total = usage["total"] - logger.brief( - f"Total tokens — input: {total['total_input_tokens']}, " - f"output: {total['total_output_tokens']}, " - f"sum: {total['total_tokens']}" - ) - - logger.info(f"Full token usage detail:\n{usage}") - - def create_best_run_copy(self): - """Create a 'best_run' folder that copies the best step folder. - - If no best step is available, uses the last successful step. - If neither is available, logs a warning and does nothing. - """ - - # Determine which step to copy - target_step = None - copy_reason = "" - - if self.best_step >= 0: - target_step = self.best_step - copy_reason = f"best validation score ({self.best_validation_score:.4f})" - elif self.last_successful_step >= 0: - target_step = self.last_successful_step - copy_reason = "last successful execution" - else: - logger.warning("No best step or successful step found. 
Cannot create best_run copy.") - return - - if target_step == self.best_step_saved: - logger.info(f"Skipping the saving process as step {target_step} has already been saved as best run.") - return - - # Create paths - source_folder = os.path.join(self.output_folder, f"generation_iter_{target_step}") - best_run_folder = os.path.join(self.output_folder, "best_run") - - # Verify source folder exists - if not os.path.exists(source_folder): - logger.warning(f"Source folder does not exist: {source_folder}") - return - - # Check if source folder has an 'output' subdirectory - source_output_folder = os.path.join(source_folder, "output") - if not os.path.exists(source_output_folder): - logger.warning(f"Source output folder does not exist: {source_output_folder}") - return - - # Remove existing best_run folder if it exists - if os.path.exists(best_run_folder): - try: - shutil.rmtree(best_run_folder) - logger.info("Removed existing best_run folder") - except Exception as e: - logger.error(f"Failed to remove existing best_run folder: {e}") - return - - try: - # Copy all files from source_output_folder to self.output_folder - for item in os.listdir(source_output_folder): - source_item = os.path.join(source_output_folder, item) - dest_item = os.path.join(self.output_folder, item) - - if os.path.isfile(source_item): - shutil.copy2(source_item, dest_item) - elif os.path.isdir(source_item): - shutil.copytree(source_item, dest_item, dirs_exist_ok=True) - - if self.config.cleanup_unused_env: - # Move conda_env folder from source to best_run folder - shutil.move( - os.path.join(source_folder, ENV_FOLDER_NAME), os.path.join(best_run_folder, ENV_FOLDER_NAME) - ) - # Copy the entire source folder to best_run folder - shutil.copytree(source_folder, best_run_folder, dirs_exist_ok=True) - - logger.brief( - f"[bold green]Created best_run folder (copied from step {target_step} - {copy_reason})[/bold green]" - ) - - # Save summary information in the best_run folder - summary_content = [ - "Best Run Summary", - "================", - f"Copied from: generation_iter_{target_step}", - f"Reason: {copy_reason}", - f"Copy created at: {os.path.basename(best_run_folder)}", - "", - self.get_validation_score_summary(), - ] - - # Save summary in both the main output folder and the best_run folder - summary_text = "\n".join(summary_content) - - self.save_and_log_states( - content=summary_text, save_name="best_run_summary.txt", per_iteration=False, add_uuid=False - ) - - self.best_step_saved = target_step - - except Exception as e: - logger.error(f"Failed to copy folder: {e}") - return - - def remove_env_folder(self, iter_folder): - if not self.config.cleanup_unused_env: - return - try: - env_folder = os.path.join(iter_folder, ENV_FOLDER_NAME) - shutil.rmtree(env_folder) - logger.info(f"Removed unused env folder {env_folder}") - except Exception as e: - logger.error(f"Failed to remove env folder {env_folder}: {e}") - - def cleanup(self): - """Clean up resources.""" - if hasattr(self, "retriever"): - self.retriever.cleanup() - - def __del__(self): - """Destructor to ensure cleanup.""" - self.cleanup() diff --git a/src/autogluon/assistant/managers/node_manager.py b/src/autogluon/assistant/managers/node_manager.py new file mode 100644 index 00000000..5d54e9fa --- /dev/null +++ b/src/autogluon/assistant/managers/node_manager.py @@ -0,0 +1,1317 @@ +""" +Node-based manager using pure Monte Carlo Tree Search. It implements a tree-based +search strategy that allows for more flexible exploration and exploitation of solution +space. 
It also ensures all available tools are tried during the exploration process. +""" + +import logging +import math +import os +import threading +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, List, Literal, Optional, Set + +from ..llm import ChatLLMFactory +from ..tools_registry import registry + +logger = logging.getLogger(__name__) + + +@dataclass +class Node: + """ + A node in the solution tree representing a single iteration. + Stores code, execution results, and evaluation information. + """ + + # Node creation time + ctime: float = field(default_factory=lambda: time.time()) + + # Tree structure + parent: Optional["Node"] = None + children: Set["Node"] = field(default_factory=set) + + # Node position in tree + time_step: int = None # Corresponds to the global time step when created + depth: int = 0 # Depth in the tree (root=0, increases with each level) + + # Solution stage + stage: Literal["root", "debug", "evolve"] = "root" + + # MCTS statistics + visits: int = 0 + validated_visits: int = 0 # Number of successful runs with validation scores + failure_visits: int = 0 # Number of failed runs + unvalidated_visits: int = 0 # Number of successful runs without validation scores + validated_reward: float = 0.0 # Total reward from validated runs + # total_reward: float = 0.0 # Replaced by separate reward tracking + + # Node state tracking + is_successful: bool = False # Did the execution succeed? + is_debug_successful: bool = False # Did the debug in the subtree succeed? + is_terminal: bool = False # Should this node not be expanded further? + debug_attempts: int = 0 # Number of debug attempts on this node + + # Solution artifacts + python_code: str = "" + bash_script: str = "" + tool_used: str = "" # The primary tool used for this solution + tools_available: List[str] = field( + default_factory=list + ) # All tools available for this solution, in priority order + tutorial_retrieval: str = "" # Retrieved tutorials for this node + tutorial_prompt: str = "" # Processed tutorial prompt for this node + + # Execution results + stdout: str = "" + stderr: str = "" + execution_time: float = 0.0 + error_message: str = "" + error_analysis: str = "" + + # Evaluation metrics + validation_score: Optional[float] = None + + # Locking for thread safety + _lock: threading.Lock = field(default_factory=threading.Lock) + expected_child_count: int = 0 + + @property + def id( + self, + ): + return self.time_step + + def __post_init__(self): + """ + Initialize a node, adding it to parent's children if parent exists. + Set depth based on parent's depth. + """ + if self.parent is not None: + self.parent.add_child(self) + self.depth = self.parent.depth + 1 + + def add_child(self, child: "Node") -> None: + """ + Add a child node to this node. + """ + logger.detail(f"Node {child.id} is added to children of Node {self.id}.") + self.children.add(child) + + def remove_child(self, child: "Node") -> None: + """ + Remove a child node of this node. + """ + logger.detail(f"Node {child.id} is removed from children of Node {self.id}.") + self.children.remove(child) + + @property + def is_leaf(self) -> bool: + """ + Check if the node is a leaf node (has no children). + """ + return len(self.children) == 0 + + @property + def num_children(self) -> int: + """ + Get the number of child nodes. 
+ """ + return len(self.children) + + @property + def prev_tutorial_prompt(self) -> str: + if self.parent and self.parent.tutorial_prompt: + return self.parent.tutorial_prompt + + def update(self, reward: float, is_validated: bool = False, is_failure: bool = False) -> None: + """ + Update the node's statistics with a new reward. + + Args: + reward: The raw validation score (for validated runs) or None + is_validated: Whether this reward comes from a validated run + is_failure: Whether this run was a failure + """ + with self._lock: + self.visits += 1 + + if is_failure: + self.failure_visits += 1 + elif is_validated and reward is not None: + # For validated runs, store the raw validation score + self.validated_visits += 1 + self.validated_reward += reward # Sum up the raw scores, will be normalized in UCT + else: + # For successful runs without validation + self.unvalidated_visits += 1 + + def uct_value( + self, + exploration_constant: float = 1.414, + best_score: Optional[float] = None, + worst_score: Optional[float] = None, + failure_offset: float = 0, + failure_penalty_weight: float = 0.5, + ) -> float: + """ + Calculate the UCT (Upper Confidence Bound for Trees) value of the node. + + Args: + exploration_constant: The constant that controls exploration vs exploitation + best_score: The best validation score seen so far (for scaling) + worst_score: The worst validation score seen so far (for scaling) + + Returns: + The UCT value + """ + # For unvisited nodes, return infinity to ensure they are visited + if self.visits == 0: + return float("inf") + + # Get parent visits for UCT calculation + if self.parent: + parent_visits = max(1, self.parent.visits) + else: + parent_visits = 1 + + # Calculate exploitation term based on node stats + self.normalized_failure_visit = max(0, self.failure_visits - failure_offset) + self.failure_penalty = -failure_penalty_weight * self.normalized_failure_visit / self.visits + + # Calculate the validated rewards part + if self.validated_visits > 0: + if best_score is not None and worst_score is not None and best_score > worst_score: + # Normalize the validated_reward using best and worst scores + # First get the average raw score + self.avg_raw_score = self.validated_reward / self.validated_visits + # Then normalize it between 0 and 1 + self.normalized_score = (self.avg_raw_score - worst_score) / (best_score - worst_score) + self.validated_weight = self.validated_visits / self.visits + self.validated_contribution = self.validated_weight * self.normalized_score + else: + # If can't normalize + self.validated_contribution = 1.0 + else: + self.validated_contribution = 0.0 + + # Unvalidated contribution (nodes that succeeded but have no score) use a score of 0. and thus can be ignored + + # Total exploitation is the weighted sum of all components + self.exploitation = self.validated_contribution + self.failure_penalty + + # Calculate exploration term + self.exploration = exploration_constant * math.sqrt(math.log(parent_visits) / self.visits) + + return self.exploitation + self.exploration + + def __eq__(self, other): + if not isinstance(other, Node): + return False + return self.id == other.id + + def __hash__(self): + return hash(self.id) + + +class NodeManager: + """ + Manages a tree of nodes representing different iterations of solution development. + Uses Monte Carlo Tree Search (MCTS) to explore the solution space more effectively. 
+ """ + + def __init__( + self, + input_data_folder: str, + output_folder: str, + config: Any, + initial_user_input: str, + enable_per_iteration_instruction: bool, + ): + """ + Initialize the NodeManager with required paths and configuration. + + Args: + input_data_folder: Path to input data directory + output_folder: Path to output directory + config: Configuration object + initial_user_input: Initial user instruction + enable_per_iteration_instruction: If asking for per iteration user input + """ + # Store required paths + self.input_data_folder = input_data_folder + self.output_folder = output_folder + + # Validate paths + for path, name in [(input_data_folder, "input_data_folder")]: + if not Path(path).exists(): + raise FileNotFoundError(f"{name} not found: {path}") + + # Create output folder if it doesn't exist + Path(output_folder).mkdir(parents=True, exist_ok=True) + + self.config = config + self.enable_per_iteration_instruction = enable_per_iteration_instruction + self.initial_user_input = initial_user_input + + # Track time_step + self.time_step = -1 + # Create root node + self.root_node = Node(stage="root", time_step=self.time_step, depth=0) + self.current_node = self.root_node + + # Track best nodes and metrics + self._best_node = None + self._best_validation_score = None + self._worst_validation_score = None + self.last_successful_node = None + + # Key node tracking + self.best_step = -1 + self.last_successful_step = -1 + + # MCTS parameters + self.exploration_constant = self.config.exploration_constant + self.max_debug_depth = self.config.max_debug_depth + self.failure_offset = self.config.failure_offset + self.failure_penalty_weight = self.config.failure_penalty_weight + + # Tracking for thread safety + self._node_lock = threading.Lock() + self.search_start_time = time.time() + + # User inputs storage + self.user_inputs = [] + + # Error analysis storage + self._all_error_analyses = [] + + # Tool tracking + self.used_tools = set() + + # Target prompt instance for meta-prompting + self.target_prompt_instance = None + + # Initialize the agent components + self._init_agents() + + def _init_agents(self): + """Initialize all required agents.""" + from ..agents import ( + CoderAgent, + DataPerceptionAgent, + DescriptionFileRetrieverAgent, + ErrorAnalyzerAgent, + ExecuterAgent, + MetaPromptingAgent, + RerankerAgent, + RetrieverAgent, + TaskDescriptorAgent, + ToolSelectorAgent, + ) + + # Data perception agent + self.dp_agent = DataPerceptionAgent( + config=self.config, + manager=self, + input_data_folder=self.input_data_folder, + reader_llm_config=self.config.reader, + reader_prompt_template=None, + ) + + # Description file retriever agent + self.dfr_agent = DescriptionFileRetrieverAgent( + config=self.config, + manager=self, + llm_config=self.config.description_file_retriever, + prompt_template=None, + ) + + # Task descriptor agent + self.td_agent = TaskDescriptorAgent( + config=self.config, + manager=self, + llm_config=self.config.task_descriptor, + prompt_template=None, + ) + + # Tool selector agent + self.ts_agent = ToolSelectorAgent( + config=self.config, + manager=self, + llm_config=self.config.tool_selector, + prompt_template=None, + ) + + # Initialize meta-prompting + self.enable_meta_prompting = self.config.enable_meta_prompting + self.meta_prompting_agent = MetaPromptingAgent( + config=self.config, + manager=self, + llm_config=self.config.meta_prompting, + ) + + # Error analyzer + self.error_analyzer = ErrorAnalyzerAgent( + config=self.config, + manager=self, + 
llm_config=self.config.error_analyzer, + prompt_template=None, + ) + + # Retriever + self.retriever = RetrieverAgent( + config=self.config, + manager=self, + llm_config=self.config.retriever, + prompt_template=None, + ) + + # Reranker + self.reranker = RerankerAgent( + config=self.config, + manager=self, + llm_config=self.config.reranker, + prompt_template=None, + ) + + # Python coder + self.python_coder = CoderAgent( + config=self.config, + manager=self, + language="python", + coding_mode="coder", + llm_config=self.config.python_coder, + prompt_template=None, + ) + + # Bash coder + self.bash_coder = CoderAgent( + config=self.config, + manager=self, + language="bash", + coding_mode="coder", + llm_config=self.config.bash_coder, + prompt_template=None, + ) + + # Executer + self.executer = ExecuterAgent( + config=self.config, + manager=self, + language="bash", + timeout=self.config.per_execution_timeout, + executer_llm_config=self.config.executer, + executer_prompt_template=None, + ) + + def initialize(self): + """Initialize the manager.""" + self.data_prompt = self.dp_agent() + self.description_files = self.dfr_agent() + self.task_description = self.td_agent() + + # Use tool selector to get prioritized list of tools + self.available_tools = self.ts_agent() + + def get_iteration_folder(self, node: Node) -> str: + """ + Get the folder for storing iteration artifacts. + + Args: + node: The node to get the folder for + + Returns: + Path to the iteration folder + """ + if node.id < 0: + iter_folder = os.path.join(self.output_folder, "node_init") + else: + iter_folder = os.path.join(self.output_folder, f"node_{node.id}") + os.makedirs(iter_folder, exist_ok=True) + return iter_folder + + def get_per_iteration_output_folder(self, node: Node) -> str: + """ + Get the folder for storing iteration output artifacts. + + Args: + node: The node to get the output folder for + + Returns: + Path to the iteration output folder + """ + iter_output_folder = os.path.join(self.get_iteration_folder(node), "output") + os.makedirs(iter_output_folder, exist_ok=True) + return iter_output_folder + + def save_and_log_states(self, content, save_name, per_iteration=False, add_uuid=False, node=None): + """ + Save states to a file and log them. 
+ + Args: + content: Content to save + save_name: Name for the saved file + per_iteration: Whether this is for a specific iteration (backward compatibility) + add_uuid: Whether to add a UUID to the filename + node: Node to associate with the saved content (required if per_iteration is False) + """ + if add_uuid: + # Split filename and extension + name, ext = os.path.splitext(save_name) + # Generate 4-digit UUID (using first 4 characters of hex) + uuid_suffix = str(uuid.uuid4()).replace("-", "")[:4] + save_name = f"{name}_{uuid_suffix}{ext}" + + # Determine the save directory + if per_iteration and self.current_node: + states_dir = os.path.join(self.get_iteration_folder(self.current_node), "states") + elif node: + states_dir = os.path.join(self.get_iteration_folder(node), "states") + else: + states_dir = os.path.join(self.output_folder, "states") + + os.makedirs(states_dir, exist_ok=True) + output_file = os.path.join(states_dir, save_name) + + logger.info(f"Saving {output_file}...") + with open(output_file, "w") as file: + if content is not None: + if isinstance(content, list): + # Join list elements with newlines + file.write("\n".join(str(item) for item in content)) + else: + # Handle as string (original behavior) + file.write(content) + else: + file.write("") + + def log_agent_start(self, message: str): + """Log agent start message.""" + logger.info(message) + + def log_agent_end(self, message: str): + """Log agent end message.""" + logger.info(message) + + def select_node(self) -> Node: + """ + Select a node for expansion using UCT selection. + + Returns: + The selected node + """ + node = self.root_node + + # Traverse the tree until we find a node to expand + while node is not None and not node.is_leaf: + # If the node is not fully expanded, return it + if not self._is_fully_expanded(node): + return node + + # Otherwise, select the best child according to UCT + node = self._uct_select(node) + + return node + + def _is_fully_expanded(self, node: Node) -> bool: + """ + Check if a node is fully expanded. + + Args: + node: The node to check + + Returns: + True if the node is fully expanded, False otherwise + """ + # Root node + if node.stage == "root": + return node.num_children >= self.config.initial_root_children or self._get_unused_tool() is None + + # For debug nodes, stop expanding after getting a successful node + if node.stage == "debug": + # TODO: better debugging workflow? + if node.is_debug_successful: + return True + return node.num_children >= self.config.max_debug_children + + # For evolve nodes + if node.stage == "evolve": + return node.num_children >= self.config.max_evolve_children + + return False + + def _uct_select(self, node: Node) -> Node: + """ + Select the best child node according to UCT, excluding terminal nodes. + + Args: + node: The parent node + + Returns: + The selected child node + """ + non_terminal_children = [child for child in node.children if not child.is_terminal] + if not non_terminal_children: + # Fallback case - this shouldn't happen if backpropagation is working correctly + assert ( + node.is_terminal + ), f"All children of node {node.id} are terminal but node itself is not marked terminal" + logger.info("All nodes are terminal. 
Run complete.") + return None + + # Pass the best and worst validation scores for proper scaling + # If current node is root, adjust exploration constant based on tool index + if node == self.root_node: + # Get each child's tool index in the available tools list + def get_child_uct(child): + # Tools earlier in the list get higher exploration constants + tool_index = self.available_tools.index(child.tool_used) + # Scale exploration constant - earlier tools get higher values + tool_specific_exploration = self.exploration_constant * max(0.25, 1.0 - 0.25 * tool_index) + # Use config for failure offset + uct_value = child.uct_value( + tool_specific_exploration, + self._best_validation_score, + self._worst_validation_score, + failure_offset=self.failure_offset, + failure_penalty_weight=self.failure_penalty_weight, + ) + logger.detail(f"UCT Value is {uct_value} for Node {child.id}") + return uct_value + + return max(non_terminal_children, key=get_child_uct) + else: + # For non-root nodes, use the standard exploration constant + def get_child_uct(child): + uct_value = self.compute_uct_value(child) + logger.detail(f"UCT Value is {uct_value} for Node {child.id}") + return uct_value + + return max(non_terminal_children, key=get_child_uct) + + def expand(self) -> Node: + """ + Expand the current node by creating a child node. + + Returns: + The newly created child node + """ + if self.current_node.stage == "root": + return self._create_evolve_node() + elif self.current_node.is_successful: + return self._create_evolve_node() + else: + return self._create_debug_node() + + def _get_unused_tool(self) -> Optional[str]: + """ + Get a tool that has not been used yet in the tree. + + Returns: + An unused tool, or None if all tools have been used + """ + unused_tools = [tool for tool in self.available_tools if tool not in self.used_tools] + if unused_tools: + # return random.choice(unused_tools) + return unused_tools[0] # TODO: enable random selection of available tools + return None + + def _create_debug_node( + self, + ) -> Node: + """ + Create a debug node to fix issues in a failed node. + + Returns: + The newly created debug node + """ + # Increment global time step for this new node + self.time_step += 1 + + # Create a new node + self.current_node = Node( + parent=self.current_node, + stage="debug", + # Use the same tool as the parent for debugging + tool_used=self.current_node.tool_used, + tools_available=self.available_tools, + time_step=self.time_step, + debug_attempts=self.current_node.debug_attempts + 1, + ) + + # Check if we've exceeded the maximum debug attempts for this node + if self.current_node.debug_attempts >= self.max_debug_depth: + logger.warning( + f"Node {self.current_node.id} has reached the maximum debug depth ({self.max_debug_depth}). Marking as terminal." + ) + self.mark_node_terminal(self.current_node) + + # Generate code for the node + self._generate_code() + + def _create_evolve_node( + self, + ) -> Node: + """ + Create an evolve node to improve a successful node. 
+ + Returns: + The newly created evolve node + """ + # Increment global time step for this new node + self.time_step += 1 + + # Check if there's an unused tool to try + unused_tool = self._get_unused_tool() + if unused_tool: + # If there's an unused tool, create a node from the root with that tool + logger.info(f"Found unused tool {unused_tool}, creating evolve node from root using this tool") + parent = self.root_node + tool_used = unused_tool + else: + # Otherwise evolve from the parent node + logger.info(f"Creating evolve node from Node {self.current_node.id} using {self.current_node.tool_used}.") + parent = self.current_node + tool_used = self.current_node.tool_used + + self.current_node = Node( + parent=parent, + stage="evolve", + tool_used=tool_used, + tools_available=self.available_tools, + time_step=self.time_step, + ) + + # Generate code for the node + self._generate_code() + + def _update_tutorials(self=None): + """ + Retrieve and update tutorials for the current selected tool. + + Args: + node: Node to associate the tutorials with (optional) + """ + # Retrieve tutorials + self.current_node.tutorial_retrieval = self.retriever() + + # Rerank the retrieved tutorials + self.current_node.tutorial_prompt = self.reranker() + + # Save to node's folder + self.save_and_log_states( + content=self.current_node.tutorial_retrieval, + save_name="tutorial_retrievals.txt", + node=self.current_node, + add_uuid=False, + ) + self.save_and_log_states( + content=self.current_node.tutorial_prompt, + save_name="tutorial_prompt.txt", + node=self.current_node, + add_uuid=False, + ) + + def _generate_code(self): + """ + Generate Python and Bash code for the current node after the tool to use is specified. + """ + # Mark this tool as used + self.used_tools.add(self.current_node.tool_used) + + # Get user input for this step if enabled + if ( + self.enable_per_iteration_instruction + ): # TODO: refine the logic to store user inputs (currently they are not in nodes) + self._get_user_input_for_step() + + # Get the tool-specific prompt for the node's selected tool + from ..tools_registry import registry + + tool_info = registry.get_tool(self.current_node.tool_used) + if not tool_info: + print(self.current_node.state) + raise ValueError(f"Tool {self.current_node.tool_used} not found in registry") + + # Get tool-specific prompt + self.tool_prompt = tool_info.get("prompt_template", "") + if isinstance(self.tool_prompt, list): + self.tool_prompt = "\n".join(self.tool_prompt) + + # Get tutorials specific to this node + self._update_tutorials() + + # Generate Python code + self.current_node.python_code = self.python_coder() + + # Write the Python code to a file + python_file_path = os.path.join(self.get_iteration_folder(self.current_node), "generated_code.py") + with open(python_file_path, "w") as file: + file.write(self.current_node.python_code) + + # Generate Bash script + self.current_node.bash_script = self.bash_coder() + + # Write the Bash script to a file + bash_file_path = os.path.join(self.get_iteration_folder(self.current_node), "execution_script.sh") + with open(bash_file_path, "w") as file: + file.write(self.current_node.bash_script) + + def _get_user_input_for_step(self): + """Get user input for the current step.""" + # TODO: refine the logic to store user inputs (currently they are not in nodes) + if self.time_step == -1: + user_input = self.initial_user_input or "" + else: + logger.info(f"Previous iteration info is stored in: {self.get_iteration_folder(self.current_node)}") + user_input = 
self.initial_user_input or "" + user_input += "\n" + input( + f"Enter your inputs for current node (step {self.time_step}) (press Enter to skip): " + ) + + self.user_inputs.append(user_input) + + def simulate(self) -> tuple: + """ + Simulate execution of current node and evaluate the result. + + Returns: + Tuple containing: (validation_score, is_validated, is_failure) + validation_score: The raw validation score (or None if not available) + is_validated: True if this run has a validation score + is_failure: True if this run failed + """ + # Execute the code + planner_decision, error_summary, validation_score, planner_prompt, stderr, stdout = self.executer( + code_to_execute=self.current_node.bash_script, + code_to_analyze=self.current_node.python_code, + execution_task=self.task_description, + execution_data=self.data_prompt, + ) + + # Store execution results + self.current_node.stdout = stdout + self.current_node.stderr = stderr + + # Save execution outputs + self.save_and_log_states(stderr, "stderr", node=self.current_node, add_uuid=False) + self.save_and_log_states(stdout, "stdout", node=self.current_node, add_uuid=False) + + # Update validation score + self.current_node.validation_score = validation_score + + # Track the best and worst validation scores for scaling in UCT calculation + if validation_score is not None: + # Update best validation score + if self._best_node is None or validation_score > self._best_validation_score: + self._best_node = self.current_node + self._best_validation_score = validation_score + self.best_step = self.time_step + + # Track worst validation score (initialize if not set yet) + if not hasattr(self, "_worst_validation_score") or self._worst_validation_score is None: + self._worst_validation_score = validation_score + else: + self._worst_validation_score = min(self._worst_validation_score, validation_score) + + # Determine if the execution was successful + if planner_decision == "SUCCESS": + self.current_node.is_successful = True + self.last_successful_node = self.current_node + self.last_successful_step = self.time_step + self.current_node.error_message = "" + + # If this is a debug node, find the origin of the debug chain + if self.current_node.stage == "debug": + # Find the original node that started this debugging chain + debug_origin = self._find_debug_origin(self.current_node) + + # Add this successful node as a sibling to the original buggy node + self.current_node.parent.remove_child(self.current_node) + self.current_node.parent = debug_origin.parent + debug_origin.parent.add_child(self.current_node) + + self.mark_node_terminal(debug_origin) + + logger.info( + f"Replaced debug origin node {debug_origin.id} with successful debug node {self.current_node.id}" + ) + + # Return the raw validation score (for tracking), is_validated flag, and is_failure flag + return (validation_score, validation_score is not None, False) + else: + self.current_node.is_successful = False + self.current_node.error_message = f"stderr: {stderr}\n\n" if stderr else "" + self.current_node.error_message += f"Error summary: {error_summary}" + + # Get error analysis + self.current_node.error_analysis = self.error_analyzer() + + self._all_error_analyses.append(self.current_node.error_analysis) + + # If this is a debug node and it failed, check parent's debug attempts + if self.current_node.stage == "debug" and self.current_node.parent: + self.current_node.parent.debug_attempts += 1 + logger.warning( + f"Debug attempt failed. 
Debug attempts on parent node {self.current_node.parent.id}: {self.current_node.parent.debug_attempts}/{self.max_debug_depth}" + ) + + # If parent has reached max debug attempts, mark it as terminal + if self.current_node.parent.debug_attempts >= self.max_debug_depth: + logger.warning( + f"Parent node {self.current_node.parent.id} has reached the maximum debug depth. Marking as terminal." + ) + self.mark_node_terminal(self.current_node.parent) + + # For failures, we return None score, not validated, and is_failure=True + return (None, False, True) + + def backpropagate(self, simulation_result): + """ + Backpropagate the reward up the tree and update terminal status. + + Args: + simulation_result: Tuple of (validation_score, is_validated, is_failure) + """ + # Extract simulation results + validation_score, is_validated, is_failure = simulation_result + + node = self.current_node + while node is not None: + node.update(validation_score, is_validated, is_failure) + node = node.parent + + def step(self): + """ + Perform one step of the Monte Carlo Tree Search. + + Returns: + True if a successful node was found, False otherwise + """ + # Selection: select a node to expand + self.current_node = self.select_node() + if self.current_node is None: + return None + + # Expansion: create a new child node + # Note: time_step is now incremented in the creation methods + self.expand() + + # Simulation: execute the code and get results + simulation_result = self.simulate() + + # Backpropagation: update node statistics + self.backpropagate(simulation_result) + + return self.current_node.is_successful + + def mark_node_terminal(self, node): + """ + Mark a node and all its descendants as terminal. + Then check if any ancestors should be marked terminal. + + Args: + node: The node to mark as terminal + """ + # Mark the node itself and all descendants as terminal + self._mark_subtree_terminal(node) + + # Check if any ancestors should be marked terminal + self._check_ancestors_terminal(node.parent) + + def _mark_subtree_terminal(self, node): + """ + Recursively mark a node and all its descendants as terminal. + + Args: + node: The node to mark as terminal + """ + if node.is_terminal: + return + + node.is_terminal = True + logger.info(f"Marking node {node.id} as terminal") + + # Recursively mark all children + for child in node.children: + self._mark_subtree_terminal(child) + + def _check_ancestors_terminal(self, node): + """ + Recursively check if ancestors should be marked as terminal. + An ancestor is terminal if fully expanded and all children are terminal. + + Args: + node: The ancestor node to check + """ + if node is None: + return + + if self._is_fully_expanded(node) and all(child.is_terminal for child in node.children): + node.is_terminal = True + logger.info(f"Marking ancestor node {node.id} as terminal (all children terminal)") + + # Continue checking up the tree + self._check_ancestors_terminal(node.parent) + + def _get_all_nodes(self) -> List[Node]: + """ + Get all nodes in the tree. 
+ + Returns: + List of all nodes + """ + all_nodes = [] + + def _collect_nodes(node): + all_nodes.append(node) + for child in node.children: + _collect_nodes(child) + + _collect_nodes(self.root_node) + return all_nodes + + def create_best_run_copy(self): + """Create a 'best_run' folder that copies the best node folder.""" + # Determine which node to copy + target_node = None + copy_reason = "" + + if self._best_node: + target_node = self._best_node + copy_reason = f"best validation score ({self._best_validation_score:.4f})" + elif self.last_successful_node: + target_node = self.last_successful_node + copy_reason = "last successful execution" + else: + logger.warning("No best node or successful node found. Cannot create best_run copy.") + return + + # Create paths + source_folder = self.get_iteration_folder(target_node) + best_run_folder = os.path.join(self.output_folder, "best_run") + + # Verify source folder exists + if not os.path.exists(source_folder): + logger.warning(f"Source folder does not exist: {source_folder}") + return + + # Check if source folder has an 'output' subdirectory + source_output_folder = os.path.join(source_folder, "output") + if not os.path.exists(source_output_folder): + logger.warning(f"Source output folder does not exist: {source_output_folder}") + return + + # Remove existing best_run folder if it exists + if os.path.exists(best_run_folder): + import shutil + + try: + shutil.rmtree(best_run_folder) + logger.info("Removed existing best_run folder") + except Exception as e: + logger.error(f"Failed to remove existing best_run folder: {e}") + return + + try: + # Copy all files from source_output_folder to self.output_folder + import shutil + + for item in os.listdir(source_output_folder): + source_item = os.path.join(source_output_folder, item) + dest_item = os.path.join(self.output_folder, item) + + if os.path.isfile(source_item): + shutil.copy2(source_item, dest_item) + elif os.path.isdir(source_item): + shutil.copytree(source_item, dest_item, dirs_exist_ok=True) + + # Copy the entire source folder to best_run folder + shutil.copytree(source_folder, best_run_folder, dirs_exist_ok=True) + + logger.info(f"Created best_run folder (copied from node {target_node.id} - {copy_reason})") + + # Save summary information in the best_run folder + summary_content = [ + "Best Run Summary", + "================", + f"Copied from: node_{target_node.id}", + f"Reason: {copy_reason}", + f"Tool used: {target_node.tool_used}", + f"Copy created at: {os.path.basename(best_run_folder)}", + "", + self.get_validation_score_summary(), + "", + "Tool Usage Summary:", + "==================", + f"Available tools: {', '.join(self.available_tools)}", + f"Tools used: {', '.join(self.used_tools)}", + f"Tools not used: {', '.join(set(self.available_tools) - self.used_tools)}", + ] + + # Save summary in both the main output folder and the best_run folder + summary_text = "\n".join(summary_content) + + self.save_and_log_states( + content=summary_text, save_name="best_run_summary.txt", node=target_node, add_uuid=False + ) + + except Exception as e: + logger.error(f"Failed to copy folder: {e}") + + def get_validation_score_summary(self) -> str: + """ + Get a summary of all validation scores. + + Returns: + A summary string + """ + all_nodes = self._get_all_nodes() + nodes_with_scores = [node for node in all_nodes if node.validation_score is not None] + + if not nodes_with_scores: + return "No validation scores available." 
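#       (Illustrative sketch only, not from the diff: with two hypothetical scored
#       nodes -- say node 3 at 0.8412 and node 5 at 0.8573, both produced with the
#       "autogluon.tabular" tool -- the summary string assembled below would read:
#
#           Validation Score Summary:
#           Node 3 (autogluon.tabular): 0.8412
#           Node 5 (autogluon.tabular): 0.8573 (BEST)
#
#           Best score: 0.8573 from node 5 using autogluon.tabular
#       )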
+ + summary = ["Validation Score Summary:"] + for node in nodes_with_scores: + marker = " (BEST)" if node == self._best_node else "" + summary.append(f"Node {node.id} ({node.tool_used}): {node.validation_score}{marker}") + + if self._best_node: + summary.append( + f"\nBest score: {self._best_validation_score:.4f} from node {self._best_node.id} using {self._best_node.tool_used}" + ) + + return "\n".join(summary) + + def cleanup(self): + """Clean up resources.""" + if hasattr(self, "retriever"): + self.retriever.cleanup() + + def _find_debug_origin(self, node: Node) -> Optional[Node]: + """ + Find the original node that started this debugging chain. + + Args: + node: The current node in the debug chain + + Returns: + The original node that started the debug chain + """ + # Go up the tree until we find a non-debug node + current = node + while current.parent and current.parent.stage == "debug": + current = current.parent + + debug_origin = current.parent + assert not debug_origin.is_successful + + return debug_origin + + def __del__(self): + """Destructor to ensure cleanup.""" + self.cleanup() + + def visualize_results(self, output_path: Optional[str] = None) -> str: + """ + Generate a PDF visualization of the node structure. + + Args: + output_path: Path to save the PDF. If not provided, it will be saved to + the output folder. + + Returns: + The path to the generated PDF file + """ + from .node_visualizer import visualize_results + + return visualize_results(self, output_path) + + def report_token_usage(self): + token_usage_path = os.path.join(self.output_folder, "token_usage.json") + usage = ChatLLMFactory.get_total_token_usage(save_path=token_usage_path) + total = usage["total"] + logger.brief( + f"Total tokens — input: {total['total_input_tokens']}, " + f"output: {total['total_output_tokens']}, " + f"sum: {total['total_tokens']}" + ) + + logger.info(f"Full token usage detail:\n{usage}") + + def compute_uct_value(self, node): + return node.uct_value( + self.exploration_constant, + self._best_validation_score, + self._worst_validation_score, + failure_offset=self.failure_offset, + failure_penalty_weight=self.failure_penalty_weight, + ) + + # Properties to maintain compatibility with Manager API + @property + def user_input(self) -> str: + """Get the user input for the current step.""" + if self.time_step < 0 or self.time_step >= len(self.user_inputs): + return "" + return self.user_inputs[self.time_step] + + @property + def best_validation_score(self) -> float: + """Get the best validation score.""" + return self._best_validation_score if self._best_validation_score is not None else 0.0 + + @property + def best_node(self) -> Node: + """Get the best node.""" + return self._best_node + + @property + def python_code(self) -> str: + """Get the Python code from the current node.""" + return self.current_node.python_code if self.current_node else "" + + @property + def python_file_path(self) -> str: + """Get the Python file path for the current node.""" + if not self.current_node: + return "" + return os.path.join(self.get_iteration_folder(self.current_node), "generated_code.py") + + @property + def previous_python_code(self) -> str: + """Get the Python code from the previous node.""" + if self.current_node and self.current_node.parent: + return self.current_node.parent.python_code + return "" + + @property + def bash_script(self) -> str: + """Get the Bash script from the current node.""" + return self.current_node.bash_script if self.current_node else "" + + @property + def 
previous_bash_script(self) -> str: + """Get the Bash script from the previous node.""" + if self.current_node and self.current_node.parent: + return self.current_node.parent.bash_script + return "" + + @property + def error_message(self) -> str: + """Get the error message from the current node.""" + return self.current_node.error_message if self.current_node else "" + + @property + def previous_error_message(self) -> str: + """Get the error message from the previous node.""" + if self.current_node and self.current_node.parent: + return self.current_node.parent.error_message + return "" + + @property + def error_analysis(self) -> str: + """Get the error analysis from the current node.""" + return self.current_node.error_analysis if self.current_node else "" + + @property + def previous_error_analysis(self) -> str: + """Get the error analysis from the previous node.""" + if self.current_node and self.current_node.parent: + return self.current_node.parent.error_analysis + return "" + + @property + def all_previous_error_analyses(self) -> str: + """Get all error analyses from previous nodes.""" + # TODO: make this recursive, handle debugging code and successful ones differently + return "\n\n".join(self._all_error_analyses) + + if not self.current_node: + return "" + + analyses = [] + node = self.current_node + while node.parent: + node = node.parent + if node.error_analysis: + analyses.append(node.error_analysis) + + return "\n\n".join(analyses) + + @property + def per_iteration_output_folder(self) -> str: + """Get the output folder for the current iteration.""" + if not self.current_node: + return os.path.join(self.output_folder, "initialization", "output") + return self.get_per_iteration_output_folder(self.current_node) + + @property + def iteration_folder(self) -> str: + """Get the folder for the current iteration.""" + if not self.current_node: + return os.path.join(self.output_folder, "initialization") + return self.get_iteration_folder(self.current_node) + + @property + def tutorial_retrieval(self) -> str: + """Get the tutorial retrieval for the current step.""" + if self.current_node: + return self.current_node.tutorial_retrieval + else: + logger.warning("Invalid node while asking for tutorial_retrieval") + + @property + def tutorial_prompt(self) -> str: + """Get the tutorial prompt for the current step.""" + return self.current_node.tutorial_prompt if self.current_node else "" + + @property + def previous_tutorial_prompt(self) -> str: + """Get the tutorial prompt from the previous step.""" + return self.current_node.prev_tutorial_prompt + + @property + def common_env_file(self) -> str: + return registry.registry_path / "_common" / "requirements.txt" + + @property + def selected_tool(self) -> str: + return self.current_node.tool_used + + @property + def selected_tool_env_file(self) -> str: + tool_path = registry.get_tool(self.selected_tool)["path"] + return registry.registry_path / tool_path / "requirements.txt" + + @property + def configure_env( + self, + ): + if self.selected_tool.lower() in ["machine learning", "huggingface", "fairseq"]: + return True + else: + return self.config.configure_env + + @property + def code_to_improve( + self, + ): + if self.current_node.stage == "evolve": + return self.current_node.parent.python_code + else: + return None + + @property + def code_to_debug( + self, + ): + if self.current_node.stage == "debug": + return self.current_node.parent.python_code + else: + return None diff --git a/src/autogluon/assistant/managers/node_visualizer.py 
b/src/autogluon/assistant/managers/node_visualizer.py new file mode 100644 index 00000000..15eae5c3 --- /dev/null +++ b/src/autogluon/assistant/managers/node_visualizer.py @@ -0,0 +1,483 @@ +""" +Node visualization utility for Node Manager. + +This module provides functionality to generate a PDF visualization of the node structures +created during the Monte Carlo Tree Search process in the NodeManager. The visualization +includes a tree view of nodes with their basic information and states. +""" + +import logging +import os +from typing import Dict, List, Optional, Tuple + +from reportlab.lib import colors +from reportlab.lib.pagesizes import landscape, letter +from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet +from reportlab.platypus import Flowable, PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle + +from .node_manager import Node, NodeManager + +# For compatibility with updated node_manager.py that stores error analyses differently + +logger = logging.getLogger(__name__) + + +class NodeTree(Flowable): + """A Flowable for drawing the node tree structure.""" + + def __init__(self, root_node: Node, node_info: Dict[int, Dict], width=700, height=500): + Flowable.__init__(self) + self.root_node = root_node + self.node_info = node_info + self.width = width + self.height = height + + def draw(self): + """Draw the node tree.""" + # Calculate positions for all nodes + positions = self._calculate_positions(self.root_node) + + # Draw connections first (so they appear behind nodes) + self._draw_connections(positions) + + # Draw nodes + for node_id, (x, y) in positions.items(): + self._draw_node(node_id, x, y) + + def _calculate_positions(self, root: Node) -> Dict[int, Tuple[float, float]]: + """Calculate positions for all nodes in the tree.""" + positions = {} + max_depth = self._get_max_depth(root) + + # Get nodes by level to calculate width distribution + nodes_by_level = [[] for _ in range(max_depth + 1)] + + def collect_by_level(node, level=0): + if len(nodes_by_level) <= level: + nodes_by_level.append([]) + nodes_by_level[level].append(node) + for child in node.children: + collect_by_level(child, level + 1) + + collect_by_level(root) + + # Level height - ensure enough vertical space between levels + level_height = self.height / (max_depth + 2) + + # Create a mapping to track node positions + for level, level_nodes in enumerate(nodes_by_level): + num_nodes = len(level_nodes) + if num_nodes == 0: + continue + + # Use the full width with margins + usable_width = self.width * 0.9 + margin = (self.width - usable_width) / 2 + + # Calculate spacing between nodes at this level + if num_nodes == 1: + # If only one node at this level, center it + x = self.width / 2 + y = self.height - (level * level_height) - (level_height / 2) + positions[level_nodes[0].id] = (x, y) + else: + # Distribute nodes evenly across the usable width + spacing = usable_width / (num_nodes - 1) if num_nodes > 1 else usable_width + + for i, node in enumerate(level_nodes): + x = margin + (i * spacing) + y = self.height - (level * level_height) - (level_height / 2) + positions[node.id] = (x, y) + + return positions + + return positions + + def _get_max_depth(self, root: Node) -> int: + """Get the maximum depth of the tree.""" + if not root.children: + return 0 + + return 1 + max(self._get_max_depth(child) for child in root.children) + + def _draw_connections(self, positions: Dict[int, Tuple[float, float]]): + """Draw connections between nodes.""" + for node_id, (x, y) in positions.items(): + 
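#           (Annotation based on the loop body that follows: `positions` maps each
#           node id to the (x, y) coordinates computed by _calculate_positions; for
#           every node the parent's id is looked up in node_info, and a thin grey
#           connector line is drawn from parent to child before the node circles
#           are painted on top of it.)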
node_data = self.node_info.get(node_id, {}) + parent_id = node_data.get("parent_id") + + if parent_id is not None and parent_id in positions: + parent_x, parent_y = positions[parent_id] + + # Draw line from parent to child + self.canv.setStrokeColor(colors.grey) + self.canv.setLineWidth(0.5) + self.canv.line(parent_x, parent_y, x, y) + + def _draw_node(self, node_id: int, x: float, y: float): + """Draw a single node.""" + # Node info + node_data = self.node_info.get(node_id, {}) + + # Determine node color based on status + if node_data.get("is_successful", False): + fill_color = colors.lightgreen + elif node_data.get("error_message"): + fill_color = colors.lightcoral + else: + fill_color = colors.lightblue + + # Draw node circle - using much smaller radius + radius = 12 + self.canv.setFillColor(fill_color) + self.canv.setStrokeColor(colors.black) + self.canv.circle(x, y, radius, stroke=1, fill=1) + + # Draw node ID - smaller font + self.canv.setFillColor(colors.black) + self.canv.setFont("Helvetica", 8) + self.canv.drawCentredString(x, y - 2, str(node_id)) + + # Standard PDF link approach + + # Add clickable area with larger detection area for easier clicking + click_padding = 4 + # Use internal PDF links + self.canv.linkURL( + f"#node_{node_id}", + ( + x - radius - click_padding, + y - radius - click_padding, + x + radius + click_padding, + y + radius + click_padding, + ), + relative=0, + ) + + +class NodeVisualizer: + """ + A utility class for generating PDF visualizations of node structures. + """ + + def __init__(self, node_manager: NodeManager): + """ + Initialize the NodeVisualizer. + + Args: + node_manager: The NodeManager instance + """ + self.node_manager = node_manager + self.styles = getSampleStyleSheet() + + # Create custom styles + self.styles.add(ParagraphStyle("NodeTitle", parent=self.styles["Heading1"], fontSize=14, spaceAfter=10)) + self.styles.add(ParagraphStyle("NodeInfo", parent=self.styles["Normal"], fontSize=10, spaceAfter=5)) + self.styles.add(ParagraphStyle("ErrorText", parent=self.styles["Normal"], textColor=colors.red, fontSize=10)) + self.styles.add( + ParagraphStyle( + "NavInstruction", + parent=self.styles["Normal"], + fontName="Helvetica-Oblique", + fontSize=10, + textColor=colors.darkblue, + ) + ) + + def _create_node_info_dict(self, node: Node) -> Dict: + """ + Create a dictionary with node information. + + Args: + node: The node to extract information from + + Returns: + A dictionary with node information + """ + return { + "id": node.id, + "stage": node.stage, + "time_step": node.time_step, + "depth": node.depth, + "is_successful": node.is_successful, + "is_terminal": node.is_terminal, + "tool_used": node.tool_used, + "debug_attempts": node.debug_attempts, + "error_message": node.error_message, + "error_analysis": node.error_analysis, + "validation_score": node.validation_score, + "parent_id": node.parent.id if node.parent else None, + "child_ids": [child.id for child in node.children], + } + + def _get_all_nodes(self) -> List[Node]: + """Get all nodes in the tree.""" + return self.node_manager._get_all_nodes() + + def _create_node_summary(self, node: Node) -> List[Flowable]: + """ + Create a summary of the node. 
+ + Args: + node: The node to create a summary for + + Returns: + A list of flowables for the PDF + """ + elements = [] + + # Create a bookmark/anchor for this node + # Create anchor for internal links + elements.append(Paragraph(f'<a name="node_{node.id}"/>', self.styles["Normal"])) + + # Node title + elements.append(Paragraph(f"Node {node.id} ({node.stage})", self.styles["NodeTitle"])) + + # Dynamically get all properties of the node (except special ones like parent/children which need special handling) + node_attrs = {} + # Extended special_attrs to exclude large text content that would break the PDF layout + special_attrs = { + "parent", + "children", + "error_message", + "error_analysis", + "_lock", + "python_code", + "bash_script", + "stdout", + "stderr", + "tutorial_retrieval", + "tutorial_prompt", + } + for attr_name in dir(node): + if (not attr_name.startswith("_") or attr_name == "_lock") and attr_name not in special_attrs: + try: + attr_value = getattr(node, attr_name) + # Skip methods + if not callable(attr_value): + node_attrs[attr_name] = attr_value + except Exception as e: + node_attrs[attr_name] = f"<error: {e}>" + + # Create a table for node properties + property_data = [["Property", "Value"]] + property_data.append(["uct_value", f"{self.node_manager.compute_uct_value(node):.4f}"]) + + # Helper function to truncate long strings + def truncate_string(s, max_len=500): + if isinstance(s, str) and len(s) > max_len: + return s[:max_len] + f" ... [truncated, {len(s)} chars total]" + return s + + # Add all properties to the table + for prop_name, prop_value in sorted(node_attrs.items()): + # Format special cases + if prop_name == "validation_score" and prop_value is not None: + prop_value = f"{prop_value:.4f}" + elif prop_name == "ctime": + from datetime import datetime + + prop_value = datetime.fromtimestamp(prop_value).strftime("%Y-%m-%d %H:%M:%S") + elif isinstance(prop_value, list) or isinstance(prop_value, set): + # Truncate list/set representations if they're too long + list_str = str(list(prop_value)) + prop_value = truncate_string(list_str) + elif isinstance(prop_value, bool): + prop_value = "✓" if prop_value else "✗" + elif isinstance(prop_value, str): + # Truncate any string values + prop_value = truncate_string(prop_value) + + property_data.append([prop_name, str(prop_value)]) + + # Add parent and children info + property_data.append(["parent_id", str(node.parent.id) if node.parent else "None"]) + property_data.append(["child_ids", str([child.id for child in node.children]) if node.children else "[]"]) + + # Create and style the table + props_table = Table(property_data, colWidths=[150, 350]) + props_table.setStyle( + TableStyle( + [ + ("BACKGROUND", (0, 0), (1, 0), colors.lightgrey), + ("TEXTCOLOR", (0, 0), (1, 0), colors.black), + ("ALIGN", (0, 0), (0, -1), "LEFT"), + ("ALIGN", (1, 0), (1, -1), "LEFT"), + ("FONTNAME", (0, 0), (1, 0), "Helvetica-Bold"), + ("BOTTOMPADDING", (0, 0), (-1, 0), 6), + ("GRID", (0, 0), (-1, -1), 0.5, colors.grey), + ("VALIGN", (0, 0), (-1, -1), "MIDDLE"), + ] + ) + ) + + elements.append(props_table) + elements.append(Spacer(1, 10)) + + # Error information if failed + if not node.is_successful and node.error_message: + elements.append(Paragraph("Error Message:", self.styles["Heading3"])) + # Truncate long error messages for better PDF layout + truncated_message = node.error_message + if len(truncated_message) > 2000: + truncated_message = ( + truncated_message[:2000] + f"...
[truncated, {len(node.error_message)} chars total]" + ) + + elements.append(Paragraph(truncated_message.replace("\n", "<br/>"), self.styles["ErrorText"])) + + # Error analysis if available + if node.error_analysis: + elements.append(Paragraph("Error Analysis:", self.styles["Heading3"])) + # Truncate long error analyses for better PDF layout + truncated_analysis = node.error_analysis + if len(truncated_analysis) > 2000: + truncated_analysis = ( + truncated_analysis[:2000] + f"... [truncated, {len(node.error_analysis)} chars total]" + ) + + elements.append(Paragraph(truncated_analysis.replace("\n", "<br/>
"), self.styles["Normal"])) + + elements.append(Spacer(1, 20)) + return elements + + def visualize_nodes(self, output_path: Optional[str] = None) -> str: + """ + Generate a PDF visualization of the node structure. + + Args: + output_path: Path to save the PDF. If not provided, it will be saved to + the node manager's output folder. + + Returns: + The path to the generated PDF file + """ + # Set default output path + if output_path is None: + output_path = os.path.join(self.node_manager.output_folder, "node_visualization.pdf") + + # Get all nodes + all_nodes = self._get_all_nodes() + + # Create PDF document + doc = SimpleDocTemplate( + output_path, pagesize=landscape(letter), topMargin=20, bottomMargin=20, leftMargin=20, rightMargin=20 + ) + + # Main elements for the document + elements = [] + + # Add title with hyperlink guide + elements.append(Paragraph("Node Structure Visualization", self.styles["Heading1"])) + elements.append( + Paragraph( + "Click on any node in the tree to jump to its detailed information.", self.styles["NavInstruction"] + ) + ) + + # Add summary statistics + best_score_text = ( + f"Best Validation Score: {self.node_manager.best_validation_score:.4f} (Node {self.node_manager.best_step})" + if self.node_manager.best_validation_score > 0 + else "No validation scores available" + ) + + elements.append(Paragraph(f"Total Nodes: {len(all_nodes)} | {best_score_text}", self.styles["Normal"])) + elements.append(Spacer(1, 10)) + + # Create a tree visualization + elements.append(Paragraph("Node Tree:", self.styles["Heading2"])) + elements.append(Spacer(1, 10)) + + # Create node info dictionary for the tree visualization + node_info = {node.id: self._create_node_info_dict(node) for node in all_nodes} + + # Add the tree visualization - wider and taller to accommodate more nodes + elements.append(NodeTree(self.node_manager.root_node, node_info, width=750, height=400)) + elements.append(Spacer(1, 20)) + + # Add legend + legend_data = [["Status", "Color"], ["Success", "Green"], ["Failure", "Red"], ["Neutral", "Blue"]] + legend_table = Table(legend_data, colWidths=[100, 100]) + legend_table.setStyle( + TableStyle( + [ + ("BACKGROUND", (0, 0), (1, 0), colors.grey), + ("TEXTCOLOR", (0, 0), (1, 0), colors.whitesmoke), + ("BACKGROUND", (1, 1), (1, 1), colors.lightgreen), + ("BACKGROUND", (1, 2), (1, 2), colors.lightcoral), + ("BACKGROUND", (1, 3), (1, 3), colors.lightblue), + ("ALIGN", (0, 0), (-1, -1), "CENTER"), + ("FONTNAME", (0, 0), (1, 0), "Helvetica-Bold"), + ("GRID", (0, 0), (-1, -1), 0.5, colors.grey), + ] + ) + ) + elements.append(legend_table) + elements.append(Spacer(1, 30)) + + # Add individual node details + elements.append(PageBreak()) + elements.append(Paragraph("Node Details:", self.styles["Heading2"])) + elements.append(Spacer(1, 20)) + + # Sort nodes by ID for easier navigation + all_nodes = sorted(all_nodes, key=lambda n: n.id) + + # Add node details + for node in all_nodes: + elements.extend(self._create_node_summary(node)) + if node != all_nodes[-1]: # No page break after the last node + elements.append(PageBreak()) + + # Create a table of contents + elements.insert(5, Paragraph("Table of Contents:", self.styles["Heading2"])) + toc_data = [["Node ID", "Stage", "Status", "Tool"]] + + # Add entry for each node to the TOC with hyperlinks + for node in all_nodes: + status = "✓" if node.is_successful else ("✗" if node.error_message else "") + toc_data.append( + [f"{node.id}", node.stage, status, node.tool_used or ""] + ) + + # Create the TOC table + toc_table = 
Table(toc_data, colWidths=[50, 70, 50, 200]) + toc_table.setStyle( + TableStyle( + [ + ("BACKGROUND", (0, 0), (-1, 0), colors.lightgrey), + ("TEXTCOLOR", (0, 0), (-1, 0), colors.black), + ("ALIGN", (0, 0), (-1, -1), "LEFT"), + ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"), + ("BOTTOMPADDING", (0, 0), (-1, 0), 6), + ("GRID", (0, 0), (-1, -1), 0.5, colors.grey), + ("VALIGN", (0, 0), (-1, -1), "MIDDLE"), + ] + ) + ) + + elements.insert(6, toc_table) + elements.insert(7, Spacer(1, 20)) + + # Build the document + doc.build(elements) + logger.info(f"Node visualization generated at: {output_path}") + + return output_path + + +def visualize_results(node_manager: NodeManager, output_path: Optional[str] = None) -> str: + """ + Generate a PDF visualization of the node structure. + + Args: + node_manager: The NodeManager instance + output_path: Path to save the PDF. If not provided, it will be saved to + the node manager's output folder. + + Returns: + The path to the generated PDF file + """ + visualizer = NodeVisualizer(node_manager) + return visualizer.visualize_nodes(output_path) diff --git a/src/autogluon/assistant/prompts/bash_coder_prompt.py b/src/autogluon/assistant/prompts/bash_coder_prompt.py index 5bb07907..eea20ab3 100644 --- a/src/autogluon/assistant/prompts/bash_coder_prompt.py +++ b/src/autogluon/assistant/prompts/bash_coder_prompt.py @@ -36,7 +36,7 @@ def default_template(self) -> str: {python_code} ### Previous Error -{all_error_analyses} +{all_previous_error_analyses} ### Previous failed bash script: {previous_bash_script} @@ -91,7 +91,7 @@ def parse(self, response: Dict) -> Tuple[str, Optional[str]]: return extracted_bash_script def get_env_prompt(self): - create_venv = self.manager.config.create_venv + configure_env = self.manager.configure_env iteration_folder = self.manager.iteration_folder selected_tool = self.manager.selected_tool common_env_file = self.manager.common_env_file @@ -104,7 +104,7 @@ def get_env_prompt(self): - pip install uv - Install required packages from {common_env_file} and {selected_tool_env_file} using uv pip install -r {selected_tool_env_file} -r {common_env_file}""" - if not create_venv: + if not configure_env: env_prompt += f"\n - Do not install or update any package unless there is an error due to the missing package.\n - Do NOT upgrade {selected_tool} which is already installed." 
else: env_prompt += "\n - Install any packages that are needed in the python script" diff --git a/src/autogluon/assistant/prompts/error_analyzer_prompt.py b/src/autogluon/assistant/prompts/error_analyzer_prompt.py index fae619e7..b0648eb1 100644 --- a/src/autogluon/assistant/prompts/error_analyzer_prompt.py +++ b/src/autogluon/assistant/prompts/error_analyzer_prompt.py @@ -35,7 +35,7 @@ def default_template(self) -> str: SUGGESTED_FIX: [Specific debugging directions in 1-3 sentences without code] ### Error Message -{previous_error_message_truncate_mid_8192} +{error_message_truncate_mid_8192} ### Task Description {task_description} @@ -47,13 +47,13 @@ def default_template(self) -> str: {user_input} ### Previous Python Code: -{previous_python_code} +{python_code} ### Previous Bash Script to Execute the Python Code: -{previous_bash_script} +{bash_script} ### Relevant Tutorials -{previous_tutorial_prompt} +{tutorial_prompt} """ def _build(self, **kwargs) -> str: diff --git a/src/autogluon/assistant/prompts/executer_prompt.py b/src/autogluon/assistant/prompts/executer_prompt.py index 22161682..825d0ed6 100644 --- a/src/autogluon/assistant/prompts/executer_prompt.py +++ b/src/autogluon/assistant/prompts/executer_prompt.py @@ -41,11 +41,11 @@ def default_template(self) -> str: ## Execution Results ### Standard Output (stdout) -{stdout_truncate_mid_8192} +{stdout_truncate_start_8192} ### Standard Error (stderr) -{stderr_truncate_mid_8192} +{stderr_truncate_start_8192} Evaluate the execution results and decide on one of the following actions: 1. SUCCESS - If the execution was completely successful and met all requirements. @@ -60,7 +60,7 @@ def default_template(self) -> str: Even if the code executed without throwing errors, it might still have issues with logic or not meet all requirements. For validation scores: -- If there is a validation score present in the execution results, extract it +- If there is a validation score present in the execution results, extract it (e.g. the last validation score reported in the training process). - Convert the score to ensure higher values indicate better performance (multiply "lower is better" metrics like RMSE, MAE, or loss by -1) - Return the converted score that follows the "higher is better" convention""" diff --git a/src/autogluon/assistant/prompts/python_coder_prompt.py b/src/autogluon/assistant/prompts/python_coder_prompt.py index e385e420..2bf88853 100644 --- a/src/autogluon/assistant/prompts/python_coder_prompt.py +++ b/src/autogluon/assistant/prompts/python_coder_prompt.py @@ -60,7 +60,7 @@ def default_template(self) -> str: {tool_prompt} -{best_code_prompt} +{code_improvement_prompt} Please provide the complete Python script that accomplishes these tasks, ensuring it's ready to run given the appropriate data inputs. @@ -74,7 +74,8 @@ def default_template(self) -> str: {user_input_truncate_end_2048} ### Previous Errors -{all_error_analyses} +These errors were encountered across different implementation approaches and may not be directly related to your current implementation. Use them as reference material to identify potential pitfalls and avoid similar mistakes in your implementation. 
+{all_previous_error_analyses} ### Tutorials for Reference {tutorial_prompt} @@ -93,12 +94,12 @@ def _build(self, **kwargs) -> str: assert self.manager.time_step >= 0, "run manager.step(user_input) before retrieving the prompt" # Generate best code prompt and validation prompt - best_code_prompt = self._generate_best_code_prompt() + code_improvement_prompt = self._generate_code_improvement_prompt() validation_prompt = self._generate_validation_prompt() # Render the prompt using the variable provider with additional variables additional_vars = { - "best_code_prompt": best_code_prompt, # Dynamically generated + "code_improvement_prompt": code_improvement_prompt, # Dynamically generated "validation_prompt": validation_prompt, # Dynamically generated } @@ -143,51 +144,32 @@ def _generate_system_resources_prompt(self) -> str: Please optimize your code to efficiently utilize the available hardware resources. """ - def _generate_best_code_prompt(self) -> str: + def _generate_code_improvement_prompt(self) -> str: """Generate prompt section about best/successful previous code.""" if self.manager.time_step == 0: return "" # No previous code on first iteration - best_code_prompt = [] - - # Check if we have a best step with validation score - if self.manager.best_step >= 0 and self.manager.best_step < self.manager.time_step: - best_code = self.manager.python_codes[self.manager.best_step] - best_score = self.manager.val_scores[self.manager.best_step] - - best_code_prompt.append("### Previous Best Code") - best_code_prompt.append( - f"The following code achieved the best validation score so far ({best_score:.4f}):" - ) - best_code_prompt.append("```python") - best_code_prompt.append(best_code) - best_code_prompt.append("```") - best_code_prompt.append("") - best_code_prompt.append( - "Please prioritize model architecture improvements and training optimization to enhance performance. Feature engineering may also be applied but with lower priority." - ) - if self.manager.config.optimize_system_resources: - best_code_prompt.append(self._generate_system_resources_prompt()) - # Check if we have a last successful step (different from best step) - elif self.manager.last_successful_step >= 0 and self.manager.last_successful_step < self.manager.time_step: - successful_code = self.manager.python_codes[self.manager.last_successful_step] - - best_code_prompt.append("### Previous Successful Code") - best_code_prompt.append("The following code executed successfully:") - best_code_prompt.append("```python") - best_code_prompt.append(successful_code) - best_code_prompt.append("```") - best_code_prompt.append("") - best_code_prompt.append( - "Please prioritize model architecture improvements and training optimization to enhance performance. Feature engineering may also be applied but with lower priority." - ) - if self.manager.config.optimize_system_resources: - best_code_prompt.append(self._generate_system_resources_prompt()) - # Do nothing if there's no successful code + if self.manager.code_to_improve: + code_improvement_prompt = f"""### Previous Code to Improve +```python +{self.manager.code_to_improve} +``` +Please prioritize model architecture improvements and training optimization to enhance performance. Feature engineering may also be applied but with lower priority. +""" + elif self.manager.code_to_debug: + code_improvement_prompt = f"""### Previous Code to Debug +```python +{self.manager.code_to_debug} +``` +Please fix the errors in the code above. Make minimal changes necessary to fix the issues. 
+""" else: - best_code_prompt = [] + code_improvement_prompt = "" + + if self.manager.config.optimize_system_resources: + code_improvement_prompt += self._generate_system_resources_prompt() - return "\n".join(best_code_prompt) + return code_improvement_prompt def parse(self, response: Dict) -> Tuple[str, Optional[str]]: """Parse the LLM's response to generated python code""" diff --git a/src/autogluon/assistant/prompts/reranker_prompt.py b/src/autogluon/assistant/prompts/reranker_prompt.py index f9a2390d..149865c9 100644 --- a/src/autogluon/assistant/prompts/reranker_prompt.py +++ b/src/autogluon/assistant/prompts/reranker_prompt.py @@ -82,7 +82,7 @@ def default_template(self) -> str: {user_input} ### Previous Error Analysis -{all_error_analyses} +{all_previous_error_analyses} Available Tutorials: {tutorials_info} diff --git a/src/autogluon/assistant/prompts/retriever_prompt.py b/src/autogluon/assistant/prompts/retriever_prompt.py index eec621c1..27c62832 100644 --- a/src/autogluon/assistant/prompts/retriever_prompt.py +++ b/src/autogluon/assistant/prompts/retriever_prompt.py @@ -39,7 +39,7 @@ def default_template(self) -> str: {user_input} ### Previous Error Analysis -{all_error_analyses} +{all_previous_error_analyses} ### Selected Tool/Library {selected_tool} diff --git a/src/autogluon/assistant/prompts/tool_selector_prompt.py b/src/autogluon/assistant/prompts/tool_selector_prompt.py index 3dba73ae..cdafc3c0 100644 --- a/src/autogluon/assistant/prompts/tool_selector_prompt.py +++ b/src/autogluon/assistant/prompts/tool_selector_prompt.py @@ -1,7 +1,8 @@ import logging import re -from typing import Dict, Tuple +from typing import Dict, List, Union +from ..constants import DEFAULT_LIBRARY from ..tools_registry import registry from .base_prompt import BasePrompt @@ -49,7 +50,7 @@ def meta_instructions(cls) -> str: def default_template(self) -> str: """Default template for tool selection""" return """ -You are a data science expert tasked with selecting the most appropriate ML library for a specific task. +You are a data science expert tasked with selecting and ranking the most appropriate ML libraries for a specific task. ### Task Description: {task_description} @@ -62,15 +63,22 @@ def default_template(self) -> str: IMPORTANT: Your response MUST follow this exact format: --- -SELECTED_LIBRARY: -EXPLANATION: +EXPLANATION: + +RANKED_LIBRARIES: +1. +2. +3. +... --- Requirements for your response: -1. The SELECTED_LIBRARY must be exactly as shown in the available libraries list -2. Use the exact headers "SELECTED_LIBRARY:" and "EXPLANATION:" -3. Provide a clear, detailed explanation of why this library is the best choice -4. Consider the task requirements, data characteristics, and library features +1. First provide a detailed explanation of your reasoning process using the "EXPLANATION:" header +2. Then provide a ranking of libraries using the "RANKED_LIBRARIES:" header +3. The library names must be exactly as shown in the available libraries list +4. Provide a ranking of at least 3 libraries (if available) +5. In your explanation, analyze each library's strengths and weaknesses for this specific task +6. Consider the task requirements, data characteristics, and library features Do not include any other formatting or additional sections in your response. 
""" @@ -93,7 +101,7 @@ def _build(self, **kwargs) -> str: return prompt - def parse(self, response: str) -> Tuple[str, str]: + def parse(self, response: str) -> Union[List[str], str]: """ Parse the library selection response from LLM with improved robustness. @@ -101,63 +109,88 @@ def parse(self, response: str) -> Tuple[str, str]: response: The raw response from the LLM Returns: - Tuple[str, str]: (selected_tool, explanation) + Union[List[str], str]: Either a prioritized list of tools or a single tool name """ - # Default values - selected_tool = "" - explanation = "" - # Clean the response response = response.strip() - # Try different parsing strategies - # Strategy 1: Look for exact headers - selected_library_match = re.search(r"SELECTED_LIBRARY:[\s]*(.+?)(?:\n|$)", response, re.IGNORECASE) + # Extract explanation first explanation_match = re.search( - r"EXPLANATION:[\s]*(.+?)(?=SELECTED_LIBRARY:|$)", response, re.IGNORECASE | re.DOTALL + r"EXPLANATION:[\s]*(.+?)(?=RANKED_LIBRARIES:|$)", response, re.IGNORECASE | re.DOTALL ) - # Strategy 2: Fallback to more lenient parsing - if not selected_library_match: - selected_library_match = re.search( - r"(?:selected|chosen|recommended).*?(?:library|tool):[\s]*(.+?)(?:\n|$)", response, re.IGNORECASE - ) - if not explanation_match: explanation_match = re.search( - r"(?:explanation|reasoning|rationale):[\s]*(.+?)(?=$)", response, re.IGNORECASE | re.DOTALL + r"(?:explanation|reasoning|rationale):[\s]*(.+?)(?=RANKED_LIBRARIES:|ranking|ranked|prioritized|priority|$)", + response, + re.IGNORECASE | re.DOTALL, ) - # Extract and clean the matches - if selected_library_match: - selected_tool = selected_library_match.group(1).strip() - if explanation_match: - explanation = explanation_match.group(1).strip() + explanation = ( + explanation_match.group(1).strip() if explanation_match else "No explanation provided by the model." + ) - # Validate against available tools - available_tools = set(registry.tools.keys()) - if selected_tool and selected_tool not in available_tools: - # Try to find the closest match - closest_match = min(available_tools, key=lambda x: len(set(x.lower()) ^ set(selected_tool.lower()))) - logger.warning( - f"Selected tool '{selected_tool}' not in available tools. " f"Using closest match: '{closest_match}'" - ) - selected_tool = closest_match + # Strategy 1: Look for ranked libraries section + ranked_libraries_section = re.search(r"RANKED_LIBRARIES:(.*?)$", response, re.IGNORECASE | re.DOTALL) - # Final validation - if not selected_tool: - logger.error("Failed to extract selected tool from LLM response") - selected_tool = list(registry.tools.keys())[0] # Default to first available tool - logger.warning(f"Defaulting to: {selected_tool}") + # Strategy 2: Fallback to more lenient parsing + if not ranked_libraries_section: + ranked_libraries_section = re.search( + r"(?:ranking|ranked|prioritized|priority).*?(?:libraries|tools):(.*?)$", + response, + re.IGNORECASE | re.DOTALL, + ) - if not explanation: - logger.error("Failed to extract explanation from LLM response") - explanation = "No explanation provided by the model." 
+ # Parse the ranked libraries + prioritized_tools = [] + + if ranked_libraries_section: + # Get the list section + ranked_section = ranked_libraries_section.group(1).strip() + + # Try to find numbered list items + list_items = re.findall(r"^\s*\d+\.\s*(.+?)$", ranked_section, re.MULTILINE) + + if list_items: + # Found a numbered list + for item in list_items: + tool_name = item.strip() + if tool_name: + prioritized_tools.append(tool_name) + else: + # Try to find bullet points or just lines + list_items = re.findall(r"(?:^|\n)\s*(?:[-*•])?\s*(.+?)(?:$|\n)", ranked_section) + for item in list_items: + tool_name = item.strip() + if tool_name: + prioritized_tools.append(tool_name) + + # Validate against available tools and clean up + available_tools = set(registry.tools.keys()) + validated_tools = [] + + for tool in prioritized_tools: + if tool in available_tools: + validated_tools.append(tool) + else: + # Try to find the closest match + closest_match = min(available_tools, key=lambda x: len(set(x.lower()) ^ set(tool.lower()))) + logger.warning(f"Tool '{tool}' not in available tools. Using closest match: '{closest_match}'") + validated_tools.append(closest_match) + + # Final validation - if we couldn't parse any tools, default to original behavior + if not validated_tools: + logger.error("Failed to extract ranked tools from LLM response") + default_tool = DEFAULT_LIBRARY + logger.warning(f"Defaulting to single tool: {default_tool}") + self._log_results(response, default_tool, explanation) + return [default_tool] # Log the results - self._log_results(response, selected_tool, explanation) + tools_str = ", ".join(validated_tools) + self._log_results(response, tools_str, explanation) - return selected_tool + return validated_tools def _log_results(self, response: str, selected_tool: str, explanation: str): """Log the parsing results.""" diff --git a/src/autogluon/assistant/prompts/variable_provider.py b/src/autogluon/assistant/prompts/variable_provider.py index f9ee8e21..74fba3df 100644 --- a/src/autogluon/assistant/prompts/variable_provider.py +++ b/src/autogluon/assistant/prompts/variable_provider.py @@ -70,7 +70,7 @@ def get_value(self, var_name: str) -> Any: "bash_script": lambda: self.manager.bash_script, "previous_bash_script": lambda: self.manager.previous_bash_script, "previous_error_message": lambda: self.manager.previous_error_message, - "all_error_analyses": lambda: self.manager.all_previous_error_analyses, + "all_previous_error_analyses": lambda: self.manager.all_previous_error_analyses, "tutorial_prompt": lambda: self.manager.tutorial_prompt, "previous_tutorial_prompt": lambda: self.manager.previous_tutorial_prompt, "selected_tool": lambda: self.manager.selected_tool, diff --git a/src/autogluon/assistant/prompts/variables.py b/src/autogluon/assistant/prompts/variables.py index 2ee569cb..f5b5d039 100644 --- a/src/autogluon/assistant/prompts/variables.py +++ b/src/autogluon/assistant/prompts/variables.py @@ -91,15 +91,15 @@ def _initialize_registry(self): # Error-related variables self.register( VariableDefinition( - name="previous_error_message", - description="Error message in the previous iteration", + name="error_message", + description="Error message in the current iteration", ) ) # Error-related variables self.register( VariableDefinition( - name="all_error_analyses", + name="all_previous_error_analyses", description="Error analysis in all completed iterations", ) ) @@ -197,7 +197,7 @@ def _initialize_registry(self): # Best code related self.register( VariableDefinition( - 
name="best_code_prompt", + name="code_improvement_prompt", description="Examples of high-quality code", ) ) diff --git a/src/autogluon/assistant/tools_registry/autogluon.tabular/tool.json b/src/autogluon/assistant/tools_registry/autogluon.tabular/tool.json index 618994bc..5814fbca 100644 --- a/src/autogluon/assistant/tools_registry/autogluon.tabular/tool.json +++ b/src/autogluon/assistant/tools_registry/autogluon.tabular/tool.json @@ -4,6 +4,7 @@ "description": "AutoGluon Tabular automates ML model training for tabular data, excelling at numerical and categorical data processing. It doesn't support NLP or image-based tasks.", "requirements": [], "prompt_template": [ + "try to use presets=\"extreme\"", "DO NOT create validation dataset from training data. Validation score is provided during the training process.", "AutoGluon Tabular requires converting multiple label columns into either a single combined label for mutually exclusive cases or training separate predictors for each label in non-exclusive scenarios." ] diff --git a/src/autogluon/assistant/tools_registry/autogluon.timeseries/tool.json b/src/autogluon/assistant/tools_registry/autogluon.timeseries/tool.json index 67a00413..4ec28db9 100644 --- a/src/autogluon/assistant/tools_registry/autogluon.timeseries/tool.json +++ b/src/autogluon/assistant/tools_registry/autogluon.timeseries/tool.json @@ -4,6 +4,7 @@ "description": "AutoGluon TimeSeries automates forecasting model training for time series data, supporting both univariate and multivariate predictions.", "requirements": [], "prompt_template": [ + "use presets: best_quality", "DO NOT drop any data samples (to make sure the frequency is regular).", "DO NOT create validation dataset from training data. Validation score is provided during the training process.", "Note that the prediction is given in a column named \"mean\". You need to rename the column in the result.",