diff --git a/acestep/api_server.py b/acestep/api_server.py index 7b2e1624..3cda01d6 100644 --- a/acestep/api_server.py +++ b/acestep/api_server.py @@ -894,13 +894,33 @@ def bool(self, name: str, default: bool = False) -> bool: def _validate_audio_path(path: Optional[str]) -> Optional[str]: """Validate a user-supplied audio file path to prevent path traversal attacks. - Rejects absolute paths and paths containing '..' traversal sequences. - Returns the validated path or None if the input is None/empty. + Accepts absolute paths strictly only if they are within the system temporary directory. + Otherwise, rejects absolute paths and paths containing '..' traversal sequences. + + Returns the validated, normalized path or None if the input is None/empty. Raises HTTPException 400 if the path is unsafe. """ if not path: return None - # Reject absolute paths (Unix and Windows) + + # Resolve requested path and system temp path to normalized absolute forms + import tempfile + system_temp = os.path.realpath(tempfile.gettempdir()) + requested_path = os.path.realpath(path) + + # SECURE CHECK: Use os.path.commonpath to verify directory boundary integrity. + # This prevents prefix bypasses (e.g., /tmp_evil when /tmp is allowed). + try: + is_in_temp = os.path.commonpath([system_temp, requested_path]) == system_temp + except ValueError: + # Occurs on Windows if paths are on different drives + is_in_temp = False + + if is_in_temp: + # Accept server-generated files in temp + return requested_path + + # Reject manual absolute paths outside of temp if os.path.isabs(path): raise HTTPException(status_code=400, detail="absolute audio file paths are not allowed") # Reject path traversal via '..' components @@ -2220,8 +2240,8 @@ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest: repainting_end=p.float("repainting_end"), instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION), audio_cover_strength=p.float("audio_cover_strength", 1.0), - reference_audio_path=_validate_audio_path(ref_audio), - src_audio_path=_validate_audio_path(src_audio), + reference_audio_path=ref_audio, + src_audio_path=src_audio, task_type=p.str("task_type", "text2music"), use_adg=p.bool("use_adg"), cfg_interval_start=p.float("cfg_interval_start", 0.0), @@ -2252,14 +2272,27 @@ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest: if not isinstance(body, dict): raise HTTPException(status_code=400, detail="JSON payload must be an object") verify_token_from_request(body, authorization) - req = _build_request(RequestParser(body)) + + # Explicitly validate manual string paths from JSON input + p = RequestParser(body) + req = _build_request( + p, + reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None), + src_audio_path=_validate_audio_path(p.str("src_audio_path") or None) + ) elif content_type.endswith("+json"): body = await request.json() if not isinstance(body, dict): raise HTTPException(status_code=400, detail="JSON payload must be an object") verify_token_from_request(body, authorization) - req = _build_request(RequestParser(body)) + + p = RequestParser(body) + req = _build_request( + p, + reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None), + src_audio_path=_validate_audio_path(p.str("src_audio_path") or None) + ) elif content_type.startswith("multipart/form-data"): form = await request.form() @@ -2312,7 +2345,12 @@ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest: body = json.loads(raw.decode("utf-8")) if isinstance(body, dict): verify_token_from_request(body, authorization) - req = _build_request(RequestParser(body)) + p = RequestParser(body) + req = _build_request( + p, + reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None), + src_audio_path=_validate_audio_path(p.str("src_audio_path") or None) + ) else: raise HTTPException(status_code=400, detail="JSON payload must be an object") except HTTPException: