diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 138018f14..1896c26af 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -46,13 +46,15 @@ repos:
     rev: v1.11.2
     hooks:
       - id: mypy
-        exclude: dbt-adapters/src/dbt/adapters/events/adapter_types_pb2.py|dbt-tests-adapter/src/dbt/__init__.py
+        files: (dbt-adapters|dbt-athena|dbt-bigquery|dbt-postgres|dbt-redshift|dbt-snowflake|dbt-spark)/src/dbt/adapters|dbt-tests-adapter/src/dbt/tests
         args:
           - --explicit-package-bases
+          - --namespace-packages
           - --ignore-missing-imports
+          - --warn-redundant-casts
+          - --warn-unused-ignores
           - --pretty
           - --show-error-codes
-        files: ^dbt-adapters/src/dbt/adapters/
         additional_dependencies:
           - types-PyYAML
           - types-protobuf
diff --git a/dbt-adapters/pyproject.toml b/dbt-adapters/pyproject.toml
index 271987b62..d9c8721b1 100644
--- a/dbt-adapters/pyproject.toml
+++ b/dbt-adapters/pyproject.toml
@@ -42,12 +42,6 @@ Repository = "https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-adap
 Issues = "https://github.com/dbt-labs/dbt-adapters/issues"
 Changelog = "https://github.com/dbt-labs/dbt-adapters/blob/main/dbt-adapters/CHANGELOG.md"

-[tool.mypy]
-mypy_path = "third-party-stubs/"
-[[tool.mypy.overrides]]
-module = ["dbt.adapters.events.adapter_types_pb2"]
-follow_imports = "skip"
-
 [tool.pytest.ini_options]
 testpaths = ["tests/unit", "tests/functional"]
 addopts = "-v --color=yes -n auto"
diff --git a/dbt-adapters/src/dbt/adapters/base/impl.py b/dbt-adapters/src/dbt/adapters/base/impl.py
index 9bd9358fd..dfe9dc12b 100644
--- a/dbt-adapters/src/dbt/adapters/base/impl.py
+++ b/dbt-adapters/src/dbt/adapters/base/impl.py
@@ -1224,7 +1224,7 @@ def _get_one_catalog(
         kwargs = {"information_schema": information_schema, "schemas": schemas}
         table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)

-        results = self._catalog_filter_table(table, used_schemas)  # type: ignore[arg-type]
+        results = self._catalog_filter_table(table, used_schemas)
         return results

     def _get_one_catalog_by_relations(
@@ -1239,7 +1239,7 @@ def _get_one_catalog_by_relations(
         }
         table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs)

-        results = self._catalog_filter_table(table, used_schemas)  # type: ignore[arg-type]
+        results = self._catalog_filter_table(table, used_schemas)
         return results

     def get_filtered_catalog(
@@ -1435,7 +1435,7 @@ def calculate_freshness_from_metadata_batch(
                 macro_resolver=macro_resolver,
                 needs_conn=True,
             )
-            adapter_response, table = result.response, result.table  # type: ignore[attr-defined]
+            adapter_response, table = result.response, result.table
             adapter_responses.append(adapter_response)

             for row in table:
diff --git a/dbt-adapters/src/dbt/adapters/base/meta.py b/dbt-adapters/src/dbt/adapters/base/meta.py
index e522a0562..2d5c8d5c8 100644
--- a/dbt-adapters/src/dbt/adapters/base/meta.py
+++ b/dbt-adapters/src/dbt/adapters/base/meta.py
@@ -121,7 +121,7 @@ def __new__(mcls, name, bases, namespace, **kwargs) -> "AdapterMeta":
         # I'm not sure there is any benefit to it after poking around a bit,
         # but having it doesn't hurt on the python side (and omitting it could
         # hurt for obscure metaclass reasons, for all I know)
-        cls = abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)  # type: ignore
+        cls = abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)

         # this is very much inspired by ABCMeta's own implementation

diff --git a/dbt-adapters/src/dbt/adapters/base/relation.py b/dbt-adapters/src/dbt/adapters/base/relation.py
index 7d4888e42..77290ad43 100644
--- a/dbt-adapters/src/dbt/adapters/base/relation.py
+++ b/dbt-adapters/src/dbt/adapters/base/relation.py
@@ -135,7 +135,7 @@ def matches(
             if str(self.path.get_lowered_part(k)).strip(self.quote_character) != v.lower().strip(
                 self.quote_character
             ):
-                approximate_match = False  # type: ignore[union-attr]
+                approximate_match = False

         if approximate_match and not exact_match:
             target = self.create(database=database, schema=schema, identifier=identifier)
diff --git a/dbt-adapters/src/dbt/adapters/contracts/connection.py b/dbt-adapters/src/dbt/adapters/contracts/connection.py
index 2d10c9a32..67763ccec 100644
--- a/dbt-adapters/src/dbt/adapters/contracts/connection.py
+++ b/dbt-adapters/src/dbt/adapters/contracts/connection.py
@@ -124,7 +124,7 @@ def resolve(self, connection: Connection) -> Connection:
 # and https://github.com/python/mypy/issues/5374
 # for why we have type: ignore. Maybe someday dataclasses + abstract classes
 # will work.
-@dataclass  # type: ignore
+@dataclass
 class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
     database: str
     schema: str
diff --git a/dbt-adapters/src/dbt/adapters/protocol.py b/dbt-adapters/src/dbt/adapters/protocol.py
index 352198663..2298468e5 100644
--- a/dbt-adapters/src/dbt/adapters/protocol.py
+++ b/dbt-adapters/src/dbt/adapters/protocol.py
@@ -75,7 +75,7 @@ def __call__(


 # TODO CT-211
-class AdapterProtocol(  # type: ignore[misc]
+class AdapterProtocol(
     Protocol,
     Generic[
         AdapterConfig_T,
diff --git a/dbt-adapters/src/dbt/adapters/relation_configs/config_base.py b/dbt-adapters/src/dbt/adapters/relation_configs/config_base.py
index 62d140595..0bb4ee612 100644
--- a/dbt-adapters/src/dbt/adapters/relation_configs/config_base.py
+++ b/dbt-adapters/src/dbt/adapters/relation_configs/config_base.py
@@ -37,7 +37,7 @@ def from_dict(cls, kwargs_dict) -> "RelationConfigBase":

         Returns: the `RelationConfigBase` representation associated with the provided dict
         """
-        return cls(**filter_null_values(kwargs_dict))  # type: ignore
+        return cls(**filter_null_values(kwargs_dict))

     @classmethod
     def _not_implemented_error(cls) -> NotImplementedError:
diff --git a/dbt-adapters/src/dbt/adapters/sql/impl.py b/dbt-adapters/src/dbt/adapters/sql/impl.py
index 8a8473f27..86bd1250b 100644
--- a/dbt-adapters/src/dbt/adapters/sql/impl.py
+++ b/dbt-adapters/src/dbt/adapters/sql/impl.py
@@ -75,7 +75,7 @@ def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
         import agate

         # TODO CT-211
-        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))  # type: ignore[attr-defined]
+        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
         return "float8" if decimals else "integer"

     @classmethod
@@ -247,7 +247,7 @@ def validate_sql(self, sql: str) -> AdapterResponse:
         # return fetched output for engines where explain plans are emitted as columnar
         # results. Any macro override that deviates from this behavior may encounter an
         # assertion error in the runtime.
-        adapter_response = result.response  # type: ignore[attr-defined]
+        adapter_response = result.response
         assert isinstance(adapter_response, AdapterResponse), (
             f"Expected AdapterResponse from validate_sql macro execution, "
             f"got {type(adapter_response)}."
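The hunks above drop `# type: ignore` comments that the updated pre-commit configuration makes unnecessary. As a minimal illustrative sketch (the function below is hypothetical, not part of this patch), the added `--warn-unused-ignores` flag is what keeps these removals honest: once a suppression no longer hides a real error, mypy reports the comment itself.

    # hypothetical example: under mypy --warn-unused-ignores, a suppression that no
    # longer masks an error is reported as 'Unused "type: ignore" comment'
    def add(a: int, b: int) -> int:
        return a + b  # type: ignore[operator]
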
diff --git a/dbt-adapters/tests/unit/fixtures/adapter.py b/dbt-adapters/tests/unit/fixtures/adapter.py index 3730a083f..eb1ecdc35 100644 --- a/dbt-adapters/tests/unit/fixtures/adapter.py +++ b/dbt-adapters/tests/unit/fixtures/adapter.py @@ -34,7 +34,7 @@ def is_cancelable(cls) -> bool: return False def list_schemas(self, database: str) -> List[str]: - return list(self.cache.schemas) + return list(schema for database, schema in self.cache.schemas if isinstance(schema, str)) ### # Abstract methods about relations diff --git a/dbt-adapters/tests/unit/fixtures/connection_manager.py b/dbt-adapters/tests/unit/fixtures/connection_manager.py index 8b353fbee..a11444121 100644 --- a/dbt-adapters/tests/unit/fixtures/connection_manager.py +++ b/dbt-adapters/tests/unit/fixtures/connection_manager.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import ContextManager, List, Optional, Tuple +from typing import Generator, List, Optional, Tuple, Any import agate @@ -15,7 +15,7 @@ class ConnectionManagerStub(BaseConnectionManager): raised_exceptions: List[Exception] @contextmanager - def exception_handler(self, sql: str) -> ContextManager: # type: ignore + def exception_handler(self, sql: str) -> Generator[None, Any, None]: # type: ignore # catch all exceptions and put them on this class for inspection in tests try: yield @@ -28,7 +28,7 @@ def cancel_open(self) -> Optional[List[str]]: names = [] for connection in self.thread_connections.values(): if connection.state == ConnectionState.OPEN: - connection.state = ConnectionState.CLOSED + connection.state = ConnectionState.CLOSED # type: ignore if name := connection.name: names.append(name) return names @@ -36,7 +36,7 @@ def cancel_open(self) -> Optional[List[str]]: @classmethod def open(cls, connection: Connection) -> Connection: # there's no database, so just change the state - connection.state = ConnectionState.OPEN + connection.state = ConnectionState.OPEN # type: ignore return connection def begin(self) -> None: diff --git a/dbt-adapters/tests/unit/fixtures/credentials.py b/dbt-adapters/tests/unit/fixtures/credentials.py index 88817f6bf..79721f08b 100644 --- a/dbt-adapters/tests/unit/fixtures/credentials.py +++ b/dbt-adapters/tests/unit/fixtures/credentials.py @@ -6,6 +6,7 @@ class CredentialsStub(Credentials): A stub for a database credentials that does not connect to a database """ + @property def type(self) -> str: return "test" diff --git a/dbt-athena/src/dbt/adapters/athena/__init__.py b/dbt-athena/src/dbt/adapters/athena/__init__.py index c2f140db6..37dd0c54e 100644 --- a/dbt-athena/src/dbt/adapters/athena/__init__.py +++ b/dbt-athena/src/dbt/adapters/athena/__init__.py @@ -4,7 +4,9 @@ from dbt.include import athena Plugin: AdapterPlugin = AdapterPlugin( - adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH + adapter=AthenaAdapter, # type:ignore + credentials=AthenaCredentials, + include_path=athena.PACKAGE_PATH, ) __all__ = [ diff --git a/dbt-athena/src/dbt/adapters/athena/column.py b/dbt-athena/src/dbt/adapters/athena/column.py index a220bf3ba..7e3fde79e 100644 --- a/dbt-athena/src/dbt/adapters/athena/column.py +++ b/dbt-athena/src/dbt/adapters/athena/column.py @@ -30,7 +30,7 @@ def is_timestamp(self) -> bool: return self.dtype.lower() in {"timestamp"} def is_array(self) -> bool: - return self.dtype.lower().startswith("array") # type: ignore + return self.dtype.lower().startswith("array") @classmethod def string_type(cls, size: int) -> str: @@ -58,7 +58,7 @@ def array_inner_type(self) -> str: 
if match: return match.group(1) # If for some reason there's no match, fall back to the original string - return self.dtype # type: ignore + return self.dtype def string_size(self) -> int: if not self.is_string(): @@ -72,7 +72,7 @@ def data_type(self) -> str: return self.string_type(self.string_size()) if self.is_numeric(): - return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) # type: ignore + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) if self.is_binary(): return self.binary_type() @@ -94,4 +94,4 @@ def data_type(self) -> str: ) return self.array_type(inner_type_col.data_type) - return self.dtype # type: ignore + return self.dtype diff --git a/dbt-athena/src/dbt/adapters/athena/config.py b/dbt-athena/src/dbt/adapters/athena/config.py index c583140e9..b1a80000b 100644 --- a/dbt-athena/src/dbt/adapters/athena/config.py +++ b/dbt-athena/src/dbt/adapters/athena/config.py @@ -112,17 +112,21 @@ def set_engine_config(self) -> Dict[str, Any]: default_spark_properties: Dict[str, str] = dict( **( - DEFAULT_SPARK_PROPERTIES.get(table_type) + DEFAULT_SPARK_PROPERTIES.get(table_type, {}) if table_type.lower() in ["iceberg", "hudi", "delta_lake"] else {} ), - **DEFAULT_SPARK_PROPERTIES.get("spark_encryption") if spark_encryption else {}, + **DEFAULT_SPARK_PROPERTIES.get("spark_encryption", {}) if spark_encryption else {}, **( - DEFAULT_SPARK_PROPERTIES.get("spark_cross_account_catalog") + DEFAULT_SPARK_PROPERTIES.get("spark_cross_account_catalog", {}) if spark_cross_account_catalog else {} ), - **DEFAULT_SPARK_PROPERTIES.get("spark_requester_pays") if spark_requester_pays else {}, + **( + DEFAULT_SPARK_PROPERTIES.get("spark_requester_pays", {}) + if spark_requester_pays + else {} + ), ) default_engine_config = { diff --git a/dbt-athena/src/dbt/adapters/athena/connections.py b/dbt-athena/src/dbt/adapters/athena/connections.py index 8357af0f0..9b5d23081 100644 --- a/dbt-athena/src/dbt/adapters/athena/connections.py +++ b/dbt-athena/src/dbt/adapters/athena/connections.py @@ -111,7 +111,7 @@ def _connection_keys(self) -> Tuple[str, ...]: class AthenaCursor(Cursor): - def __init__(self, **kwargs) -> None: # type: ignore + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self._executor = ThreadPoolExecutor() @@ -224,9 +224,9 @@ def execute_with_iceberg_retries() -> AthenaCursor: return self raise OperationalError(query_execution.state_change_reason) - return execute_with_iceberg_retries() # type: ignore + return execute_with_iceberg_retries() - return inner() # type: ignore + return inner() class AthenaConnectionManager(SQLConnectionManager): @@ -236,7 +236,7 @@ def set_query_header(self, query_header_context: Dict[str, Any]) -> None: self.query_header = AthenaMacroQueryStringSetter(self.profile, query_header_context) @classmethod - def data_type_code_to_name(cls, type_code: str) -> str: + def data_type_code_to_name(cls, type_code: str) -> str: # type:ignore """ Get the string representation of the data type from the Athena metadata. Dbt performs a query to retrieve the types of the columns in the SQL query. 
Then these types are compared @@ -287,7 +287,7 @@ def open(cls, connection: Connection) -> Connection: config=get_boto3_config(num_retries=creds.effective_num_retries), ) - connection.state = ConnectionState.OPEN + connection.state = ConnectionState.OPEN # type:ignore connection.handle = handle except Exception as exc: @@ -295,7 +295,7 @@ def open(cls, connection: Connection) -> Connection: f"Got an error when attempting to open a Athena connection due to {exc}" ) connection.handle = None - connection.state = ConnectionState.FAIL + connection.state = ConnectionState.FAIL # type:ignore raise ConnectionError(str(exc)) return connection diff --git a/dbt-athena/src/dbt/adapters/athena/impl.py b/dbt-athena/src/dbt/adapters/athena/impl.py index 1875953bb..f646504bc 100755 --- a/dbt-athena/src/dbt/adapters/athena/impl.py +++ b/dbt-athena/src/dbt/adapters/athena/impl.py @@ -216,7 +216,7 @@ def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str, region_name=client.region_name, config=get_boto3_config(num_retries=creds.effective_num_retries), ) - catalog = self._get_data_catalog(relation.database) + catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(catalog) lf_permissions = LfPermissions(catalog_id, relation, lf) # type: ignore lf_permissions.process_filters(lf_config) @@ -321,11 +321,15 @@ def generate_s3_location( mapping = { S3DataNaming.UNIQUE: path.join(table_prefix, str(uuid4())), - S3DataNaming.TABLE: path.join(table_prefix, s3_path_table_part), - S3DataNaming.TABLE_UNIQUE: path.join(table_prefix, s3_path_table_part, str(uuid4())), - S3DataNaming.SCHEMA_TABLE: path.join(table_prefix, schema_name, s3_path_table_part), + S3DataNaming.TABLE: path.join(table_prefix, s3_path_table_part), # type:ignore + S3DataNaming.TABLE_UNIQUE: path.join( + table_prefix, s3_path_table_part, str(uuid4()) # type:ignore + ), + S3DataNaming.SCHEMA_TABLE: path.join( + table_prefix, schema_name, s3_path_table_part # type:ignore + ), S3DataNaming.SCHEMA_TABLE_UNIQUE: path.join( - table_prefix, schema_name, s3_path_table_part, str(uuid4()) + table_prefix, schema_name, s3_path_table_part, str(uuid4()) # type:ignore ), } @@ -340,7 +344,7 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -402,7 +406,7 @@ def clean_up_partitions(self, relation: AthenaRelation, where_condition: str) -> creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -441,7 +445,7 @@ def clean_up_table(self, relation: AthenaRelation) -> None: def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str: return f"{suffix_initial}_{str(uuid4()).replace('-', '_')}" - def quote(self, identifier: str) -> str: + def quote(self, identifier: str) -> str: # type:ignore return f"{self.quote_character}{identifier}{self.quote_character}" @available @@ -622,7 +626,7 @@ def _get_one_catalog( """ This function is invoked by Adapter.get_catalog for each schema. 
""" - data_catalog = self._get_data_catalog(information_schema.database) + data_catalog = self._get_data_catalog(information_schema.database) # type:ignore data_catalog_type = get_catalog_type(data_catalog) conn = self.connections.get_thread_connection() @@ -652,7 +656,9 @@ def _get_one_catalog( for page in paginator.paginate(**kwargs): for table in page["TableList"]: catalog.extend( - self._get_one_table_for_catalog(table, information_schema.database) + self._get_one_table_for_catalog( + table, information_schema.database # type:ignore + ) ) table = agate.Table.from_object(catalog) else: @@ -674,14 +680,14 @@ def _get_one_catalog( for table in page["TableMetadataList"]: catalog.extend( self._get_one_table_for_non_glue_catalog( - table, schema, information_schema.database + table, schema, information_schema.database # type:ignore ) ) table = agate.Table.from_object(catalog) return self._catalog_filter_table(table, used_schemas) - def _get_catalog_schemas( + def _get_catalog_schemas( # type:ignore self, relation_configs: Iterable[RelationConfig] ) -> AthenaSchemaSearchMap: """ @@ -723,10 +729,10 @@ def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]: def list_relations_without_caching( self, schema_relation: AthenaRelation ) -> List[BaseRelation]: - data_catalog = self._get_data_catalog(schema_relation.database) + data_catalog = self._get_data_catalog(schema_relation.database) # type:ignore if data_catalog and data_catalog["Type"] != "GLUE": # For non-Glue Data Catalogs, use the original Athena query against INFORMATION_SCHEMA approach - return super().list_relations_without_caching(schema_relation) # type: ignore + return super().list_relations_without_caching(schema_relation) conn = self.connections.get_thread_connection() creds = conn.credentials @@ -795,7 +801,7 @@ def _get_one_catalog_by_relations( glue_table_definition = self.get_glue_table(_rel) if glue_table_definition: _table_definition = self._get_one_table_for_catalog( - glue_table_definition["Table"], _rel.database + glue_table_definition["Table"], _rel.database # type:ignore ) _table_definitions.extend(_table_definition) table = agate.Table.from_object(_table_definitions) @@ -812,7 +818,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(src_relation.database) + data_catalog = self._get_data_catalog(src_relation.database) # type:ignore src_catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -838,7 +844,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati ) src_table_partitions = src_table_partitions_result.build_full_result().get("Partitions") - data_catalog = self._get_data_catalog(src_relation.database) + data_catalog = self._get_data_catalog(src_relation.database) # type:ignore target_catalog_id = get_catalog_id(data_catalog) target_get_partitions_paginator = glue_client.get_paginator("get_partitions") @@ -945,7 +951,7 @@ def expire_glue_table_versions( creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -1006,7 +1012,7 @@ def persist_docs_to_glue( creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = 
get_catalog_id(data_catalog) with boto3_client_lock: @@ -1174,7 +1180,7 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -1222,7 +1228,7 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None: creds = conn.credentials client = conn.handle - data_catalog = self._get_data_catalog(relation.database) + data_catalog = self._get_data_catalog(relation.database) # type:ignore catalog_id = get_catalog_id(data_catalog) with boto3_client_lock: @@ -1309,7 +1315,7 @@ def _generate_snapshot_migration_sql( """ col_csv = f", \n{' ' * 16}".join(table_columns) staging_relation = relation.incorporate( - path={"identifier": relation.identifier + "__dbt_tmp_migration_staging"} + path={"identifier": relation.identifier + "__dbt_tmp_migration_staging"} # type:ignore ) ctas = dedent( f"""\ @@ -1329,7 +1335,7 @@ def _generate_snapshot_migration_sql( ) backup_relation = relation.incorporate( - path={"identifier": relation.identifier + "__dbt_tmp_migration_backup"} + path={"identifier": relation.identifier + "__dbt_tmp_migration_backup"} # type:ignore ) backup_sql = self.execute_macro( "create_table_as", @@ -1477,7 +1483,7 @@ def _run_query(self, sql: str, catch_partitions_limit: bool) -> AthenaCursor: @classmethod def _get_adapter_specific_run_info(cls, config: RelationConfig) -> Dict[str, Any]: try: - table_format = config._extra.get("table_type") + table_format = config._extra.get("table_type") # type:ignore except AttributeError: table_format = None return { diff --git a/dbt-athena/src/dbt/adapters/athena/lakeformation.py b/dbt-athena/src/dbt/adapters/athena/lakeformation.py index e4952cebd..25599853f 100644 --- a/dbt-athena/src/dbt/adapters/athena/lakeformation.py +++ b/dbt-athena/src/dbt/adapters/athena/lakeformation.py @@ -174,7 +174,7 @@ def _parse_and_log_lf_response( ) -> None: table_appendix = f".{self.table}" if self.table else "" columns_appendix = f" for columns {columns}" if columns else "" - resource_msg = self.database + table_appendix + columns_appendix + resource_msg = self.database + table_appendix + columns_appendix # type:ignore if failures := response.get("Failures", []): base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg for failure in failures: @@ -224,8 +224,8 @@ def __init__( ) -> None: self.catalog_id = catalog_id self.relation = relation - self.database: str = relation.schema - self.table: str = relation.identifier + self.database: str = relation.schema # type:ignore + self.table: str = relation.identifier # type:ignore self.lf_client = lf_client def get_filters(self) -> Dict[str, DataCellsFilterTypeDef]: diff --git a/dbt-athena/src/dbt/adapters/athena/relation.py b/dbt-athena/src/dbt/adapters/athena/relation.py index cd4e65934..21673c443 100644 --- a/dbt-athena/src/dbt/adapters/athena/relation.py +++ b/dbt-athena/src/dbt/adapters/athena/relation.py @@ -85,7 +85,7 @@ def add(self, relation: AthenaRelation) -> None: self[key] = {} if relation.schema is not None: schema = relation.schema.lower() - relation_name = relation.name.lower() + relation_name = relation.name.lower() # type:ignore if schema not in self[key]: self[key][schema] = set() self[key][schema].add(relation_name) diff --git a/dbt-bigquery/pyproject.toml b/dbt-bigquery/pyproject.toml index aa180b3a9..94dc9ed60 
100644 --- a/dbt-bigquery/pyproject.toml +++ b/dbt-bigquery/pyproject.toml @@ -43,9 +43,6 @@ Repository = "https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-bigq Issues = "https://github.com/dbt-labs/dbt-adapters/issues" Changelog = "https://github.com/dbt-labs/dbt-adapters/blob/main/dbt-bigquery/CHANGELOG.md" -[tool.mypy] -mypy_path = "third-party-stubs/" - [tool.pytest.ini_options] testpaths = ["tests/unit", "tests/functional"] addopts = "-v --color=yes -n auto" diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/__init__.py b/dbt-bigquery/src/dbt/adapters/bigquery/__init__.py index 74fa17cda..f0e6b0b57 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/__init__.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/__init__.py @@ -8,5 +8,7 @@ from dbt.include import bigquery Plugin = AdapterPlugin( - adapter=BigQueryAdapter, credentials=BigQueryCredentials, include_path=bigquery.PACKAGE_PATH + adapter=BigQueryAdapter, # type:ignore + credentials=BigQueryCredentials, + include_path=bigquery.PACKAGE_PATH, ) diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/column.py b/dbt-bigquery/src/dbt/adapters/bigquery/column.py index a676fef4b..48aa1ce14 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/column.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/column.py @@ -19,7 +19,7 @@ class BigQueryColumn(Column): "INTEGER": "INT64", } fields: List[Self] # type: ignore - mode: str + mode: str = "NULLABLE" def __init__( self, @@ -111,7 +111,7 @@ def is_numeric(self) -> bool: def is_float(self): return self.dtype.lower() == "float64" - def can_expand_to(self: Self, other_column: Self) -> bool: + def can_expand_to(self: Self, other_column: Column) -> bool: """returns True if both columns are strings""" return self.is_string() and other_column.is_string() diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/connections.py b/dbt-bigquery/src/dbt/adapters/bigquery/connections.py index cafaffef7..73e55446b 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/connections.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/connections.py @@ -127,7 +127,7 @@ def exception_handler(self, sql): exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip() raise DbtDatabaseError(exc_message) - def cancel_open(self): + def cancel_open(self) -> List[str]: names = [] this_connection = self.get_if_exists() with self.lock: diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/credentials.py b/dbt-bigquery/src/dbt/adapters/bigquery/credentials.py index 94d70a931..cb3a9a723 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/credentials.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/credentials.py @@ -46,8 +46,8 @@ class BigQueryCredentials(Credentials): # BigQuery allows an empty database / project, where it defers to the # environment for the project - database: Optional[str] = None - schema: Optional[str] = None + database: Optional[str] = None # type:ignore + schema: Optional[str] = None # type:ignore execution_project: Optional[str] = None quota_project: Optional[str] = None location: Optional[str] = None diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/impl.py b/dbt-bigquery/src/dbt/adapters/bigquery/impl.py index 51c457129..b22dad6e8 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/impl.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/impl.py @@ -158,7 +158,7 @@ def is_cancelable(cls) -> bool: return True def drop_relation(self, relation: BigQueryRelation) -> None: - is_cached = self._schema_is_cached(relation.database, relation.schema) + is_cached = self._schema_is_cached(relation.database, 
relation.schema) # type:ignore if is_cached: self.cache_dropped(relation) @@ -350,7 +350,7 @@ def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: import agate - decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined] + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "float64" if decimals else "int64" @classmethod @@ -461,7 +461,9 @@ def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]: schema=bq_table.dataset_id, identifier=bq_table.table_id, quote_policy={"schema": True, "identifier": True}, - type=self.RELATION_TYPES.get(bq_table.table_type, RelationType.External), + type=self.RELATION_TYPES.get( + bq_table.table_type, RelationType.External + ), # type:ignore ) @classmethod @@ -661,7 +663,7 @@ def load_dataframe( connection = self.connections.get_thread_connection() client: Client = connection.handle table_schema = self._agate_to_schema(agate_table, column_override) - file_path = agate_table.original_abspath # type: ignore + file_path = agate_table.original_abspath self.connections.write_dataframe_to_table( client, @@ -713,8 +715,8 @@ def _get_catalog_schemas(self, relation_config: Iterable[RelationConfig]) -> Sch for candidate, schemas in candidates.items(): database = candidate.database if database not in db_schemas: - db_schemas[database] = set(self.list_schemas(database)) - if candidate.schema in db_schemas[database]: + db_schemas[database] = set(self.list_schemas(database)) # type:ignore + if candidate.schema in db_schemas[database]: # type:ignore result[candidate] = schemas else: logger.debug( @@ -826,7 +828,7 @@ def grant_access_to(self, entity, entity_type, role, grant_target_dict) -> None: Given an entity, grants it access to a dataset. 
""" conn: BigQueryConnectionManager = self.connections.get_thread_connection() - client = conn.handle + client = conn.handle # type:ignore GrantTarget.validate(grant_target_dict) grant_target = GrantTarget.from_dict(grant_target_dict) if entity_type == "view": diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation.py index 037761918..05784d324 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation.py @@ -34,7 +34,7 @@ class BigQueryRelation(BaseRelation): renameable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.Table, + RelationType.Table, # type:ignore } ) ) @@ -42,8 +42,8 @@ class BigQueryRelation(BaseRelation): replaceable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.View, - RelationType.Table, + RelationType.View, # type:ignore + RelationType.Table, # type:ignore } ) ) @@ -97,7 +97,7 @@ def materialized_view_config_changeset( if new_materialized_view.options != existing_materialized_view.options: config_change_collection.options = BigQueryOptionsConfigChange( - action=RelationConfigChangeAction.alter, + action=RelationConfigChangeAction.alter, # type:ignore context=new_materialized_view.options, ) @@ -105,12 +105,12 @@ def materialized_view_config_changeset( # the existing PartitionConfig is not hashable, but since we need to do # a full refresh either way, we don't need to provide a context config_change_collection.partition = BigQueryPartitionConfigChange( - action=RelationConfigChangeAction.alter, + action=RelationConfigChangeAction.alter, # type:ignore ) if new_materialized_view.cluster != existing_materialized_view.cluster: config_change_collection.cluster = BigQueryClusterConfigChange( - action=RelationConfigChangeAction.alter, + action=RelationConfigChangeAction.alter, # type:ignore context=new_materialized_view.cluster, ) diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_base.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_base.py index 8bc861587..5b25a75a3 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_base.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_base.py @@ -32,7 +32,7 @@ def quote_policy(cls) -> Policy: def from_relation_config(cls, relation_config: RelationConfig) -> Self: relation_config_dict = cls.parse_relation_config(relation_config) relation = cls.from_dict(relation_config_dict) - return relation + return relation # type:ignore @classmethod def parse_relation_config(cls, relation_config: RelationConfig) -> Dict: @@ -44,7 +44,7 @@ def parse_relation_config(cls, relation_config: RelationConfig) -> Dict: def from_bq_table(cls, table: BigQueryTable) -> Self: relation_config = cls.parse_bq_table(table) relation = cls.from_dict(relation_config) - return relation + return relation # type:ignore @classmethod def parse_bq_table(cls, table: BigQueryTable) -> Dict: diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_cluster.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_cluster.py index b3dbaf2e9..79efb4f66 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_cluster.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_cluster.py @@ -25,13 +25,13 @@ class BigQueryClusterConfig(BigQueryBaseRelationConfig): @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> Self: kwargs_dict = {"fields": 
config_dict.get("fields")} - return super().from_dict(kwargs_dict) + return super().from_dict(kwargs_dict) # type:ignore @classmethod def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]: config_dict = {} - if cluster_by := relation_config.config.extra.get("cluster_by"): + if cluster_by := relation_config.config.extra.get("cluster_by"): # type:ignore # users may input a single field as a string if isinstance(cluster_by, str): cluster_by = [cluster_by] diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_materialized_view.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_materialized_view.py index 7c63ba3bc..f63923da2 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_materialized_view.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_materialized_view.py @@ -48,9 +48,15 @@ class BigQueryMaterializedViewConfig(BigQueryBaseRelationConfig): def from_dict(cls, config_dict: Dict[str, Any]) -> "BigQueryMaterializedViewConfig": # required kwargs_dict: Dict[str, Any] = { - "table_id": cls._render_part(ComponentName.Identifier, config_dict["table_id"]), - "dataset_id": cls._render_part(ComponentName.Schema, config_dict["dataset_id"]), - "project_id": cls._render_part(ComponentName.Database, config_dict["project_id"]), + "table_id": cls._render_part( + ComponentName.Identifier, config_dict["table_id"] # type:ignore + ), + "dataset_id": cls._render_part( + ComponentName.Schema, config_dict["dataset_id"] # type:ignore + ), + "project_id": cls._render_part( + ComponentName.Database, config_dict["project_id"] # type:ignore + ), "options": BigQueryOptionsConfig.from_dict(config_dict["options"]), } @@ -61,7 +67,9 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "BigQueryMaterializedViewConf if cluster := config_dict.get("cluster"): kwargs_dict.update({"cluster": BigQueryClusterConfig.from_dict(cluster)}) - materialized_view: "BigQueryMaterializedViewConfig" = super().from_dict(kwargs_dict) + materialized_view: "BigQueryMaterializedViewConfig" = super().from_dict( + kwargs_dict + ) # type:ignore return materialized_view @classmethod diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_options.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_options.py index 7fd8797df..5cfc51ac1 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_options.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_options.py @@ -103,13 +103,13 @@ def formatted_setting(name: str) -> Any: if kwargs_dict["enable_refresh"] is False: kwargs_dict.update({"refresh_interval_minutes": None, "max_staleness": None}) - options: Self = super().from_dict(kwargs_dict) + options: Self = super().from_dict(kwargs_dict) # type:ignore return options @classmethod def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]: config_dict = { - option: relation_config.config.extra.get(option) + option: relation_config.config.extra.get(option) # type:ignore for option in [ "enable_refresh", "refresh_interval_minutes", @@ -122,11 +122,13 @@ def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any } # update dbt-specific versions of these settings - if hours_to_expiration := relation_config.config.extra.get("hours_to_expiration"): + if hours_to_expiration := relation_config.config.extra.get( # type:ignore + "hours_to_expiration" + ): config_dict.update( {"expiration_timestamp": datetime.now() + timedelta(hours=hours_to_expiration)} ) - if not 
relation_config.config.persist_docs: + if not relation_config.config.persist_docs: # type:ignore del config_dict["description"] return config_dict diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_partition.py b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_partition.py index e1a5ac171..0699f4232 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_partition.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/relation_configs/_partition.py @@ -111,7 +111,9 @@ def parse_model_node(cls, relation_config: RelationConfig) -> Dict[str, Any]: This doesn't currently collect `time_ingestion_partitioning` and `copy_partitions` because this was built for materialized views, which do not support those settings. """ - config_dict: Dict[str, Any] = relation_config.config.extra.get("partition_by") + config_dict: Dict[str, Any] = relation_config.config.extra.get( # type:ignore + "partition_by" + ) if "time_ingestion_partitioning" in config_dict: del config_dict["time_ingestion_partitioning"] if "copy_partitions" in config_dict: diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/retry.py b/dbt-bigquery/src/dbt/adapters/bigquery/retry.py index cc197a7d3..7f0f0bfcb 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/retry.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/retry.py @@ -101,14 +101,14 @@ def on_error(error: Exception): try: connection.handle = create_bigquery_client(connection.credentials) - connection.state = ConnectionState.OPEN + connection.state = ConnectionState.OPEN # type:ignore except Exception as e: _logger.debug( f"""Got an error when attempting to create a bigquery " "client: '{e}'""" ) connection.handle = None - connection.state = ConnectionState.FAIL + connection.state = ConnectionState.FAIL # type:ignore raise FailedToConnectError(str(e)) return on_error diff --git a/dbt-bigquery/tests/functional/adapter/incremental/test_incremental_microbatch.py b/dbt-bigquery/tests/functional/adapter/incremental/test_incremental_microbatch.py index 912f96eec..2823195c2 100644 --- a/dbt-bigquery/tests/functional/adapter/incremental/test_incremental_microbatch.py +++ b/dbt-bigquery/tests/functional/adapter/incremental/test_incremental_microbatch.py @@ -53,7 +53,7 @@ def insert_two_rows_sql(self, project) -> str: class TestBigQueryMicrobatchMissingPartitionBy: @pytest.fixture(scope="class") - def models(self) -> str: + def models(self): return { "microbatch.sql": microbatch_model_no_partition_by_sql, "input_model.sql": microbatch_input_sql, @@ -67,7 +67,7 @@ def test_execution_failure_no_partition_by(self, project): class TestBigQueryMicrobatchInvalidPartitionByGranularity: @pytest.fixture(scope="class") - def models(self) -> str: + def models(self): return { "microbatch.sql": microbatch_model_invalid_partition_by_sql, "input_model.sql": microbatch_input_sql, diff --git a/dbt-bigquery/third-party-stubs/agate/__init__.pyi b/dbt-bigquery/third-party-stubs/agate/__init__.pyi deleted file mode 100644 index c773cc7d7..000000000 --- a/dbt-bigquery/third-party-stubs/agate/__init__.pyi +++ /dev/null @@ -1,89 +0,0 @@ -from collections.abc import Sequence - -from typing import Any, Optional, Callable, Iterable, Dict, Union - -from . import data_types as data_types -from .data_types import ( - Text as Text, - Number as Number, - Boolean as Boolean, - DateTime as DateTime, - Date as Date, - TimeDelta as TimeDelta, -) - -class MappedSequence(Sequence): - def __init__(self, values: Any, keys: Optional[Any] = ...) -> None: ... - def __unicode__(self): ... 
- def __getitem__(self, key: Any): ... - def __setitem__(self, key: Any, value: Any) -> None: ... - def __iter__(self): ... - def __len__(self): ... - def __eq__(self, other: Any): ... - def __ne__(self, other: Any): ... - def __contains__(self, value: Any): ... - def keys(self): ... - def values(self): ... - def items(self): ... - def get(self, key: Any, default: Optional[Any] = ...): ... - def dict(self): ... - -class Row(MappedSequence): ... - -class Table: - def __init__( - self, - rows: Any, - column_names: Optional[Any] = ..., - column_types: Optional[Any] = ..., - row_names: Optional[Any] = ..., - _is_fork: bool = ..., - ) -> None: ... - def __len__(self): ... - def __iter__(self): ... - def __getitem__(self, key: Any): ... - @property - def column_types(self): ... - @property - def column_names(self): ... - @property - def row_names(self): ... - @property - def columns(self): ... - @property - def rows(self): ... - def print_csv(self, **kwargs: Any) -> None: ... - def print_json(self, **kwargs: Any) -> None: ... - def where(self, test: Callable[[Row], bool]) -> "Table": ... - def select(self, key: Union[Iterable[str], str]) -> "Table": ... - # these definitions are much narrower than what's actually accepted - @classmethod - def from_object( - cls, obj: Iterable[Dict[str, Any]], *, column_types: Optional["TypeTester"] = None - ) -> "Table": ... - @classmethod - def from_csv( - cls, path: Iterable[str], *, column_types: Optional["TypeTester"] = None - ) -> "Table": ... - @classmethod - def merge(cls, tables: Iterable["Table"]) -> "Table": ... - def rename( - self, - column_names: Optional[Iterable[str]] = None, - row_names: Optional[Any] = None, - slug_columns: bool = False, - slug_rows: bool = False, - **kwargs: Any, - ) -> "Table": ... - -class TypeTester: - def __init__( - self, force: Any = ..., limit: Optional[Any] = ..., types: Optional[Any] = ... - ) -> None: ... - def run(self, rows: Any, column_names: Any): ... - -class MaxPrecision: - def __init__(self, column_name: Any) -> None: ... - -# this is not strictly true, but it's all we care about. -def aggregate(self, aggregations: MaxPrecision) -> int: ... diff --git a/dbt-bigquery/third-party-stubs/agate/data_types.pyi b/dbt-bigquery/third-party-stubs/agate/data_types.pyi deleted file mode 100644 index 8114f7b55..000000000 --- a/dbt-bigquery/third-party-stubs/agate/data_types.pyi +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Any, Optional - -DEFAULT_NULL_VALUES: Any - -class DataType: - null_values: Any = ... - def __init__(self, null_values: Any = ...) -> None: ... - def test(self, d: Any): ... - def cast(self, d: Any) -> None: ... - def csvify(self, d: Any): ... - def jsonify(self, d: Any): ... - -DEFAULT_TRUE_VALUES: Any -DEFAULT_FALSE_VALUES: Any - -class Boolean(DataType): - true_values: Any = ... - false_values: Any = ... - def __init__( - self, true_values: Any = ..., false_values: Any = ..., null_values: Any = ... - ) -> None: ... - def cast(self, d: Any): ... - def jsonify(self, d: Any): ... - -ZERO_DT: Any - -class Date(DataType): - date_format: Any = ... - parser: Any = ... - def __init__(self, date_format: Optional[Any] = ..., **kwargs: Any) -> None: ... - def cast(self, d: Any): ... - def csvify(self, d: Any): ... - def jsonify(self, d: Any): ... - -class DateTime(DataType): - datetime_format: Any = ... - timezone: Any = ... - def __init__( - self, datetime_format: Optional[Any] = ..., timezone: Optional[Any] = ..., **kwargs: Any - ) -> None: ... - def cast(self, d: Any): ... - def csvify(self, d: Any): ... 
- def jsonify(self, d: Any): ... - -DEFAULT_CURRENCY_SYMBOLS: Any -POSITIVE: Any -NEGATIVE: Any - -class Number(DataType): - locale: Any = ... - currency_symbols: Any = ... - group_symbol: Any = ... - decimal_symbol: Any = ... - def __init__( - self, - locale: str = ..., - group_symbol: Optional[Any] = ..., - decimal_symbol: Optional[Any] = ..., - currency_symbols: Any = ..., - **kwargs: Any, - ) -> None: ... - def cast(self, d: Any): ... - def jsonify(self, d: Any): ... - -class TimeDelta(DataType): - def cast(self, d: Any): ... - -class Text(DataType): - cast_nulls: Any = ... - def __init__(self, cast_nulls: bool = ..., **kwargs: Any) -> None: ... - def cast(self, d: Any): ... diff --git a/dbt-postgres/src/dbt/adapters/postgres/relation.py b/dbt-postgres/src/dbt/adapters/postgres/relation.py index e8128c462..1fb7ae5d6 100644 --- a/dbt-postgres/src/dbt/adapters/postgres/relation.py +++ b/dbt-postgres/src/dbt/adapters/postgres/relation.py @@ -23,16 +23,16 @@ class PostgresRelation(BaseRelation): renameable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.View, - RelationType.Table, + RelationType.View, # type:ignore + RelationType.Table, # type:ignore } ) ) replaceable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.View, - RelationType.Table, + RelationType.View, # type:ignore + RelationType.Table, # type:ignore } ) ) @@ -108,4 +108,4 @@ def _get_index_config_changes( ) for index in new_indexes.difference(existing_indexes) ] - return drop_changes + create_changes + return drop_changes + create_changes # type:ignore diff --git a/dbt-redshift/pyproject.toml b/dbt-redshift/pyproject.toml index c6a799871..6783867a3 100644 --- a/dbt-redshift/pyproject.toml +++ b/dbt-redshift/pyproject.toml @@ -44,9 +44,6 @@ Repository = "https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-reds Issues = "https://github.com/dbt-labs/dbt-adapters/issues" Changelog = "https://github.com/dbt-labs/dbt-adapters/blob/main/dbt-redshift/CHANGELOG.md" -[tool.mypy] -mypy_path = "third-party-stubs/" - [tool.pytest.ini_options] testpaths = ["tests/unit", "tests/functional"] addopts = "-v --color=yes -n auto" diff --git a/dbt-redshift/src/dbt/adapters/redshift/connections.py b/dbt-redshift/src/dbt/adapters/redshift/connections.py index 9632be77b..5b0d43751 100644 --- a/dbt-redshift/src/dbt/adapters/redshift/connections.py +++ b/dbt-redshift/src/dbt/adapters/redshift/connections.py @@ -2,9 +2,9 @@ import redshift_connector import sqlparse -from multiprocessing import Lock +from multiprocessing.synchronize import RLock from contextlib import contextmanager -from typing import Any, Callable, Dict, Tuple, Union, Optional, List, TYPE_CHECKING +from typing import Any, Callable, Dict, Generator, Tuple, Union, Optional, List, TYPE_CHECKING from dataclasses import dataclass, field from dbt.adapters.exceptions import FailedToConnectError @@ -85,7 +85,7 @@ class RedshiftSSLMode(StrEnum): @dataclass -class RedshiftSSLConfig(dbtClassMixin, Replaceable): # type: ignore +class RedshiftSSLConfig(dbtClassMixin, Replaceable): ssl: bool = True sslmode: Optional[RedshiftSSLMode] = SSL_MODE_TRANSLATION[UserSSLMode.default()] @@ -119,9 +119,9 @@ def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig": class RedshiftCredentials(Credentials): host: str port: Port - method: str = RedshiftConnectionMethod.DATABASE # type: ignore + method: str = RedshiftConnectionMethod.DATABASE user: Optional[str] = None - password: Optional[str] = None # 
type: ignore + password: Optional[str] = None cluster_id: Optional[str] = field( default=None, metadata={"description": "If using IAM auth, the name of the cluster"}, @@ -392,7 +392,7 @@ class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" def cancel(self, connection: Connection): - pid = connection.backend_pid # type: ignore + pid = connection.backend_pid sql = f"select pg_terminate_backend({pid})" logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") logger.debug(sql) @@ -443,14 +443,14 @@ def exception_handler(self, sql): raise DbtRuntimeError(str(e)) from e @contextmanager - def fresh_transaction(self): + def fresh_transaction(self) -> Generator[None, None, None]: """On entrance to this context manager, hold an exclusive lock and create a fresh transaction for redshift, then commit and begin a new one before releasing the lock on exit. See drop_relation in RedshiftAdapter for more information. """ - drop_lock: Lock = self.lock + drop_lock: RLock = self.lock with drop_lock: connection = self.get_thread_connection() @@ -486,7 +486,7 @@ def open(cls, connection): retry_limit=credentials.retries, retryable_exceptions=retryable_exceptions, ) - open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore + open_connection.backend_pid = cls._get_backend_pid(open_connection) return open_connection def execute( @@ -560,7 +560,7 @@ def _initialize_sqlparse_lexer(): Resolves: https://github.com/dbt-labs/dbt-redshift/issues/710 Implementation of this fix: https://github.com/dbt-labs/dbt-core/pull/8215 """ - from sqlparse.lexer import Lexer # type: ignore + from sqlparse.lexer import Lexer if hasattr(Lexer, "get_default_instance"): Lexer.get_default_instance() diff --git a/dbt-redshift/src/dbt/adapters/redshift/impl.py b/dbt-redshift/src/dbt/adapters/redshift/impl.py index aaf3d46ca..f67b5fedd 100644 --- a/dbt-redshift/src/dbt/adapters/redshift/impl.py +++ b/dbt-redshift/src/dbt/adapters/redshift/impl.py @@ -44,11 +44,11 @@ class RedshiftConfig(AdapterConfig): class RedshiftAdapter(SQLAdapter): - Relation = RedshiftRelation # type: ignore + Relation = RedshiftRelation ConnectionManager = RedshiftConnectionManager connections: RedshiftConnectionManager - AdapterSpecificConfigs = RedshiftConfig # type: ignore + AdapterSpecificConfigs = RedshiftConfig CONSTRAINT_SUPPORT = { ConstraintType.check: ConstraintSupport.NOT_SUPPORTED, diff --git a/dbt-redshift/src/dbt/adapters/redshift/relation.py b/dbt-redshift/src/dbt/adapters/redshift/relation.py index eaf60f54c..9c5ab56fa 100644 --- a/dbt-redshift/src/dbt/adapters/redshift/relation.py +++ b/dbt-redshift/src/dbt/adapters/redshift/relation.py @@ -29,20 +29,20 @@ class RedshiftRelation(BaseRelation): quote_policy = RedshiftQuotePolicy # type: ignore require_alias: bool = False relation_configs = { - RelationType.MaterializedView.value: RedshiftMaterializedViewConfig, + RelationType.MaterializedView.value: RedshiftMaterializedViewConfig, # type:ignore } renameable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.View, - RelationType.Table, + RelationType.View, # type:ignore + RelationType.Table, # type:ignore } ) ) replaceable_relations: FrozenSet[RelationType] = field( default_factory=lambda: frozenset( { - RelationType.View, + RelationType.View, # type:ignore } ) ) @@ -89,19 +89,19 @@ def materialized_view_config_changeset( if new_materialized_view.autorefresh != existing_materialized_view.autorefresh: config_change_collection.autorefresh = 
RedshiftAutoRefreshConfigChange(
-                action=RelationConfigChangeAction.alter,
+                action=RelationConfigChangeAction.alter,  # type:ignore
                 context=new_materialized_view.autorefresh,
             )
 
         if new_materialized_view.dist != existing_materialized_view.dist:
             config_change_collection.dist = RedshiftDistConfigChange(
-                action=RelationConfigChangeAction.alter,
+                action=RelationConfigChangeAction.alter,  # type:ignore
                 context=new_materialized_view.dist,
             )
 
         if new_materialized_view.sort != existing_materialized_view.sort:
             config_change_collection.sort = RedshiftSortConfigChange(
-                action=RelationConfigChangeAction.alter,
+                action=RelationConfigChangeAction.alter,  # type:ignore
                 context=new_materialized_view.sort,
             )
diff --git a/dbt-redshift/src/dbt/adapters/redshift/relation_configs/materialized_view.py b/dbt-redshift/src/dbt/adapters/redshift/relation_configs/materialized_view.py
index a01185f22..db249b1c5 100644
--- a/dbt-redshift/src/dbt/adapters/redshift/relation_configs/materialized_view.py
+++ b/dbt-redshift/src/dbt/adapters/redshift/relation_configs/materialized_view.py
@@ -99,10 +99,14 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]:
     @classmethod
     def from_dict(cls, config_dict) -> Self:
         kwargs_dict = {
-            "mv_name": cls._render_part(ComponentName.Identifier, config_dict.get("mv_name")),
-            "schema_name": cls._render_part(ComponentName.Schema, config_dict.get("schema_name")),
+            "mv_name": cls._render_part(
+                ComponentName.Identifier, config_dict.get("mv_name")  # type:ignore
+            ),
+            "schema_name": cls._render_part(
+                ComponentName.Schema, config_dict.get("schema_name")  # type:ignore
+            ),
             "database_name": cls._render_part(
-                ComponentName.Database, config_dict.get("database_name")
+                ComponentName.Database, config_dict.get("database_name")  # type:ignore
             ),
             "query": config_dict.get("query"),
             "backup": config_dict.get("backup"),
@@ -136,7 +140,7 @@ def parse_relation_config(cls, config: RelationConfig) -> Dict[str, Any]:
         if autorefresh_value is not None:
             config_dict["autorefresh"] = evaluate_bool(autorefresh_value)
 
-        if query := config.compiled_code:  # type: ignore
+        if query := config.compiled_code:
             config_dict.update({"query": query.strip()})
 
         if config.config.get("dist"):  # type: ignore
diff --git a/dbt-redshift/src/dbt/adapters/redshift/relation_configs/sort.py b/dbt-redshift/src/dbt/adapters/redshift/relation_configs/sort.py
index f38d5a1e1..edaae1454 100644
--- a/dbt-redshift/src/dbt/adapters/redshift/relation_configs/sort.py
+++ b/dbt-redshift/src/dbt/adapters/redshift/relation_configs/sort.py
@@ -106,7 +106,7 @@ def from_dict(cls, config_dict) -> Self:
             "sortkey": tuple(column for column in config_dict.get("sortkey", {})),
         }
         sort: Self = super().from_dict(kwargs_dict)  # type: ignore
-        return sort  # type: ignore
+        return sort
 
     @classmethod
     def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
diff --git a/dbt-redshift/tests/boundary/conftest.py b/dbt-redshift/tests/boundary/conftest.py
index 402fa2d66..5db54a53b 100644
--- a/dbt-redshift/tests/boundary/conftest.py
+++ b/dbt-redshift/tests/boundary/conftest.py
@@ -12,7 +12,7 @@ def connection() -> redshift_connector.Connection:
         user=os.getenv("REDSHIFT_TEST_USER"),
         password=os.getenv("REDSHIFT_TEST_PASS"),
         host=os.getenv("REDSHIFT_TEST_HOST"),
-        port=int(os.getenv("REDSHIFT_TEST_PORT")),
+        port=int(os.getenv("REDSHIFT_TEST_PORT", 5439)),
         database=os.getenv("REDSHIFT_TEST_DBNAME"),
         region=os.getenv("REDSHIFT_TEST_REGION"),
     )
diff --git a/dbt-redshift/tests/boundary/test_redshift_connector.py b/dbt-redshift/tests/boundary/test_redshift_connector.py
index 200d0cccf..711c4e1ea 100644
--- a/dbt-redshift/tests/boundary/test_redshift_connector.py
+++ b/dbt-redshift/tests/boundary/test_redshift_connector.py
@@ -1,8 +1,10 @@
+from typing import Generator
+
 import pytest
 
 
 @pytest.fixture
-def schema(connection, schema_name) -> str:
+def schema(connection, schema_name) -> Generator[str, None, None]:
     with connection.cursor() as cursor:
         cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
     yield schema_name
diff --git a/dbt-redshift/tests/functional/adapter/conftest.py b/dbt-redshift/tests/functional/adapter/conftest.py
index c5c980154..be7d3c64e 100644
--- a/dbt-redshift/tests/functional/adapter/conftest.py
+++ b/dbt-redshift/tests/functional/adapter/conftest.py
@@ -1,8 +1,10 @@
+from typing import Generator
+
 import pytest
 
 
 @pytest.fixture
-def model_ddl(request) -> str:
+def model_ddl(request) -> Generator[str, None, None]:
     """
     Returns the contents of the DDL file for the model provided. Use with pytest parameterization.
 
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/__init__.py b/dbt-snowflake/src/dbt/adapters/snowflake/__init__.py
index f0c546067..8585201a0 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/__init__.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/__init__.py
@@ -8,5 +8,7 @@
 from dbt.include import snowflake
 
 Plugin = AdapterPlugin(
-    adapter=SnowflakeAdapter, credentials=SnowflakeCredentials, include_path=snowflake.PACKAGE_PATH
+    adapter=SnowflakeAdapter,  # type:ignore
+    credentials=SnowflakeCredentials,
+    include_path=snowflake.PACKAGE_PATH,
 )
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/column.py b/dbt-snowflake/src/dbt/adapters/snowflake/column.py
index 281831b29..3eb5c97a7 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/column.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/column.py
@@ -47,5 +47,5 @@ def from_description(cls, name: str, raw_data_type: str) -> "SnowflakeColumn":
         if "vector" in raw_data_type.lower():
             column = cls(name, raw_data_type, None, None, None)
         else:
-            column = super().from_description(name, raw_data_type)
+            column = super().from_description(name, raw_data_type)  # type:ignore
         return column
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/connections.py b/dbt-snowflake/src/dbt/adapters/snowflake/connections.py
index fc2c09c19..af286cb77 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/connections.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/connections.py
@@ -530,6 +530,8 @@ def add_query(
         auto_begin: bool = True,
         bindings: Optional[Any] = None,
         abridge_sql_log: bool = False,
+        *args,
+        **kwargs,
     ) -> Tuple[Connection, Any]:
         if bindings:
             # The snowflake connector is stricter than, e.g., psycopg2 -
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/relation.py b/dbt-snowflake/src/dbt/adapters/snowflake/relation.py
index f3ee3e510..1bca5fa4e 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/relation.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/relation.py
@@ -76,7 +76,7 @@ def get_relation_type(cls) -> Type[SnowflakeRelationType]:
 
     @classmethod
     def from_config(cls, config: RelationConfig) -> RelationConfigBase:
-        relation_type: str = config.config.materialized
+        relation_type: str = config.config.materialized  # type:ignore
 
         if relation_config := cls.relation_configs.get(relation_type):
             return relation_config.from_relation_config(config)
@@ -98,14 +98,14 @@ def dynamic_table_config_changeset(
 
         if new_dynamic_table.target_lag != existing_dynamic_table.target_lag:
             config_change_collection.target_lag = SnowflakeDynamicTableTargetLagConfigChange(
-                action=RelationConfigChangeAction.alter,
+                action=RelationConfigChangeAction.alter,  # type:ignore
                 context=new_dynamic_table.target_lag,
             )
 
         if new_dynamic_table.snowflake_warehouse != existing_dynamic_table.snowflake_warehouse:
             config_change_collection.snowflake_warehouse = (
                 SnowflakeDynamicTableWarehouseConfigChange(
-                    action=RelationConfigChangeAction.alter,
+                    action=RelationConfigChangeAction.alter,  # type:ignore
                     context=new_dynamic_table.snowflake_warehouse,
                 )
             )
@@ -115,13 +115,13 @@
             and new_dynamic_table.refresh_mode != existing_dynamic_table.refresh_mode
         ):
             config_change_collection.refresh_mode = SnowflakeDynamicTableRefreshModeConfigChange(
-                action=RelationConfigChangeAction.create,
+                action=RelationConfigChangeAction.create,  # type:ignore
                 context=new_dynamic_table.refresh_mode,
             )
 
         if new_dynamic_table.catalog != existing_dynamic_table.catalog:
             config_change_collection.catalog = SnowflakeCatalogConfigChange(
-                action=RelationConfigChangeAction.create,
+                action=RelationConfigChangeAction.create,  # type:ignore
                 context=new_dynamic_table.catalog,
             )
 
@@ -132,7 +132,7 @@ def dynamic_table_config_changeset(
     def as_case_sensitive(self) -> "SnowflakeRelation":
         path_part_map = {}
 
-        for path in ComponentName:
+        for path in ComponentName:  # type:ignore
            if self.include_policy.get_part(path):
                 part = self.path.get_part(path)
                 if part:
@@ -166,7 +166,7 @@ def get_ddl_prefix_for_create(self, config: RelationConfig, temporary: bool) ->
         support temporary relations.
         """
 
-        transient_explicitly_set_true: bool = config.get("transient", False)
+        transient_explicitly_set_true: bool = config.get("transient", False)  # type:ignore
 
         # Temporary tables are a Snowflake feature that do not exist in the
         # Iceberg framework. We ignore the Iceberg status of the model.
@@ -191,7 +191,7 @@ def get_ddl_prefix_for_create(self, config: RelationConfig, temporary: bool) ->
 
         # Always supply transient on table create DDL unless user specifically sets
         # transient to false or unset. Might as well update the object attribute too!
-        elif transient_explicitly_set_true or config.get("transient", True):
+        elif transient_explicitly_set_true or config.get("transient", True):  # type:ignore
             return "transient"
         else:
             return ""
@@ -206,14 +206,15 @@ def get_ddl_prefix_for_alter(self) -> str:
     def get_iceberg_ddl_options(self, config: RelationConfig) -> str:
         # If the base_location_root config is supplied, overwrite the default value ("_dbt/")
         base_location: str = (
-            f"{config.get('base_location_root', '_dbt')}/{self.schema}/{self.name}"
+            f"{config.get('base_location_root', '_dbt')}/{self.schema}/{self.name}"  # type:ignore
         )
 
-        if subpath := config.get("base_location_subpath"):
+        if subpath := config.get("base_location_subpath"):  # type:ignore
             base_location += f"/{subpath}"
 
+        external_volume = config.get("external_volume")  # type:ignore
         iceberg_ddl_predicates: str = f"""
-            external_volume = '{config.get('external_volume')}'
+            external_volume = '{external_volume}'
             catalog = 'snowflake'
             base_location = '{base_location}'
         """
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/catalog.py b/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/catalog.py
index c8d7de40f..44febf8a0 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/catalog.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/catalog.py
@@ -66,24 +66,26 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> Self:
         }
         if table_format := config_dict.get("table_format"):
             kwargs_dict["table_format"] = TableFormat(table_format)
-        return super().from_dict(kwargs_dict)
+        return super().from_dict(kwargs_dict)  # type:ignore
 
     @classmethod
     def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
-        if relation_config.config.extra.get("table_format") is None:
+        if relation_config.config.extra.get("table_format") is None:  # type:ignore
             return {}
 
         config_dict = {
-            "table_format": relation_config.config.extra.get("table_format"),
+            "table_format": relation_config.config.extra.get("table_format"),  # type:ignore
             "name": "SNOWFLAKE",  # this is not currently configurable
         }
 
-        if external_volume := relation_config.config.extra.get("external_volume"):
+        if external_volume := relation_config.config.extra.get("external_volume"):  # type:ignore
             config_dict["external_volume"] = external_volume
 
         catalog_dirs: List[str] = ["_dbt", relation_config.schema, relation_config.name]
-        if base_location_subpath := relation_config.config.extra.get("base_location_subpath"):
+        if base_location_subpath := relation_config.config.extra.get(  # type:ignore
+            "base_location_subpath"
+        ):
             catalog_dirs.append(base_location_subpath)
 
         config_dict["base_location"] = "/".join(catalog_dirs)
 
diff --git a/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/dynamic_table.py b/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/dynamic_table.py
index 7361df80a..055d80ddd 100644
--- a/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/dynamic_table.py
+++ b/dbt-snowflake/src/dbt/adapters/snowflake/relation_configs/dynamic_table.py
@@ -67,10 +67,14 @@ class SnowflakeDynamicTableConfig(SnowflakeRelationConfigBase):
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> Self:
         kwargs_dict = {
-            "name": cls._render_part(ComponentName.Identifier, config_dict.get("name")),
-            "schema_name": cls._render_part(ComponentName.Schema, config_dict.get("schema_name")),
+            "name": cls._render_part(
+                ComponentName.Identifier, config_dict.get("name")  # type:ignore
+            ),
+            "schema_name": cls._render_part(
+                ComponentName.Schema, config_dict.get("schema_name")  # type:ignore
+            ),
             "database_name": cls._render_part(
-                ComponentName.Database, config_dict.get("database_name")
+                ComponentName.Database, config_dict.get("database_name")  # type:ignore
             ),
             "query": config_dict.get("query"),
             "target_lag": config_dict.get("target_lag"),
@@ -80,7 +84,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> Self:
             "initialize": config_dict.get("initialize"),
         }
 
-        return super().from_dict(kwargs_dict)
+        return super().from_dict(kwargs_dict)  # type:ignore
 
     @classmethod
     def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
@@ -89,15 +93,17 @@ def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any
             "schema_name": relation_config.schema,
             "database_name": relation_config.database,
             "query": relation_config.compiled_code,
-            "target_lag": relation_config.config.extra.get("target_lag"),
-            "snowflake_warehouse": relation_config.config.extra.get("snowflake_warehouse"),
+            "target_lag": relation_config.config.extra.get("target_lag"),  # type:ignore
+            "snowflake_warehouse": relation_config.config.extra.get(  # type:ignore
+                "snowflake_warehouse"
+            ),
             "catalog": SnowflakeCatalogConfig.parse_relation_config(relation_config),
         }
 
-        if refresh_mode := relation_config.config.extra.get("refresh_mode"):
+        if refresh_mode := relation_config.config.extra.get("refresh_mode"):  # type:ignore
             config_dict["refresh_mode"] = refresh_mode.upper()
 
-        if initialize := relation_config.config.extra.get("initialize"):
+        if initialize := relation_config.config.extra.get("initialize"):  # type:ignore
             config_dict["initialize"] = initialize.upper()
 
         return config_dict
diff --git a/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_pagination.py b/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_pagination.py
index 7dd382af5..b839a7069 100644
--- a/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_pagination.py
+++ b/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_pagination.py
@@ -33,7 +33,7 @@
     materialized='dynamic_table',
     target_lag='1 hour',
     snowflake_warehouse='"""
-    + os.getenv("SNOWFLAKE_TEST_WAREHOUSE")
+    + os.getenv("SNOWFLAKE_TEST_WAREHOUSE", "")
     + """',
 ) }}
 
diff --git a/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_show_objects.py b/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_show_objects.py
index 91fb94f79..a6d2e065b 100644
--- a/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_show_objects.py
+++ b/dbt-snowflake/tests/functional/adapter/list_relations_tests/test_show_objects.py
@@ -36,7 +36,7 @@
     materialized='dynamic_table',
     target_lag='1 day',
     snowflake_warehouse='"""
-    + os.getenv("SNOWFLAKE_TEST_WAREHOUSE")
+    + os.getenv("SNOWFLAKE_TEST_WAREHOUSE", "")
     + """',
 ) }}
 select * from {{ ref('my_seed') }}
diff --git a/dbt-spark/dagger/run_dbt_spark_tests.py b/dbt-spark/dagger/run_dbt_spark_tests.py
index dae366f89..3857650b5 100644
--- a/dbt-spark/dagger/run_dbt_spark_tests.py
+++ b/dbt-spark/dagger/run_dbt_spark_tests.py
@@ -2,7 +2,7 @@
 import argparse
 import sys
 
-from typing import Dict
+from typing import Dict, Tuple
 
 import anyio as anyio
 import dagger as dagger
@@ -29,7 +29,7 @@ def env_variables_inner(ctr: dagger.Container):
     return env_variables_inner
 
 
-def get_postgres_container(client: dagger.Client) -> (dagger.Container, str):
+def get_postgres_container(client: dagger.Client) -> Tuple[dagger.Container, str]:
     ctr = (
         client.container()
         .from_("postgres:13")
@@ -41,7 +41,7 @@ def get_postgres_container(client: dagger.Client) -> (dagger.Container, str):
     return ctr, "postgres_db"
 
 
-def get_spark_container(client: dagger.Client) -> (dagger.Service, str):
+def get_spark_container(client: dagger.Client) -> Tuple[dagger.Service, str]:
     spark_dir = client.host().directory("./dagger/spark-container")
     spark_ctr_base = (
         client.container()
diff --git a/dbt-spark/src/dbt/__init__.py b/dbt-spark/src/dbt/__init__.py
new file mode 100644
index 000000000..b36383a61
--- /dev/null
+++ b/dbt-spark/src/dbt/__init__.py
@@ -0,0 +1,3 @@
+from pkgutil import extend_path
+
+__path__ = extend_path(__path__, __name__)
diff --git a/dbt-spark/src/dbt/adapters/spark/__init__.py b/dbt-spark/src/dbt/adapters/spark/__init__.py
index 6ecc5eccf..fd984df0e 100644
--- a/dbt-spark/src/dbt/adapters/spark/__init__.py
+++ b/dbt-spark/src/dbt/adapters/spark/__init__.py
@@ -8,5 +8,7 @@
 from dbt.include import spark
 
 Plugin = AdapterPlugin(
-    adapter=SparkAdapter, credentials=SparkCredentials, include_path=spark.PACKAGE_PATH
+    adapter=SparkAdapter,  # type:ignore
+    credentials=SparkCredentials,
+    include_path=spark.PACKAGE_PATH,
 )
diff --git a/dbt-spark/src/dbt/adapters/spark/column.py b/dbt-spark/src/dbt/adapters/spark/column.py
index 98fa24a17..44d2fc39b 100644
--- a/dbt-spark/src/dbt/adapters/spark/column.py
+++ b/dbt-spark/src/dbt/adapters/spark/column.py
@@ -21,7 +21,7 @@ class SparkColumn(dbtClassMixin, Column):
     def translate_type(cls, dtype: str) -> str:
         return dtype
 
-    def can_expand_to(self: Self, other_column: Self) -> bool:
+    def can_expand_to(self, other_column: Column) -> bool:
         """returns True if both columns are strings"""
         return self.is_string() and other_column.is_string()
 
diff --git a/dbt-spark/src/dbt/adapters/spark/connections.py b/dbt-spark/src/dbt/adapters/spark/connections.py
index d9b615ecb..571fa0493 100644
--- a/dbt-spark/src/dbt/adapters/spark/connections.py
+++ b/dbt-spark/src/dbt/adapters/spark/connections.py
@@ -65,9 +65,9 @@ class SparkConnectionMethod(StrEnum):
 @dataclass
 class SparkCredentials(Credentials):
     host: Optional[str] = None
-    schema: Optional[str] = None
+    schema: Optional[str] = None  # type:ignore
     method: SparkConnectionMethod = None  # type: ignore
-    database: Optional[str] = None
+    database: Optional[str] = None  # type:ignore
     driver: Optional[str] = None
     cluster: Optional[str] = None
     endpoint: Optional[str] = None
@@ -578,11 +578,11 @@ def open(cls, connection: Connection) -> Connection:
             raise exc  # type: ignore
 
         connection.handle = handle
-        connection.state = ConnectionState.OPEN
+        connection.state = ConnectionState.OPEN  # type:ignore
         return connection
 
     @classmethod
-    def data_type_code_to_name(cls, type_code: Union[type, str]) -> str:
+    def data_type_code_to_name(cls, type_code: Union[type, str]) -> str:  # type:ignore
         """
         :param Union[type, str] type_code: The sql to execute.
             * type_code is a python type (!) in pyodbc https://github.com/mkleehammer/pyodbc/wiki/Cursor#description, and a string for other spark runtimes.
diff --git a/dbt-spark/src/dbt/adapters/spark/impl.py b/dbt-spark/src/dbt/adapters/spark/impl.py
index 5f8178a9d..d4c666fda 100644
--- a/dbt-spark/src/dbt/adapters/spark/impl.py
+++ b/dbt-spark/src/dbt/adapters/spark/impl.py
@@ -157,7 +157,7 @@ def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
     def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
         return "timestamp"
 
-    def quote(self, identifier: str) -> str:
+    def quote(self, identifier: str) -> str:  # type:ignore
         return "`{}`".format(identifier)
 
     def _get_relation_information(self, row: "agate.Row") -> RelationInfo:
@@ -208,7 +208,9 @@ def _build_spark_relation_list(
             _schema, name, information = relation_info_func(row)
 
             rel_type: RelationType = (
-                RelationType.View if "Type: VIEW" in information else RelationType.Table
+                RelationType.View
+                if "Type: VIEW" in information
+                else RelationType.Table  # type:ignore
             )
             is_delta: bool = "Provider: delta" in information
             is_hudi: bool = "Provider: hudi" in information
diff --git a/dbt-spark/tests/unit/test_adapter_telemetry.py b/dbt-spark/tests/unit/test_adapter_telemetry.py
index b0de952b6..67758c50c 100644
--- a/dbt-spark/tests/unit/test_adapter_telemetry.py
+++ b/dbt-spark/tests/unit/test_adapter_telemetry.py
@@ -1,5 +1,6 @@
 from unittest import mock
 
+import dbt.adapters.__about__
 import dbt.adapters.spark.__version__
 
 from dbt.adapters.spark.impl import SparkAdapter
diff --git a/dbt-spark/tests/unit/test_credentials.py b/dbt-spark/tests/unit/test_credentials.py
index 7a81fdbb1..5e436ff11 100644
--- a/dbt-spark/tests/unit/test_credentials.py
+++ b/dbt-spark/tests/unit/test_credentials.py
@@ -4,9 +4,9 @@
 def test_credentials_server_side_parameters_keys_and_values_are_strings() -> None:
     credentials = SparkCredentials(
         host="localhost",
-        method=SparkConnectionMethod.THRIFT,
+        method=SparkConnectionMethod.THRIFT,  # type:ignore
         database="tests",
         schema="tests",
-        server_side_parameters={"spark.configuration": 10},
+        server_side_parameters={"spark.configuration": "10"},
     )
     assert credentials.server_side_parameters["spark.configuration"] == "10"