Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Edge calls for Worker to FastAPI part 4 - Cleanup #44434

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.9.0pre0
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
.........

Misc
~~~~

* ``Remove dependency on Internal API after migration to FastAPI.``

0.8.2pre0
.........

Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.8.2pre0"
__version__ = "0.9.0pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
Expand Down
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/edge/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def worker_register(
def worker_set_state(
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict
) -> list[str] | None:
"""Register worker with the Edge API."""
"""Update the state of the worker in the central site and thereby implicitly heartbeat."""
return _make_generic_request(
"PATCH",
f"worker/{quote(hostname)}",
Expand All @@ -123,7 +123,7 @@ def worker_set_state(
def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
"GET",
"POST",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
exclude_unset=True
Expand Down
105 changes: 1 addition & 104 deletions providers/src/airflow/providers/edge/models/edge_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,21 @@
# under the License.
from __future__ import annotations

from ast import literal_eval
from datetime import datetime
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Index,
Integer,
String,
select,
text,
)

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base, StringID
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from airflow.utils.sqlalchemy import UtcDateTime


class EdgeJobModel(Base, LoggingMixin):
Expand Down Expand Up @@ -103,95 +92,3 @@ def key(self):
@property
def last_update_t(self) -> float:
return self.last_update.timestamp()


class EdgeJob(BaseModel, LoggingMixin):
    """Accessor for edge jobs as logical model.

    Pydantic mirror of ``EdgeJobModel`` used to ship job data over the
    (deprecated) Internal API; ``from_attributes=True`` lets instances be
    built directly from the ORM row.
    """

    dag_id: str
    task_id: str
    run_id: str
    map_index: int
    try_number: int
    state: TaskInstanceState
    queue: str
    # Number of concurrency slots this job occupies on a worker.
    concurrency_slots: int
    # Task execution command as an argv-style list of strings.
    command: list[str]
    queued_dttm: datetime
    edge_worker: Optional[str]  # noqa: UP007 - prevent Sphinx failing
    last_update: Optional[datetime]  # noqa: UP007 - prevent Sphinx failing
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @property
    def key(self) -> TaskInstanceKey:
        """Return the TaskInstanceKey identifying this job's task instance."""
        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

    @staticmethod
    @internal_api_call
    @provide_session
    def reserve_task(
        worker_name: str,
        free_concurrency: int,
        queues: list[str] | None = None,
        session: Session = NEW_SESSION,
    ) -> EdgeJob | None:
        """Atomically pick one queued job for a worker and mark it running.

        :param worker_name: Name of the worker that takes the job.
        :param free_concurrency: Remaining slots on the worker; only jobs whose
            ``concurrency_slots`` fit within it are considered.
        :param queues: Optional queue names to restrict the search to.
        :param session: SQLAlchemy session (injected by ``@provide_session``).
        :return: The reserved job, or ``None`` if no suitable job is queued.
        """
        # Oldest queued job first that fits into the worker's free slots.
        query = (
            select(EdgeJobModel)
            .where(
                EdgeJobModel.state == TaskInstanceState.QUEUED,
                EdgeJobModel.concurrency_slots <= free_concurrency,
            )
            .order_by(EdgeJobModel.queued_dttm)
        )
        if queues:
            query = query.where(EdgeJobModel.queue.in_(queues))
        query = query.limit(1)
        # Row lock with skip_locked: a row locked by a concurrent worker is
        # skipped instead of blocking, so two workers never grab the same job.
        query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
        job: EdgeJobModel = session.scalar(query)
        if not job:
            return None
        job.state = TaskInstanceState.RUNNING
        job.edge_worker = worker_name
        job.last_update = timezone.utcnow()
        session.commit()
        return EdgeJob(
            dag_id=job.dag_id,
            task_id=job.task_id,
            run_id=job.run_id,
            map_index=job.map_index,
            try_number=job.try_number,
            state=job.state,
            queue=job.queue,
            concurrency_slots=job.concurrency_slots,
            # The command is stored as a string repr; parse it back to a list.
            command=literal_eval(job.command),
            queued_dttm=job.queued_dttm,
            edge_worker=job.edge_worker,
            last_update=job.last_update,
        )

    @staticmethod
    @internal_api_call
    @provide_session
    def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, session: Session = NEW_SESSION) -> None:
        """Set the state of the job identified by the given task instance key.

        :param task: TaskInstanceKey of the job (may arrive as a plain tuple of
            its fields, presumably after Internal API serialization - TODO confirm).
        :param state: New task instance state to record.
        :param session: SQLAlchemy session (injected by ``@provide_session``).
        """
        if isinstance(task, tuple):
            task = TaskInstanceKey(*task)
        query = select(EdgeJobModel).where(
            EdgeJobModel.dag_id == task.dag_id,
            EdgeJobModel.task_id == task.task_id,
            EdgeJobModel.run_id == task.run_id,
            EdgeJobModel.map_index == task.map_index,
            EdgeJobModel.try_number == task.try_number,
        )
        job: EdgeJobModel = session.scalar(query)
        # No matching row means nothing to update; silently ignore.
        if job:
            job.state = state
            job.last_update = timezone.utcnow()
            session.commit()

    def __hash__(self):
        # Identity is the task-instance key fields (same fields as ``key``).
        return f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()


# Resolve forward references (e.g. the ``EdgeJob | None`` return annotation).
EdgeJob.model_rebuild()

# Register the ORM/pydantic pair with Airflow's serialization machinery.
add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
81 changes: 0 additions & 81 deletions providers/src/airflow/providers/edge/models/edge_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from __future__ import annotations

from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Integer,
Expand All @@ -30,19 +26,10 @@
)
from sqlalchemy.dialects.mysql import MEDIUMTEXT

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.models.base import Base, StringID
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session


