diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/connections.py b/dbt-bigquery/src/dbt/adapters/bigquery/connections.py index bb062f330..e0eed5be1 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/connections.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/connections.py @@ -221,6 +221,12 @@ def get_labels_from_query_comment(cls): return {} + def get_job_labels(self): + labels = self.get_labels_from_query_comment() + labels["dbt_invocation_id"] = get_invocation_id() + + return labels + def generate_job_id(self) -> str: # Generating a fresh job_id for every _query_and_results call to avoid job_id reuse. # Generating a job id instead of persisting a BigQuery-generated one after client.query is called. @@ -244,9 +250,7 @@ def raw_execute( fire_event(SQLQuery(conn_name=conn.name, sql=sql, node_info=get_node_info())) - labels = self.get_labels_from_query_comment() - - labels["dbt_invocation_id"] = get_invocation_id() + labels = self.get_job_labels() job_params = { "use_legacy_sql": use_legacy_sql, @@ -424,6 +428,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: destination_ref = self.table_ref( destination.database, destination.schema, destination.table ) + labels = self.get_job_labels() logger.debug( 'Copying table(s) "{}" to "{}" with disposition: "{}"', @@ -440,7 +445,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: copy_job = client.copy_table( source_ref_array, destination_ref, - job_config=CopyJobConfig(write_disposition=write_disposition), + job_config=CopyJobConfig(write_disposition=write_disposition, labels=labels), retry=self._retry.create_reopen_with_deadline(conn), ) copy_job.result(timeout=self._retry.create_job_execution_timeout(fallback=300)) @@ -456,10 +461,12 @@ def write_dataframe_to_table( field_delimiter: str, fallback_timeout: Optional[float] = None, ) -> None: + labels = self.get_job_labels() load_config = LoadJobConfig( skip_leading_rows=1, schema=table_schema, field_delimiter=field_delimiter, + labels=labels, ) table = self.table_ref(database, schema, identifier) self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) @@ -477,6 +484,9 @@ def write_file_to_table( config = kwargs["kwargs"] if "schema" in config: config["schema"] = json.load(config["schema"]) + if "labels" not in config: + config["labels"] = self.get_job_labels() + load_config = LoadJobConfig(**config) table = self.table_ref(database, schema, identifier) self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) diff --git a/dbt-bigquery/tests/unit/test_bigquery_connection_manager.py b/dbt-bigquery/tests/unit/test_bigquery_connection_manager.py index e7afd692f..706935d5e 100644 --- a/dbt-bigquery/tests/unit/test_bigquery_connection_manager.py +++ b/dbt-bigquery/tests/unit/test_bigquery_connection_manager.py @@ -128,12 +128,23 @@ def test_copy_bq_table_truncates(self): kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE ) - def test_job_labels_valid_json(self): + @patch("dbt.adapters.bigquery.connections.get_invocation_id", new=Mock(return_value="uuid4")) + @patch.object(BigQueryConnectionManager, "get_labels_from_query_comment") + def test_get_job_labels(self, mock_get_query_comment_labels): + query_labels = {"key": "value"} + expected = {**query_labels, "dbt_invocation_id": "uuid4"} + mock_get_query_comment_labels.return_value = \ + self.connections._labels_from_query_comment(json.dumps(query_labels)) + + labels = self.connections.get_job_labels() + self.assertEqual(labels, expected) + + def test_job_labels_from_query_comment_valid_json(self): expected = {"key": "value"} labels = self.connections._labels_from_query_comment(json.dumps(expected)) self.assertEqual(labels, expected) - def test_job_labels_invalid_json(self): + def test_job_labels_from_query_comment_invalid_json(self): labels = self.connections._labels_from_query_comment("not json") self.assertEqual(labels, {"query_comment": "not_json"})