Skip to content

Commit

Permalink
Fix authentication for cases where webserver.base_url is not defined …
Browse files Browse the repository at this point in the history
…and worker is not using localhost in 2.10.
  • Loading branch information
jscheffl committed Jan 18, 2025
1 parent d414ff9 commit 169c2b1
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 15 deletions.
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.10.2pre0
..........

Misc
~~~~

* ``Fix authentication for cases where webserver.base_url is not defined and worker is not using localhost in 2.10.``

0.10.1pre0
..........

Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.10.1pre0"
__version__ = "0.10.2pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ source-date-epoch: 1729683247

# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.10.1pre0
- 0.10.2pre0

dependencies:
- apache-airflow>=2.10.0
Expand Down
9 changes: 2 additions & 7 deletions providers/src/airflow/providers/edge/worker_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,13 @@ def _forbidden_response(message: str):
def jwt_token_authorization(method: str, authorization: str):
"""Check if the JWT token is correct."""
try:
# worker sends method without api_url
api_url = conf.get("edge", "api_url")
base_url = conf.get("webserver", "base_url")
url_prefix = api_url.replace(base_url, "").replace("/rpcapi", "/")
pure_method = method.replace(url_prefix, "")
payload = jwt_signer().verify_token(authorization)
signed_method = payload.get("method")
if not signed_method or signed_method != pure_method:
if not signed_method or signed_method != method:
_forbidden_response(
"Invalid method in token authorization. "
f"signed method='{signed_method}' "
f"called method='{pure_method}'",
f"called method='{method}'",
)
except BadSignature:
_forbidden_response("Bad Signature. Please use only the tokens provided by the API.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,19 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
return e.to_response() # type: ignore[attr-defined]


def jwt_token_authorization_v2(method: str, authorization: str):
"""Proxy for v2 method path handling."""
PREFIX = "/edge_worker/v1/"
method_path = method[method.find(PREFIX) + len(PREFIX) :] if PREFIX in method else method
jwt_token_authorization(method_path, authorization)


@provide_session
def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any:
"""Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = WorkerStateBody(
state=body["state"], jobs_active=0, queues=body["queues"], sysinfo=body["sysinfo"]
)
Expand All @@ -139,7 +146,7 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
"""Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = WorkerStateBody(
state=body["state"],
jobs_active=body["jobs_active"],
Expand All @@ -158,7 +165,7 @@ def job_fetch_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->

try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
queues = body.get("queues")
free_concurrency = body.get("free_concurrency", 1)
request_obj = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency)
Expand All @@ -183,7 +190,7 @@ def job_state_v2(

try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
state_api(dag_id, task_id, run_id, try_number, int(map_index), state, session)
except HTTPException as e:
return e.to_response() # type: ignore[attr-defined]
Expand All @@ -199,7 +206,7 @@ def logfile_path_v2(
"""Handle Edge Worker API `/edge_worker/v1/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
return logfile_path(dag_id, task_id, run_id, try_number, int(map_index))
except HTTPException as e:
return e.to_response() # type: ignore[attr-defined]
Expand All @@ -216,7 +223,7 @@ def push_logs_v2(
"""Handle Edge Worker API `/edge_worker/v1/logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = PushLogsBody(
log_chunk_data=body["log_chunk_data"], log_chunk_time=body["log_chunk_time"]
)
Expand Down

0 comments on commit 169c2b1

Please sign in to comment.