Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize mypy config #800

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ repos:
rev: v1.11.2
hooks:
- id: mypy
exclude: dbt-adapters/src/dbt/adapters/events/adapter_types_pb2.py|dbt-tests-adapter/src/dbt/__init__.py
files: (dbt-adapters|dbt-athena|dbt-bigquery|dbt-postgres|dbt-redshift|dbt-snowflake|dbt-spark)/src/dbt/adapters|dbt-tests-adapter/src/dbt/tests
args:
- --explicit-package-bases
- --namespace-packages
- --ignore-missing-imports
- --warn-redundant-casts
- --warn-unused-ignores
- --pretty
- --show-error-codes
files: ^dbt-adapters/src/dbt/adapters/
additional_dependencies:
- types-PyYAML
- types-protobuf
Expand Down
6 changes: 0 additions & 6 deletions dbt-adapters/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ Repository = "https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-adap
Issues = "https://github.com/dbt-labs/dbt-adapters/issues"
Changelog = "https://github.com/dbt-labs/dbt-adapters/blob/main/dbt-adapters/CHANGELOG.md"

[tool.mypy]
mypy_path = "third-party-stubs/"
[[tool.mypy.overrides]]
module = ["dbt.adapters.events.adapter_types_pb2"]
follow_imports = "skip"

