Migrate Edge calls for Worker to FastAPI part 4 - Cleanup (#44434)
* Remove internal API bindings after migration to FastAPI

* Move import into a function to prevent module import errors (see the sketch after this list)
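
A minimal sketch of the deferred-import pattern named in the second bullet, mirroring the FileTaskHandler import that sits inside logfile_path further down this diff; the function body here is illustrative:

def logfile_path(task):
    # Importing inside the function instead of at module level means a
    # heavy or unavailable dependency only fails this call, not every
    # `import` of the module itself.
    from airflow.utils.log.file_task_handler import FileTaskHandler

    ...  # use FileTaskHandler to resolve the task's log file path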
jscheffl authored Dec 1, 2024
1 parent 161beeb commit 0d98e2b
Showing 11 changed files with 26 additions and 523 deletions.
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
Changelog
---------

0.9.0pre0
.........

Misc
~~~~

* ``Remove dependency to Internal API after migration to FastAPI.``

0.8.2pre0
.........

2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.8.2pre0"
__version__ = "0.9.0pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/edge/cli/api_client.py
@@ -110,7 +110,7 @@ def worker_register(
def worker_set_state(
hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict
) -> list[str] | None:
"""Register worker with the Edge API."""
"""Update the state of the worker in the central site and thereby implicitly heartbeat."""
return _make_generic_request(
"PATCH",
f"worker/{quote(hostname)}",
@@ -123,7 +123,7 @@
def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
"GET",
"POST",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
exclude_unset=True
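
For context on the GET-to-POST switch above: the fetch call carries a JSON body (the WorkerQueuesBody payload), and request bodies on GET are not reliably supported by HTTP clients and proxies. A hedged sketch of the equivalent raw HTTP call; the base URL and auth handling are placeholders, and the real client routes through _make_generic_request:

import requests
from urllib.parse import quote

def jobs_fetch_raw(base_url: str, hostname: str, queues: list[str] | None, free_concurrency: int):
    """Illustrative raw-HTTP equivalent of jobs_fetch (no auth, placeholder URL)."""
    payload = {"queues": queues, "free_concurrency": free_concurrency}
    # POST, not GET: the queue filter and free slot count travel in the body.
    response = requests.post(f"{base_url}/jobs/fetch/{quote(hostname)}", json=payload, timeout=30)
    response.raise_for_status()
    return response.json() or None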
105 changes: 1 addition & 104 deletions providers/src/airflow/providers/edge/models/edge_job.py
@@ -16,32 +16,21 @@
# under the License.
from __future__ import annotations

from ast import literal_eval
from datetime import datetime
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Index,
Integer,
String,
select,
text,
)

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base, StringID
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from airflow.utils.sqlalchemy import UtcDateTime


class EdgeJobModel(Base, LoggingMixin):
@@ -103,95 +92,3 @@ def key(self):
@property
def last_update_t(self) -> float:
return self.last_update.timestamp()


class EdgeJob(BaseModel, LoggingMixin):
"""Accessor for edge jobs as logical model."""

dag_id: str
task_id: str
run_id: str
map_index: int
try_number: int
state: TaskInstanceState
queue: str
concurrency_slots: int
command: list[str]
queued_dttm: datetime
edge_worker: Optional[str] # noqa: UP007 - prevent Sphinx failing
last_update: Optional[datetime] # noqa: UP007 - prevent Sphinx failing
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

@property
def key(self) -> TaskInstanceKey:
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

@staticmethod
@internal_api_call
@provide_session
def reserve_task(
worker_name: str,
free_concurrency: int,
queues: list[str] | None = None,
session: Session = NEW_SESSION,
) -> EdgeJob | None:
query = (
select(EdgeJobModel)
.where(
EdgeJobModel.state == TaskInstanceState.QUEUED,
EdgeJobModel.concurrency_slots <= free_concurrency,
)
.order_by(EdgeJobModel.queued_dttm)
)
if queues:
query = query.where(EdgeJobModel.queue.in_(queues))
query = query.limit(1)
query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
job: EdgeJobModel = session.scalar(query)
if not job:
return None
job.state = TaskInstanceState.RUNNING
job.edge_worker = worker_name
job.last_update = timezone.utcnow()
session.commit()
return EdgeJob(
dag_id=job.dag_id,
task_id=job.task_id,
run_id=job.run_id,
map_index=job.map_index,
try_number=job.try_number,
state=job.state,
queue=job.queue,
concurrency_slots=job.concurrency_slots,
command=literal_eval(job.command),
queued_dttm=job.queued_dttm,
edge_worker=job.edge_worker,
last_update=job.last_update,
)

@staticmethod
@internal_api_call
@provide_session
def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, session: Session = NEW_SESSION):
if isinstance(task, tuple):
task = TaskInstanceKey(*task)
query = select(EdgeJobModel).where(
EdgeJobModel.dag_id == task.dag_id,
EdgeJobModel.task_id == task.task_id,
EdgeJobModel.run_id == task.run_id,
EdgeJobModel.map_index == task.map_index,
EdgeJobModel.try_number == task.try_number,
)
job: EdgeJobModel = session.scalar(query)
if job:
job.state = state
job.last_update = timezone.utcnow()
session.commit()

def __hash__(self):
return f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()


EdgeJob.model_rebuild()

add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
81 changes: 0 additions & 81 deletions providers/src/airflow/providers/edge/models/edge_logs.py
@@ -17,11 +17,7 @@
from __future__ import annotations

from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Integer,
@@ -30,19 +26,10 @@
)
from sqlalchemy.dialects.mysql import MEDIUMTEXT

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.models.base import Base, StringID
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session


class EdgeLogsModel(Base, LoggingMixin):
"""
@@ -84,71 +71,3 @@ def __init__(
self.log_chunk_time = log_chunk_time
self.log_chunk_data = log_chunk_data
super().__init__()


class EdgeLogs(BaseModel, LoggingMixin):
"""Deprecated Internal API for Edge Worker instances as logical model."""

dag_id: str
task_id: str
run_id: str
map_index: int
try_number: int
log_chunk_time: datetime
log_chunk_data: str
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

@staticmethod
@internal_api_call
@provide_session
def push_logs(
task: TaskInstanceKey | tuple,
log_chunk_time: datetime,
log_chunk_data: str,
session: Session = NEW_SESSION,
) -> None:
"""Push an incremental log chunk from Edge Worker to central site."""
if isinstance(task, tuple):
task = TaskInstanceKey(*task)
log_chunk = EdgeLogsModel(
dag_id=task.dag_id,
task_id=task.task_id,
run_id=task.run_id,
map_index=task.map_index,
try_number=task.try_number,
log_chunk_time=log_chunk_time,
log_chunk_data=log_chunk_data,
)
session.add(log_chunk)
# Write logs to local file to make them accessible
logfile_path = EdgeLogs.logfile_path(task)
if not logfile_path.exists():
new_folder_permissions = int(
conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8
)
logfile_path.parent.mkdir(parents=True, exist_ok=True, mode=new_folder_permissions)
with logfile_path.open("a") as logfile:
logfile.write(log_chunk_data)

@staticmethod
@lru_cache
def logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
from airflow.utils.log.file_task_handler import FileTaskHandler

ti = TaskInstance.get_task_instance(
dag_id=task.dag_id,
run_id=task.run_id,
task_id=task.task_id,
map_index=task.map_index,
)
if TYPE_CHECKING:
assert ti
assert isinstance(ti, TaskInstance)
base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
return Path(base_log_folder, FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))


EdgeLogs.model_rebuild()

add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
75 changes: 1 addition & 74 deletions providers/src/airflow/providers/edge/models/edge_worker.py
@@ -20,24 +20,16 @@
import json
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Column, Integer, String, select
from sqlalchemy import Column, Integer, String

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException
from airflow.models.base import Base
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session


class EdgeWorkerVersionException(AirflowException):
"""Signal a version mismatch between core and Edge Site."""
@@ -170,68 +162,3 @@ def reset_metrics(worker_name: str) -> None:
free_concurrency=-1,
queues=None,
)


class EdgeWorker(BaseModel, LoggingMixin):
"""Deprecated Edge Worker internal API, keeping for one minor for graceful migration."""

worker_name: str
state: EdgeWorkerState
queues: Optional[list[str]] # noqa: UP007 - prevent Sphinx failing
first_online: datetime
last_update: Optional[datetime] = None # noqa: UP007 - prevent Sphinx failing
jobs_active: int
jobs_taken: int
jobs_success: int
jobs_failed: int
sysinfo: str
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

@staticmethod
@internal_api_call
@provide_session
def set_state(
worker_name: str,
state: EdgeWorkerState,
jobs_active: int,
sysinfo: dict[str, str],
session: Session = NEW_SESSION,
) -> list[str] | None:
"""Set state of worker and returns the current assigned queues."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker.state = state
worker.jobs_active = jobs_active
worker.sysinfo = json.dumps(sysinfo)
worker.last_update = timezone.utcnow()
session.commit()
Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name})
set_metrics(
worker_name=worker_name,
state=state,
jobs_active=jobs_active,
concurrency=int(sysinfo["concurrency"]),
free_concurrency=int(sysinfo["free_concurrency"]),
queues=worker.queues,
)
raise EdgeWorkerVersionException(
"Edge Worker runs on an old version. Rejecting access due to difference."
)

@staticmethod
@internal_api_call
def register_worker(
worker_name: str,
state: EdgeWorkerState,
queues: list[str] | None,
sysinfo: dict[str, str],
) -> EdgeWorker:
raise EdgeWorkerVersionException(
"Edge Worker runs on an old version. Rejecting access due to difference."
)


EdgeWorker.model_rebuild()

add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
@@ -179,7 +179,7 @@ paths:
tags:
- Worker
/jobs/fetch/{worker_name}:
get:
post:
description: Fetch a job to execute on the edge worker.
x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes
operationId: job_fetch_v2
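
This spec change matches the client change in api_client.py: /jobs/fetch/{worker_name} becomes POST because the operation accepts a request body. A minimal FastAPI-style sketch of such an endpoint, illustrative only, since the provider actually wires this route through the _v2_routes controller named above:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class WorkerQueuesBody(BaseModel):
    # Field names follow the client-side model used in api_client.py.
    queues: list[str] | None = None
    free_concurrency: int

@app.post("/jobs/fetch/{worker_name}")
def job_fetch(worker_name: str, body: WorkerQueuesBody) -> dict | None:
    # A JSON body on GET is undefined behavior for many clients and proxies,
    # hence POST even though the call is semantically a "fetch".
    return None  # placeholder: would return a serialized EdgeJobFetched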
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247

# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.8.2pre0
- 0.9.0pre0

dependencies:
- apache-airflow>=2.10.0