From 9324cd1cc3b57b699aee6dfdd90114986e6be50e Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sat, 18 Jan 2025 22:33:55 +0100 Subject: [PATCH] Fix authentication for cases where webserver.base_url is not defined and worker is not using localhost in 2.10. --- providers/edge/docs/changelog.rst | 8 ++++++++ providers/edge/provider.yaml | 2 +- .../src/airflow/providers/edge/__init__.py | 2 +- .../airflow/providers/edge/worker_api/auth.py | 9 ++------- .../edge/worker_api/routes/_v2_routes.py | 19 +++++++++++++------ 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/providers/edge/docs/changelog.rst b/providers/edge/docs/changelog.rst index 8ef261b4f85c6..3b69ae8dad877 100644 --- a/providers/edge/docs/changelog.rst +++ b/providers/edge/docs/changelog.rst @@ -27,6 +27,14 @@ Changelog --------- +0.10.2pre0 +.......... + +Misc +~~~~ + +* ``Fix authentication for cases where webserver.base_url is not defined and worker is not using localhost in 2.10.`` + 0.10.1pre0 .......... diff --git a/providers/edge/provider.yaml b/providers/edge/provider.yaml index 4b36732e0392e..cca5561b2f372 100644 --- a/providers/edge/provider.yaml +++ b/providers/edge/provider.yaml @@ -25,7 +25,7 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.10.1pre0 + - 0.10.2pre0 plugins: - name: edge_executor diff --git a/providers/edge/src/airflow/providers/edge/__init__.py b/providers/edge/src/airflow/providers/edge/__init__.py index a1fd5c5fb6e69..32f5d92608134 100644 --- a/providers/edge/src/airflow/providers/edge/__init__.py +++ b/providers/edge/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.10.1pre0" +__version__ = "0.10.2pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/edge/src/airflow/providers/edge/worker_api/auth.py b/providers/edge/src/airflow/providers/edge/worker_api/auth.py index 3d04fa93525c9..5373d39825e97 100644 --- a/providers/edge/src/airflow/providers/edge/worker_api/auth.py +++ b/providers/edge/src/airflow/providers/edge/worker_api/auth.py @@ -66,18 +66,13 @@ def _forbidden_response(message: str): def jwt_token_authorization(method: str, authorization: str): """Check if the JWT token is correct.""" try: - # worker sends method without api_url - api_url = conf.get("edge", "api_url") - base_url = conf.get("webserver", "base_url") - url_prefix = api_url.replace(base_url, "").replace("/rpcapi", "/") - pure_method = method.replace(url_prefix, "") payload = jwt_signer().verify_token(authorization) signed_method = payload.get("method") - if not signed_method or signed_method != pure_method: + if not signed_method or signed_method != method: _forbidden_response( "Invalid method in token authorization. " f"signed method='{signed_method}' " - f"called method='{pure_method}'", + f"called method='{method}'", ) except BadSignature: _forbidden_response("Bad Signature. Please use only the tokens provided by the API.") diff --git a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py index 684c9dd4dae68..a6d1d3edfdce4 100644 --- a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py +++ b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py @@ -120,12 +120,19 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse: return e.to_response() # type: ignore[attr-defined] +def jwt_token_authorization_v2(method: str, authorization: str): + """Proxy for v2 method path handling.""" + PREFIX = "/edge_worker/v1/" + method_path = method[method.find(PREFIX) + len(PREFIX) :] if PREFIX in method else method + jwt_token_authorization(method_path, authorization) + + @provide_session def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any: """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10.""" try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) request_obj = WorkerStateBody( state=body["state"], jobs_active=0, queues=body["queues"], sysinfo=body["sysinfo"] ) @@ -139,7 +146,7 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10.""" try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) request_obj = WorkerStateBody( state=body["state"], jobs_active=body["jobs_active"], @@ -158,7 +165,7 @@ def job_fetch_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) queues = body.get("queues") free_concurrency = body.get("free_concurrency", 1) request_obj = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency) @@ -183,7 +190,7 @@ def job_state_v2( try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) state_api(dag_id, task_id, run_id, try_number, int(map_index), state, session) except HTTPException as e: return e.to_response() # type: ignore[attr-defined] @@ -199,7 +206,7 @@ def logfile_path_v2( """Handle Edge Worker API `/edge_worker/v1/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10.""" try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) return logfile_path(dag_id, task_id, run_id, try_number, int(map_index)) except HTTPException as e: return e.to_response() # type: ignore[attr-defined] @@ -216,7 +223,7 @@ def push_logs_v2( """Handle Edge Worker API `/edge_worker/v1/logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10.""" try: auth = request.headers.get("Authorization", "") - jwt_token_authorization(request.path, auth) + jwt_token_authorization_v2(request.path, auth) request_obj = PushLogsBody( log_chunk_data=body["log_chunk_data"], log_chunk_time=body["log_chunk_time"] )