[tool.pytest]
env_files = ["test.env"]
testpaths = [
Expand Down
6 changes: 3 additions & 3 deletions dbt-adapters/src/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ def _get_one_catalog(
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
results = self._catalog_filter_table(table, used_schemas)
return results

def _get_one_catalog_by_relations(
Expand All @@ -1239,7 +1239,7 @@ def _get_one_catalog_by_relations(
}
table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
results = self._catalog_filter_table(table, used_schemas)
return results

def get_filtered_catalog(
Expand Down Expand Up @@ -1435,7 +1435,7 @@ def calculate_freshness_from_metadata_batch(
macro_resolver=macro_resolver,
needs_conn=True,
)
adapter_response, table = result.response, result.table # type: ignore[attr-defined]
adapter_response, table = result.response, result.table
adapter_responses.append(adapter_response)

for row in table:
Expand Down
2 changes: 1 addition & 1 deletion dbt-adapters/src/dbt/adapters/base/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __new__(mcls, name, bases, namespace, **kwargs) -> "AdapterMeta":
# I'm not sure there is any benefit to it after poking around a bit,
# but having it doesn't hurt on the python side (and omitting it could
# hurt for obscure metaclass reasons, for all I know)
cls = abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs) # type: ignore
cls = abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)

# this is very much inspired by ABCMeta's own implementation

Expand Down
2 changes: 1 addition & 1 deletion dbt-adapters/src/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def matches(
if str(self.path.get_lowered_part(k)).strip(self.quote_character) != v.lower().strip(
self.quote_character
):
approximate_match = False # type: ignore[union-attr]
approximate_match = False

if approximate_match and not exact_match:
target = self.create(database=database, schema=schema, identifier=identifier)
Expand Down
2 changes: 1 addition & 1 deletion dbt-adapters/src/dbt/adapters/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def resolve(self, connection: Connection) -> Connection:
# and https://github.com/python/mypy/issues/5374
# for why we have type: ignore. Maybe someday dataclasses + abstract classes
# will work.
@dataclass # type: ignore
@dataclass
class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
database: str
schema: str
Expand Down
2 changes: 1 addition & 1 deletion dbt-adapters/src/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __call__(


# TODO CT-211
class AdapterProtocol( # type: ignore[misc]
class AdapterProtocol(
Protocol,
Generic[
AdapterConfig_T,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def from_dict(cls, kwargs_dict) -> "RelationConfigBase":

Returns: the `RelationConfigBase` representation associated with the provided dict
"""
return cls(**filter_null_values(kwargs_dict)) # type: ignore
return cls(**filter_null_values(kwargs_dict))

@classmethod
def _not_implemented_error(cls) -> NotImplementedError:
Expand Down
4 changes: 2 additions & 2 deletions dbt-adapters/src/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
import agate

# TODO CT-211
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined]
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float8" if decimals else "integer"

@classmethod
Expand Down Expand Up @@ -247,7 +247,7 @@ def validate_sql(self, sql: str) -> AdapterResponse:
# return fetched output for engines where explain plans are emitted as columnar
# results. Any macro override that deviates from this behavior may encounter an
# assertion error in the runtime.
adapter_response = result.response # type: ignore[attr-defined]
adapter_response = result.response
assert isinstance(adapter_response, AdapterResponse), (
f"Expected AdapterResponse from validate_sql macro execution, "
f"got {type(adapter_response)}."
Expand Down
2 changes: 1 addition & 1 deletion dbt-adapters/tests/unit/fixtures/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def is_cancelable(cls) -> bool:
return False

def list_schemas(self, database: str) -> List[str]:
return list(self.cache.schemas)
return list(schema for database, schema in self.cache.schemas if isinstance(schema, str))

###
# Abstract methods about relations
Expand Down
8 changes: 4 additions & 4 deletions dbt-adapters/tests/unit/fixtures/connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import ContextManager, List, Optional, Tuple
from typing import Generator, List, Optional, Tuple, Any

import agate

Expand All @@ -15,7 +15,7 @@ class ConnectionManagerStub(BaseConnectionManager):
raised_exceptions: List[Exception]

@contextmanager
def exception_handler(self, sql: str) -> ContextManager: # type: ignore
def exception_handler(self, sql: str) -> Generator[None, Any, None]: # type: ignore
# catch all exceptions and put them on this class for inspection in tests
try:
yield
Expand All @@ -28,15 +28,15 @@ def cancel_open(self) -> Optional[List[str]]:
names = []
for connection in self.thread_connections.values():
if connection.state == ConnectionState.OPEN:
connection.state = ConnectionState.CLOSED
connection.state = ConnectionState.CLOSED # type: ignore
if name := connection.name:
names.append(name)
return names

@classmethod
def open(cls, connection: Connection) -> Connection:
# there's no database, so just change the state
connection.state = ConnectionState.OPEN
connection.state = ConnectionState.OPEN # type: ignore
return connection

def begin(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions dbt-adapters/tests/unit/fixtures/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class CredentialsStub(Credentials):
A stub for a database credentials that does not connect to a database
"""

@property
def type(self) -> str:
return "test"

Expand Down
4 changes: 3 additions & 1 deletion dbt-athena/src/dbt/adapters/athena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from dbt.include import athena

Plugin: AdapterPlugin = AdapterPlugin(
adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH
adapter=AthenaAdapter, # type:ignore
credentials=AthenaCredentials,
include_path=athena.PACKAGE_PATH,
)

__all__ = [
Expand Down
8 changes: 4 additions & 4 deletions dbt-athena/src/dbt/adapters/athena/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def is_timestamp(self) -> bool:
return self.dtype.lower() in {"timestamp"}

def is_array(self) -> bool:
return self.dtype.lower().startswith("array") # type: ignore
return self.dtype.lower().startswith("array")

@classmethod
def string_type(cls, size: int) -> str:
Expand Down Expand Up @@ -58,7 +58,7 @@ def array_inner_type(self) -> str:
if match:
return match.group(1)
# If for some reason there's no match, fall back to the original string
return self.dtype # type: ignore
return self.dtype

def string_size(self) -> int:
if not self.is_string():
Expand All @@ -72,7 +72,7 @@ def data_type(self) -> str:
return self.string_type(self.string_size())

if self.is_numeric():
return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) # type: ignore
return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale)

if self.is_binary():
return self.binary_type()
Expand All @@ -94,4 +94,4 @@ def data_type(self) -> str:
)
return self.array_type(inner_type_col.data_type)

return self.dtype # type: ignore
return self.dtype
12 changes: 8 additions & 4 deletions dbt-athena/src/dbt/adapters/athena/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,21 @@ def set_engine_config(self) -> Dict[str, Any]:

default_spark_properties: Dict[str, str] = dict(
**(
DEFAULT_SPARK_PROPERTIES.get(table_type)
DEFAULT_SPARK_PROPERTIES.get(table_type, {})
if table_type.lower() in ["iceberg", "hudi", "delta_lake"]
else {}
),
**DEFAULT_SPARK_PROPERTIES.get("spark_encryption") if spark_encryption else {},
**DEFAULT_SPARK_PROPERTIES.get("spark_encryption", {}) if spark_encryption else {},
**(
DEFAULT_SPARK_PROPERTIES.get("spark_cross_account_catalog")
DEFAULT_SPARK_PROPERTIES.get("spark_cross_account_catalog", {})
if spark_cross_account_catalog
else {}
),
**DEFAULT_SPARK_PROPERTIES.get("spark_requester_pays") if spark_requester_pays else {},
**(
DEFAULT_SPARK_PROPERTIES.get("spark_requester_pays", {})
if spark_requester_pays
else {}
),
)

default_engine_config = {
Expand Down
12 changes: 6 additions & 6 deletions dbt-athena/src/dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _connection_keys(self) -> Tuple[str, ...]:


class AthenaCursor(Cursor):
def __init__(self, **kwargs) -> None: # type: ignore
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._executor = ThreadPoolExecutor()

Expand Down Expand Up @@ -224,9 +224,9 @@ def execute_with_iceberg_retries() -> AthenaCursor:
return self
raise OperationalError(query_execution.state_change_reason)

return execute_with_iceberg_retries() # type: ignore
return execute_with_iceberg_retries()

return inner() # type: ignore
return inner()


class AthenaConnectionManager(SQLConnectionManager):
Expand All @@ -236,7 +236,7 @@ def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
self.query_header = AthenaMacroQueryStringSetter(self.profile, query_header_context)

@classmethod
def data_type_code_to_name(cls, type_code: str) -> str:
def data_type_code_to_name(cls, type_code: str) -> str: # type:ignore
"""
Get the string representation of the data type from the Athena metadata. Dbt performs a
query to retrieve the types of the columns in the SQL query. Then these types are compared
Expand Down Expand Up @@ -287,15 +287,15 @@ def open(cls, connection: Connection) -> Connection:
config=get_boto3_config(num_retries=creds.effective_num_retries),
)

connection.state = ConnectionState.OPEN
connection.state = ConnectionState.OPEN # type:ignore
connection.handle = handle

except Exception as exc:
LOGGER.exception(
f"Got an error when attempting to open a Athena connection due to {exc}"
)
connection.handle = None
connection.state = ConnectionState.FAIL
connection.state = ConnectionState.FAIL # type:ignore
raise ConnectionError(str(exc))

return connection
Expand Down
Loading
Loading