diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py
index eec8beb0..b3613570 100644
--- a/ai_scientist/perform_writeup.py
+++ b/ai_scientist/perform_writeup.py
@@ -16,6 +16,12 @@
     AVAILABLE_LLMS,
 )
 
+from ai_scientist.perform_icbinb_writeup import (
+    load_idea_text,
+    load_exp_summaries,
+    filter_experiment_summaries,
+)
+
 from ai_scientist.tools.semantic_scholar import search_for_papers
 from ai_scientist.perform_vlm_review import generate_vlm_img_review
 
@@ -236,7 +242,7 @@ def get_citation_addition(
 
     try:
         text, msg_history = get_response_from_llm(
-            msg=citation_first_prompt_template.format(
+            prompt=citation_first_prompt_template.format(
                 current_round=current_round + 1,
                 total_rounds=total_rounds,
                 Idea=idea_text,
@@ -284,7 +290,7 @@ def get_citation_addition(
 
     try:
         text, msg_history = get_response_from_llm(
-            msg=citation_second_prompt_template.format(
+            prompt=citation_second_prompt_template.format(
                 papers=papers_str,
                 current_round=current_round + 1,
                 total_rounds=total_rounds,
@@ -451,9 +457,106 @@ def get_citation_addition(
 ```
 """
 
+def gather_citations(base_folder, num_cite_rounds=20, small_model="gpt-4o-2024-05-13"):
+    """
+    Gather citations for a paper, with ability to resume from previous progress.
+
+    Args:
+        base_folder: Path to project folder
+        num_cite_rounds: Maximum number of citation gathering rounds
+        small_model: Model to use for citation collection
+        resume: Whether to try to resume from previous progress
+
+    Returns:
+        str: The gathered citations text, or None if failed
+    """
+
+    # Initialize or load progress
+    current_round = 0
+    citations_text = ""
+
+    latex_folder = osp.join(base_folder, "latex")
+
+    # Prepare a new fresh latex folder
+    if not osp.exists(osp.join(latex_folder, "template.tex")):
+        shutil.copytree(
+            "ai_scientist/blank_icml_latex", latex_folder, dirs_exist_ok=True
+        )
+
+    writeup_file = osp.join(latex_folder, "template.tex")
+    with open(writeup_file, "r") as f:
+        writeup_text = f.read()
+
+    writeup_file = osp.join(latex_folder, "template.tex")
+    with open(writeup_file, "r") as f:
+        writeup_text = f.read()
+
+    try:
+        # Load idea text and summaries
+        idea_text = load_idea_text(base_folder)
+        exp_summaries = load_exp_summaries(base_folder)
+        filtered_summaries = filter_experiment_summaries(
+            exp_summaries, step_name="citation_gathering"
+        )
+        combined_summaries_str = json.dumps(filtered_summaries, indent=2)
+
+        # Run small model for citation additions
+        client, client_model = create_client(small_model)
+        for round_idx in range(num_cite_rounds):
+            with open(writeup_file, "r") as f:
+                writeup_text = f.read()
+            try:
+                references_bib = re.search(
+                    r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
+                    writeup_text,
+                    re.DOTALL,
+                )
+                if references_bib is None:
+                    raise ValueError("No references.bib found in template.tex")
+                citations_text = references_bib.group(1)
+                context_for_citation = (combined_summaries_str, citations_text)
+
+                addition, done = get_citation_addition(
+                    client,
+                    client_model,
+                    context_for_citation,
+                    round_idx,
+                    num_cite_rounds,
+                    idea_text,
+                )
+                if done:
+                    break
+
+                if addition is not None:
+                    # Simple check to avoid duplicating the same title
+                    title_match = re.search(r" title = {(.*?)}", addition)
+                    if title_match:
+                        new_title = title_match.group(1).lower()
+                        existing_titles = re.findall(
+                            r" title = {(.*?)}", citations_text
+                        )
+                        existing_titles = [t.lower() for t in existing_titles]
+                        if new_title not in existing_titles:
+                            pattern_end = r"\end{filecontents}"
+                            revised = writeup_text.replace(
+                                pattern_end, f"\n{addition}{pattern_end}"
+                            )
+                            with open(writeup_file, "w") as fo:
+                                fo.write(revised)
+            except Exception:
+                print("EXCEPTION in gather_citations:")
+                print(traceback.format_exc())
+                continue
+        return citations_text if citations_text else None
+
+    except Exception:
+        print("EXCEPTION in gather_citations:")
+        print(traceback.format_exc())
+        return citations_text if citations_text else None
 
 def perform_writeup(
     base_folder,
+    citations_text=None,
     no_writing=False,
     num_cite_rounds=20,
     small_model="gpt-4o-2024-05-13",
@@ -472,41 +575,14 @@ def perform_writeup(
     # os.remove(pdf_file)
 
     try:
-        # Load idea text
-        idea_text = ""
-        research_idea_path = osp.join(base_folder, "research_idea.md")
-        if osp.exists(research_idea_path):
-            with open(research_idea_path, "r") as f_idea:
-                idea_text = f_idea.read()
-        else:
-            idea_md_path = osp.join(base_folder, "idea.md")
-            if osp.exists(idea_md_path):
-                with open(idea_md_path, "r") as f_idea:
-                    idea_text = f_idea.read()
-
-        # Load summaries
-        summary_files = [
-            ("logs/0-run/baseline_summary.json", "BASELINE_SUMMARY"),
-            ("logs/0-run/research_summary.json", "RESEARCH_SUMMARY"),
-            ("logs/0-run/ablation_summary.json", "ABLATION_SUMMARY"),
-        ]
-        loaded_summaries = {}
-        for fname, key in summary_files:
-            path = osp.join(base_folder, fname)
-            if osp.exists(path):
-                try:
-                    with open(path, "r") as f:
-                        loaded_summaries[key] = json.load(f)
-                except json.JSONDecodeError:
-                    print(
-                        f"Warning: {fname} is not valid JSON. Using empty data for {key}."
-                    )
-                    loaded_summaries[key] = {}
-            else:
-                loaded_summaries[key] = {}
-
+        # Load idea text and summaries
+        idea_text = load_idea_text(base_folder)
+        exp_summaries = load_exp_summaries(base_folder)
+        filtered_summaries_for_writeup = filter_experiment_summaries(
+            exp_summaries, step_name="writeup"
+        )
         # Convert them to one big JSON string for context
-        combined_summaries_str = json.dumps(loaded_summaries, indent=2)
+        combined_summaries_str = json.dumps(filtered_summaries_for_writeup, indent=2)
 
         # Prepare a new fresh latex folder
         if not osp.exists(osp.join(latex_folder, "template.tex")):
@@ -538,54 +614,36 @@ def perform_writeup(
         if no_writing:
             compile_latex(latex_folder, base_pdf_file + ".pdf")
             return osp.exists(base_pdf_file + ".pdf")
-
-        # Run small model for citation additions
-        client, client_model = create_client(small_model)
-        for round_idx in range(num_cite_rounds):
-            with open(writeup_file, "r") as f:
-                writeup_text = f.read()
-            try:
-                references_bib = re.search(
-                    r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
-                    writeup_text,
-                    re.DOTALL,
-                )
-                if references_bib is None:
-                    raise ValueError("No references.bib found in template.tex")
-                citations_text = references_bib.group(1)
-                context_for_citation = (combined_summaries_str, citations_text)
-
-                addition, done = get_citation_addition(
-                    client,
-                    client_model,
-                    context_for_citation,
-                    round_idx,
-                    num_cite_rounds,
-                    idea_text,
+
+        # If no citations provided, try to load from cache first
+        if citations_text is None:
+            citations_cache_path = osp.join(base_folder, "cached_citations.bib")
+            if osp.exists(citations_cache_path):
+                try:
+                    with open(citations_cache_path, "r") as f:
+                        citations_text = f.read()
+                    print("Loaded citations from cache")
+                except Exception as e:
+                    print(f"Error loading cached citations: {e}")
+                    citations_text = None
+
+        # If still no citations, gather them
+        if not citations_text:
+            citations_text = gather_citations(
+                base_folder, num_cite_rounds, small_model
             )
-                if done:
-                    break
+            if citations_text is None:
+                print("Warning: Citation gathering failed")
+                citations_text = ""
 
-            if addition is not None:
-                # Simple check to avoid duplicating the same title
-                title_match = re.search(r" title = {(.*?)}", addition)
-                if title_match:
-                    new_title = title_match.group(1).lower()
-                    existing_titles = re.findall(
-                        r" title = {(.*?)}", citations_text
-                    )
-                    existing_titles = [t.lower() for t in existing_titles]
-                    if new_title not in existing_titles:
-                        pattern_end = r"\end{filecontents}"
-                        revised = writeup_text.replace(
-                            pattern_end, f"\n{addition}{pattern_end}"
-                        )
-                        with open(writeup_file, "w") as fo:
-                            fo.write(revised)
-            except Exception:
-                print("EXCEPTION in perform_writeup (citation round):")
-                print(traceback.format_exc())
-                continue
+        # Insert citations into template.tex
+        if citations_text:
+            with open(writeup_file, "r") as f:
+                content = f.read()
+            pattern_end = r"\end{filecontents}"
+            content = content.replace(pattern_end, f"\n{citations_text}{pattern_end}")
+            with open(writeup_file, "w") as f:
+                f.write(content)
 
         # Generate VLM-based descriptions but do not overwrite plot_names
         try:
@@ -636,7 +694,7 @@ def perform_writeup(
         )
 
         response, msg_history = get_response_from_llm(
-            msg=combined_prompt,
+            prompt=combined_prompt,
             client=big_client,
             model=big_client_model,
             system_message=big_model_system_message,
@@ -645,6 +703,9 @@ def perform_writeup(
 
         latex_code_match = re.search(r"```latex(.*?)```", response, re.DOTALL)
         if not latex_code_match:
+            print("No valid LaTeX code block found in initial writeup response.")
+            print("Response was:")
+            print(response)
             return False
         updated_latex_code = latex_code_match.group(1).strip()
         with open(writeup_file, "w") as f:
@@ -665,9 +726,9 @@ def perform_writeup(
             invalid_figs = used_figs - all_figs
 
             # Compile current version before reflection
-            compile_latex(latex_folder, base_pdf_file + f"_{compile_attempt}.pdf")
+            compile_latex(latex_folder, base_pdf_file + f"_reflection_{compile_attempt}.pdf")
             compile_attempt += 1
-            print(f"Compiled {base_pdf_file}_{compile_attempt}.pdf")
+            print(f"Compiled {base_pdf_file}_reflection_{compile_attempt}.pdf")
 
             # Detect where "Impact Statement" appears
             impact_loc = detect_pages_before_impact(latex_folder)
@@ -707,7 +768,7 @@ def perform_writeup(
             """
 
             reflection_response, msg_history = get_response_from_llm(
-                msg=reflection_prompt,
+                prompt=reflection_prompt,
                 client=big_client,
                 model=big_client_model,
                 system_message=big_model_system_message,
@@ -741,10 +802,10 @@ def perform_writeup(
                         fo.write(final_text)
 
                     compile_latex(
-                        latex_folder, base_pdf_file + f"_{compile_attempt}.pdf"
+                        latex_folder, base_pdf_file + f"_reflection_{compile_attempt}.pdf"
                     )
                     compile_attempt += 1
-                    print(f"Compiled {base_pdf_file}_{compile_attempt}.pdf")
+                    print(f"Compiled {base_pdf_file}_reflection_{compile_attempt}.pdf")
                 else:
                     print(f"No changes in reflection step {i+1}.")
                     break
@@ -752,7 +813,7 @@ def perform_writeup(
                 print(f"No valid LaTeX code block found in reflection step {i+1}.")
                 break
 
-        return osp.exists(base_pdf_file + f"_{compile_attempt-1}.pdf")
+        return osp.exists(base_pdf_file + f"_reflection_{compile_attempt-1}.pdf")
 
     except Exception:
         print("EXCEPTION in perform_writeup:")
diff --git a/launch_scientist_bfts.py b/launch_scientist_bfts.py
index e854af0e..34fd7929 100644
--- a/launch_scientist_bfts.py
+++ b/launch_scientist_bfts.py
@@ -18,10 +18,10 @@
     edit_bfts_config_file,
 )
 from ai_scientist.perform_plotting import aggregate_plots
-from ai_scientist.perform_writeup import perform_writeup
+from ai_scientist.perform_writeup import perform_writeup, gather_citations
 from ai_scientist.perform_icbinb_writeup import (
     perform_writeup as perform_icbinb_writeup,
-    gather_citations,
+    gather_citations as gather_icbinb_citations,
 )
 from ai_scientist.perform_llm_review import perform_review, load_paper
 from ai_scientist.perform_vlm_review import perform_imgs_cap_ref_review
@@ -88,6 +88,12 @@ def parse_arguments():
         default="o3-mini-2025-01-31",
         help="Model to use for plot aggregation",
     )
+    parser.add_argument(
+        "--model_agg_plots_ref",
+        type=int,
+        default=5,
+        help="Number of reflections to use for plot aggregation",
+    )
     parser.add_argument(
         "--model_writeup",
         type=str,
@@ -262,7 +268,7 @@ def redirect_stdout_stderr_to_file(log_file_path):
             dirs_exist_ok=True,
         )
 
-        aggregate_plots(base_folder=idea_dir, model=args.model_agg_plots)
+        aggregate_plots(base_folder=idea_dir, model=args.model_agg_plots, n_reflections=args.model_agg_plots_ref)
 
         shutil.rmtree(osp.join(idea_dir, "experiment_results"))
 
@@ -270,14 +276,14 @@ def redirect_stdout_stderr_to_file(log_file_path):
 
     if not args.skip_writeup:
         writeup_success = False
-        citations_text = gather_citations(
-            idea_dir,
-            num_cite_rounds=args.num_cite_rounds,
-            small_model=args.model_citation,
-        )
         for attempt in range(args.writeup_retries):
             print(f"Writeup attempt {attempt+1} of {args.writeup_retries}")
             if args.writeup_type == "normal":
+                citations_text = gather_citations(
+                    idea_dir,
+                    num_cite_rounds=args.num_cite_rounds,
+                    small_model=args.model_citation,
+                )
                 writeup_success = perform_writeup(
                     base_folder=idea_dir,
                     small_model=args.model_writeup_small,
@@ -286,6 +292,11 @@ def redirect_stdout_stderr_to_file(log_file_path):
                     citations_text=citations_text,
                 )
             else:
+                citations_text = gather_icbinb_citations(
+                    idea_dir,
+                    num_cite_rounds=args.num_cite_rounds,
+                    small_model=args.model_citation,
+                )
                 writeup_success = perform_icbinb_writeup(
                     base_folder=idea_dir,
                     small_model=args.model_writeup_small,
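Note: the sketch below is illustrative only; the folder path and model name are placeholders, and it assumes the remaining perform_writeup keyword arguments keep their defaults. It mirrors the normal-writeup path in launch_scientist_bfts.py after this change: gather_citations() grows the references.bib filecontents block inside latex/template.tex and returns its contents, and perform_writeup() either reuses that text, falls back to cached_citations.bib, or re-gathers citations itself when citations_text is None.

from ai_scientist.perform_writeup import gather_citations, perform_writeup

# Placeholder project folder; it is assumed to contain idea.md (or research_idea.md)
# and the logs/0-run summaries that load_idea_text/load_exp_summaries expect.
idea_dir = "experiments/example_idea"

# Collect BibTeX entries into latex/template.tex over up to 20 rounds.
citations_text = gather_citations(
    idea_dir,
    num_cite_rounds=20,
    small_model="gpt-4o-2024-05-13",
)

# Write the paper; passing citations_text=None instead would make perform_writeup
# load cached_citations.bib or call gather_citations on its own.
writeup_success = perform_writeup(
    base_folder=idea_dir,
    citations_text=citations_text,
)
print("Writeup succeeded:", writeup_success)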