Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 150 additions & 89 deletions ai_scientist/perform_writeup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
AVAILABLE_LLMS,
)

from ai_scientist.perform_icbinb_writeup import (
load_idea_text,
load_exp_summaries,
filter_experiment_summaries,
)

from ai_scientist.tools.semantic_scholar import search_for_papers

from ai_scientist.perform_vlm_review import generate_vlm_img_review
Expand Down Expand Up @@ -236,7 +242,7 @@ def get_citation_addition(

try:
text, msg_history = get_response_from_llm(
msg=citation_first_prompt_template.format(
prompt=citation_first_prompt_template.format(
current_round=current_round + 1,
total_rounds=total_rounds,
Idea=idea_text,
Expand Down Expand Up @@ -284,7 +290,7 @@ def get_citation_addition(

try:
text, msg_history = get_response_from_llm(
msg=citation_second_prompt_template.format(
prompt=citation_second_prompt_template.format(
papers=papers_str,
current_round=current_round + 1,
total_rounds=total_rounds,
Expand Down Expand Up @@ -451,9 +457,106 @@ def get_citation_addition(
```
"""

def gather_citations(base_folder, num_cite_rounds=20, small_model="gpt-4o-2024-05-13"):
    """
    Gather citations for a paper by iteratively querying an LLM for bib entries.

    Each round re-reads the current ``references.bib`` filecontents block from
    the LaTeX template, asks the model for one new entry, and appends it to the
    template file (skipping entries whose title already exists, case-insensitively).

    Args:
        base_folder: Path to the project folder.
        num_cite_rounds: Maximum number of citation-gathering rounds.
        small_model: Model name to use for citation collection.

    Returns:
        str: The gathered citations text (contents of references.bib),
        or None if nothing was gathered or setup failed early.
    """
    citations_text = ""
    latex_folder = osp.join(base_folder, "latex")

    # Prepare a fresh latex folder from the blank template if one is missing
    if not osp.exists(osp.join(latex_folder, "template.tex")):
        shutil.copytree(
            "ai_scientist/blank_icml_latex", latex_folder, dirs_exist_ok=True
        )

    writeup_file = osp.join(latex_folder, "template.tex")
    bib_pattern = r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}"

    try:
        # Load idea text and experiment summaries as context for the model
        idea_text = load_idea_text(base_folder)
        exp_summaries = load_exp_summaries(base_folder)
        filtered_summaries = filter_experiment_summaries(
            exp_summaries, step_name="citation_gathering"
        )
        combined_summaries_str = json.dumps(filtered_summaries, indent=2)

        client, client_model = create_client(small_model)
        for round_idx in range(num_cite_rounds):
            # Re-read the template each round so prior additions are visible
            with open(writeup_file, "r") as f:
                writeup_text = f.read()
            try:
                references_bib = re.search(bib_pattern, writeup_text, re.DOTALL)
                if references_bib is None:
                    raise ValueError("No references.bib found in template.tex")
                citations_text = references_bib.group(1)
                context_for_citation = (combined_summaries_str, citations_text)

                addition, done = get_citation_addition(
                    client,
                    client_model,
                    context_for_citation,
                    round_idx,
                    num_cite_rounds,
                    idea_text,
                )
                if done:
                    break

                if addition is not None:
                    # Simple check to avoid duplicating the same title
                    title_match = re.search(r" title = {(.*?)}", addition)
                    if title_match:
                        new_title = title_match.group(1).lower()
                        existing_titles = re.findall(
                            r" title = {(.*?)}", citations_text
                        )
                        existing_titles = [t.lower() for t in existing_titles]
                        if new_title not in existing_titles:
                            pattern_end = r"\end{filecontents}"
                            revised = writeup_text.replace(
                                pattern_end, f"\n{addition}{pattern_end}"
                            )
                            with open(writeup_file, "w") as fo:
                                fo.write(revised)
            except Exception:
                print("EXCEPTION in gather_citations:")
                print(traceback.format_exc())
                continue

        # Re-read once more so the final round's addition (written to disk
        # above) is reflected in the returned citations text.
        with open(writeup_file, "r") as f:
            final_text = f.read()
        final_bib = re.search(bib_pattern, final_text, re.DOTALL)
        if final_bib is not None:
            citations_text = final_bib.group(1)

        return citations_text if citations_text else None

    except Exception:
        print("EXCEPTION in gather_citations:")
        print(traceback.format_exc())
        return citations_text if citations_text else None


def perform_writeup(
base_folder,
citations_text=None,
no_writing=False,
num_cite_rounds=20,
small_model="gpt-4o-2024-05-13",
Expand All @@ -472,41 +575,14 @@ def perform_writeup(
# os.remove(pdf_file)

try:
# Load idea text
idea_text = ""
research_idea_path = osp.join(base_folder, "research_idea.md")
if osp.exists(research_idea_path):
with open(research_idea_path, "r") as f_idea:
idea_text = f_idea.read()
else:
idea_md_path = osp.join(base_folder, "idea.md")
if osp.exists(idea_md_path):
with open(idea_md_path, "r") as f_idea:
idea_text = f_idea.read()

# Load summaries
summary_files = [
("logs/0-run/baseline_summary.json", "BASELINE_SUMMARY"),
("logs/0-run/research_summary.json", "RESEARCH_SUMMARY"),
("logs/0-run/ablation_summary.json", "ABLATION_SUMMARY"),
]
loaded_summaries = {}
for fname, key in summary_files:
path = osp.join(base_folder, fname)
if osp.exists(path):
try:
with open(path, "r") as f:
loaded_summaries[key] = json.load(f)
except json.JSONDecodeError:
print(
f"Warning: {fname} is not valid JSON. Using empty data for {key}."
)
loaded_summaries[key] = {}
else:
loaded_summaries[key] = {}

# Load idea text and summaries
idea_text = load_idea_text(base_folder)
exp_summaries = load_exp_summaries(base_folder)
filtered_summaries_for_writeup = filter_experiment_summaries(
exp_summaries, step_name="writeup"
)
# Convert them to one big JSON string for context
combined_summaries_str = json.dumps(loaded_summaries, indent=2)
combined_summaries_str = json.dumps(filtered_summaries_for_writeup, indent=2)

# Prepare a new fresh latex folder
if not osp.exists(osp.join(latex_folder, "template.tex")):
Expand Down Expand Up @@ -538,54 +614,36 @@ def perform_writeup(
if no_writing:
compile_latex(latex_folder, base_pdf_file + ".pdf")
return osp.exists(base_pdf_file + ".pdf")

# Run small model for citation additions
client, client_model = create_client(small_model)
for round_idx in range(num_cite_rounds):
with open(writeup_file, "r") as f:
writeup_text = f.read()
try:
references_bib = re.search(
r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
writeup_text,
re.DOTALL,
)
if references_bib is None:
raise ValueError("No references.bib found in template.tex")
citations_text = references_bib.group(1)
context_for_citation = (combined_summaries_str, citations_text)

addition, done = get_citation_addition(
client,
client_model,
context_for_citation,
round_idx,
num_cite_rounds,
idea_text,

# If no citations provided, try to load from cache first
if citations_text is None:
citations_cache_path = osp.join(base_folder, "cached_citations.bib")
if osp.exists(citations_cache_path):
try:
with open(citations_cache_path, "r") as f:
citations_text = f.read()
print("Loaded citations from cache")
except Exception as e:
print(f"Error loading cached citations: {e}")
citations_text = None

# If still no citations, gather them
if not citations_text:
citations_text = gather_citations(
base_folder, num_cite_rounds, small_model
)
if done:
break
if citations_text is None:
print("Warning: Citation gathering failed")
citations_text = ""

if addition is not None:
# Simple check to avoid duplicating the same title
title_match = re.search(r" title = {(.*?)}", addition)
if title_match:
new_title = title_match.group(1).lower()
existing_titles = re.findall(
r" title = {(.*?)}", citations_text
)
existing_titles = [t.lower() for t in existing_titles]
if new_title not in existing_titles:
pattern_end = r"\end{filecontents}"
revised = writeup_text.replace(
pattern_end, f"\n{addition}{pattern_end}"
)
with open(writeup_file, "w") as fo:
fo.write(revised)
except Exception:
print("EXCEPTION in perform_writeup (citation round):")
print(traceback.format_exc())
continue
# Insert citations into template.tex
if citations_text:
with open(writeup_file, "r") as f:
content = f.read()
pattern_end = r"\end{filecontents}"
content = content.replace(pattern_end, f"\n{citations_text}{pattern_end}")
with open(writeup_file, "w") as f:
f.write(content)

# Generate VLM-based descriptions but do not overwrite plot_names
try:
Expand Down Expand Up @@ -636,7 +694,7 @@ def perform_writeup(
)

response, msg_history = get_response_from_llm(
msg=combined_prompt,
prompt=combined_prompt,
client=big_client,
model=big_client_model,
system_message=big_model_system_message,
Expand All @@ -645,6 +703,9 @@ def perform_writeup(

latex_code_match = re.search(r"```latex(.*?)```", response, re.DOTALL)
if not latex_code_match:
print("No valid LaTeX code block found in initial writeup response.")
print("Response was:")
print(response)
return False
updated_latex_code = latex_code_match.group(1).strip()
with open(writeup_file, "w") as f:
Expand All @@ -665,9 +726,9 @@ def perform_writeup(
invalid_figs = used_figs - all_figs

# Compile current version before reflection
compile_latex(latex_folder, base_pdf_file + f"_{compile_attempt}.pdf")
compile_latex(latex_folder, base_pdf_file + f"_reflection_{compile_attempt}.pdf")
compile_attempt += 1
print(f"Compiled {base_pdf_file}_{compile_attempt}.pdf")
print(f"Compiled {base_pdf_file}_reflection_{compile_attempt}.pdf")

# Detect where "Impact Statement" appears
impact_loc = detect_pages_before_impact(latex_folder)
Expand Down Expand Up @@ -707,7 +768,7 @@ def perform_writeup(
"""

reflection_response, msg_history = get_response_from_llm(
msg=reflection_prompt,
prompt=reflection_prompt,
client=big_client,
model=big_client_model,
system_message=big_model_system_message,
Expand Down Expand Up @@ -741,18 +802,18 @@ def perform_writeup(
fo.write(final_text)

compile_latex(
latex_folder, base_pdf_file + f"_{compile_attempt}.pdf"
latex_folder, base_pdf_file + f"_reflection_{compile_attempt}.pdf"
)
compile_attempt += 1
print(f"Compiled {base_pdf_file}_{compile_attempt}.pdf")
print(f"Compiled {base_pdf_file}_reflection_{compile_attempt}.pdf")
else:
print(f"No changes in reflection step {i+1}.")
break
else:
print(f"No valid LaTeX code block found in reflection step {i+1}.")
break

return osp.exists(base_pdf_file + f"_{compile_attempt-1}.pdf")
return osp.exists(base_pdf_file + f"_reflection_{compile_attempt-1}.pdf")

except Exception:
print("EXCEPTION in perform_writeup:")
Expand Down
27 changes: 19 additions & 8 deletions launch_scientist_bfts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
edit_bfts_config_file,
)
from ai_scientist.perform_plotting import aggregate_plots
from ai_scientist.perform_writeup import perform_writeup
from ai_scientist.perform_writeup import perform_writeup, gather_citations
from ai_scientist.perform_icbinb_writeup import (
perform_writeup as perform_icbinb_writeup,
gather_citations,
gather_citations as gather_icbinb_citations,
)
from ai_scientist.perform_llm_review import perform_review, load_paper
from ai_scientist.perform_vlm_review import perform_imgs_cap_ref_review
Expand Down Expand Up @@ -88,6 +88,12 @@ def parse_arguments():
default="o3-mini-2025-01-31",
help="Model to use for plot aggregation",
)
parser.add_argument(
"--model_agg_plots_ref",
type=int,
default=5,
help="Number of reflections to use for plot aggregation",
)
parser.add_argument(
"--model_writeup",
type=str,
Expand Down Expand Up @@ -262,22 +268,22 @@ def redirect_stdout_stderr_to_file(log_file_path):
dirs_exist_ok=True,
)

aggregate_plots(base_folder=idea_dir, model=args.model_agg_plots)
aggregate_plots(base_folder=idea_dir, model=args.model_agg_plots, n_reflections=args.model_agg_plots_ref)

shutil.rmtree(osp.join(idea_dir, "experiment_results"))

save_token_tracker(idea_dir)

if not args.skip_writeup:
writeup_success = False
citations_text = gather_citations(
idea_dir,
num_cite_rounds=args.num_cite_rounds,
small_model=args.model_citation,
)
for attempt in range(args.writeup_retries):
print(f"Writeup attempt {attempt+1} of {args.writeup_retries}")
if args.writeup_type == "normal":
citations_text = gather_citations(
idea_dir,
num_cite_rounds=args.num_cite_rounds,
small_model=args.model_citation,
)
writeup_success = perform_writeup(
base_folder=idea_dir,
small_model=args.model_writeup_small,
Expand All @@ -286,6 +292,11 @@ def redirect_stdout_stderr_to_file(log_file_path):
citations_text=citations_text,
)
else:
citations_text = gather_icbinb_citations(
idea_dir,
num_cite_rounds=args.num_cite_rounds,
small_model=args.model_citation,
)
writeup_success = perform_icbinb_writeup(
base_folder=idea_dir,
small_model=args.model_writeup_small,
Expand Down