diff --git a/README.md b/README.md index 5a2c1dc..534e942 100644 --- a/README.md +++ b/README.md @@ -324,6 +324,7 @@ Optional CLI flags (API): | **OpenRouter** | `openrouter.ai/api/v1` | Supports Gemini/Claude/others | | **Bianxie** | `api.bianxie.ai/v1` | OpenAI-compatible API | | **Gemini (Google)** | `generativelanguage.googleapis.com/v1beta` | Official Google Gemini API (`google-genai`) | +| **MiniMax** | `api.minimax.io/v1` | OpenAI-compatible API; SVG generation only (no image gen) | Common CLI flags: @@ -337,6 +338,33 @@ Common CLI flags: - `--merge_threshold` (0 disables merging) - `--optimize_iterations` (0 disables optimization) - `--reference_image_path` (optional) +- `--figure_path` (pre-generated figure image, skips step 1) + +### Using MiniMax + +[MiniMax](https://www.minimaxi.com/) provides OpenAI-compatible LLM APIs with strong multimodal capabilities (MiniMax-M2.7, 204K context). MiniMax excels at SVG generation and text understanding tasks but does not support image generation. 
Use `--figure_path` to provide a pre-generated figure: + +```bash +# Step 1: Generate figure with another provider +python autofigure2.py \ + --method_file paper.txt \ + --output_dir outputs/step1 \ + --provider gemini \ + --api_key YOUR_GEMINI_KEY \ + --stop_after 1 + +# Step 2: Use MiniMax for SVG generation (steps 2-5) +export MINIMAX_API_KEY="your-minimax-key" +python autofigure2.py \ + --method_file paper.txt \ + --output_dir outputs/demo \ + --provider minimax \ + --figure_path outputs/step1/figure.png +``` + +Available MiniMax models: +- `MiniMax-M2.7` (default) — latest model, 204K context +- `MiniMax-M2.7-highspeed` — faster variant, 204K context ### Custom Provider / Custom Base URL diff --git a/README_ZH.md b/README_ZH.md index baf13ef..b35374f 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -314,10 +314,11 @@ python autofigure2.py \ | **OpenRouter** | `openrouter.ai/api/v1` | 支持 Gemini/Claude/其他模型 | | **Bianxie** | `api.bianxie.ai/v1` | 兼容 OpenAI 接口 | | **Gemini (Google)** | `generativelanguage.googleapis.com/v1beta` | Google 官方 Gemini API(`google-genai`) | +| **MiniMax** | `api.minimax.io/v1` | 兼容 OpenAI 接口;仅 SVG 生成(不支持图像生成) | 常用 CLI 参数: -- `--provider` (openrouter | bianxie | gemini) +- `--provider` (openrouter | bianxie | gemini | minimax) - `--image_model`, `--svg_model` - `--image_size` (1K | 2K | 4K,仅 Gemini) - `--sam_prompt` (逗号分隔的提示词) @@ -327,6 +328,33 @@ python autofigure2.py \ - `--merge_threshold` (0 禁用合并) - `--optimize_iterations` (0 禁用优化) - `--reference_image_path` (可选) +- `--figure_path` (预生成图片路径,跳过步骤一) + +### 使用 MiniMax + +[MiniMax](https://www.minimaxi.com/) 提供兼容 OpenAI 的 LLM API,具有强大的多模态能力(MiniMax-M2.7,204K 上下文)。MiniMax 擅长 SVG 生成和文本理解,但不支持图像生成。使用 `--figure_path` 提供预生成的图片: + +```bash +# 步骤 1:用其他 provider 生成图片 +python autofigure2.py \ + --method_file paper.txt \ + --output_dir outputs/step1 \ + --provider gemini \ + --api_key YOUR_GEMINI_KEY \ + --stop_after 1 + +# 步骤 2:用 MiniMax 进行 SVG 生成(步骤 2-5) +export 
MINIMAX_API_KEY="your-minimax-key" +python autofigure2.py \ + --method_file paper.txt \ + --output_dir outputs/demo \ + --provider minimax \ + --figure_path outputs/step1/figure.png +``` + +可用的 MiniMax 模型: +- `MiniMax-M2.7`(默认)— 最新模型,204K 上下文 +- `MiniMax-M2.7-highspeed` — 高速版本,204K 上下文 ### 自定义提供商 / 自定义 Base URL diff --git a/autofigure2.py b/autofigure2.py index 8101b7f..ec5188a 100644 --- a/autofigure2.py +++ b/autofigure2.py @@ -5,6 +5,7 @@ - openrouter: OpenRouter API (https://openrouter.ai/api/v1) - bianxie: Bianxie API (https://api.bianxie.ai/v1) - 使用 OpenAI SDK - gemini: Google Gemini 官方 API (https://ai.google.dev/) +- minimax: MiniMax API (https://api.minimax.io/v1) - 使用 OpenAI SDK(仅 SVG 生成,不支持图像生成) 占位符模式 (--placeholder_mode): - none: 无特殊样式(默认黑色边框) @@ -104,9 +105,14 @@ "default_image_model": "gemini-3-pro-image-preview", "default_svg_model": "gemini-3.1-pro", }, + "minimax": { + "base_url": "https://api.minimax.io/v1", + "default_image_model": None, + "default_svg_model": "MiniMax-M2.7", + }, } -ProviderType = Literal["openrouter", "bianxie", "gemini"] +ProviderType = Literal["openrouter", "bianxie", "gemini", "minimax"] PlaceholderMode = Literal["none", "box", "label"] GEMINI_DEFAULT_IMAGE_SIZE = "4K" IMAGE_SIZE_CHOICES = ("1K", "2K", "4K") @@ -158,6 +164,8 @@ def call_llm_text( return _call_bianxie_text(prompt, api_key, model, base_url, max_tokens, temperature) if provider == "gemini": return _call_gemini_text(prompt, api_key, model, max_tokens, temperature) + if provider == "minimax": + return _call_minimax_text(prompt, api_key, model, base_url, max_tokens, temperature) return _call_openrouter_text(prompt, api_key, model, base_url, max_tokens, temperature) @@ -189,6 +197,8 @@ def call_llm_multimodal( return _call_bianxie_multimodal(contents, api_key, model, base_url, max_tokens, temperature) if provider == "gemini": return _call_gemini_multimodal(contents, api_key, model, max_tokens, temperature) + if provider == "minimax": + return 
_call_minimax_multimodal(contents, api_key, model, base_url, max_tokens, temperature) return _call_openrouter_multimodal(contents, api_key, model, base_url, max_tokens, temperature) @@ -224,6 +234,8 @@ def call_llm_image_generation( reference_image=reference_image, image_size=image_size, ) + if provider == "minimax": + return _call_minimax_image_generation(prompt, api_key, model, base_url, reference_image) return _call_openrouter_image_generation(prompt, api_key, model, base_url, reference_image) @@ -941,6 +953,114 @@ def _call_gemini_image_generation( raise +# ============================================================================ +# MiniMax Provider 实现 (OpenAI-compatible API) +# ============================================================================ + +def _clamp_minimax_temperature(temperature: float) -> float: + """MiniMax requires temperature in (0.0, 1.0].""" + if temperature <= 0.0: + return 0.01 + if temperature > 1.0: + return 1.0 + return temperature + + +def _call_minimax_text( + prompt: str, + api_key: str, + model: str, + base_url: str, + max_tokens: int = 16000, + temperature: float = 0.7, +) -> Optional[str]: + """Call MiniMax text API via OpenAI SDK (OpenAI-compatible endpoint).""" + try: + from openai import OpenAI + + client = OpenAI(base_url=base_url, api_key=api_key) + temperature = _clamp_minimax_temperature(temperature) + + completion = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, + temperature=temperature, + ) + + return completion.choices[0].message.content if completion and completion.choices else None + except Exception as e: + print(f"[MiniMax] API call failed: {e}") + raise + + +def _call_minimax_multimodal( + contents: List[Any], + api_key: str, + model: str, + base_url: str, + max_tokens: int = 16000, + temperature: float = 0.7, +) -> Optional[str]: + """Call MiniMax multimodal API via OpenAI SDK (vision support).""" + try: + from openai import 
OpenAI
+
+        client = OpenAI(base_url=base_url, api_key=api_key)
+        temperature = _clamp_minimax_temperature(temperature)
+
+        message_content: List[Dict[str, Any]] = []
+        for part in contents:
+            if isinstance(part, str):
+                message_content.append({"type": "text", "text": part})
+            elif isinstance(part, Image.Image):
+                buf = io.BytesIO()
+                part.save(buf, format='PNG')
+                image_b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
+                message_content.append({
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_b64}"}
+                })
+
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": message_content}],
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+
+        text = completion.choices[0].message.content if completion and completion.choices else None
+        if text:
+            # Strip <think>...</think> tags from MiniMax reasoning models,
+            # but only if there is content outside the think block.
+            stripped = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
+            if stripped:
+                text = stripped
+        return text
+    except Exception as e:
+        print(f"[MiniMax] Multimodal API call failed: {e}")
+        raise
+
+
+def _call_minimax_image_generation(
+    prompt: str,
+    api_key: str,
+    model: str,
+    base_url: str,
+    reference_image: Optional[Image.Image] = None,
+) -> Optional[Image.Image]:
+    """MiniMax does not support image generation.
+
+    Use --figure_path to provide a pre-generated figure when using
+    MiniMax as the provider, or use a different provider for step 1.
+    """
+    raise NotImplementedError(
+        "MiniMax does not support image generation. "
+        "Use --figure_path to provide a pre-generated figure, "
+        "or use a different provider (openrouter/bianxie/gemini) for image generation."
+ ) + + # ============================================================================ # 步骤一:调用 LLM 生成图片 # ============================================================================ @@ -2854,6 +2974,7 @@ def method_to_svg( optimize_iterations: int = 2, merge_threshold: float = 0.9, image_size: str = GEMINI_DEFAULT_IMAGE_SIZE, + figure_path: Optional[str] = None, ) -> dict: """ 完整流程:Paper Method → SVG with Icons @@ -2879,6 +3000,7 @@ def method_to_svg( - "label": 灰色填充+黑色边框+序号标签(推荐) optimize_iterations: 步骤 4.6 优化迭代次数(0 表示跳过优化) merge_threshold: Box合并阈值,重叠比例超过此值则合并(0表示不合并,默认0.9) + figure_path: 预生成的图片路径(跳过步骤一),适用于不支持图像生成的 provider(如 MiniMax) Returns: 结果字典 @@ -2919,24 +3041,35 @@ def method_to_svg( print(f"生图分辨率: {image_size}") print("=" * 60) - # 步骤一:生成图片 - figure_path = output_dir / "figure.png" - generate_figure_from_method( - method_text=method_text, - output_path=str(figure_path), - api_key=api_key, - model=image_gen_model, - base_url=base_url, - provider=provider, - image_size=image_size, - ) + # 步骤一:生成图片(或使用预生成的图片) + figure_path_out = output_dir / "figure.png" + if figure_path: + # 使用预生成的图片,跳过 LLM 图像生成 + src = Path(figure_path) + if not src.is_file(): + raise FileNotFoundError(f"预生成的图片不存在: {figure_path}") + print("=" * 60) + print("步骤一:使用预生成图片(跳过 LLM 图像生成)") + print("=" * 60) + print(f"图片来源: {figure_path}") + shutil.copy2(str(src), str(figure_path_out)) + else: + generate_figure_from_method( + method_text=method_text, + output_path=str(figure_path_out), + api_key=api_key, + model=image_gen_model, + base_url=base_url, + provider=provider, + image_size=image_size, + ) if stop_after == 1: print("\n" + "=" * 60) print("已在步骤 1 后停止") print("=" * 60) return { - "figure_path": str(figure_path), + "figure_path": str(figure_path_out), "samed_path": None, "boxlib_path": None, "icon_infos": [], @@ -2947,7 +3080,7 @@ def method_to_svg( # 步骤二:SAM3 分割(包含Box合并) samed_path, boxlib_path, valid_boxes = segment_with_sam3( - image_path=str(figure_path), + 
image_path=str(figure_path_out), output_dir=str(output_dir), text_prompts=sam_prompts, min_score=min_score, @@ -2968,7 +3101,7 @@ def method_to_svg( print("已在步骤 2 后停止") print("=" * 60) return { - "figure_path": str(figure_path), + "figure_path": str(figure_path_out), "samed_path": samed_path, "boxlib_path": boxlib_path, "icon_infos": [], @@ -2984,7 +3117,7 @@ def method_to_svg( else: _ensure_rmbg2_access_ready(rmbg_model_path) icon_infos = crop_and_remove_background( - image_path=str(figure_path), + image_path=str(figure_path_out), boxlib_path=boxlib_path, output_dir=str(output_dir), rmbg_model_path=rmbg_model_path, @@ -2995,7 +3128,7 @@ def method_to_svg( print("已在步骤 3 后停止") print("=" * 60) return { - "figure_path": str(figure_path), + "figure_path": str(figure_path_out), "samed_path": samed_path, "boxlib_path": boxlib_path, "icon_infos": icon_infos, @@ -3010,7 +3143,7 @@ def method_to_svg( final_svg_path = output_dir / "final.svg" try: generate_svg_template( - figure_path=str(figure_path), + figure_path=str(figure_path_out), samed_path=samed_path, boxlib_path=boxlib_path, output_path=str(template_svg_path), @@ -3024,7 +3157,7 @@ def method_to_svg( # 步骤 4.6:LLM 优化 SVG 模板(可配置迭代次数,0 表示跳过) optimize_svg_with_llm( - figure_path=str(figure_path), + figure_path=str(figure_path_out), samed_path=samed_path, final_svg_path=str(template_svg_path), output_path=str(optimized_template_path), @@ -3041,7 +3174,7 @@ def method_to_svg( raise print(f"无图标模式下 SVG 重建失败({exc}),改用内嵌原图的保底 SVG") create_embedded_figure_svg( - figure_path=str(figure_path), + figure_path=str(figure_path_out), output_path=str(final_svg_path), ) @@ -3050,7 +3183,7 @@ def method_to_svg( print("已在步骤 4 后停止") print("=" * 60) return { - "figure_path": str(figure_path), + "figure_path": str(figure_path_out), "samed_path": samed_path, "boxlib_path": boxlib_path, "icon_infos": icon_infos, @@ -3069,7 +3202,7 @@ def method_to_svg( else: print("无图标模式缺少模板 SVG,生成保底 final.svg") create_embedded_figure_svg( - 
figure_path=str(figure_path), + figure_path=str(figure_path_out), output_path=str(final_svg_path), ) else: @@ -3078,7 +3211,7 @@ def method_to_svg( print("步骤 4.7:坐标系对齐") print("-" * 50) - figure_img = Image.open(figure_path) + figure_img = Image.open(figure_path_out) figure_width, figure_height = figure_img.size print(f"原图尺寸: {figure_width} x {figure_height}") @@ -3114,7 +3247,7 @@ def method_to_svg( print("\n" + "=" * 60) print("流程完成!") print("=" * 60) - print(f"原始图片: {figure_path}") + print(f"原始图片: {figure_path_out}") print(f"标记图片: {samed_path}") print(f"Box信息: {boxlib_path}") print(f"图标数量: {len(icon_infos)}") @@ -3123,7 +3256,7 @@ def method_to_svg( print(f"最终SVG: {final_svg_path}") return { - "figure_path": str(figure_path), + "figure_path": str(figure_path_out), "samed_path": samed_path, "boxlib_path": boxlib_path, "icon_infos": icon_infos, @@ -3180,13 +3313,13 @@ def create_embedded_figure_svg( # Provider 参数 parser.add_argument( "--provider", - choices=["openrouter", "bianxie", "gemini"], + choices=["openrouter", "bianxie", "gemini", "minimax"], default="bianxie", help="API 提供商(默认: bianxie)" ) # API 参数 - parser.add_argument("--api_key", default=None, help="API Key") + parser.add_argument("--api_key", default=None, help="API Key(或设置 MINIMAX_API_KEY 环境变量)") parser.add_argument("--base_url", default=None, help="API base URL(默认根据 provider 自动设置)") # 模型参数 @@ -3199,6 +3332,13 @@ def create_embedded_figure_svg( ) parser.add_argument("--svg_model", default=None, help="SVG生成模型(默认根据 provider 自动设置)") + # 预生成图片参数(跳过步骤一,适用于 MiniMax 等不支持图像生成的 provider) + parser.add_argument( + "--figure_path", + default=None, + help="预生成的图片路径(跳过步骤一的 LLM 图像生成,适用于 minimax 等不支持图像生成的 provider)", + ) + # Step 1 参考图片参数 parser.add_argument( "--use_reference_image", @@ -3266,12 +3406,19 @@ def create_embedded_figure_svg( parser.error("--use_reference_image 需要 --reference_image_path") if args.reference_image_path and not Path(args.reference_image_path).is_file(): parser.error(f"参考图片不存在: 
{args.reference_image_path}") + if args.figure_path and not Path(args.figure_path).is_file(): + parser.error(f"预生成图片不存在: {args.figure_path}") USE_REFERENCE_IMAGE = bool(args.use_reference_image) REFERENCE_IMAGE_PATH = args.reference_image_path if REFERENCE_IMAGE_PATH: USE_REFERENCE_IMAGE = True + # API key: CLI flag > environment variable + api_key = args.api_key + if not api_key and args.provider == "minimax": + api_key = os.environ.get("MINIMAX_API_KEY") + # 获取 method 文本:优先使用 --method_text method_text = args.method_text if method_text is None: @@ -3282,7 +3429,7 @@ def create_embedded_figure_svg( result = method_to_svg( method_text=method_text, output_dir=args.output_dir, - api_key=args.api_key, + api_key=api_key, base_url=args.base_url, provider=args.provider, image_gen_model=args.image_model, @@ -3298,4 +3445,5 @@ def create_embedded_figure_svg( placeholder_mode=args.placeholder_mode, optimize_iterations=args.optimize_iterations, merge_threshold=args.merge_threshold, + figure_path=args.figure_path, ) diff --git a/server.py b/server.py index 4cecf5c..13a5502 100644 --- a/server.py +++ b/server.py @@ -101,6 +101,7 @@ class RunRequest(BaseModel): merge_threshold: Optional[float] = None optimize_iterations: Optional[int] = None reference_image_path: Optional[str] = None + figure_path: Optional[str] = None app = FastAPI() @@ -174,6 +175,15 @@ def run_job(req: RunRequest) -> JSONResponse: ) cmd += ["--reference_image_path", reference_path] + if req.figure_path: + fig_path = req.figure_path + fig_path = ( + str((BASE_DIR / fig_path).resolve()) + if not Path(fig_path).is_absolute() + else fig_path + ) + cmd += ["--figure_path", fig_path] + env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_minimax_provider.py b/tests/test_minimax_provider.py new file mode 100644 index 0000000..f6077a4 --- /dev/null +++ b/tests/test_minimax_provider.py @@ -0,0 +1,357 
@@ +"""Unit tests for the MiniMax LLM provider integration.""" + +from __future__ import annotations + +import io +import json +import os +import shutil +import tempfile +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Import helpers – the module under test uses heavy ML deps (torch, etc.) +# We mock them when they are not installed so the provider-level tests still +# run in lightweight CI environments. +# --------------------------------------------------------------------------- + +import importlib +import sys + +_STUBS: dict[str, Any] = {} + + +def _ensure_stubs(): + """Create lightweight stubs for heavy optional deps if they are missing.""" + for mod_name in ("torch", "torchvision", "torchvision.transforms", + "timm", "transformers", "kornia"): + if mod_name not in sys.modules: + stub = MagicMock() + sys.modules[mod_name] = stub + _STUBS[mod_name] = stub + + +_ensure_stubs() + +# Now we can import the target module +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +import autofigure2 as af + + +# ============================================================================ +# Temperature clamping +# ============================================================================ + +class TestMiniMaxTemperatureClamping: + """_clamp_minimax_temperature must keep values in (0, 1].""" + + def test_zero_is_clamped_up(self): + assert af._clamp_minimax_temperature(0.0) == 0.01 + + def test_negative_is_clamped_up(self): + assert af._clamp_minimax_temperature(-0.5) == 0.01 + + def test_above_one_is_clamped_down(self): + assert af._clamp_minimax_temperature(1.5) == 1.0 + + def test_exactly_one_is_kept(self): + assert af._clamp_minimax_temperature(1.0) == 1.0 + + def test_normal_value_is_kept(self): + assert af._clamp_minimax_temperature(0.7) == 0.7 + + def test_small_positive_is_kept(self): + assert 
af._clamp_minimax_temperature(0.01) == 0.01 + + +# ============================================================================ +# Provider configuration +# ============================================================================ + +class TestMiniMaxProviderConfig: + """PROVIDER_CONFIGS must include minimax with correct defaults.""" + + def test_minimax_in_configs(self): + assert "minimax" in af.PROVIDER_CONFIGS + + def test_base_url(self): + assert af.PROVIDER_CONFIGS["minimax"]["base_url"] == "https://api.minimax.io/v1" + + def test_default_svg_model(self): + assert af.PROVIDER_CONFIGS["minimax"]["default_svg_model"] == "MiniMax-M2.7" + + def test_no_default_image_model(self): + assert af.PROVIDER_CONFIGS["minimax"]["default_image_model"] is None + + +# ============================================================================ +# Image generation raises NotImplementedError +# ============================================================================ + +class TestMiniMaxImageGeneration: + """MiniMax must refuse image generation with a clear error.""" + + def test_raises_not_implemented(self): + with pytest.raises(NotImplementedError, match="MiniMax does not support image generation"): + af._call_minimax_image_generation( + prompt="draw a cat", + api_key="test-key", + model="MiniMax-M2.7", + base_url="https://api.minimax.io/v1", + ) + + def test_dispatcher_raises_for_minimax(self): + with pytest.raises(NotImplementedError): + af.call_llm_image_generation( + prompt="draw a cat", + api_key="test-key", + model="MiniMax-M2.7", + base_url="https://api.minimax.io/v1", + provider="minimax", + ) + + +# ============================================================================ +# Text call dispatching +# ============================================================================ + +class TestMiniMaxTextCall: + """call_llm_text must route to _call_minimax_text for minimax provider.""" + + @patch("autofigure2._call_minimax_text", return_value="hello from minimax") 
+    def test_dispatch(self, mock_call):
+        result = af.call_llm_text(
+            prompt="say hello",
+            api_key="test-key",
+            model="MiniMax-M2.7",
+            base_url="https://api.minimax.io/v1",
+            provider="minimax",
+        )
+        assert result == "hello from minimax"
+        mock_call.assert_called_once()
+
+    @patch("autofigure2._call_minimax_text")
+    def test_temperature_passed(self, mock_call):
+        af.call_llm_text(
+            prompt="hi",
+            api_key="k",
+            model="MiniMax-M2.7",
+            base_url="https://api.minimax.io/v1",
+            provider="minimax",
+            temperature=0.5,
+        )
+        _, kwargs = mock_call.call_args
+        assert kwargs.get("temperature", mock_call.call_args[0][5] if len(mock_call.call_args[0]) > 5 else None) is not None
+
+
+# ============================================================================
+# Multimodal call dispatching
+# ============================================================================
+
+class TestMiniMaxMultimodalCall:
+    """call_llm_multimodal must route to _call_minimax_multimodal."""
+
+    @patch("autofigure2._call_minimax_multimodal", return_value="svg code here")
+    def test_dispatch(self, mock_call):
+        result = af.call_llm_multimodal(
+            contents=["generate svg"],
+            api_key="test-key",
+            model="MiniMax-M2.7",
+            base_url="https://api.minimax.io/v1",
+            provider="minimax",
+        )
+        assert result == "svg code here"
+        mock_call.assert_called_once()
+
+
+# ============================================================================
+# Think-tag stripping in multimodal response
+# ============================================================================
+
+class TestThinkTagStripping:
+    """MiniMax multimodal must strip <think> tags from responses."""
+
+    @patch("autofigure2.OpenAI", create=True)
+    def test_think_tag_stripped_when_content_outside(self, mock_openai_cls):
+        """Think tags should be stripped when there is content outside them."""
+        mock_msg = MagicMock()
+        mock_msg.content = "<think>reasoning here</think>SVG output"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_msg
+        mock_completion =
MagicMock()
+        mock_completion.choices = [mock_choice]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_completion
+        mock_openai_cls.return_value = mock_client
+
+        with patch.dict("sys.modules", {"openai": MagicMock(OpenAI=mock_openai_cls)}):
+            result = af._call_minimax_multimodal(
+                contents=["test"],
+                api_key="key",
+                model="MiniMax-M2.7",
+                base_url="https://api.minimax.io/v1",
+            )
+
+        assert "<think>" not in result
+        assert result == "SVG output"
+
+    @patch("autofigure2.OpenAI", create=True)
+    def test_think_tag_preserved_when_only_content(self, mock_openai_cls):
+        """If all content is inside think tags, preserve the original response."""
+        mock_msg = MagicMock()
+        mock_msg.content = "<think>the entire answer is here</think>"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_msg
+        mock_completion = MagicMock()
+        mock_completion.choices = [mock_choice]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_completion
+        mock_openai_cls.return_value = mock_client
+
+        with patch.dict("sys.modules", {"openai": MagicMock(OpenAI=mock_openai_cls)}):
+            result = af._call_minimax_multimodal(
+                contents=["test"],
+                api_key="key",
+                model="MiniMax-M2.7",
+                base_url="https://api.minimax.io/v1",
+            )
+
+        # Should preserve the original since stripping would leave empty
+        assert result == "<think>the entire answer is here</think>"
+
+
+# ============================================================================
+# figure_path support in method_to_svg
+# ============================================================================
+
+class TestFigurePathSupport:
+    """method_to_svg should accept figure_path and skip image generation."""
+
+    def test_figure_path_copies_to_output(self):
+        """When figure_path is provided, step 1 should copy it instead of calling LLM."""
+        # Create a temporary figure file
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a minimal PNG (100x100)
+            from PIL import Image
+            fig_src = Path(tmpdir) /
"source_figure.png" + img = Image.new("RGB", (100, 100), color="red") + img.save(str(fig_src)) + + output_dir = Path(tmpdir) / "output" + output_dir.mkdir() + + # Patch heavy functions that would run in step 2+ + with patch.object(af, "segment_with_sam3") as mock_sam, \ + patch.object(af, "_ensure_rmbg2_access_ready"), \ + patch.object(af, "crop_and_remove_background", return_value=[]), \ + patch.object(af, "generate_svg_template"), \ + patch.object(af, "optimize_svg_with_llm"), \ + patch.object(af, "replace_icons_in_svg"): + + mock_sam.return_value = (str(output_dir / "samed.png"), str(output_dir / "boxlib.json"), []) + + # Create dummy samed.png and boxlib.json that steps expect + Image.new("RGB", (100, 100)).save(str(output_dir / "samed.png")) + Path(output_dir / "boxlib.json").write_text("[]") + + result = af.method_to_svg( + method_text="test method", + output_dir=str(output_dir), + api_key="test-key", + provider="minimax", + figure_path=str(fig_src), + stop_after=1, + ) + + # The figure should be copied to output + assert (output_dir / "figure.png").is_file() + assert result["figure_path"] == str(output_dir / "figure.png") + + def test_missing_figure_path_raises(self): + """When figure_path points to a missing file, should raise FileNotFoundError.""" + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(FileNotFoundError, match="预生成的图片不存在"): + af.method_to_svg( + method_text="test", + output_dir=str(Path(tmpdir) / "output"), + api_key="test-key", + provider="minimax", + figure_path="/nonexistent/figure.png", + ) + + +# ============================================================================ +# CLI argument parsing +# ============================================================================ + +class TestCLIArgs: + """Verify minimax is accepted by the CLI argparse config.""" + + def test_minimax_in_provider_choices(self): + """The --provider choices must include minimax.""" + import argparse + # Extract the provider argument from the 
parser setup + # We check PROVIDER_CONFIGS since the choices list comes from there + assert "minimax" in af.PROVIDER_CONFIGS + + def test_figure_path_default_none(self): + """--figure_path should default to None when not provided.""" + # This is implicitly tested by the argparse default + # We verify via the method_to_svg signature + import inspect + sig = inspect.signature(af.method_to_svg) + assert sig.parameters["figure_path"].default is None + + +# ============================================================================ +# Integration test (requires MINIMAX_API_KEY) +# ============================================================================ + +@pytest.mark.skipif( + not os.environ.get("MINIMAX_API_KEY"), + reason="MINIMAX_API_KEY not set", +) +class TestMiniMaxIntegration: + """Integration tests that call the real MiniMax API.""" + + def test_text_call(self): + """A simple text completion should return a non-empty string.""" + result = af._call_minimax_text( + prompt="Reply with exactly: MINIMAX_OK", + api_key=os.environ["MINIMAX_API_KEY"], + model="MiniMax-M2.7", + base_url="https://api.minimax.io/v1", + max_tokens=50, + temperature=0.01, + ) + assert result is not None + assert len(result) > 0 + + def test_multimodal_call(self): + """A multimodal call with text-only content should work.""" + result = af._call_minimax_multimodal( + contents=["What is 2+2? 
Answer with only the number."],
+            api_key=os.environ["MINIMAX_API_KEY"],
+            model="MiniMax-M2.7",
+            base_url="https://api.minimax.io/v1",
+            max_tokens=50,
+            temperature=0.5,
+        )
+        assert result is not None
+        assert len(result) > 0
+
+    def test_image_generation_raises(self):
+        """Image generation must raise NotImplementedError even with a real key."""
+        with pytest.raises(NotImplementedError):
+            af._call_minimax_image_generation(
+                prompt="a red circle",
+                api_key=os.environ["MINIMAX_API_KEY"],
+                model="MiniMax-M2.7",
+                base_url="https://api.minimax.io/v1",
+            )
diff --git a/web/index.html b/web/index.html
index cf479cb..919d21e 100644
--- a/web/index.html
+++ b/web/index.html
@@ -34,6 +34,7 @@
+            <option value="minimax">MiniMax</option>