diff --git a/docs/apache-airflow-providers-openlineage/guides/user.rst b/docs/apache-airflow-providers-openlineage/guides/user.rst
index 9b408f728ba1d..c4a12a7962960 100644
--- a/docs/apache-airflow-providers-openlineage/guides/user.rst
+++ b/docs/apache-airflow-providers-openlineage/guides/user.rst
@@ -413,6 +413,7 @@ This is because each emitting task sends a `ParentRunFacet `_.
 
-.. warning::
+This configuration serves as the default behavior for all Operators that support automatic Spark properties injection,
+unless it is explicitly overridden at the Operator level.
+To prevent a specific Operator from injecting the parent job information while
+allowing all other supported Operators to do so by default,
+``openlineage_inject_parent_job_info=False`` can be passed explicitly to that specific Operator.
+
+.. note::
 
-    If any of the above properties are manually specified in the Spark job configuration, the integration will refrain from injecting parent job properties to ensure that manually provided values are preserved.
+    If any of the ``spark.openlineage.parent*`` properties are manually specified in the Spark job configuration, the integration will refrain from injecting parent job properties to ensure that manually provided values are preserved.
 
 You can enable this automation by setting the ``spark_inject_parent_job_info`` option to ``true`` in Airflow configuration.
diff --git a/docs/exts/templates/openlineage.rst.jinja2 b/docs/exts/templates/openlineage.rst.jinja2
index 4c7341ab52363..217e634457c70 100644
--- a/docs/exts/templates/openlineage.rst.jinja2
+++ b/docs/exts/templates/openlineage.rst.jinja2
@@ -29,12 +29,15 @@ Spark operators
 ===============
 The OpenLineage integration can automatically inject information into Spark application properties when the application is being submitted from Airflow. The following is a list of supported operators along with the corresponding information that can be injected.
+See :ref:`automatic injection of parent job information ` for more details.
 
 apache-airflow-providers-google
 """""""""""""""""""""""""""""""
 
 - :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator`
   - Parent Job Information
+- :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCreateBatchOperator`
+  - Parent Job Information
 
 :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
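A minimal, hypothetical DAG snippet illustrating the per-operator override described in the ``user.rst`` change above (project, region, bucket and IDs are placeholders)::

    from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

    # Assuming [openlineage] spark_inject_parent_job_info = true is set globally,
    # this single task opts out while all other supported Operators keep injecting.
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch_without_parent_info",
        project_id="my-project",
        region="us-central1",
        batch={"pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"}},
        batch_id="my-batch",
        openlineage_inject_parent_job_info=False,
    )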
diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
index 53c0e4676e58c..1700ff29619e3 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import copy
 import logging
 import os
 import pathlib
@@ -30,6 +31,8 @@
     from airflow.providers.common.compat.openlineage.facet import Dataset
     from airflow.utils.context import Context
 
+from google.cloud.dataproc_v1 import Batch, RuntimeConfig
+
 from airflow.providers.common.compat.openlineage.facet import (
     BaseFacet,
     ColumnLineageDatasetFacet,
@@ -386,3 +389,135 @@ def inject_openlineage_properties_into_dataproc_job(
         job=job, job_type=job_type, new_properties=properties
     )
     return job_with_ol_config
+
+
+def _is_dataproc_batch_of_supported_type(batch: dict | Batch) -> bool:
+    """
+    Check if a Dataproc batch is of a supported type for OpenLineage automatic injection.
+
+    This function determines if the given batch is of a supported type
+    by checking for specific job type attributes or keys in the batch.
+
+    Args:
+        batch: The Dataproc batch to check.
+
+    Returns:
+        True if the batch is of a supported type (`spark_batch` or
+        `pyspark_batch`), otherwise False.
+    """
+    supported_job_types = ("spark_batch", "pyspark_batch")
+    if isinstance(batch, Batch):
+        if any(getattr(batch, job_type) for job_type in supported_job_types):
+            return True
+        return False
+
+    # For dictionary-based batch
+    if any(job_type in batch for job_type in supported_job_types):
+        return True
+    return False
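One subtlety, mirrored by the parametrized test further down: for ``Batch`` objects the check relies on proto-plus truthiness, so an empty ``pyspark_batch`` message counts as unset, while for plain dicts mere key presence is enough. An illustrative sketch::

    from google.cloud.dataproc_v1 import Batch

    _is_dataproc_batch_of_supported_type({"pyspark_batch": {}})    # True - the key is present
    _is_dataproc_batch_of_supported_type(Batch(pyspark_batch={}))  # False - empty message is falsy
    _is_dataproc_batch_of_supported_type(Batch(spark_batch={"jar_file_uris": ["uri"]}))  # True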
+ """ + if not inject_parent_job_info: + log.debug("Automatic injection of OpenLineage information is disabled.") + return batch + + if not _is_openlineage_provider_accessible(): + log.warning( + "Could not access OpenLineage provider for automatic OpenLineage " + "properties injection. No action will be performed." + ) + return batch + + if not _is_dataproc_batch_of_supported_type(batch): + log.warning( + "Could not find a supported Dataproc batch type for automatic OpenLineage " + "properties injection. No action will be performed.", + ) + return batch + + properties = _extract_dataproc_batch_properties(batch) + + properties = inject_parent_job_information_into_spark_properties(properties=properties, context=context) + + batch_with_ol_config = _replace_dataproc_batch_properties(batch=batch, new_properties=properties) + return batch_with_ol_config diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py index 9bf34e0ab8c95..5e64f7d920707 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py @@ -55,6 +55,7 @@ DataprocWorkflowTemplateLink, ) from airflow.providers.google.cloud.openlineage.utils import ( + inject_openlineage_properties_into_dataproc_batch, inject_openlineage_properties_into_dataproc_job, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -2425,6 +2426,9 @@ def __init__( asynchronous: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 5, + openlineage_inject_parent_job_info: bool = conf.getboolean( + "openlineage", "spark_inject_parent_job_info", fallback=False + ), **kwargs, ): super().__init__(**kwargs) @@ -2446,6 +2450,7 @@ def __init__( self.asynchronous = asynchronous self.deferrable = deferrable self.polling_interval_seconds = polling_interval_seconds + self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info def execute(self, context: Context): if self.asynchronous and self.deferrable: @@ -2468,6 +2473,14 @@ def execute(self, context: Context): else: self.log.info("Starting batch. 
+
+
+def inject_openlineage_properties_into_dataproc_batch(
+    batch: dict | Batch, context: Context, inject_parent_job_info: bool
+) -> dict | Batch:
+    """
+    Inject OpenLineage properties into a Dataproc batch definition.
+
+    This function adds the desired OpenLineage properties to the Dataproc batch
+    configuration; it does not remove any configuration or modify the batch in any other way.
+
+    Note:
+        Any modification to the batch will be skipped if:
+            - The OpenLineage provider is not accessible.
+            - The batch type is not supported.
+            - Automatic parent job information injection is disabled.
+            - Any OpenLineage properties with parent job information are already present
+              in the Spark job configuration.
+
+    Args:
+        batch: The original Dataproc batch definition.
+        context: The Airflow context in which the job is running.
+        inject_parent_job_info: Flag indicating whether to inject parent job information.
+
+    Returns:
+        The modified batch definition with OpenLineage properties injected, if applicable.
+    """
+    if not inject_parent_job_info:
+        log.debug("Automatic injection of OpenLineage information is disabled.")
+        return batch
+
+    if not _is_openlineage_provider_accessible():
+        log.warning(
+            "Could not access OpenLineage provider for automatic OpenLineage "
+            "properties injection. No action will be performed."
+        )
+        return batch
+
+    if not _is_dataproc_batch_of_supported_type(batch):
+        log.warning(
+            "Could not find a supported Dataproc batch type for automatic OpenLineage "
+            "properties injection. No action will be performed.",
+        )
+        return batch
+
+    properties = _extract_dataproc_batch_properties(batch)
+
+    properties = inject_parent_job_information_into_spark_properties(properties=properties, context=context)
+
+    batch_with_ol_config = _replace_dataproc_batch_properties(batch=batch, new_properties=properties)
+    return batch_with_ol_config
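Taken together, the helpers give the following flow (a sketch only; the exact ``spark.openlineage.parent*`` keys are produced by ``inject_parent_job_information_into_spark_properties`` and can be seen in the tests below)::

    batch = {
        "pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"},  # placeholder
        "runtime_config": {"properties": {"spark.executor.instances": "2"}},
    }
    enriched = inject_openlineage_properties_into_dataproc_batch(batch, context, inject_parent_job_info=True)
    # enriched["runtime_config"]["properties"] now additionally carries
    # spark.openlineage.parentJobName, spark.openlineage.parentJobNamespace and
    # spark.openlineage.parentRunId, while pre-existing Spark properties are preserved.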
diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
index 9bf34e0ab8c95..5e64f7d920707 100644
--- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -55,6 +55,7 @@
     DataprocWorkflowTemplateLink,
 )
 from airflow.providers.google.cloud.openlineage.utils import (
+    inject_openlineage_properties_into_dataproc_batch,
     inject_openlineage_properties_into_dataproc_job,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -2425,6 +2426,9 @@ def __init__(
         asynchronous: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 5,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -2446,6 +2450,7 @@ def __init__(
         self.asynchronous = asynchronous
         self.deferrable = deferrable
         self.polling_interval_seconds = polling_interval_seconds
+        self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
 
     def execute(self, context: Context):
         if self.asynchronous and self.deferrable:
@@ -2468,6 +2473,14 @@ def execute(self, context: Context):
         else:
             self.log.info("Starting batch. The batch ID will be generated since it was not provided.")
 
+        if self.openlineage_inject_parent_job_info:
+            self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.")
+            self.batch = inject_openlineage_properties_into_dataproc_batch(
+                batch=self.batch,
+                context=context,
+                inject_parent_job_info=self.openlineage_inject_parent_job_info,
+            )
+
         try:
             self.operation = self.hook.create_batch(
                 region=self.region,
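Because the operator default above comes from ``conf.getboolean("openlineage", "spark_inject_parent_job_info", fallback=False)``, injection can be enabled fleet-wide from ``airflow.cfg`` (or the equivalent ``AIRFLOW__OPENLINEAGE__SPARK_INJECT_PARENT_JOB_INFO`` environment variable) without touching DAG code::

    [openlineage]
    spark_inject_parent_job_info = true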
"runtime_config": {"properties": {"newProperty": "newValue"}}, + } + + updated_batch = _replace_dataproc_batch_properties(original_batch, new_properties) + + assert updated_batch == expected_batch + assert original_batch["runtime_config"]["properties"] == {"existingProperty": "value"} + assert original_batch["spark_batch"]["main_class"] == "org.apache.spark.examples.SparkPi" + + +def test_replace_dataproc_batch_properties_with_dict_and_run_time_config_object(): + original_batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": RuntimeConfig(properties={"existingProperty": "value"}), + } + new_properties = {"newProperty": "newValue"} + expected_batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": RuntimeConfig(properties={"newProperty": "newValue"}), + } + + updated_batch = _replace_dataproc_batch_properties(original_batch, new_properties) + + assert updated_batch == expected_batch + assert original_batch["runtime_config"].properties == {"existingProperty": "value"} + assert original_batch["spark_batch"]["main_class"] == "org.apache.spark.examples.SparkPi" + + +def test_replace_dataproc_batch_properties_with_empty_dict(): + original_batch = {} + new_properties = {"newProperty": "newValue"} + expected_batch = {"runtime_config": {"properties": {"newProperty": "newValue"}}} + + updated_batch = _replace_dataproc_batch_properties(original_batch, new_properties) + + assert updated_batch == expected_batch + assert original_batch == {} + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_batch_provider_not_accessible(mock_is_accessible): + mock_is_accessible.return_value = False + batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": {"existingProperty": "value"}}, + } + result = inject_openlineage_properties_into_dataproc_batch(batch, None, True) + assert result == batch + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +@patch("airflow.providers.google.cloud.openlineage.utils._is_dataproc_batch_of_supported_type") +def test_inject_openlineage_properties_into_dataproc_batch_unsupported_batch_type( + mock_valid_job_type, mock_is_accessible +): + mock_is_accessible.return_value = True + mock_valid_job_type.return_value = False + batch = { + "unsupported_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": {"existingProperty": "value"}}, + } + result = inject_openlineage_properties_into_dataproc_batch(batch, None, True) + assert result == batch + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +@patch("airflow.providers.google.cloud.openlineage.utils._is_dataproc_batch_of_supported_type") +def test_inject_openlineage_properties_into_dataproc_batch_no_inject_parent_job_info( + mock_valid_job_type, mock_is_accessible +): + mock_is_accessible.return_value = True + mock_valid_job_type.return_value = True + inject_parent_job_info = False + batch = { + "spark_batch": { + 
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": {"existingProperty": "value"}}, + } + result = inject_openlineage_properties_into_dataproc_batch(batch, None, inject_parent_job_info) + assert result == batch + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_batch(mock_is_ol_accessible): + mock_is_ol_accessible.return_value = True + context = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) + } + batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": {"existingProperty": "value"}}, + } + expected_properties = { + "existingProperty": "value", + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + } + expected_batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": expected_properties}, + } + result = inject_openlineage_properties_into_dataproc_batch(batch, context, True) + assert result == expected_batch diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py index 860373829f638..5d4a9b0d79c8d 100644 --- a/providers/tests/google/cloud/operators/test_dataproc.py +++ b/providers/tests/google/cloud/operators/test_dataproc.py @@ -2607,6 +2607,192 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook): metadata=METADATA, ) + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible): + expected_batch = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": { + "properties": { + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + } + }, + } + context = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) + } + + mock_ol_accessible.return_value = True + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=BATCH, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + openlineage_inject_parent_job_info=True, + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=context) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=expected_batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, 
diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py
index 860373829f638..5d4a9b0d79c8d 100644
--- a/providers/tests/google/cloud/operators/test_dataproc.py
+++ b/providers/tests/google/cloud/operators/test_dataproc.py
@@ -2607,6 +2607,192 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook):
             metadata=METADATA,
         )
 
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible):
+        expected_batch = {
+            "spark_batch": {
+                "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+                "main_class": "org.apache.spark.examples.SparkPi",
+            },
+            "runtime_config": {
+                "properties": {
+                    "spark.openlineage.parentJobName": "dag_id.task_id",
+                    "spark.openlineage.parentJobNamespace": "default",
+                    "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267",
+                }
+            },
+        }
+        context = {
+            "ti": MagicMock(
+                dag_id="dag_id",
+                task_id="task_id",
+                try_number=1,
+                map_index=1,
+                logical_date=dt.datetime(2024, 11, 11),
+            )
+        }
+
+        mock_ol_accessible.return_value = True
+
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            openlineage_inject_parent_job_info=True,
+        )
+        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        op.execute(context=context)
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=expected_batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection_skipped_when_already_present(
+        self, mock_hook, to_dict_mock, mock_ol_accessible
+    ):
+        batch = {
+            "spark_batch": {
+                "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+                "main_class": "org.apache.spark.examples.SparkPi",
+            },
+            "runtime_config": {
+                "properties": {
+                    "spark.openlineage.parentJobName": "dag_id.task_id",
+                }
+            },
+        }
+        mock_ol_accessible.return_value = True
+
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            openlineage_inject_parent_job_info=True,
+        )
+        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        op.execute(context=MagicMock())
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless_enabled(
+        self, mock_hook, to_dict_mock, mock_ol_accessible
+    ):
+        batch = {
+            "spark_batch": {
+                "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+                "main_class": "org.apache.spark.examples.SparkPi",
+            },
+            "runtime_config": {"properties": {}},
+        }
+        mock_ol_accessible.return_value = True
+
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            # not passing openlineage_inject_parent_job_info, should be False by default
+        )
+        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        op.execute(context=MagicMock())
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_accessible(
+        self, mock_hook, to_dict_mock, mock_ol_accessible
+    ):
+        batch = {
+            "spark_batch": {
+                "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+                "main_class": "org.apache.spark.examples.SparkPi",
+            },
+            "runtime_config": {"properties": {}},
+        }
+        mock_ol_accessible.return_value = False
+
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            openlineage_inject_parent_job_info=True,
+        )
+        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        op.execute(context=MagicMock())
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=batch,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
 
 class TestDataprocDeleteBatchOperator:
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))