diff --git a/.changeset/tricky-coins-sniff.md b/.changeset/tricky-coins-sniff.md new file mode 100644 index 0000000000000..da3879ece93d3 --- /dev/null +++ b/.changeset/tricky-coins-sniff.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": patch +"gradio": patch +--- + +fix:Set `root` correctly for Gradio apps that are deployed behind reverse proxies diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 41516ed48bfa5..2fb161c055f0a 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -301,6 +301,9 @@ export function api_factory( async function config_success(_config: Config): Promise { config = _config; + if (window.location.protocol === "https:") { + config.root = config.root.replace("http://", "https://"); + } api_map = map_names_to_ids(_config?.dependencies || []); if (config.auth_required) { return { diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 6e663e5ff4b5b..bceda289f0a8a 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -261,18 +261,24 @@ async def call_process_api( return output -def get_root_url(request: fastapi.Request) -> str: +def get_root_url( + request: fastapi.Request, route_path: str, root_path: str | None +) -> str: """ - Gets the root url of the request, stripping off any query parameters and trailing slashes. - Also ensures that the root url is https if the request is https. + Gets the root url of the request, stripping off any query parameters, the route_path, and trailing slashes. + Also ensures that the root url is https if the request is https. If root_path is provided, it is appended to the root url. + The final root url will not have a trailing slash. """ root_url = str(request.url) root_url = httpx.URL(root_url) root_url = root_url.copy_with(query=None) - root_url = str(root_url) + root_url = str(root_url).rstrip("/") if request.headers.get("x-forwarded-proto") == "https": root_url = root_url.replace("http://", "https://") - return root_url.rstrip("/") + route_path = route_path.rstrip("/") + if len(route_path) > 0: + root_url = root_url[: -len(route_path)] + return (root_url.rstrip("/") + (root_path or "")).rstrip("/") def _user_safe_decode(src: bytes, codec: str) -> str: diff --git a/gradio/routes.py b/gradio/routes.py index 45af1a2bb644a..240aeb45117e1 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -311,7 +311,9 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()): def main(request: fastapi.Request, user: str = Depends(get_current_user)): mimetypes.add_type("application/javascript", ".js") blocks = app.get_blocks() - root_path = route_utils.get_root_url(request) + root_path = route_utils.get_root_url( + request=request, route_path="/", root_path=app.root_path + ) if app.auth is None or user is not None: config = copy.deepcopy(app.get_blocks().config) config["root"] = root_path @@ -353,7 +355,9 @@ def api_info(): @app.get("/config", dependencies=[Depends(login_check)]) def get_config(request: fastapi.Request): config = copy.deepcopy(app.get_blocks().config) - root_path = route_utils.get_root_url(request)[: -len("/config")] + root_path = route_utils.get_root_url( + request=request, route_path="/config", root_path=app.root_path + ) config["root"] = root_path config = add_root_url(config, root_path) return config @@ -570,7 +574,9 @@ async def predict( content={"error": str(error) if show_error else None}, status_code=500, ) - root_path = route_utils.get_root_url(request)[: -len(f"/api/{api_name}")] + root_path = route_utils.get_root_url( + request=request, route_path=f"/api/{api_name}", root_path=app.root_path + ) output = add_root_url(output, root_path) return output @@ -580,7 +586,9 @@ async def queue_data( session_hash: str, ): blocks = app.get_blocks() - root_path = route_utils.get_root_url(request)[: -len("/queue/data")] + root_path = route_utils.get_root_url( + request=request, route_path="/queue/data", root_path=app.root_path + ) async def sse_stream(request: fastapi.Request): try: diff --git a/test/test_routes.py b/test/test_routes.py index d6cf76bc7c670..84a3e1a5c907b 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -10,7 +10,7 @@ import pandas as pd import pytest import starlette.routing -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.testclient import TestClient from gradio_client import media_data @@ -25,7 +25,7 @@ routes, wasm_utils, ) -from gradio.route_utils import FnIndexInferError +from gradio.route_utils import FnIndexInferError, get_root_url @pytest.fixture() @@ -862,3 +862,50 @@ def test_component_server_endpoints(connect): }, ) assert fail_req.status_code == 404 + + +@pytest.mark.parametrize( + "request_url, route_path, root_path, expected_root_url", + [ + ("http://localhost:7860/", "/", None, "http://localhost:7860"), + ( + "http://localhost:7860/demo/test", + "/demo/test", + None, + "http://localhost:7860", + ), + ( + "http://localhost:7860/demo/test/", + "/demo/test", + None, + "http://localhost:7860", + ), + ( + "http://localhost:7860/demo/test?query=1", + "/demo/test", + None, + "http://localhost:7860", + ), + ( + "http://localhost:7860/demo/test?query=1", + "/demo/test/", + "/gradio/", + "http://localhost:7860/gradio", + ), + ( + "http://localhost:7860/demo/test?query=1", + "/demo/test", + "/gradio/", + "http://localhost:7860/gradio", + ), + ( + "https://localhost:7860/demo/test?query=1", + "/demo/test", + "/gradio/", + "https://localhost:7860/gradio", + ), + ], +) +def test_get_root_url(request_url, route_path, root_path, expected_root_url): + request = Request({"path": request_url, "type": "http", "headers": {}}) + assert get_root_url(request, route_path, root_path) == expected_root_url