diff --git a/acestep/api_server.py b/acestep/api_server.py index 186ef065..35dbb138 100644 --- a/acestep/api_server.py +++ b/acestep/api_server.py @@ -41,6 +41,7 @@ load_dotenv = None # type: ignore from fastapi import FastAPI, HTTPException, Request, Depends, Header +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from starlette.datastructures import UploadFile as StarletteUploadFile @@ -2144,6 +2145,16 @@ async def _job_store_cleanup_worker() -> None: app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan) + # Enable CORS for browser-based frontends (e.g. studio.html opened via file://) + # Restricted to localhost origins and the "null" origin (file:// protocol) + app.add_middleware( + CORSMiddleware, + allow_origins=["null", "http://localhost", "http://127.0.0.1"], + allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$", + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["Content-Type", "Authorization"], + ) + # Mount OpenRouter-compatible endpoints (/v1/chat/completions, /v1/models) from acestep.openrouter_adapter import create_openrouter_router openrouter_router = create_openrouter_router(lambda: app.state) diff --git a/acestep/gradio_ui/api_routes.py b/acestep/gradio_ui/api_routes.py index 42c1bd81..76e4088f 100644 --- a/acestep/gradio_ui/api_routes.py +++ b/acestep/gradio_ui/api_routes.py @@ -10,6 +10,7 @@ from uuid import uuid4 from fastapi import APIRouter, HTTPException, Request, Depends, Header +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse # Global results directory inside project root @@ -512,6 +513,38 @@ def to_bool(val, default=False): raise HTTPException(status_code=500, detail=str(e)) +# Origins that are expected to call the API: +# - "null" → studio.html opened via file:// protocol +# - http://localhost:* → local dev servers / Gradio UI +# - http://127.0.0.1:* → same, numeric form +_CORS_KWARGS = dict( + allow_origins=["null", "http://localhost", "http://127.0.0.1"], + allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$", + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["Content-Type", "Authorization"], +) + + +def _add_cors_middleware(app): + """Add CORS middleware so browser-based frontends (e.g. studio.html via file://) can call the API.""" + app.add_middleware(CORSMiddleware, **_CORS_KWARGS) + + +def _add_cors_middleware_post_launch(app): + """Wrap an already-started app's middleware stack with CORS. + + ``add_middleware`` raises after Starlette has started, so we patch the + compiled middleware stack directly instead. + """ + from starlette.middleware.cors import CORSMiddleware as _CORSImpl + + if app.middleware_stack is not None: + app.middleware_stack = _CORSImpl(app=app.middleware_stack, **_CORS_KWARGS) + else: + # App hasn't built its stack yet – safe to use the normal path + _add_cors_middleware(app) + + def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str] = None): """ Mount API routes to a FastAPI application (for use with gr.mount_gradio_app) @@ -523,6 +556,7 @@ def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str api_key: Optional API key for authentication """ set_api_key(api_key) + _add_cors_middleware(app) app.state.dit_handler = dit_handler app.state.llm_handler = llm_handler app.include_router(router) @@ -540,6 +574,7 @@ def setup_api_routes(demo, dit_handler, llm_handler, api_key: Optional[str] = No """ set_api_key(api_key) app = demo.app + _add_cors_middleware_post_launch(app) app.state.dit_handler = dit_handler app.state.llm_handler = llm_handler app.include_router(router)