diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 591ba39018e1d..2fabd29157ddc 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -189,6 +189,7 @@ def task_group( ui_color: str = "CornflowerBlue", ui_fgcolor: str = "#000", add_suffix_on_collision: bool = False, + group_display_name: str = "", ) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]: ... diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 1cc9c42db0737..4eaae95ba2b01 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -310,6 +310,7 @@ "type": "object", "required": [ "_group_id", + "group_display_name", "prefix_group_id", "children", "tooltip", @@ -322,6 +323,7 @@ ], "properties": { "_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]}, + "group_display_name": {"type": "string" }, "is_mapped": { "type": "boolean" }, "prefix_group_id": { "type": "boolean" }, "children": { "$ref": "#/definitions/dict" }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3263f3dc5c320..a0e5da74145cf 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1787,6 +1787,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None: # When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur encoded = { "_group_id": task_group._group_id, + "group_display_name": task_group.group_display_name, "prefix_group_id": task_group.prefix_group_id, "tooltip": task_group.tooltip, "ui_color": task_group.ui_color, @@ -1822,7 +1823,7 @@ def deserialize_task_group( group_id = cls.deserialize(encoded_group["_group_id"]) kwargs = { key: cls.deserialize(encoded_group[key]) - for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"] + for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor", "group_display_name"] } if not encoded_group.get("is_mapped"): diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index 52b30ba31f8af..07f8b452c1954 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -101,11 +101,13 @@ class TaskGroup(DAGNode): :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI :param add_suffix_on_collision: If this task group name already exists, automatically add `__1` etc suffixes + :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI. """ _group_id: str | None = attrs.field( validator=attrs.validators.optional(attrs.validators.instance_of(str)) ) + group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str)) prefix_group_id: bool = attrs.field(default=True) parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group) dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True)) @@ -270,7 +272,7 @@ def group_id(self) -> str | None: @property def label(self) -> str | None: """group_id excluding parent's group_id used as the node label in UI.""" - return self._group_id + return self.group_display_name or self._group_id def update_relative( self, diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index ce1b518a8ff59..fb38c95759a26 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -333,3 +333,19 @@ def another_tg(): assert test_task.retries == 1 assert test_task.owner == "y" assert test_task.execution_timeout == timedelta(seconds=10) + + +def test_task_group_display_name_used_as_label(): + """Test that the group_display_name for TaskGroup is used as the label for display on the UI.""" + + @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1)) + def pipeline(): + @task_group(group_display_name="my_custom_name") + def tg(): + pass + + tg() + + p = pipeline() + + assert p.task_group_dict["tg"].label == "my_custom_name" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 3955d17477bab..d806e162d0c83 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -140,6 +140,7 @@ }, "task_group": { "_group_id": None, + "group_display_name": "", "prefix_group_id": True, "children": { "bash_task": ("operator", "bash_task"), @@ -2994,6 +2995,7 @@ def tg(a: str) -> None: "type": "dict-of-lists", "value": {"__type": "dict", "__var": {"a": [".", ".."]}}, }, + "group_display_name": "", "is_mapped": True, "prefix_group_id": True, "tooltip": "", diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 71d6748b44ab1..4d73349d08314 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1744,3 +1744,32 @@ def test_task_group_with_invalid_arg_type_raises_error(): with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None): with pytest.raises(TypeError, match=error_msg): _ = TaskGroup("group_1", ui_color=123) + + +def test_task_group_display_name_used_as_label(): + """Test that the group_display_name for TaskGroup is used as the label for display on the UI.""" + + with DAG(dag_id="display_name", schedule=None, start_date=pendulum.datetime(2022, 1, 1)) as dag: + with TaskGroup(group_id="tg", group_display_name="my_custom_name") as tg: + task1 = BaseOperator(task_id="task1") + task2 = BaseOperator(task_id="task2") + task1 >> task2 + + assert tg.group_id == "tg" + assert tg.label == "my_custom_name" + expected_node_id = { + "id": None, + "label": None, + "children": [ + { + "id": "tg", + "label": "my_custom_name", + "children": [ + {"id": "tg.task1", "label": "task1"}, + {"id": "tg.task2", "label": "task2"}, + ], + }, + ], + } + + assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == expected_node_id