Skip to content

Commit

Permalink
Add group_display_name attribute to allow specifying a custom displ…
Browse files Browse the repository at this point in the history
…ay name in the UI for TaskGroup (#45264)

* Add group_display_name attribute to allow specifying a custom display name in the UI for TaskGroup

* Add new tests and update old test to suppport the changes made.
  • Loading branch information
hardeybisey authored Jan 4, 2025
1 parent 1e04741 commit 3a9a032
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 2 deletions.
1 change: 1 addition & 0 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def task_group(
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
group_display_name: str = "",
) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]: ...


Expand Down
2 changes: 2 additions & 0 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@
"type": "object",
"required": [
"_group_id",
"group_display_name",
"prefix_group_id",
"children",
"tooltip",
Expand All @@ -322,6 +323,7 @@
],
"properties": {
"_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
"group_display_name": {"type": "string" },
"is_mapped": { "type": "boolean" },
"prefix_group_id": { "type": "boolean" },
"children": { "$ref": "#/definitions/dict" },
Expand Down
3 changes: 2 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
# When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur
encoded = {
"_group_id": task_group._group_id,
"group_display_name": task_group.group_display_name,
"prefix_group_id": task_group.prefix_group_id,
"tooltip": task_group.tooltip,
"ui_color": task_group.ui_color,
Expand Down Expand Up @@ -1822,7 +1823,7 @@ def deserialize_task_group(
group_id = cls.deserialize(encoded_group["_group_id"])
kwargs = {
key: cls.deserialize(encoded_group[key])
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor", "group_display_name"]
}

if not encoded_group.get("is_mapped"):
Expand Down
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ class TaskGroup(DAGNode):
:param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
:param add_suffix_on_collision: If this task group name already exists,
automatically add `__1` etc suffixes
:param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
"""

_group_id: str | None = attrs.field(
validator=attrs.validators.optional(attrs.validators.instance_of(str))
)
group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
prefix_group_id: bool = attrs.field(default=True)
parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
Expand Down Expand Up @@ -270,7 +272,7 @@ def group_id(self) -> str | None:
@property
def label(self) -> str | None:
"""group_id excluding parent's group_id used as the node label in UI."""
return self._group_id
return self.group_display_name or self._group_id

def update_relative(
self,
Expand Down
16 changes: 16 additions & 0 deletions tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,19 @@ def another_tg():
assert test_task.retries == 1
assert test_task.owner == "y"
assert test_task.execution_timeout == timedelta(seconds=10)


def test_task_group_display_name_used_as_label():
"""Test that the group_display_name for TaskGroup is used as the label for display on the UI."""

@dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
def pipeline():
@task_group(group_display_name="my_custom_name")
def tg():
pass

tg()

p = pipeline()

assert p.task_group_dict["tg"].label == "my_custom_name"
2 changes: 2 additions & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
},
"task_group": {
"_group_id": None,
"group_display_name": "",
"prefix_group_id": True,
"children": {
"bash_task": ("operator", "bash_task"),
Expand Down Expand Up @@ -2994,6 +2995,7 @@ def tg(a: str) -> None:
"type": "dict-of-lists",
"value": {"__type": "dict", "__var": {"a": [".", ".."]}},
},
"group_display_name": "",
"is_mapped": True,
"prefix_group_id": True,
"tooltip": "",
Expand Down
29 changes: 29 additions & 0 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,3 +1744,32 @@ def test_task_group_with_invalid_arg_type_raises_error():
with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None):
with pytest.raises(TypeError, match=error_msg):
_ = TaskGroup("group_1", ui_color=123)


def test_task_group_display_name_used_as_label():
"""Test that the group_display_name for TaskGroup is used as the label for display on the UI."""

with DAG(dag_id="display_name", schedule=None, start_date=pendulum.datetime(2022, 1, 1)) as dag:
with TaskGroup(group_id="tg", group_display_name="my_custom_name") as tg:
task1 = BaseOperator(task_id="task1")
task2 = BaseOperator(task_id="task2")
task1 >> task2

assert tg.group_id == "tg"
assert tg.label == "my_custom_name"
expected_node_id = {
"id": None,
"label": None,
"children": [
{
"id": "tg",
"label": "my_custom_name",
"children": [
{"id": "tg.task1", "label": "task1"},
{"id": "tg.task2", "label": "task2"},
],
},
],
}

assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == expected_node_id

0 comments on commit 3a9a032

Please sign in to comment.