diff --git a/autofigure2.py b/autofigure2.py index bacdadb..ff2fad5 100644 --- a/autofigure2.py +++ b/autofigure2.py @@ -1587,6 +1587,68 @@ def crop_and_remove_background( return icon_infos +def _load_icon_infos_from_existing(output_dir: Path, boxlib_path: str) -> list: + """从已有 outputs 目录加载 icon_infos(用于 --start_from 4/5 恢复流程)""" + icons_dir = output_dir / "icons" + with open(boxlib_path, 'r', encoding='utf-8') as f: + boxlib_data = json.load(f) + icon_infos = [] + for box in boxlib_data.get("boxes", []): + box_id = box["id"] + label = box.get("label", f"{box_id + 1:02d}") + label_clean = label.replace("<", "").replace(">", "") + x1, y1, x2, y2 = box["x1"], box["y1"], box["x2"], box["y2"] + crop_path = icons_dir / f"icon_{label_clean}.png" + nobg_path = icons_dir / f"icon_{label_clean}_nobg.png" + nobg_str = str(nobg_path) if nobg_path.exists() else str(crop_path) + if not crop_path.exists() and not nobg_path.exists(): + raise FileNotFoundError(f"start_from>=4 需要图标文件: {crop_path} 或 {nobg_path}") + icon_infos.append({ + "id": box_id, "label": label, "label_clean": label_clean, + "x1": x1, "y1": y1, "x2": x2, "y2": y2, + "width": x2 - x1, "height": y2 - y1, + "crop_path": str(crop_path), "nobg_path": nobg_str, + }) + return icon_infos + + +def _require_files_for_step(output_dir: Path, step: int) -> None: + """检查指定步骤所需文件是否存在,不存在则抛出 FileNotFoundError""" + figure = output_dir / "figure.png" + samed = output_dir / "samed.png" + boxlib = output_dir / "boxlib.json" + icons_dir = output_dir / "icons" + template = output_dir / "template.svg" + optimized = output_dir / "optimized_template.svg" + + if step >= 2: + if not figure.exists(): + raise FileNotFoundError(f"从步骤 {step} 开始需要: {figure}") + if step >= 3: + if not samed.exists(): + raise FileNotFoundError(f"从步骤 {step} 开始需要: {samed}") + if not boxlib.exists(): + raise FileNotFoundError(f"从步骤 {step} 开始需要: {boxlib}") + if step >= 4: + if not icons_dir.is_dir(): + raise FileNotFoundError(f"从步骤 {step} 开始需要目录: {icons_dir}") + # 至少需有 boxlib 中对应的裁剪图 + with open(boxlib, 'r', encoding='utf-8') as f: + boxes = json.load(f).get("boxes", []) + for box in boxes: + label = box.get("label", f"{box['id']+1:02d}") + label_clean = label.replace("<", "").replace(">", "") + crop = icons_dir / f"icon_{label_clean}.png" + nobg = icons_dir / f"icon_{label_clean}_nobg.png" + if not crop.exists() and not nobg.exists(): + raise FileNotFoundError(f"从步骤 {step} 开始需要: {crop} 或 {nobg}") + if step >= 5: + if not optimized.exists() and not template.exists(): + raise FileNotFoundError( + f"从步骤 {step} 开始需要: {optimized} 或 {template}" + ) + + # ============================================================================ # 步骤四:多模态调用生成 SVG # ============================================================================ @@ -2402,6 +2464,7 @@ def method_to_svg( sam_max_masks: int = 32, rmbg_model_path: Optional[str] = None, stop_after: int = 5, + start_from: int = 1, placeholder_mode: PlaceholderMode = "label", optimize_iterations: int = 2, merge_threshold: float = 0.9, @@ -2424,6 +2487,7 @@ def method_to_svg( sam_max_masks: SAM3 API 最大 masks 数(api 模式使用) rmbg_model_path: RMBG 模型路径 stop_after: 执行到指定步骤后停止 + start_from: 从指定步骤开始(1-5,默认1。>=4 时跳过步骤1-3,使用已有 figure.png/samed.png/boxlib.json/icons) placeholder_mode: 占位符模式 - "none": 无特殊样式 - "box": 传入 boxlib 坐标 @@ -2463,21 +2527,29 @@ def method_to_svg( if sam_backend_value == "fal": print(f"SAM3 API max_masks: {sam_max_masks}") print(f"执行到步骤: {stop_after}") + print(f"从步骤开始: {start_from}") print(f"占位符模式: {placeholder_mode}") print(f"优化迭代次数: {optimize_iterations}") print(f"Box合并阈值: {merge_threshold}") print("=" * 60) - # 步骤一:生成图片 figure_path = output_dir / "figure.png" - generate_figure_from_method( - method_text=method_text, - output_path=str(figure_path), - api_key=api_key, - model=image_gen_model, - base_url=base_url, - provider=provider, - ) + samed_path = str(output_dir / "samed.png") + boxlib_path = str(output_dir / "boxlib.json") + + # 预先检查所需资源 + _require_files_for_step(output_dir, start_from) + + if start_from <= 1: + # 步骤一:生成图片 + generate_figure_from_method( + method_text=method_text, + output_path=str(figure_path), + api_key=api_key, + model=image_gen_model, + base_url=base_url, + provider=provider, + ) if stop_after == 1: print("\n" + "=" * 60) @@ -2493,95 +2565,107 @@ def method_to_svg( "final_svg_path": None, } - # 步骤二:SAM3 分割(包含Box合并) - samed_path, boxlib_path, valid_boxes = segment_with_sam3( - image_path=str(figure_path), - output_dir=str(output_dir), - text_prompts=sam_prompts, - min_score=min_score, - merge_threshold=merge_threshold, - sam_backend=sam_backend_value, - sam_api_key=sam_api_key, - sam_max_masks=sam_max_masks, - ) - - if len(valid_boxes) == 0: - print("\n警告: 没有检测到有效的图标,流程终止") - return { - "figure_path": str(figure_path), - "samed_path": samed_path, - "boxlib_path": boxlib_path, - "icon_infos": [], - "template_svg_path": None, - "final_svg_path": None, - } - - print(f"\n检测到 {len(valid_boxes)} 个图标") - - if stop_after == 2: - print("\n" + "=" * 60) - print("已在步骤 2 后停止") - print("=" * 60) - return { - "figure_path": str(figure_path), - "samed_path": samed_path, - "boxlib_path": boxlib_path, - "icon_infos": [], - "template_svg_path": None, - "optimized_template_path": None, - "final_svg_path": None, - } - - # 步骤三:裁切 + 去背景 - icon_infos = crop_and_remove_background( - image_path=str(figure_path), - boxlib_path=boxlib_path, - output_dir=str(output_dir), - rmbg_model_path=rmbg_model_path, - ) - - if stop_after == 3: - print("\n" + "=" * 60) - print("已在步骤 3 后停止") - print("=" * 60) - return { - "figure_path": str(figure_path), - "samed_path": samed_path, - "boxlib_path": boxlib_path, - "icon_infos": icon_infos, - "template_svg_path": None, - "optimized_template_path": None, - "final_svg_path": None, - } - - # 步骤四:生成 SVG 模板 - template_svg_path = output_dir / "template.svg" - generate_svg_template( - figure_path=str(figure_path), - samed_path=samed_path, - boxlib_path=boxlib_path, - output_path=str(template_svg_path), - api_key=api_key, - model=svg_gen_model, - base_url=base_url, - provider=provider, - placeholder_mode=placeholder_mode, - ) + if start_from <= 2: + # 步骤二:SAM3 分割(包含Box合并) + samed_path, boxlib_path, valid_boxes = segment_with_sam3( + image_path=str(figure_path), + output_dir=str(output_dir), + text_prompts=sam_prompts, + min_score=min_score, + merge_threshold=merge_threshold, + sam_backend=sam_backend_value, + sam_api_key=sam_api_key, + sam_max_masks=sam_max_masks, + ) + if len(valid_boxes) == 0: + print("\n警告: 没有检测到有效的图标,流程终止") + return { + "figure_path": str(figure_path), + "samed_path": samed_path, + "boxlib_path": boxlib_path, + "icon_infos": [], + "template_svg_path": None, + "final_svg_path": None, + } + print(f"\n检测到 {len(valid_boxes)} 个图标") + if stop_after == 2: + print("\n" + "=" * 60) + print("已在步骤 2 后停止") + print("=" * 60) + return { + "figure_path": str(figure_path), + "samed_path": samed_path, + "boxlib_path": boxlib_path, + "icon_infos": [], + "template_svg_path": None, + "optimized_template_path": None, + "final_svg_path": None, + } + + if start_from <= 3: + # 步骤三:裁切 + 去背景 + icon_infos = crop_and_remove_background( + image_path=str(figure_path), + boxlib_path=boxlib_path, + output_dir=str(output_dir), + rmbg_model_path=rmbg_model_path, + ) + if stop_after == 3: + print("\n" + "=" * 60) + print("已在步骤 3 后停止") + print("=" * 60) + return { + "figure_path": str(figure_path), + "samed_path": samed_path, + "boxlib_path": boxlib_path, + "icon_infos": icon_infos, + "template_svg_path": None, + "optimized_template_path": None, + "final_svg_path": None, + } + + if start_from >= 4: + # 从步骤 4 恢复:使用已有 figure.png, samed.png, boxlib.json, icons/ + samed_path = str(output_dir / "samed.png") + boxlib_path = str(output_dir / "boxlib.json") + icon_infos = _load_icon_infos_from_existing(output_dir, boxlib_path) + print(f"\n从步骤 {start_from} 恢复,已加载 {len(icon_infos)} 个图标") + + if start_from <= 4: + # 步骤四:生成 SVG 模板 + template_svg_path = output_dir / "template.svg" + generate_svg_template( + figure_path=str(figure_path), + samed_path=samed_path, + boxlib_path=boxlib_path, + output_path=str(template_svg_path), + api_key=api_key, + model=svg_gen_model, + base_url=base_url, + provider=provider, + placeholder_mode=placeholder_mode, + ) - # 步骤 4.6:LLM 优化 SVG 模板(可配置迭代次数,0 表示跳过) - optimized_template_path = output_dir / "optimized_template.svg" - optimize_svg_with_llm( - figure_path=str(figure_path), - samed_path=samed_path, - final_svg_path=str(template_svg_path), - output_path=str(optimized_template_path), - api_key=api_key, - model=svg_gen_model, - base_url=base_url, - provider=provider, - max_iterations=optimize_iterations, - skip_base64_validation=True, - ) + # 步骤 4.6:LLM 优化 SVG 模板(可配置迭代次数,0 表示跳过) + optimized_template_path = output_dir / "optimized_template.svg" + optimize_svg_with_llm( + figure_path=str(figure_path), + samed_path=samed_path, + final_svg_path=str(template_svg_path), + output_path=str(optimized_template_path), + api_key=api_key, + model=svg_gen_model, + base_url=base_url, + provider=provider, + max_iterations=optimize_iterations, + skip_base64_validation=True, + ) + else: + # start_from >= 5:使用已有的 template / optimized_template + template_svg_path = output_dir / "template.svg" + optimized_template_path = output_dir / "optimized_template.svg" + if not optimized_template_path.exists(): + optimized_template_path = template_svg_path if stop_after == 4: print("\n" + "=" * 60) @@ -2728,6 +2812,13 @@ def method_to_svg( default=5, help="执行到指定步骤后停止(1-5,默认: 5 完整流程)" ) + parser.add_argument( + "--start_from", + type=int, + choices=[1, 2, 3, 4, 5], + default=1, + help="从指定步骤开始(4 或 5 时跳过 1-3,使用 outputs 中已有的 figure.png/samed.png/boxlib.json/icons)" + ) # 占位符模式参数 parser.add_argument( @@ -2787,6 +2878,7 @@ def method_to_svg( sam_max_masks=args.sam_max_masks, rmbg_model_path=args.rmbg_model_path, stop_after=args.stop_after, + start_from=args.start_from, placeholder_mode=args.placeholder_mode, optimize_iterations=args.optimize_iterations, merge_threshold=args.merge_threshold, diff --git a/web/app.js b/web/app.js index 6264928..e21d072 100644 --- a/web/app.js +++ b/web/app.js @@ -40,7 +40,6 @@ } if (uploadZone && referenceFile) { - uploadZone.addEventListener("click", () => referenceFile.click()); uploadZone.addEventListener("dragover", (event) => { event.preventDefault(); uploadZone.classList.add("dragging"); diff --git a/web/index.html b/web/index.html index c03a8e6..afbc0a1 100644 --- a/web/index.html +++ b/web/index.html @@ -67,7 +67,7 @@
- +
Reference Image
Drop image here or click to upload