From 52885aa15a9fce454a793f61699b42f6a23edc12 Mon Sep 17 00:00:00 2001
From: Max <maximilianotorres@theluxergroup.com>
Date: Fri, 25 Oct 2024 17:08:01 -0300
Subject: [PATCH] Add option forwarded_allow_ips

---
 src/fastapi_cli/cli.py | 16 ++++++++++++++++
 tests/test_cli.py      | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py
index d5bcb8e..b12b4a0 100644
--- a/src/fastapi_cli/cli.py
+++ b/src/fastapi_cli/cli.py
@@ -60,6 +60,7 @@ def _run(
     command: str,
     app: Union[str, None] = None,
     proxy_headers: bool = False,
+    forwarded_allow_ips: Union[str, None] = None,
 ) -> None:
     try:
         use_uvicorn_app = get_import_string(path=path, app_name=app)
@@ -97,6 +98,7 @@ def _run(
         workers=workers,
         root_path=root_path,
         proxy_headers=proxy_headers,
+        forwarded_allow_ips=forwarded_allow_ips,
     )
 
 
@@ -145,6 +147,12 @@ def dev(
             help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to populate remote address info."
         ),
     ] = True,
+    forwarded_allow_ips: Annotated[
+        Union[str, None],
+        typer.Option(
+            help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything."
+        ),
+    ] = None,
 ) -> Any:
     """
     Run a [bold]FastAPI[/bold] app in [yellow]development[/yellow] mode. ๐Ÿงช
@@ -180,6 +188,7 @@ def dev(
         app=app,
         command="dev",
         proxy_headers=proxy_headers,
+        forwarded_allow_ips=forwarded_allow_ips,
     )
 
 
@@ -234,6 +243,12 @@ def run(
             help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to populate remote address info."
         ),
     ] = True,
+    forwarded_allow_ips: Annotated[
+        Union[str, None],
+        typer.Option(
+            help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything."
+        ),
+    ] = None,
 ) -> Any:
     """
     Run a [bold]FastAPI[/bold] app in [green]production[/green] mode. ๐Ÿš€
@@ -270,6 +285,7 @@ def run(
         app=app,
         command="run",
         proxy_headers=proxy_headers,
+        forwarded_allow_ips=forwarded_allow_ips,
     )
 
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 44c14d2..a2634a7 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -29,6 +29,7 @@ def test_dev() -> None:
                 "workers": None,
                 "root_path": "",
                 "proxy_headers": True,
+                "forwarded_allow_ips": None,
             }
         assert "Using import string single_file_app:app" in result.output
         assert (
@@ -71,6 +72,7 @@ def test_dev_args() -> None:
                 "workers": None,
                 "root_path": "/api",
                 "proxy_headers": False,
+                "forwarded_allow_ips": None,
             }
         assert "Using import string single_file_app:api" in result.output
         assert (
@@ -97,6 +99,36 @@ def test_run() -> None:
                 "workers": None,
                 "root_path": "",
                 "proxy_headers": True,
+                "forwarded_allow_ips": None,
+            }
+        assert "Using import string single_file_app:app" in result.output
+        assert (
+            "โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ FastAPI CLI - Production mode โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ" in result.output
+        )
+        assert "โ”‚  Serving at: http://0.0.0.0:8000" in result.output
+        assert "โ”‚  API docs: http://0.0.0.0:8000/docs" in result.output
+        assert "โ”‚  Running in production mode, for development use:" in result.output
+        assert "โ”‚  fastapi dev" in result.output
+
+
+def test_run_trust_proxy() -> None:
+    with changing_dir(assets_path):
+        with patch.object(uvicorn, "run") as mock_run:
+            result = runner.invoke(
+                app, ["run", "single_file_app.py", "--forwarded-allow-ips", "*"]
+            )
+            assert result.exit_code == 0, result.output
+            assert mock_run.called
+            assert mock_run.call_args
+            assert mock_run.call_args.kwargs == {
+                "app": "single_file_app:app",
+                "host": "0.0.0.0",
+                "port": 8000,
+                "reload": False,
+                "workers": None,
+                "root_path": "",
+                "proxy_headers": True,
+                "forwarded_allow_ips": "*",
             }
         assert "Using import string single_file_app:app" in result.output
         assert (
@@ -141,6 +173,7 @@ def test_run_args() -> None:
                 "workers": 2,
                 "root_path": "/api",
                 "proxy_headers": False,
+                "forwarded_allow_ips": None,
             }
         assert "Using import string single_file_app:api" in result.output
         assert (