diff --git a/dbt-adapters/src/dbt/adapters/base/impl.py b/dbt-adapters/src/dbt/adapters/base/impl.py index 9bd9358fd..9175ac582 100644 --- a/dbt-adapters/src/dbt/adapters/base/impl.py +++ b/dbt-adapters/src/dbt/adapters/base/impl.py @@ -42,6 +42,7 @@ NotImplementedError, UnexpectedNullError, ) +from dbt_common.record import auto_record_function, supports_replay, Recorder from dbt_common.utils import ( AttrDict, cast_to_str, @@ -90,8 +91,21 @@ ) from dbt.adapters.protocol import AdapterConfig, MacroContextGeneratorCallable -if TYPE_CHECKING: - import agate +# if TYPE_CHECKING: +import agate + +from dbt.adapters.record.serialization import AdapterExecuteSerializer, PartitionsMetadataSerializer, \ + AgateTableSerializer + +ExecuteReturn = Tuple[AdapterResponse, agate.Table] + +Recorder.register_serialization_strategy(ExecuteReturn, AdapterExecuteSerializer()) + +PartitionsMetadata = Tuple[agate.Table] + +Recorder.register_serialization_strategy(PartitionsMetadata, PartitionsMetadataSerializer()) +Recorder.register_serialization_strategy(agate.Table, AgateTableSerializer()) + GET_CATALOG_MACRO_NAME = "get_catalog" @@ -383,13 +397,15 @@ def connection_named( self.connections.query_header.reset() @available.parse(_parse_callback_empty_table) + @auto_record_function("AdapterExecute", group="Available") def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> Tuple[AdapterResponse, "agate.Table"]: + ) -> ExecuteReturn: + """Execute the given SQL. This is a thin wrapper around ConnectionManager.execute. 
@@ -414,8 +430,10 @@ def validate_sql(self, sql: str) -> AdapterResponse: """ raise NotImplementedError("`validate_sql` is not implemented for this adapter!") + @auto_record_function("AdapterGetColumnSchemaFromQuery", group="Available") @available.parse(lambda *a, **k: []) def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: + """Get a list of the Columns with names and data types from the given sql.""" _, cursor = self.connections.add_select_query(sql) columns = [ @@ -427,8 +445,10 @@ def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: ] return columns + @auto_record_function("AdapterGetPartitionsMetadata", group="Available") @available.parse(_parse_callback_empty_table) - def get_partitions_metadata(self, table: str) -> Tuple["agate.Table"]: + def get_partitions_metadata(self, table: str) -> PartitionsMetadata: + """ TODO: Can we move this to dbt-bigquery? Obtain partitions metadata for a BigQuery partitioned table. @@ -576,8 +596,10 @@ def set_relations_cache( self.cache.clear() self._relations_cache_for_schemas(relation_configs, required_schemas) + @auto_record_function("AdapterCacheAdded", group="Available") @available def cache_added(self, relation: Optional[BaseRelation]) -> str: + """Cache a new relation in dbt. It will show up in `list relations`.""" if relation is None: name = self.nice_connection_name() @@ -586,8 +608,10 @@ def cache_added(self, relation: Optional[BaseRelation]) -> str: # so jinja doesn't render things return "" + @auto_record_function("AdapterCacheDropped", group="Available") @available def cache_dropped(self, relation: Optional[BaseRelation]) -> str: + """Drop a relation in dbt. 
It will no longer show up in `list relations`, and any bound views will be dropped from the cache """ @@ -597,6 +621,7 @@ def cache_dropped(self, relation: Optional[BaseRelation]) -> str: self.cache.drop(relation) return "" + @auto_record_function("AdapterCacheRenamed", group="Available") @available def cache_renamed( self, @@ -637,8 +662,10 @@ def list_schemas(self, database: str) -> List[str]: """Get a list of existing schemas in database""" raise NotImplementedError("`list_schemas` is not implemented for this adapter!") + @auto_record_function("AdapterCheckSchemaExists", group="Available") @available.parse(lambda *a, **k: False) def check_schema_exists(self, database: str, schema: str) -> bool: + """Check if a schema exists. The default implementation of this is potentially unnecessarily slow, @@ -651,6 +678,7 @@ def check_schema_exists(self, database: str, schema: str) -> bool: ### # Abstract methods about relations ### + @auto_record_function("AdapterDropRelation", group="Available") @abc.abstractmethod @available.parse_none def drop_relation(self, relation: BaseRelation) -> None: @@ -660,12 +688,14 @@ def drop_relation(self, relation: BaseRelation) -> None: """ raise NotImplementedError("`drop_relation` is not implemented for this adapter!") + @auto_record_function("AdapterTruncateRelation", group="Available") @abc.abstractmethod @available.parse_none def truncate_relation(self, relation: BaseRelation) -> None: """Truncate the given relation.""" raise NotImplementedError("`truncate_relation` is not implemented for this adapter!") + @auto_record_function("AdapterRenameRelation", group="Available") @abc.abstractmethod @available.parse_none def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: @@ -675,6 +705,7 @@ def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation """ raise NotImplementedError("`rename_relation` is not implemented for this adapter!") + 
@auto_record_function("AdapterGetColumnsInRelation", group="Available") @abc.abstractmethod @available.parse_list def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]: @@ -687,8 +718,10 @@ def get_catalog_for_single_relation(self, relation: BaseRelation) -> Optional[Ca "`get_catalog_for_single_relation` is not implemented for this adapter!" ) + @auto_record_function("AdapterGetColumnsInTable", group="Available") @available.deprecated("get_columns_in_relation", lambda *a, **k: []) def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]: + """DEPRECATED: Get a list of the columns in the given table.""" relation = self.Relation.create( database=self.config.credentials.database, @@ -729,8 +762,9 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[ ### # Methods about grants ### + @auto_record_function("AdapterStandardizeGrantsDict", group="Available") @available - def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: + def standardize_grants_dict(self, grants_table: agate.Table) -> dict: """Translate the result of `show grants` (or equivalent) to match the grants which a user would configure in their project. @@ -743,6 +777,7 @@ def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: :return: A standardized dictionary matching the `grants` config :rtype: dict """ + grants_dict: Dict[str, List[str]] = {} for row in grants_table: grantee = row["grantee"] @@ -756,10 +791,12 @@ def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: ### # Provided methods about relations ### + @auto_record_function("AdapterGetMissingColumns", group="Available") @available.parse_list def get_missing_columns( self, from_relation: BaseRelation, to_relation: BaseRelation ) -> List[BaseColumn]: + """Returns a list of Columns in from_relation that are missing from to_relation. 
""" @@ -787,10 +824,12 @@ def get_missing_columns( return [col for (col_name, col) in from_columns.items() if col_name in missing_columns] + @auto_record_function("AdapterValidSnapshotTarget", group="Available") @available.parse_none def valid_snapshot_target( self, relation: BaseRelation, column_names: Optional[Dict[str, str]] = None ) -> None: + """Ensure that the target relation is valid, by making sure it has the expected columns. @@ -819,10 +858,12 @@ def valid_snapshot_target( if missing: raise SnapshotTargetNotSnapshotTableError(missing) + @auto_record_function("AdapterAssertValidSnapshotTargetGivenStrategy", group="Available") @available.parse_none def assert_valid_snapshot_target_given_strategy( self, relation: BaseRelation, column_names: Dict[str, str], strategy: SnapshotStrategy ) -> None: + # Assert everything we can with the legacy function. self.valid_snapshot_target(relation, column_names) @@ -841,10 +882,12 @@ def assert_valid_snapshot_target_given_strategy( if missing: raise SnapshotTargetNotSnapshotTableError(missing) + @auto_record_function("AdapterExpandTargetColumnTypes", group="Available") @available.parse_none def expand_target_column_types( self, from_relation: BaseRelation, to_relation: BaseRelation ) -> None: + if not isinstance(from_relation, self.Relation): raise MacroArgTypeError( method_name="expand_target_column_types", @@ -925,18 +968,23 @@ def _make_match( schema: str, identifier: str, ) -> List[BaseRelation]: - matches = [] + try: + matches = [] - search = self._make_match_kwargs(database, schema, identifier) + search = self._make_match_kwargs(database, schema, identifier) - for relation in relations_list: - if relation.matches(**search): - matches.append(relation) + for relation in relations_list: + if relation.matches(**search): + matches.append(relation) - return matches + return matches + except Exception as e: + pass + @auto_record_function("AdapterGetRelation", group="Available") @available.parse_none def get_relation(self, 
database: str, schema: str, identifier: str) -> Optional[BaseRelation]: + relations_list = self.list_relations(database, schema) matches = self._make_match(relations_list, database, schema, identifier) @@ -954,9 +1002,11 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[ return None + @auto_record_function("AdapterAlreadyExists", group="Available") @available.deprecated("get_relation", lambda *a, **k: False) def already_exists(self, schema: str, name: str) -> bool: """DEPRECATED: Return if a model already exists in the database""" + database = self.config.credentials.database relation = self.get_relation(database, schema, name) return relation is not None @@ -965,12 +1015,14 @@ def already_exists(self, schema: str, name: str) -> bool: # ODBC FUNCTIONS -- these should not need to change for every adapter, # although some adapters may override them ### + @auto_record_function("AdapterCreateSchema", group="Available") @abc.abstractmethod @available.parse_none def create_schema(self, relation: BaseRelation): """Create the given schema if it does not exist.""" raise NotImplementedError("`create_schema` is not implemented for this adapter!") + @auto_record_function("AdapterDropSchema", group="Available") @abc.abstractmethod @available.parse_none def drop_schema(self, relation: BaseRelation): @@ -980,10 +1032,12 @@ def drop_schema(self, relation: BaseRelation): @available @classmethod @abc.abstractmethod + @auto_record_function("AdapterQuote", group="Available") def quote(cls, identifier: str) -> str: """Quote the given identifier, as appropriate for the database.""" raise NotImplementedError("`quote` is not implemented for this adapter!") + @auto_record_function("AdapterQuoteAsConfigured", group="Available") @available def quote_as_configured(self, identifier: str, quote_key: str) -> str: """Quote or do not quote the given identifer as configured in the @@ -992,6 +1046,7 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> 
str: The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ + try: key = ComponentName(quote_key) except ValueError: @@ -1003,8 +1058,10 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> str: else: return identifier + @auto_record_function("AdapterQuoteSeedColumn", group="Available") @available def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str: + quote_columns: bool = True if isinstance(quote_config, bool): quote_columns = quote_config @@ -1107,7 +1164,9 @@ def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str: @available @classmethod - def convert_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]: + @auto_record_function("AdapterConvertType", group="Available") + def convert_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[str]: + return cls.convert_agate_type(agate_table, col_idx) @classmethod @@ -1617,6 +1676,8 @@ def valid_incremental_strategies(self): """ return ["append"] + "".format() + def builtin_incremental_strategies(self): """ List of possible builtin strategies for adapters @@ -1709,7 +1770,9 @@ def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional @available @classmethod - def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]]) -> List: + @auto_record_function("AdapterRenderRawColumnConstraints", group="Available") + def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]]) -> List[str]: + rendered_column_constraints = [] for v in raw_columns.values(): @@ -1763,7 +1826,9 @@ def _parse_model_constraint(cls, raw_constraint: Dict[str, Any]) -> ModelLevelCo @available @classmethod + @auto_record_function("AdapterRenderRawModelConstraints", group="Available") def render_raw_model_constraints(cls, raw_constraints: List[Dict[str, Any]]) -> List[str]: + return [c for c in 
map(cls.render_raw_model_constraint, raw_constraints) if c is not None] @classmethod @@ -1835,7 +1900,10 @@ def _get_adapter_specific_run_info(cls, config) -> Dict[str, Any]: @available.parse_none @classmethod - def get_hard_deletes_behavior(cls, config): + # @auto_record_function("AdapterGetHardDeletesBehavior", group="Available") + # TODO: type is a lie + def get_hard_deletes_behavior(cls, config: Dict[str, str]) -> str: + """Check the hard_deletes config enum, and the legacy invalidate_hard_deletes config flag in order to determine which behavior should be used for deleted records in a snapshot. The default is to ignore them.""" diff --git a/dbt-adapters/src/dbt/adapters/record/cursor/cursor.py b/dbt-adapters/src/dbt/adapters/record/cursor/cursor.py index 577178dbb..24d7de8c0 100644 --- a/dbt-adapters/src/dbt/adapters/record/cursor/cursor.py +++ b/dbt-adapters/src/dbt/adapters/record/cursor/cursor.py @@ -51,4 +51,4 @@ def rowcount(self) -> int: @property @record_function(CursorGetDescriptionRecord, method=True, id_field_name="connection_name") def description(self) -> str: - return self.native_cursor.description + return self.native_cursor.description \ No newline at end of file diff --git a/dbt-adapters/src/dbt/adapters/record/handle.py b/dbt-adapters/src/dbt/adapters/record/handle.py index 31817c374..1b3a65cb2 100644 --- a/dbt-adapters/src/dbt/adapters/record/handle.py +++ b/dbt-adapters/src/dbt/adapters/record/handle.py @@ -22,3 +22,25 @@ def cursor(self) -> Any: # actual database access should be performed in that mode. cursor = None if self.native_handle is None else self.native_handle.cursor() return RecordReplayCursor(cursor, self.connection) + + def commit(self): + self.native_handle.commit() + + def rollback(self): + self.native_handle.rollback() + + def close(self): + # NOTE: Some native handles apparently don't have close, so this + # might cause record/replay problems. 
+        self.native_handle.close()
+
+    def get_backend_pid(self):
+        # NOTE: Some native handles apparently don't have get_backend_pid,
+        # so this might cause record/replay problems.
+        return self.native_handle.get_backend_pid()
+
+    @property
+    def closed(self):
+        # NOTE: Some native handles apparently don't have a closed attribute,
+        # so this might cause record/replay problems.
+        return self.native_handle.closed
diff --git a/dbt-adapters/src/dbt/adapters/record/serialization.py b/dbt-adapters/src/dbt/adapters/record/serialization.py
new file mode 100644
index 000000000..e04368dd0
--- /dev/null
+++ b/dbt-adapters/src/dbt/adapters/record/serialization.py
@@ -0,0 +1,56 @@
+from datetime import datetime, date
+from decimal import Decimal
+from typing import Tuple, Dict, Any
+
+from dbt.adapters.contracts.connection import AdapterResponse
+
+import agate
+
+from mashumaro.types import SerializationStrategy
+
+
+def _column_filter(val: Any) -> Any:
+    return float(val) if isinstance(val, Decimal) else str(val) if isinstance(val, datetime) else str(val) if isinstance(val, date) else str(val)  # NOTE(review): the three non-Decimal branches are all str(val) — collapse them, or use isoformat() for datetime/date if that was the intent; confirm.
+
+
+def _serialize_agate_table(table: agate.Table) -> Dict[str, Any]:
+    rows = []
+    for row in table.rows:
+        row = list(map(_column_filter, row))
+        rows.append(row)
+
+    return {
+        "column_names": table.column_names,
+        "column_types": [t.__class__.__name__ for t in table.column_types],
+        "rows": rows
+    }
+
+
+class AdapterExecuteSerializer(SerializationStrategy):
+    def serialize(self, table: Tuple[AdapterResponse, agate.Table]):
+        adapter_response, agate_table = table
+        return {
+            "adapter_response": adapter_response.to_dict(),
+            "table": _serialize_agate_table(agate_table)
+        }
+
+    def deserialize(self, data):
+        # TODO: reconstruct the (AdapterResponse, agate.Table) tuple from
+        # data["adapter_response"] / data["table"]; replay returns None for now.
+        return None
+
+class PartitionsMetadataSerializer(SerializationStrategy):
+    def serialize(self, tables: Tuple[agate.Table]):
+        return list(map(_serialize_agate_table, tables))
+
+    def deserialize(self, data):
+        # TODO: reconstruct the tuple of agate.Table objects for replay.
+        return None
+
+class AgateTableSerializer(SerializationStrategy):
+    def serialize(self, table: agate.Table):
+        return _serialize_agate_table(table)
+
+    def deserialize(self, data):
+        # TODO: reconstruct the agate.Table for replay; returns None for now.
+        return None