
Starship Operator #100

Merged 19 commits on May 31, 2024
Changes from all commits
459 changes: 247 additions & 212 deletions astronomer_starship/providers/starship/hooks/starship.py

Large diffs are not rendered by default.
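
Since the hooks diff is collapsed above, here is a sketch of the interface the operators in this PR rely on, inferred purely from their call sites below. The method names are grounded in this diff; the exact signatures and return types are assumptions, not the rendered file:

```python
# Interface sketch (assumed signatures) for the Starship hooks, inferred from
# how the operators below call them.
from typing import List, Optional


class StarshipLocalHook:
    """Reads state from the local (source) Airflow instance."""

    def get_variables(self) -> List[dict]: ...
    def get_pools(self) -> List[dict]: ...
    def get_connections(self) -> List[dict]: ...
    def get_dags(self) -> List[dict]: ...
    def get_dag_runs(self, dag_id: str, limit: int = 10) -> dict: ...
    def get_task_instances(self, dag_id: str, limit: int = 10) -> dict: ...
    def set_dag_is_paused(self, dag_id: str, is_paused: bool) -> None: ...


class StarshipHttpHook:
    """Writes state to the target Airflow instance over an HTTP connection."""

    def __init__(self, http_conn_id: Optional[str] = None): ...
    def set_variable(self, **variable) -> None: ...
    def set_pool(self, **pool) -> None: ...
    def set_connection(self, **connection) -> None: ...
    def set_dag_runs(self, dag_runs: List[dict]) -> None: ...
    def set_task_instances(self, task_instances: List[dict]) -> None: ...
    def set_dag_is_paused(self, dag_id: str, is_paused: bool) -> None: ...
```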

350 changes: 339 additions & 11 deletions astronomer_starship/providers/starship/operators/starship.py
@@ -1,21 +1,349 @@
"""Operators, TaskGroups, and DAGs for interacting with the Starship migrations."""
from datetime import datetime
from typing import Any, Union, List

import airflow
from airflow import DAG
from airflow.decorators import task
from airflow.exceptions import AirflowSkipException
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup
from packaging.version import Version

from astronomer_starship.providers.starship.hooks.starship import (
    StarshipLocalHook,
    StarshipHttpHook,
)

# Compatibility Notes:
# - @task() is >=AF2.0
# - @task_group is >=AF2.1
# - Dynamic Task Mapping is >=AF2.3
# - Dynamic Task Mapping labelling (map_index_template) is >=AF2.9
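#
# The helpers below branch on the running Airflow version accordingly: they use
# Dynamic Task Mapping (.partial().expand()) where available, and otherwise
# statically declare one operator per item.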


class StarshipMigrationOperator(BaseOperator):
    """Base operator holding a local source hook and an HTTP target hook."""

    def __init__(self, http_conn_id=None, **kwargs):
        super().__init__(**kwargs)
        self.source_hook = StarshipLocalHook()
        self.target_hook = StarshipHttpHook(http_conn_id=http_conn_id)


class StarshipVariableMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Variable from one Airflow instance to another."""

    def __init__(self, variable_key: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.variable_key = variable_key

    def execute(self, context: Context) -> Any:
        print("Getting Variable", self.variable_key)
        variables = self.source_hook.get_variables()
        variable: Union[dict, None] = (
            [v for v in variables if v["key"] == self.variable_key] or [None]
        )[0]
        if variable is not None:
            print("Migrating Variable", self.variable_key)
            self.target_hook.set_variable(**variable)
        else:
            raise RuntimeError("Variable not found! " + self.variable_key)


def starship_variables_migration(variables: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Variables from one Airflow instance to another."""
    with TaskGroup("variables") as tg:

        @task()
        def get_variables():
            _variables = StarshipLocalHook().get_variables()

            _variables = (
                [k["key"] for k in _variables if k["key"] in variables]
                if variables is not None
                else [k["key"] for k in _variables]
            )

            if not len(_variables):
                raise AirflowSkipException("Nothing to migrate")
            return _variables

        variables_results = get_variables()
        if Version(airflow.__version__) >= Version("2.3.0"):
            StarshipVariableMigrationOperator.partial(
                task_id="migrate_variables", **kwargs
            ).expand(variable_key=variables_results)
        else:
            for variable in variables_results.output:
                variables_results >> StarshipVariableMigrationOperator(
                    task_id="migrate_variable_" + variable,
                    variable_key=variable,
                    **kwargs,
                )
    return tg
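
# Note (illustrative, not part of this diff): each starship_*_migration helper
# returns a TaskGroup, so it can also be dropped piecemeal into an existing
# DAG, e.g.:
#
#     with DAG(dag_id="my_migration_dag", start_date=datetime(1970, 1, 1)) as dag:
#         starship_variables_migration(http_conn_id="starship_default")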


class StarshipPoolMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Pool from one Airflow instance to another."""

    def __init__(self, pool_name: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.pool_name = pool_name

    def execute(self, context: Context) -> Any:
        print("Getting Pool", self.pool_name)
        pool: Union[dict, None] = (
            [v for v in self.source_hook.get_pools() if v["name"] == self.pool_name]
            or [None]
        )[0]
        if pool is not None:
            print("Migrating Pool", self.pool_name)
            self.target_hook.set_pool(**pool)
        else:
            raise RuntimeError("Pool not found!")


def starship_pools_migration(pools: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Pools from one Airflow instance to another."""
    with TaskGroup("pools") as tg:

        @task()
        def get_pools():
            _pools = StarshipLocalHook().get_pools()
            _pools = (
                [k["name"] for k in _pools if k["name"] in pools]
                if pools is not None
                else [k["name"] for k in _pools]
            )

            if not len(_pools):
                raise AirflowSkipException("Nothing to migrate")
            return _pools

        pools_result = get_pools()
        if Version(airflow.__version__) >= Version("2.3.0"):
            StarshipPoolMigrationOperator.partial(
                task_id="migrate_pools", **kwargs
            ).expand(pool_name=pools_result)
        else:
            for pool in pools_result.output:
                pools_result >> StarshipPoolMigrationOperator(
                    task_id="migrate_pool_" + pool, pool_name=pool, **kwargs
                )
    return tg


class StarshipConnectionMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Connection from one Airflow instance to another."""

    def __init__(self, connection_id: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.connection_id = connection_id

    def execute(self, context: Context) -> Any:
        print("Getting Connection", self.connection_id)
        connection: Union[dict, None] = (
            [
                v
                for v in self.source_hook.get_connections()
                if v["conn_id"] == self.connection_id
            ]
            or [None]
        )[0]
        if connection is not None:
            print("Migrating Connection", self.connection_id)
            self.target_hook.set_connection(**connection)
        else:
            raise RuntimeError("Connection not found!")


def starship_connections_migration(connections: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Connections from one Airflow instance to another."""
    with TaskGroup("connections") as tg:

        @task()
        def get_connections():
            _connections = StarshipLocalHook().get_connections()
            _connections = (
                [k["conn_id"] for k in _connections if k["conn_id"] in connections]
                if connections is not None
                else [k["conn_id"] for k in _connections]
            )

            if not len(_connections):
                raise AirflowSkipException("Nothing to migrate")
            return _connections

        connections_result = get_connections()
        if Version(airflow.__version__) >= Version("2.3.0"):
            StarshipConnectionMigrationOperator.partial(
                task_id="migrate_connections", **kwargs
            ).expand(connection_id=connections_result)
        else:
            for connection in connections_result.output:
                connections_result >> StarshipConnectionMigrationOperator(
                    task_id="migrate_connection_" + connection,
                    connection_id=connection,
                    **kwargs,
                )
    return tg


class StarshipDagHistoryMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single DAG from one Airflow instance to another, with its history."""

    def __init__(
        self,
        target_dag_id: str,
        unpause_dag_in_target: bool = False,
        dag_run_limit: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_dag_id = target_dag_id
        self.unpause_dag_in_target = unpause_dag_in_target
        self.dag_run_limit = dag_run_limit

    def execute(self, context):
        print("Pausing local DAG for", self.target_dag_id)
        self.source_hook.set_dag_is_paused(dag_id=self.target_dag_id, is_paused=True)
        # TODO - Poll until all tasks are done

        print("Getting local DAG Runs for", self.target_dag_id)
        dag_runs = self.source_hook.get_dag_runs(
            dag_id=self.target_dag_id, limit=self.dag_run_limit
        )
        if len(dag_runs["dag_runs"]) == 0:
            raise AirflowSkipException("No DAG Runs found for " + self.target_dag_id)

        print("Getting local Task Instances for", self.target_dag_id)
        task_instances = self.source_hook.get_task_instances(
            dag_id=self.target_dag_id, limit=self.dag_run_limit
        )
        if len(task_instances["task_instances"]) == 0:
            raise AirflowSkipException(
                "No Task Instances found for " + self.target_dag_id
            )

        print("Setting target DAG Runs for", self.target_dag_id)
        self.target_hook.set_dag_runs(dag_runs=dag_runs["dag_runs"])

        print("Setting target Task Instances for", self.target_dag_id)
        self.target_hook.set_task_instances(
            task_instances=task_instances["task_instances"]
        )

        if self.unpause_dag_in_target:
            print("Unpausing target DAG for", self.target_dag_id)
            self.target_hook.set_dag_is_paused(
                dag_id=self.target_dag_id, is_paused=False
            )
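
# Illustrative standalone usage (not part of this diff): the operator can also
# be used directly as a task in a DAG to migrate one DAG's run history, e.g.:
#
#     StarshipDagHistoryMigrationOperator(
#         task_id="migrate_example_dag",
#         target_dag_id="example_dag",  # placeholder DAG id
#         http_conn_id="starship_default",
#     )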


def starship_dag_history_migration(dag_ids: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate DAGs with their history from one Airflow instance to another."""
    with TaskGroup("dag_history") as tg:

        @task()
        def get_dags():
            _dags = StarshipLocalHook().get_dags()
            _dags = (
                [
                    k["dag_id"]
                    for k in _dags
                    if k["dag_id"] in dag_ids
                    and k["dag_id"] != "StarshipAirflowMigrationDAG"
                ]
                if dag_ids is not None
                else [
                    k["dag_id"]
                    for k in _dags
                    if k["dag_id"] != "StarshipAirflowMigrationDAG"
                ]
            )

            if not len(_dags):
                raise AirflowSkipException("Nothing to migrate")
            return _dags

        dags_result = get_dags()
        if Version(airflow.__version__) >= Version("2.3.0"):
            StarshipDagHistoryMigrationOperator.partial(
                task_id="migrate_dag_ids",
                **(
                    {"map_index_template": "{{ task.target_dag_id }}"}
                    if Version(airflow.__version__) >= Version("2.9.0")
                    else {}
                ),
                **kwargs,
            ).expand(target_dag_id=dags_result)
        else:
            for dag_id in dags_result.output:
                dags_result >> StarshipDagHistoryMigrationOperator(
                    task_id="migrate_dag_" + dag_id, target_dag_id=dag_id, **kwargs
                )
    return tg


# noinspection PyPep8Naming
def StarshipAirflowMigrationDAG(
    http_conn_id: str,
    variables: List[str] = None,
    pools: List[str] = None,
    connections: List[str] = None,
    dag_ids: List[str] = None,
    **kwargs,
):
    """DAG to fetch and migrate Variables, Pools, Connections, and DAGs with history from one Airflow instance to another."""
    dag = DAG(
        dag_id="starship_airflow_migration_dag",
        schedule="@once",
        start_date=datetime(1970, 1, 1),
        tags=["migration", "starship"],
        default_args={"owner": "Astronomer"},
        doc_md="""
# Starship Migration DAG
A DAG to migrate Airflow Variables, Pools, Connections, and DAG History from one Airflow instance to another.

You can use this DAG to migrate all items, or specific items by providing a list of names.

You can skip migration by providing an empty list.

## Setup:
Make a connection in Airflow with the following details:
- **Conn ID**: `starship_default`
- **Conn Type**: `HTTP`
- **Host**: the URL of the homepage of Airflow (excluding `/home` on the end of the URL)
  - For example, if your deployment URL is `https://astronomer.astronomer.run/abcdt4ry/home`, you'll use `https://astronomer.astronomer.run/abcdt4ry`
- **Schema**: `https`
- **Extras**: `{"Authorization": "Bearer <token>"}`

## Usage:
```python
from astronomer_starship.providers.starship.operators.starship import (
    StarshipAirflowMigrationDAG,
)

globals()["starship_airflow_migration_dag"] = StarshipAirflowMigrationDAG(
    http_conn_id="starship_default",
    variables=None,  # None to migrate all, or ["var1", "var2"] to migrate specific items, or empty list to skip all
    pools=None,  # None to migrate all, or ["pool1", "pool2"] to migrate specific items, or empty list to skip all
    connections=None,  # None to migrate all, or ["conn1", "conn2"] to migrate specific items, or empty list to skip all
    dag_ids=None,  # None to migrate all, or ["dag1", "dag2"] to migrate specific items, or empty list to skip all
)
```
""",  # noqa: E501
    )
    with dag:
        starship_variables_migration(
            variables=variables, http_conn_id=http_conn_id, **kwargs
        )
        starship_pools_migration(pools=pools, http_conn_id=http_conn_id, **kwargs)
        starship_connections_migration(
            connections=connections, http_conn_id=http_conn_id, **kwargs
        )
        starship_dag_history_migration(
            dag_ids=dag_ids, http_conn_id=http_conn_id, **kwargs
        )
    return dag
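
For reference, a hedged example of wiring the factory into a deployment with selective lists (the list values are placeholders, not items from this PR):

```python
# dags/starship_migration.py: selective migration. None migrates everything,
# a list migrates only the named items, and an empty list skips that category.
from astronomer_starship.providers.starship.operators.starship import (
    StarshipAirflowMigrationDAG,
)

globals()["starship_airflow_migration_dag"] = StarshipAirflowMigrationDAG(
    http_conn_id="starship_default",
    variables=["shared_config"],  # placeholder Variable key
    pools=[],  # skip Pools entirely
    connections=None,  # migrate all Connections
    dag_ids=["etl_daily"],  # placeholder DAG id
)
```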