Remove Pydantic models introduced for AIP-44
The Pydantic models are still used in a number of places, and
we also use them for context passing for PythonVirtualenvOperator
and ExternalPythonOperator - this PR removes all the models and
their usages.

Closes: apache#44436

potiuk committed Dec 12, 2024
1 parent b7df463 commit 5c8c85c
Showing 30 changed files with 81 additions and 1,154 deletions.
13 changes: 12 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/dag_tags.py
@@ -17,7 +17,18 @@

from __future__ import annotations

from pydantic import BaseModel
from pydantic import ConfigDict

from airflow.api_fastapi.core_api.base import BaseModel


class DagTagResponse(BaseModel):
"""DAG Tag serializer for responses."""

model_config = ConfigDict(populate_by_name=True, from_attributes=True)

name: str
dag_id: str


class DAGTagCollectionResponse(BaseModel):
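The new response model follows the same pattern as the other `api_fastapi` data models: `from_attributes=True` lets it validate straight off an ORM row. A minimal sketch of that behaviour, using plain `pydantic.BaseModel` in place of Airflow's `airflow.api_fastapi.core_api.base.BaseModel` and a `SimpleNamespace` standing in for the `DagTag` ORM object:

```python
from types import SimpleNamespace

from pydantic import BaseModel, ConfigDict


class DagTagResponse(BaseModel):
    """DAG Tag serializer for responses."""

    # from_attributes lets model_validate read values via attribute access,
    # so an ORM row (or anything attribute-shaped) validates directly.
    model_config = ConfigDict(populate_by_name=True, from_attributes=True)

    name: str
    dag_id: str


orm_row = SimpleNamespace(name="example", dag_id="example_dag")
print(DagTagResponse.model_validate(orm_row).model_dump())
# {'name': 'example', 'dag_id': 'example_dag'}
```

This is also what makes the change in the next file work: `DAGResponse.tags` can be typed as `list[DagTagResponse]`, and each nested `DagTag` row validates the same way.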
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/dags.py
@@ -32,8 +32,8 @@
)

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.datamodels.dag_tags import DagTagResponse
from airflow.configuration import conf
from airflow.serialization.pydantic.dag import DagTagPydantic


class DAGResponse(BaseModel):
@@ -50,7 +50,7 @@ class DAGResponse(BaseModel):
description: str | None
timetable_summary: str | None
timetable_description: str | None
tags: list[DagTagPydantic]
tags: list[DagTagResponse]
max_active_tasks: int
max_active_runs: int | None
max_consecutive_failed_dag_runs: int
13 changes: 6 additions & 7 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -6714,7 +6714,7 @@ components:
title: Timetable Description
tags:
items:
$ref: '#/components/schemas/DagTagPydantic'
$ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -6936,7 +6936,7 @@ components:
title: Timetable Description
tags:
items:
$ref: '#/components/schemas/DagTagPydantic'
$ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -7412,7 +7412,7 @@ components:
title: Timetable Description
tags:
items:
$ref: '#/components/schemas/DagTagPydantic'
$ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -7665,7 +7665,7 @@ components:
- count
title: DagStatsStateResponse
description: DagStatsState serializer for responses.
DagTagPydantic:
DagTagResponse:
properties:
name:
type: string
@@ -7677,9 +7677,8 @@
required:
- name
- dag_id
title: DagTagPydantic
description: Serializable representation of the DagTag ORM SqlAlchemyModel used
by internal API.
title: DagTagResponse
description: DAG Tag serializer for responses.
DagWarningType:
type: string
enum:
@@ -33,7 +33,7 @@ class TIEnterRunningPayload(BaseModel):

state: Annotated[
Literal[TIState.RUNNING],
# Specify a default in the schema, but not in code, so Pydantic marks it as required.
# Specify a default in the schema, but not in code.
WithJsonSchema({"type": "string", "enum": [TIState.RUNNING], "default": TIState.RUNNING}),
]
hostname: str
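The comment being trimmed above describes a subtle Pydantic trick worth spelling out: declaring the default only in the JSON schema keeps the field required at validation time. A sketch of the pattern under pydantic v2, with a plain string `Literal` standing in for `TIState`:

```python
from typing import Annotated, Literal

from pydantic import BaseModel, WithJsonSchema


class Payload(BaseModel):
    # No default is assigned in code, so Pydantic treats the field as
    # required; WithJsonSchema overrides only the generated schema, which
    # advertises a default to API clients.
    state: Annotated[
        Literal["running"],
        WithJsonSchema({"type": "string", "enum": ["running"], "default": "running"}),
    ]


schema = Payload.model_json_schema()
print(schema["properties"]["state"])
# {'type': 'string', 'enum': ['running'], 'default': 'running'}
print("state" in schema["required"])  # True: still required at validation time
```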
20 changes: 2 additions & 18 deletions airflow/cli/commands/remote_commands/task_command.py
@@ -47,7 +47,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
@@ -74,7 +73,6 @@
from sqlalchemy.orm.session import Session

from airflow.models.operator import Operator
from airflow.serialization.pydantic.dag_run import DagRunPydantic

log = logging.getLogger(__name__)

@@ -96,7 +94,7 @@ def _fetch_dag_run_from_run_id_or_logical_date_string(
dag_id: str,
value: str,
session: Session,
) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
) -> tuple[DagRun, pendulum.DateTime | None]:
"""
Try to find a DAG run with a given string value.
@@ -132,7 +130,7 @@ def _get_dag_run(
create_if_necessary: CreateIfNecessary,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
) -> tuple[DagRun | DagRunPydantic, bool]:
) -> tuple[DagRun, bool]:
"""
Try to retrieve a DAG run from a string representing either a run ID or logical date.
@@ -259,8 +257,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | TaskReturnCode:
- as raw task
- by executor
"""
if TYPE_CHECKING:
assert not isinstance(ti, TaskInstancePydantic) # Wait for AIP-44 implementation to complete
if args.local:
return _run_task_by_local_task_job(args, ti)
if args.raw:
@@ -497,9 +493,6 @@ def task_failed_deps(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
# tasks_failed-deps is executed with access to the database.
if isinstance(ti, TaskInstancePydantic):
raise ValueError("not a TaskInstance")
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
# TODO, Do we want to print or log this
@@ -524,9 +517,6 @@ def task_state(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
# task_state is executed with access to the database.
if isinstance(ti, TaskInstancePydantic):
raise ValueError("not a TaskInstance")
print(ti.current_state())


@@ -654,9 +644,6 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
ti, dr_created = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
)
# task_test is executed with access to the database.
if isinstance(ti, TaskInstancePydantic):
raise ValueError("not a TaskInstance")
try:
with redirect_stdout(RedactedIO()):
if args.dry_run:
@@ -705,9 +692,6 @@ def task_render(args, dag: DAG | None = None) -> None:
ti, _ = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
)
# task_render is executed with access to the database.
if isinstance(ti, TaskInstancePydantic):
raise ValueError("not a TaskInstance")
with create_session() as session, set_current_task_instance_session(session=session):
ti.render_templates()
for attr in task.template_fields:
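The pattern removed throughout this file is the defensive `isinstance` guard that existed only because helpers could return either an ORM object or its Pydantic twin. A schematic illustration of why the guards could go, with hypothetical stand-in classes rather than Airflow's real ones:

```python
from dataclasses import dataclass


@dataclass
class DagRun:
    run_id: str


def _get_dag_run(run_id: str) -> DagRun:
    # Previously annotated as `DagRun | DagRunPydantic`, which forced every
    # caller needing ORM-only behaviour to narrow the union with a runtime
    # isinstance check. With the union gone, the guard has nothing to guard.
    return DagRun(run_id=run_id)


dr = _get_dag_run("manual__2024-12-12")
print(dr.run_id)  # no `isinstance(dr, DagRunPydantic)` check needed
```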
8 changes: 4 additions & 4 deletions airflow/jobs/JOB_LIFECYCLE.md
@@ -100,7 +100,7 @@ sequenceDiagram
DB --> Internal API: Close Session
deactivate DB
Internal API->>CLI component: JobPydantic object
Internal API->>CLI component: Job object
CLI component->>JobRunner: Create Job Runner
JobRunner ->> CLI component: JobRunner object
@@ -109,7 +109,7 @@ sequenceDiagram
activate JobRunner
JobRunner->>Internal API: prepare_for_execution [JobPydantic]
JobRunner->>Internal API: prepare_for_execution [Job]
Internal API-->>DB: Create Session
activate DB
@@ -131,7 +131,7 @@ sequenceDiagram
deactivate DB
Internal API ->> JobRunner: returned data
and
JobRunner->>Internal API: perform_heartbeat <br> [Job Pydantic]
JobRunner->>Internal API: perform_heartbeat <br> [Job]
Internal API-->>DB: Create Session
activate DB
Internal API->>DB: perform_heartbeat [Job]
@@ -142,7 +142,7 @@ sequenceDiagram
deactivate DB
end
JobRunner->>Internal API: complete_execution <br> [Job Pydantic]
JobRunner->>Internal API: complete_execution <br> [Job]
Internal API-->>DB: Create Session
Internal API->>DB: complete_execution [Job]
activate DB
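Read top to bottom, the diagram amounts to a prepare / heartbeat-while-executing / complete lifecycle, now passing plain `Job` objects instead of `JobPydantic` copies. A toy rendering of that control flow, with illustrative names only (the real calls go through the internal API and sessions shown above):

```python
class Job:
    """Toy stand-in for airflow.jobs.job.Job; methods just trace the flow."""

    def prepare_for_execution(self) -> None:
        print("prepare_for_execution: state -> running")

    def perform_heartbeat(self) -> None:
        print("perform_heartbeat: latest_heartbeat updated")

    def complete_execution(self) -> None:
        print("complete_execution: state -> completed")


def run_job(job: Job, work_items: list[str]) -> None:
    job.prepare_for_execution()
    try:
        for item in work_items:
            print(f"_execute: {item}")
            job.perform_heartbeat()  # runs alongside execution in the diagram
    finally:
        job.complete_execution()


run_job(Job(), ["task_a", "task_b"])
```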
3 changes: 1 addition & 2 deletions airflow/jobs/base_job_runner.py
@@ -26,7 +26,6 @@
from sqlalchemy.orm import Session

from airflow.jobs.job import Job
from airflow.serialization.pydantic.job import JobPydantic


class BaseJobRunner:
@@ -64,7 +63,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:

@classmethod
@provide_session
def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last heartbeat received."""
from airflow.jobs.job import most_recent_job

2 changes: 2 additions & 0 deletions airflow/models/__init__.py
@@ -82,6 +82,7 @@ def __getattr__(name):


__lazy_imports = {
"Job": "airflow.jobs.job",
"DAG": "airflow.models.dag",
"ID_LEN": "airflow.models.base",
"Base": "airflow.models.base",
@@ -112,6 +113,7 @@ def __getattr__(name):
if TYPE_CHECKING:
# I was unable to get mypy to respect a airflow/models/__init__.pyi, so
# having to resort back to this hacky method
from airflow.jobs.job import Job
from airflow.models.base import ID_LEN, Base
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
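Adding `Job` to both the lazy-import mapping and the `TYPE_CHECKING` block follows the PEP 562 pattern this module already uses: a module-level `__getattr__` imports a name only on first access. A condensed sketch of the mechanism (not Airflow's exact implementation):

```python
import importlib

__lazy_imports = {
    "Job": "airflow.jobs.job",
    "DAG": "airflow.models.dag",
}


def __getattr__(name: str):
    # PEP 562: called only when `name` is not found in the module globals.
    try:
        module_path = __lazy_imports[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    value = getattr(importlib.import_module(module_path), name)
    globals()[name] = value  # cache so later lookups skip this hook
    return value
```

The effect is that `from airflow.models import Job` keeps working even though `Job` lives in `airflow.jobs.job`, without paying for the jobs machinery at `import airflow.models` time.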
3 changes: 1 addition & 2 deletions airflow/models/param.py
@@ -31,7 +31,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.context import Context

logger = logging.getLogger(__name__)
@@ -334,7 +333,7 @@ def deserialize(cls, data: dict, dags: dict) -> DagParam:
def process_params(
dag: DAG,
task: Operator,
dag_run: DagRun | DagRunPydantic | None,
dag_run: DagRun | None,
*,
suppress_exception: bool,
) -> dict[str, Any]:
5 changes: 2 additions & 3 deletions airflow/models/skipmixin.py
@@ -37,7 +37,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
from airflow.serialization.pydantic.dag_run import DagRunPydantic

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -61,7 +60,7 @@ class SkipMixin(LoggingMixin):

@staticmethod
def _set_state_to_skipped(
dag_run: DagRun | DagRunPydantic,
dag_run: DagRun,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
@@ -95,7 +94,7 @@ def _set_state_to_skipped(
@provide_session
def skip(
self,
dag_run: DagRun | DagRunPydantic,
dag_run: DagRun,
tasks: Iterable[DAGNode],
map_index: int = -1,
session: Session = NEW_SESSION,
8 changes: 3 additions & 5 deletions airflow/models/taskinstance.py
@@ -163,8 +163,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup
@@ -984,7 +982,7 @@ def get_prev_end_date_success() -> pendulum.DateTime | None:
return None
return timezone.coerce_datetime(dagrun.end_date)

def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
def get_triggering_events() -> dict[str, list[AssetEvent]]:
if TYPE_CHECKING:
assert session is not None

@@ -995,7 +993,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
if dag_run not in session:
dag_run = session.merge(dag_run, load=False)
asset_events = dag_run.consumed_asset_events
triggering_events: dict[str, list[AssetEvent | AssetEventPydantic]] = defaultdict(list)
triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
for event in asset_events:
if event.asset:
triggering_events[event.asset.uri].append(event)
@@ -1890,7 +1888,7 @@ def _command_as_list(
pool: str | None = None,
cfg_path: str | None = None,
) -> list[str]:
dag: DAG | DagModel | DagModelPydantic | None
dag: DAG | DagModel | None
# Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
if TYPE_CHECKING:
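The only behavioural surface in `get_triggering_events` is the grouping of consumed asset events by asset URI; the diff merely narrows its value type to `AssetEvent`. A self-contained sketch of that grouping, with `SimpleNamespace` objects standing in for the real event and asset classes:

```python
from collections import defaultdict
from types import SimpleNamespace

events = [
    SimpleNamespace(asset=SimpleNamespace(uri="s3://bucket/a")),
    SimpleNamespace(asset=SimpleNamespace(uri="s3://bucket/a")),
    SimpleNamespace(asset=None),  # events without an asset are skipped
]

# Bucket events by the URI of the asset that produced them.
triggering_events: dict[str, list] = defaultdict(list)
for event in events:
    if event.asset:
        triggering_events[event.asset.uri].append(event)

print({uri: len(evts) for uri, evts in triggering_events.items()})
# {'s3://bucket/a': 2}
```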
16 changes: 0 additions & 16 deletions airflow/serialization/pydantic/__init__.py

This file was deleted.

