Skip to content

Commit

Permalink
Remove AIP-44 from taskinstance (apache#44540)
Browse files Browse the repository at this point in the history
Part of apache#44436

Co-authored-by: kalyanr <[email protected]>
  • Loading branch information
2 people authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 315f041 commit 0806683
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 442 deletions.
12 changes: 5 additions & 7 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _get_ti(
)

ti_or_none = dag_run.get_task_instance(task.task_id, map_index=map_index, session=session)
ti: TaskInstance | TaskInstancePydantic
ti: TaskInstance
if ti_or_none is None:
if not create_if_necessary:
raise TaskInstanceNotFound(
Expand All @@ -249,9 +249,7 @@ def _get_ti(
return ti, dr_created


def _run_task_by_selected_method(
args, dag: DAG, ti: TaskInstance | TaskInstancePydantic
) -> None | TaskReturnCode:
def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | TaskReturnCode:
"""
Run the task based on a mode.
Expand Down Expand Up @@ -308,7 +306,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
executor.end()


def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic) -> TaskReturnCode | None:
def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None:
"""Run LocalTaskJob, which monitors the raw task execution process."""
job_runner = LocalTaskJobRunner(
job=Job(dag_id=ti.dag_id),
Expand Down Expand Up @@ -354,7 +352,7 @@ def _extract_external_executor_id(args) -> str | None:


@contextmanager
def _move_task_handlers_to_root(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
def _move_task_handlers_to_root(ti: TaskInstance) -> Generator[None, None, None]:
"""
Move handlers for task logging to root logger.
Expand All @@ -381,7 +379,7 @@ def _move_task_handlers_to_root(ti: TaskInstance | TaskInstancePydantic) -> Gene


@contextmanager
def _redirect_stdout_to_ti_log(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
def _redirect_stdout_to_ti_log(ti: TaskInstance) -> Generator[None, None, None]:
"""
Redirect stdout to ti logger.
Expand Down
3 changes: 1 addition & 2 deletions airflow/jobs/local_task_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

from airflow.jobs.job import Job
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

SIGSEGV_MESSAGE = """
******************************************* Received SIGSEGV *******************************************
Expand Down Expand Up @@ -83,7 +82,7 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
def __init__(
self,
job: Job,
task_instance: TaskInstance | TaskInstancePydantic,
task_instance: TaskInstance,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
wait_for_past_depends_before_skipping: bool = False,
Expand Down
6 changes: 2 additions & 4 deletions airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from sqlalchemy.sql import FromClause

from airflow.models import Operator
from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
from airflow.models.taskinstance import TaskInstance


def get_serialized_template_fields(task: Operator):
Expand Down Expand Up @@ -173,9 +173,7 @@ def _update_runtime_evaluated_template_fields(

@classmethod
@provide_session
def get_templated_fields(
cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
) -> dict | None:
def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
"""
Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.
Expand Down
5 changes: 2 additions & 3 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
Expand Down Expand Up @@ -153,7 +152,7 @@ def _skip(

def skip_all_except(
self,
ti: TaskInstance | TaskInstancePydantic,
ti: TaskInstance,
branch_task_ids: None | str | Iterable[str],
):
"""Facade for compatibility for call to internal API."""
Expand All @@ -167,7 +166,7 @@ def skip_all_except(
@provide_session
def _skip_all_except(
cls,
ti: TaskInstance | TaskInstancePydantic,
ti: TaskInstance,
branch_task_ids: None | str | Iterable[str],
session: Session = NEW_SESSION,
):
Expand Down
Loading

0 comments on commit 0806683

Please sign in to comment.