Skip to content

Commit

Permalink
Fix authentication for cases where webserver.base_url is not defined …
Browse files Browse the repository at this point in the history
…and worker is not using localhost in 2.10. (#45785)

* Fix authentication for cases where webserver.base_url is not defined and worker is not using localhost in 2.10.

* Move Edge to new provider structure
  • Loading branch information
jscheffl authored Jan 20, 2025
1 parent 24b1fe8 commit 90af410
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 22 deletions.
6 changes: 3 additions & 3 deletions providers/edge/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Package ``apache-airflow-providers-edge``

Release: ``0.10.1pre0``
Release: ``0.10.2pre0``


Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites
Expand All @@ -37,7 +37,7 @@ This is a provider package for ``edge`` provider. All classes for this provider
are in ``airflow.providers.edge`` python package.

You can find package information and changelog for the provider
in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.1pre0/>`_.
in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.2pre0/>`_.

Installation
------------
Expand All @@ -60,4 +60,4 @@ PIP package Version required
================== ==================

The changelog for the provider package can be found in the
`changelog <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.1pre0/changelog.html>`_.
`changelog <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.2pre0/changelog.html>`_.
8 changes: 8 additions & 0 deletions providers/edge/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.10.2pre0
..........

Misc
~~~~

* ``Fix authentication for cases where webserver.base_url is not defined and worker is not using localhost in 2.10.``

0.10.1pre0
..........

Expand Down
2 changes: 1 addition & 1 deletion providers/edge/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ source-date-epoch: 1737371680

# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.10.1pre0
- 0.10.2pre0

plugins:
- name: edge_executor
Expand Down
6 changes: 3 additions & 3 deletions providers/edge/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "apache-airflow-providers-edge"
version = "0.10.1pre0"
version = "0.10.2pre0"
description = "Provider package apache-airflow-providers-edge for Apache Airflow"
readme = "README.rst"
authors = [
Expand Down Expand Up @@ -61,8 +61,8 @@ dependencies = [
]

[project.urls]
"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.1pre0"
"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.1pre0/changelog.html"
"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.2pre0"
"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.10.2pre0/changelog.html"
"Bug Tracker" = "https://github.com/apache/airflow/issues"
"Source Code" = "https://github.com/apache/airflow"
"Slack Chat" = "https://s.apache.org/airflow-slack"
Expand Down
2 changes: 1 addition & 1 deletion providers/edge/src/airflow/providers/edge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.10.1pre0"
__version__ = "0.10.2pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_provider_info():
"description": "Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites\n",
"state": "not-ready",
"source-date-epoch": 1737371680,
"versions": ["0.10.1pre0"],
"versions": ["0.10.2pre0"],
"plugins": [
{
"name": "edge_executor",
Expand Down
9 changes: 2 additions & 7 deletions providers/edge/src/airflow/providers/edge/worker_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,13 @@ def _forbidden_response(message: str):
def jwt_token_authorization(method: str, authorization: str):
"""Check if the JWT token is correct."""
try:
# worker sends method without api_url
api_url = conf.get("edge", "api_url")
base_url = conf.get("webserver", "base_url")
url_prefix = api_url.replace(base_url, "").replace("/rpcapi", "/")
pure_method = method.replace(url_prefix, "")
payload = jwt_signer().verify_token(authorization)
signed_method = payload.get("method")
if not signed_method or signed_method != pure_method:
if not signed_method or signed_method != method:
_forbidden_response(
"Invalid method in token authorization. "
f"signed method='{signed_method}' "
f"called method='{pure_method}'",
f"called method='{method}'",
)
except BadSignature:
_forbidden_response("Bad Signature. Please use only the tokens provided by the API.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,19 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
return e.to_response() # type: ignore[attr-defined]


def jwt_token_authorization_v2(method: str, authorization: str):
"""Proxy for v2 method path handling."""
PREFIX = "/edge_worker/v1/"
method_path = method[method.find(PREFIX) + len(PREFIX) :] if PREFIX in method else method
jwt_token_authorization(method_path, authorization)


@provide_session
def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any:
"""Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = WorkerStateBody(
state=body["state"], jobs_active=0, queues=body["queues"], sysinfo=body["sysinfo"]
)
Expand All @@ -139,7 +146,7 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
"""Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = WorkerStateBody(
state=body["state"],
jobs_active=body["jobs_active"],
Expand All @@ -158,7 +165,7 @@ def job_fetch_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->

try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
queues = body.get("queues")
free_concurrency = body.get("free_concurrency", 1)
request_obj = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency)
Expand All @@ -183,7 +190,7 @@ def job_state_v2(

try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
state_api(dag_id, task_id, run_id, try_number, int(map_index), state, session)
except HTTPException as e:
return e.to_response() # type: ignore[attr-defined]
Expand All @@ -199,7 +206,7 @@ def logfile_path_v2(
"""Handle Edge Worker API `/edge_worker/v1/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
return logfile_path(dag_id, task_id, run_id, try_number, int(map_index))
except HTTPException as e:
return e.to_response() # type: ignore[attr-defined]
Expand All @@ -216,7 +223,7 @@ def push_logs_v2(
"""Handle Edge Worker API `/edge_worker/v1/logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
jwt_token_authorization_v2(request.path, auth)
request_obj = PushLogsBody(
log_chunk_data=body["log_chunk_data"], log_chunk_time=body["log_chunk_time"]
)
Expand Down

0 comments on commit 90af410

Please sign in to comment.