class EdgeLogsModel(Base, LoggingMixin):
"""
Expand Down Expand Up @@ -84,71 +71,3 @@ def __init__(
self.log_chunk_time = log_chunk_time
self.log_chunk_data = log_chunk_data
super().__init__()


class EdgeLogs(BaseModel, LoggingMixin):
    """Deprecated Internal API for Edge Worker instances as logical model.

    Pydantic counterpart of ``EdgeLogsModel``; ``from_attributes=True``
    lets instances be built directly from the ORM row.
    """

    dag_id: str
    task_id: str
    run_id: str
    map_index: int
    try_number: int
    log_chunk_time: datetime
    log_chunk_data: str
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @staticmethod
    @internal_api_call
    @provide_session
    def push_logs(
        task: TaskInstanceKey | tuple,
        log_chunk_time: datetime,
        log_chunk_data: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """Push an incremental log chunk from Edge Worker to central site.

        The chunk is persisted to the database and additionally appended to the
        local task log file.

        :param task: TaskInstanceKey of the task (may arrive as a plain tuple
            of its fields, presumably after serialization - TODO confirm).
        :param log_chunk_time: Time of the chunk as reported by the worker.
        :param log_chunk_data: Raw log text of this chunk.
        :param session: SQLAlchemy session (injected by ``@provide_session``).
        """
        if isinstance(task, tuple):
            task = TaskInstanceKey(*task)
        log_chunk = EdgeLogsModel(
            dag_id=task.dag_id,
            task_id=task.task_id,
            run_id=task.run_id,
            map_index=task.map_index,
            try_number=task.try_number,
            log_chunk_time=log_chunk_time,
            log_chunk_data=log_chunk_data,
        )
        session.add(log_chunk)
        # Write logs to local file to make them accessible
        logfile_path = EdgeLogs.logfile_path(task)
        if not logfile_path.exists():
            # Create missing parent folders with the configured permissions.
            new_folder_permissions = int(
                conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8
            )
            logfile_path.parent.mkdir(parents=True, exist_ok=True, mode=new_folder_permissions)
        with logfile_path.open("a") as logfile:
            logfile.write(log_chunk_data)

    @staticmethod
    @lru_cache
    def logfile_path(task: TaskInstanceKey) -> Path:
        """Elaborate the path and filename to expect from task execution.

        NOTE(review): cached per TaskInstanceKey via an unbounded ``lru_cache``
        - assumes the rendered path is stable for a given try and the number of
        distinct keys per process stays small; verify.
        """
        # Local import: FileTaskHandler is only needed for filename rendering.
        from airflow.utils.log.file_task_handler import FileTaskHandler

        ti = TaskInstance.get_task_instance(
            dag_id=task.dag_id,
            run_id=task.run_id,
            task_id=task.task_id,
            map_index=task.map_index,
        )
        if TYPE_CHECKING:
            assert ti
            assert isinstance(ti, TaskInstance)
        base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
        # Build the filename via FileTaskHandler's standard rendering.
        return Path(base_log_folder, FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))


# Resolve forward references in the pydantic model.
EdgeLogs.model_rebuild()

# Register the ORM/pydantic pair with Airflow's serialization machinery.
add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
75 changes: 1 addition & 74 deletions providers/src/airflow/providers/edge/models/edge_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,16 @@
import json
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Column, Integer, String, select
from sqlalchemy import Column, Integer, String

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException
from airflow.models.base import Base
from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session


class EdgeWorkerVersionException(AirflowException):
"""Signal a version mismatch between core and Edge Site."""
Expand Down Expand Up @@ -170,68 +162,3 @@ def reset_metrics(worker_name: str) -> None:
free_concurrency=-1,
queues=None,
)


class EdgeWorker(BaseModel, LoggingMixin):
    """Deprecated Edge Worker internal API, keeping for one minor for graceful migration."""

    worker_name: str
    state: EdgeWorkerState
    queues: Optional[list[str]]  # noqa: UP007 - prevent Sphinx failing
    first_online: datetime
    last_update: Optional[datetime] = None  # noqa: UP007 - prevent Sphinx failing
    jobs_active: int
    jobs_taken: int
    jobs_success: int
    jobs_failed: int
    # JSON-encoded system information as reported by the worker.
    sysinfo: str
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @staticmethod
    @internal_api_call
    @provide_session
    def set_state(
        worker_name: str,
        state: EdgeWorkerState,
        jobs_active: int,
        sysinfo: dict[str, str],
        session: Session = NEW_SESSION,
    ) -> list[str] | None:
        """Record the worker's reported state and heartbeat, then reject it as outdated.

        The incoming state and heartbeat metrics are still persisted, but this
        deprecated endpoint ALWAYS ends by raising ``EdgeWorkerVersionException``
        to tell an old-version worker to upgrade - it never actually returns
        the assigned queues despite the return annotation.
        """
        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
        worker: EdgeWorkerModel = session.scalar(query)
        worker.state = state
        worker.jobs_active = jobs_active
        worker.sysinfo = json.dumps(sysinfo)
        worker.last_update = timezone.utcnow()
        session.commit()
        # Emit heartbeat metrics in both legacy (name-suffixed) and tagged form.
        Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
        Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name})
        set_metrics(
            worker_name=worker_name,
            state=state,
            jobs_active=jobs_active,
            concurrency=int(sysinfo["concurrency"]),
            free_concurrency=int(sysinfo["free_concurrency"]),
            queues=worker.queues,
        )
        # Unconditional: any caller of this Internal API is on an old version.
        raise EdgeWorkerVersionException(
            "Edge Worker runs on an old version. Rejecting access due to difference."
        )

    @staticmethod
    @internal_api_call
    def register_worker(
        worker_name: str,
        state: EdgeWorkerState,
        queues: list[str] | None,
        sysinfo: dict[str, str],
    ) -> EdgeWorker:
        """Always raise: registration via the deprecated Internal API is rejected."""
        raise EdgeWorkerVersionException(
            "Edge Worker runs on an old version. Rejecting access due to difference."
        )


# Resolve forward references in the pydantic model.
EdgeWorker.model_rebuild()

# Register the ORM/pydantic pair with Airflow's serialization machinery.
add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ paths:
tags:
- Worker
/jobs/fetch/{worker_name}:
get:
post:
description: Fetch a job to execute on the edge worker.
x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes
operationId: job_fetch_v2
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ source-date-epoch: 1729683247

# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.8.2pre0
- 0.9.0pre0

dependencies:
- apache-airflow>=2.10.0
Expand Down
Loading