16 changes: 16 additions & 0 deletions foambench_main.py
@@ -35,6 +35,18 @@ def parse_args():
default=None,
help="Path to custom mesh file (e.g., .msh, .stl, .obj). If not provided, no custom mesh will be used."
)
parser.add_argument(
'--dataset_log_path',
type=str,
default="",
help="Path to per-case dataset.jsonl for fine-tuning data extraction."
)
parser.add_argument(
'--case_id',
type=str,
default="",
help="Case identifier, e.g. 'Basic/Cavity/1' or 'Advanced/Cavity_LES'."
)
return parser.parse_args()

def run_command(command_str):
@@ -76,6 +88,10 @@ def main():
main_cmd = f"python src/main.py --prompt_path='{args.prompt_path}' --output_dir='{args.output}'"
if args.custom_mesh_path:
main_cmd += f" --custom_mesh_path='{args.custom_mesh_path}'"
if args.dataset_log_path:
main_cmd += f" --dataset_log_path='{args.dataset_log_path}'"
if args.case_id:
main_cmd += f" --case_id='{args.case_id}'"

print(f"Main workflow command: {main_cmd}")

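For orientation, a minimal sketch of the command string the block above assembles when both new flags are supplied (all paths here are illustrative, not from the PR):

# Illustrative only: main_cmd with hypothetical inputs.
prompt_path = "prompts/cavity.txt"   # hypothetical path
output_dir = "runs/cavity"           # hypothetical path
main_cmd = f"python src/main.py --prompt_path='{prompt_path}' --output_dir='{output_dir}'"
main_cmd += " --dataset_log_path='runs/cavity/dataset.jsonl'"
main_cmd += " --case_id='Basic/Cavity/1'"
# -> python src/main.py --prompt_path='prompts/cavity.txt' --output_dir='runs/cavity'
#    --dataset_log_path='runs/cavity/dataset.jsonl' --case_id='Basic/Cavity/1'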
2 changes: 2 additions & 0 deletions src/config.py
@@ -23,6 +23,8 @@ class Config:
# If set, InputWriter will check <reuse_generated_dir>/<folder>/<file> first.
# When present, it will copy into the current case_dir and skip LLM generation.
reuse_generated_dir: str = ""
dataset_log_path: str = ""
case_id: str = ""
# LLM backend:
# - "openai": OpenAI Platform usage-based (API key)
# - "openai-codex": ChatGPT/Codex subscription sign-in (Codex auth cache)
37 changes: 32 additions & 5 deletions src/main.py
@@ -12,13 +12,15 @@
from nodes.input_writer_node import input_writer_node
from nodes.local_runner_node import local_runner_node
from nodes.reviewer_node import reviewer_node
from nodes.restore_best_node import restore_best_node
from nodes.visualization_node import visualization_node
from nodes.hpc_runner_node import hpc_runner_node
from router_func import (
route_after_planner,
route_after_input_writer,
route_after_runner,
route_after_reviewer
route_after_reviewer,
route_after_restore_best,
)
from logger import close_logging
import json
@@ -36,8 +38,9 @@ def create_foam_agent_graph() -> StateGraph:
workflow.add_node("local_runner", local_runner_node)
workflow.add_node("hpc_runner", hpc_runner_node)
workflow.add_node("reviewer", reviewer_node)
workflow.add_node("restore_best", restore_best_node)
workflow.add_node("visualization", visualization_node)

# Add edges
workflow.add_edge(START, "planner")
workflow.add_conditional_edges("planner", route_after_planner)
@@ -46,6 +49,7 @@
workflow.add_conditional_edges("hpc_runner", route_after_runner)
workflow.add_conditional_edges("local_runner", route_after_runner)
workflow.add_conditional_edges("reviewer", route_after_reviewer)
workflow.add_conditional_edges("restore_best", route_after_restore_best)
workflow.add_edge("visualization", END)

return workflow
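route_after_restore_best is imported from router_func, but its body is not part of this diff. A plausible sketch, assuming restoration always hands off to visualization (hypothetical, not the PR's actual router):

def route_after_restore_best(state):
    # Assumption: once the best snapshot is restored there is nothing
    # left to retry, so the graph proceeds straight to visualization.
    return "visualization"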
@@ -92,7 +96,9 @@ def initialize_state(user_requirement: str, config: Config, custom_mesh_path: Optional[str]
job_id=None,
cluster_info=None,
slurm_script_path=None,
termination_reason=None
termination_reason=None,
best_case_snapshot_dir=None,
best_progress_score=None,
)
if custom_mesh_path:
print(f"<custom_mesh_path>{custom_mesh_path}</custom_mesh_path>")
@@ -164,7 +170,19 @@ def main(user_requirement: str, config: Config, custom_mesh_path: Optional[str]
"If a file exists at <reuse_generated_dir>/<folder>/<file>, Foam-Agent will copy it into the current output and skip generation for that file."
),
)

parser.add_argument(
"--dataset_log_path",
type=str,
default="",
help="Path to per-case dataset.jsonl for fine-tuning data extraction.",
)
parser.add_argument(
"--case_id",
type=str,
default="",
help="Case identifier, e.g. 'Basic/Cavity/1' or 'Advanced/Cavity_LES'.",
)

args = parser.parse_args()
print(f"args: {args}")

@@ -178,7 +196,16 @@

if args.reuse_generated_dir:
config.reuse_generated_dir = args.reuse_generated_dir

if args.dataset_log_path:
config.dataset_log_path = args.dataset_log_path
if args.case_id:
config.case_id = args.case_id

# Sync global LLM service with CLI-provided dataset_log_path/case_id
from services import global_llm_service
global_llm_service.dataset_log_path = config.dataset_log_path
global_llm_service.case_id = config.case_id

with open(args.prompt_path, 'r') as f:
user_requirement = f.read()

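global_llm_service is project-internal and its logging code is not shown in this diff; the following is only a sketch of the kind of per-case JSONL record such a service could append, with hypothetical field names:

import json
import time

def append_dataset_record(dataset_log_path: str, case_id: str,
                          prompt: str, response: str) -> None:
    # Hypothetical helper: append one fine-tuning record as a JSON line.
    record = {
        "case_id": case_id,        # e.g. 'Basic/Cavity/1'
        "timestamp": time.time(),
        "prompt": prompt,
        "response": response,
    }
    with open(dataset_log_path, "a") as f:
        f.write(json.dumps(record) + "\n")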
27 changes: 10 additions & 17 deletions src/nodes/input_writer_node.py
@@ -6,19 +6,8 @@
from typing import List
from pydantic import BaseModel, Field

# System prompts for different modes
INITIAL_WRITE_SYSTEM_PROMPT = (
"You are an expert in OpenFOAM simulation and numerical modeling."
f"Your task is to generate a complete and functional file named: <file_name>{{file_name}}</file_name> within the <folder_name>{{folder_name}}</folder_name> directory. "
"Ensure all required values are present and match with the files content already generated."
"Before finalizing the output, ensure:\n"
"- All necessary fields exist (e.g., if `nu` is defined in `constant/transportProperties`, it must be used correctly in `0/U`).\n"
"- Cross-check field names between different files to avoid mismatches.\n"
"- Ensure units and dimensions are correct** for all physical variables.\n"
f"- Ensure case solver settings are consistent with the user's requirements. Available solvers are: {{case_solver}}.\n"
"Provide only the code—no explanations, comments, or additional text."
)




def parse_allrun(text: str) -> str:
match = re.search(r'```(.*?)```', text, re.DOTALL)
@@ -56,17 +45,17 @@ def _rewrite_mode(state):
print("No review analysis available for rewrite mode.")
print("</input_writer>")
return state
out = rewrite_files(

return rewrite_files(
case_dir=state["case_dir"],
error_logs=state.get("error_logs", []),
review_analysis=state.get("review_analysis", ""),
rewrite_plan=state.get("rewrite_plan"),
rewrite_plan=None,
user_requirement=state.get("user_requirement", ""),
foamfiles=state.get("foamfiles"),
dir_structure=state.get("dir_structure", {}),
loop_count=state.get("loop_count", 0),
)
print("</input_writer>")
return out

def _initial_write_mode(state):
"""
@@ -92,6 +81,9 @@
# Build Allrun via service
mesh_type = state.get("mesh_type")
mesh_commands = state.get("mesh_commands") or []
advice = state.get("similar_case_advice")
advice_d = advice.model_dump() if hasattr(advice, "model_dump") else (advice if isinstance(advice, dict) else {})
pre_solver_steps = advice_d.get("pre_solver_steps") if advice_d else None
allrun_out = build_allrun(
case_dir=state["case_dir"],
database_path=config.database_path,
@@ -101,6 +93,7 @@
allrun_reference=state["allrun_reference"],
mesh_type=mesh_type,
mesh_commands=mesh_commands,
pre_solver_steps=pre_solver_steps,
)

print("</input_writer>")
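The advice normalization above duck-types between a Pydantic model, a plain dict, and None. A self-contained illustration of that pattern (the Advice shape is hypothetical):

from pydantic import BaseModel

class Advice(BaseModel):
    pre_solver_steps: list = []

for advice in (Advice(pre_solver_steps=["blockMesh"]), {"pre_solver_steps": ["topoSet"]}, None):
    advice_d = advice.model_dump() if hasattr(advice, "model_dump") else (advice if isinstance(advice, dict) else {})
    print(advice_d.get("pre_solver_steps") if advice_d else None)
# -> ['blockMesh'], then ['topoSet'], then None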
40 changes: 28 additions & 12 deletions src/nodes/planner_node.py
@@ -57,26 +57,42 @@ def planner_node(state):
subtasks = plan_data["subtasks"]
similar_case_advice = plan_data.get("similar_case_advice")

# Handle case directory creation/cleanup
print(f"Parsed case name: {case_name}")
print(f"Parsed case domain: {case_domain}")
print(f"Parsed case category: {case_category}")
print(f"Parsed case solver: {case_solver}")
print(f"Created case directory: {case_dir}")
print(f"Retrieved similar case structure: {dir_structure_reference}")
print(f"Generated {len(subtasks)} subtasks.")
if similar_case_advice:
print(f"Similar case advice: {similar_case_advice}")

# Handle case directory creation/cleanup.
# Preserve dataset_log_path across rmtree when it lives inside case_dir,
# so plan-phase LLM records (already written this run) aren't destroyed.
dataset_log_path = getattr(config, "dataset_log_path", "") or ""
preserved_log_bytes = None
if dataset_log_path and os.path.exists(case_dir) and os.path.exists(dataset_log_path):
try:
rel = os.path.relpath(os.path.abspath(dataset_log_path), os.path.abspath(case_dir))
if not rel.startswith("..") and not os.path.isabs(rel):
with open(dataset_log_path, "rb") as _f:
preserved_log_bytes = _f.read()
except Exception as _e:
print(f"Warning: failed to snapshot dataset_log_path before rmtree: {_e}")
if os.path.exists(case_dir):
print(f"Warning: Case directory {case_dir} already exists. Overwriting.")
shutil.rmtree(case_dir)
os.makedirs(case_dir)

if preserved_log_bytes is not None:
os.makedirs(os.path.dirname(dataset_log_path), exist_ok=True)
with open(dataset_log_path, "wb") as _f:
_f.write(preserved_log_bytes)

# Initialize logging now that case_dir exists
setup_logging(case_dir)

print("<planner>")
print(f"<case_name>{case_name}</case_name>")
print(f"<case_domain>{case_domain}</case_domain>")
print(f"<case_category>{case_category}</case_category>")
print(f"<case_solver>{case_solver}</case_solver>")
print(f"<case_dir>{case_dir}</case_dir>")
print(f"<similar_case_structure>{dir_structure_reference}</similar_case_structure>")
print(f"<subtask_count>{len(subtasks)} subtasks generated.</subtask_count>")
if similar_case_advice:
print(f"<similar_case_advice>{similar_case_advice}</similar_case_advice>")

# Save reference file
save_file(case_path_reference, f"{faiss_detailed}\n\n\n{allrun_reference}")

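The snapshot guard relies on os.path.relpath to decide whether the log file lives inside case_dir. A minimal standalone version of that containment test (paths illustrative):

import os

def is_inside(path: str, directory: str) -> bool:
    # Same test as in planner_node: a relative path that neither climbs
    # upward ("..") nor is absolute means `path` is inside `directory`.
    rel = os.path.relpath(os.path.abspath(path), os.path.abspath(directory))
    return not rel.startswith("..") and not os.path.isabs(rel)

print(is_inside("runs/cavity/dataset.jsonl", "runs/cavity"))  # True
print(is_inside("logs/dataset.jsonl", "runs/cavity"))         # False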
19 changes: 19 additions & 0 deletions src/nodes/restore_best_node.py
@@ -0,0 +1,19 @@
import os
import shutil


def restore_best_node(state):
"""Restore the best-snapshot case directory when the review loop exits at max_loop."""
snap = state.get("best_case_snapshot_dir")
if snap and os.path.exists(snap):
case_dir = state["case_dir"]
if os.path.exists(case_dir):
shutil.rmtree(case_dir)
shutil.copytree(snap, case_dir)
print(
f"<restore_best>Restored best snapshot from {snap} "
f"(score={state.get('best_progress_score')})</restore_best>"
)
else:
print("<restore_best>No best snapshot to restore.</restore_best>")
return {}
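Illustrative call with a hand-built state (keys mirror what the node reads; paths hypothetical):

state = {
    "case_dir": "runs/cavity",
    "best_case_snapshot_dir": "runs/cavity_best",
    "best_progress_score": 0.012,
}
restore_best_node(state)  # replaces runs/cavity with the snapshot copy, if one exists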
73 changes: 49 additions & 24 deletions src/nodes/reviewer_node.py
@@ -1,51 +1,76 @@
# reviewer_node.py
from pydantic import BaseModel, Field
from typing import List
from services.review import review_error_logs, generate_rewrite_plan
import os
import glob
import re
import shutil
from services.review import review_error_logs
from logger import log_review


def _compute_progress_score(case_dir: str, error_logs: list) -> float:
"""Highest timestep reached across log.* files; fall back to -len(error_logs)."""
best_t = None
for log_path in glob.glob(os.path.join(case_dir, "log.*")):
try:
with open(log_path, errors="ignore") as f:
for line in f:
m = re.match(r"^Time = ([\d.eE+\-]+)", line)
if m:
t = float(m.group(1))
if best_t is None or t > best_t:
best_t = t
except Exception:
pass
return float(best_t) if best_t is not None else -len(error_logs)


def reviewer_node(state):
"""
Reviewer node: Reviews the error logs and provides analysis and suggestions
for fixing the errors. This node only focuses on analysis, not file modification.
"""
"""Reviewer node: single-call review (FA 1.1.0 style) + best-state snapshot."""
print("<reviewer>")
if len(state["error_logs"]) == 0:
print("No error to review.")
print("</reviewer>")
return state

# Log error logs to review.log
log_review(str(state["error_logs"]), "error_logs")

# Stateless review via service
case_dir = state["case_dir"]
error_logs = state.get("error_logs", [])
loop_count = state.get("loop_count", 0)
history_text = state.get("history_text") or []

# Best-state snapshot before this loop's rewrite can regress it
snapshot_updates = {}
score = _compute_progress_score(case_dir, error_logs)
best_score = state.get("best_progress_score")
if best_score is None:
best_score = float("-inf")
if score > best_score:
snap = case_dir.rstrip("/") + "_best"
if os.path.exists(snap):
shutil.rmtree(snap)
shutil.copytree(case_dir, snap)
print(f"<snapshot>progress={score:.4g} > {best_score:.4g}, saved to {snap}</snapshot>")
snapshot_updates = {"best_case_snapshot_dir": snap, "best_progress_score": score}

review_content, updated_history = review_error_logs(
tutorial_reference=state.get('tutorial_reference', ''),
foamfiles=state.get('foamfiles'),
error_logs=state.get('error_logs'),
user_requirement=state.get('user_requirement', ''),
similar_case_advice=state.get('similar_case_advice'),
tutorial_reference=state.get("tutorial_reference", ""),
foamfiles=state.get("foamfiles"),
error_logs=error_logs,
user_requirement=state.get("user_requirement", ""),
similar_case_advice=state.get("similar_case_advice"),
history_text=history_text,
loop_count=loop_count,
)

log_review(review_content, "review_analysis")

rewrite_plan = generate_rewrite_plan(
foamfiles=state.get('foamfiles'),
error_logs=state.get('error_logs', []),
review_analysis=review_content,
user_requirement=state.get('user_requirement', ''),
)
log_review(str(rewrite_plan), "rewrite_plan")

print("</reviewer>")

return {
**snapshot_updates,
"history_text": updated_history,
"review_analysis": review_content,
"rewrite_plan": rewrite_plan,
"loop_count": state.get("loop_count", 0) + 1,
"loop_count": loop_count + 1,
"input_writer_mode": "rewrite",
}
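A quick self-contained check of the progress-score heuristic defined above (solver log contents are made up):

import os
import tempfile

with tempfile.TemporaryDirectory() as case_dir:
    with open(os.path.join(case_dir, "log.icoFoam"), "w") as f:
        f.write("Time = 0.005\nTime = 0.01\n")
    print(_compute_progress_score(case_dir, []))          # 0.01 (highest timestep)
with tempfile.TemporaryDirectory() as empty_dir:
    print(_compute_progress_score(empty_dir, ["error"]))  # -1.0 (fallback: -len(error_logs))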