Rewrite how DAG to dataset / dataset alias are stored (#41987) (#42055)
Lee-W authored Sep 6, 2024
1 parent ecc94f7 commit a02325f
Showing 1 changed file with 50 additions and 38 deletions.
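
In short, `bulk_write_to_db` now keys the DAG-to-dataset scheduling references by plain `("dataset", uri)` / `("dataset-alias", name)` tuples instead of by `Dataset` / `DatasetAlias` objects, and the intermediate variables are renamed to make clear they hold ORM models. Below is a minimal, self-contained sketch of the new bookkeeping; the DAG id and URIs are illustrative only, and the real logic lives in the diff that follows:

```python
from collections import defaultdict
from typing import Literal

# The registry maps dag_id -> the dataset conditions the DAG is scheduled on,
# stored as hashable (kind, identifier) tuples rather than Dataset/DatasetAlias objects.
dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)

# Illustrative entries only -- in the real method these come from
# dataset_condition.iter_datasets() and dataset_condition.iter_dataset_aliases().
dag_references["example_dag"].add(("dataset", "s3://bucket/example.csv"))
dag_references["example_dag"].add(("dataset-alias", "example-alias"))

# The first element of each tuple later tells the reconcile step whether to look the
# identifier up by URI (datasets) or by name (dataset aliases).
for kind, identifier in sorted(dag_references["example_dag"]):
    print(kind, identifier)
```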
88 changes: 50 additions & 38 deletions airflow/models/dag.py
@@ -3236,8 +3236,6 @@ def bulk_write_to_db(
         if not dags:
             return

-        from airflow.models.dataset import DagScheduleDatasetAliasReference
-
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}

@@ -3344,18 +3342,19 @@ def bulk_write_to_db(

         from airflow.datasets import Dataset
         from airflow.models.dataset import (
+            DagScheduleDatasetAliasReference,
             DagScheduleDatasetReference,
             DatasetModel,
             TaskOutletDatasetReference,
         )

-        dag_references: dict[str, set[Dataset | DatasetAlias]] = defaultdict(set)
+        dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
         outlet_references = defaultdict(set)
         # We can't use a set here as we want to preserve order
-        outlet_datasets: dict[DatasetModel, None] = {}
-        input_datasets: dict[DatasetModel, None] = {}
+        outlet_dataset_models: dict[DatasetModel, None] = {}
+        input_dataset_models: dict[DatasetModel, None] = {}
         outlet_dataset_alias_models: set[DatasetAliasModel] = set()
-        input_dataset_aliases: set[DatasetAliasModel] = set()
+        input_dataset_alias_models: set[DatasetAliasModel] = set()

         # here we go through dags and tasks to check for dataset references
         # if there are now None and previously there were some, we delete them
@@ -3371,12 +3370,12 @@ def bulk_write_to_db(
                     curr_orm_dag.schedule_dataset_alias_references = []
             else:
                 for _, dataset in dataset_condition.iter_datasets():
-                    dag_references[dag.dag_id].add(Dataset(uri=dataset.uri))
-                    input_datasets[DatasetModel.from_public(dataset)] = None
+                    dag_references[dag.dag_id].add(("dataset", dataset.uri))
+                    input_dataset_models[DatasetModel.from_public(dataset)] = None

                 for dataset_alias in dataset_condition.iter_dataset_aliases():
-                    dag_references[dag.dag_id].add(dataset_alias)
-                    input_dataset_aliases.add(DatasetAliasModel.from_public(dataset_alias))
+                    dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name))
+                    input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))

             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
Expand All @@ -3399,63 +3398,70 @@ def bulk_write_to_db(
curr_outlet_references.remove(ref)

for d in dataset_outlets:
outlet_dataset_models[DatasetModel.from_public(d)] = None
outlet_references[(task.dag_id, task.task_id)].add(d.uri)
outlet_datasets[DatasetModel.from_public(d)] = None

for d_a in dataset_alias_outlets:
outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))

all_datasets = outlet_datasets
all_datasets.update(input_datasets)
all_dataset_models = outlet_dataset_models
all_dataset_models.update(input_dataset_models)

# store datasets
stored_datasets: dict[str, DatasetModel] = {}
new_datasets: list[DatasetModel] = []
for dataset in all_datasets:
stored_dataset = session.scalar(
stored_dataset_models: dict[str, DatasetModel] = {}
new_dataset_models: list[DatasetModel] = []
for dataset in all_dataset_models:
stored_dataset_model = session.scalar(
select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
)
if stored_dataset:
if stored_dataset_model:
# Some datasets may have been previously unreferenced, and therefore orphaned by the
# scheduler. But if we're here, then we have found that dataset again in our DAGs, which
# means that it is no longer an orphan, so set is_orphaned to False.
stored_dataset.is_orphaned = expression.false()
stored_datasets[stored_dataset.uri] = stored_dataset
stored_dataset_model.is_orphaned = expression.false()
stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model
else:
new_datasets.append(dataset)
dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
new_dataset_models.append(dataset)
dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session)
stored_dataset_models.update(
{dataset_model.uri: dataset_model for dataset_model in new_dataset_models}
)

del new_datasets
del all_datasets
del new_dataset_models
del all_dataset_models

# store dataset aliases
all_datasets_alias_models = input_dataset_aliases | outlet_dataset_alias_models
stored_dataset_aliases: dict[str, DatasetAliasModel] = {}
all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models
stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
new_dataset_alias_models: set[DatasetAliasModel] = set()
if all_datasets_alias_models:
all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_datasets_alias_models}
all_dataset_alias_names = {
dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models
}

stored_dataset_aliases = {
stored_dataset_alias_models = {
dsa_m.name: dsa_m
for dsa_m in session.scalars(
select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
).fetchall()
}

if stored_dataset_aliases:
if stored_dataset_alias_models:
new_dataset_alias_models = {
dataset_alias_model
for dataset_alias_model in all_datasets_alias_models
if dataset_alias_model.name not in stored_dataset_aliases.keys()
if dataset_alias_model.name not in stored_dataset_alias_models.keys()
}
else:
new_dataset_alias_models = all_datasets_alias_models

session.add_all(new_dataset_alias_models)
session.flush()
stored_dataset_aliases.update(
{dataset_alias.name: dataset_alias for dataset_alias in new_dataset_alias_models}
stored_dataset_alias_models.update(
{
dataset_alias_model.name: dataset_alias_model
for dataset_alias_model in new_dataset_alias_models
}
)

del new_dataset_alias_models
Expand All @@ -3464,14 +3470,18 @@ def bulk_write_to_db(
# reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references
for dag_id, base_dataset_list in dag_references.items():
dag_refs_needed = {
DagScheduleDatasetReference(dataset_id=stored_datasets[base_dataset.uri].id, dag_id=dag_id)
if isinstance(base_dataset, Dataset)
DagScheduleDatasetReference(
dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
)
if base_dataset_type == "dataset"
else DagScheduleDatasetAliasReference(
alias_id=stored_dataset_aliases[base_dataset.name].id, dag_id=dag_id
alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
)
for base_dataset in base_dataset_list
for base_dataset_type, base_dataset_identifier in base_dataset_list
}

# if isinstance(base_dataset, Dataset)

dag_refs_stored = (
set(existing_dags.get(dag_id).schedule_dataset_references) # type: ignore
| set(existing_dags.get(dag_id).schedule_dataset_alias_references) # type: ignore
@@ -3491,7 +3501,9 @@ def bulk_write_to_db(
         # reconcile task-outlet-dataset references
         for (dag_id, task_id), uri_list in outlet_references.items():
             task_refs_needed = {
-                TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
+                TaskOutletDatasetReference(
+                    dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id
+                )
                 for uri in uri_list
             }
             task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
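
For orientation, here is roughly how the reconcile loop above consumes the stored tuples. This is a simplified, standalone sketch: frozen dataclasses stand in for the `DagScheduleDatasetReference` / `DagScheduleDatasetAliasReference` ORM models, and plain dicts stand in for the `stored_dataset_models` / `stored_dataset_alias_models` lookups built earlier in the method.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DagScheduleDatasetReference:  # stand-in for the Airflow ORM model
    dataset_id: int
    dag_id: str


@dataclass(frozen=True)
class DagScheduleDatasetAliasReference:  # stand-in for the Airflow ORM model
    alias_id: int
    dag_id: str


def build_dag_refs(dag_id, base_dataset_list, dataset_ids_by_uri, alias_ids_by_name):
    """Mirror of the dag_refs_needed set comprehension, with dict lookups as stand-ins."""
    return {
        DagScheduleDatasetReference(dataset_id=dataset_ids_by_uri[identifier], dag_id=dag_id)
        if kind == "dataset"
        else DagScheduleDatasetAliasReference(alias_id=alias_ids_by_name[identifier], dag_id=dag_id)
        for kind, identifier in base_dataset_list
    }


# Illustrative values only: one dataset keyed by URI, one alias keyed by name.
refs = build_dag_refs(
    "example_dag",
    {("dataset", "s3://bucket/example.csv"), ("dataset-alias", "example-alias")},
    dataset_ids_by_uri={"s3://bucket/example.csv": 1},
    alias_ids_by_name={"example-alias": 7},
)
print(refs)
```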
