Skip to content

Commit 9e70d1d

Browse files
hovaescodamian3031
authored andcommitted
mypy checks
1 parent 61b8133 commit 9e70d1d

File tree

3 files changed

+75
-118
lines changed

3 files changed

+75
-118
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ ignore_missing_imports = true
1919
no_implicit_optional = true
2020
warn_unused_ignores = true
2121

22-
[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*]
22+
[mypy-tests.*,trino.client,trino.sqlalchemy.*,trino.dbapi]
2323
ignore_errors = true

trino/client.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ class ClientSession(object):
125125

126126
def __init__(
127127
self,
128-
user: str,
129-
catalog: str = None,
130-
schema: str = None,
131-
source: str = None,
128+
user: Optional[str],
129+
catalog: Optional[str] = None,
130+
schema: Optional[str] = None,
131+
source: Optional[str] = None,
132132
properties: Dict[str, str] = None,
133133
headers: Dict[str, str] = None,
134134
transaction_id: str = None,

trino/dbapi.py

+70-113
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import math
2323
import uuid
2424
from decimal import Decimal
25-
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
25+
from types import TracebackType
26+
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union
2627

2728
import trino.client
2829
import trino.exceptions
@@ -72,7 +73,7 @@
7273
logger = trino.logging.get_logger(__name__)
7374

7475

75-
def connect(*args, **kwargs):
76+
def connect(*args: Any, **kwargs: Any) -> trino.dbapi.Connection:
7677
"""Constructor for creating a connection to the database.
7778
7879
See class :py:class:`Connection` for arguments.
@@ -92,28 +93,28 @@ class Connection(object):
9293

9394
def __init__(
9495
self,
95-
host,
96-
port=constants.DEFAULT_PORT,
97-
user=None,
98-
source=constants.DEFAULT_SOURCE,
99-
catalog=constants.DEFAULT_CATALOG,
100-
schema=constants.DEFAULT_SCHEMA,
101-
session_properties=None,
102-
http_headers=None,
103-
http_scheme=constants.HTTP,
104-
auth=constants.DEFAULT_AUTH,
105-
extra_credential=None,
106-
redirect_handler=None,
107-
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
108-
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
109-
isolation_level=IsolationLevel.AUTOCOMMIT,
110-
verify=True,
111-
http_session=None,
112-
client_tags=None,
113-
legacy_primitive_types=False,
114-
roles=None,
96+
host: str,
97+
port: int = constants.DEFAULT_PORT,
98+
user: Optional[str] = None,
99+
source: str = constants.DEFAULT_SOURCE,
100+
catalog: Optional[str] = constants.DEFAULT_CATALOG,
101+
schema: Optional[str] = constants.DEFAULT_SCHEMA,
102+
session_properties: Optional[Dict[str, str]] = None,
103+
http_headers: Optional[Dict[str, str]] = None,
104+
http_scheme: str = constants.HTTP,
105+
auth: Optional[trino.auth.Authentication] = constants.DEFAULT_AUTH,
106+
extra_credential: Optional[List[Tuple[str, str]]] = None,
107+
redirect_handler: Optional[str] = None,
108+
max_attempts: int = constants.DEFAULT_MAX_ATTEMPTS,
109+
request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT,
110+
isolation_level: IsolationLevel = IsolationLevel.AUTOCOMMIT,
111+
verify: Union[bool | str] = True,
112+
http_session: Optional[trino.client.TrinoRequest.http.Session] = None,
113+
client_tags: Optional[List[str]] = None,
114+
legacy_primitive_types: Optional[bool] = False,
115+
roles: Optional[Dict[str, str]] = None,
115116
timezone=None,
116-
):
117+
) -> None:
117118
self.host = host
118119
self.port = port
119120
self.user = user
@@ -151,50 +152,53 @@ def __init__(
151152

152153
self._isolation_level = isolation_level
153154
self._request = None
154-
self._transaction = None
155+
self._transaction: Optional[Transaction] = None
155156
self.legacy_primitive_types = legacy_primitive_types
156157

157158
@property
158-
def isolation_level(self):
159+
def isolation_level(self) -> IsolationLevel:
159160
return self._isolation_level
160161

161162
@property
162-
def transaction(self):
163+
def transaction(self) -> Optional[Transaction]:
163164
return self._transaction
164165

165-
def __enter__(self):
166+
def __enter__(self) -> object:
166167
return self
167168

168-
def __exit__(self, exc_type, exc_value, traceback):
169+
def __exit__(self,
170+
exc_type: Optional[Type[BaseException]],
171+
exc_value: Optional[BaseException],
172+
traceback: Optional[TracebackType]) -> None:
169173
try:
170174
self.commit()
171175
except Exception:
172176
self.rollback()
173177
else:
174178
self.close()
175179

176-
def close(self):
180+
def close(self) -> None:
177181
# TODO cancel outstanding queries?
178182
self._http_session.close()
179183

180-
def start_transaction(self):
184+
def start_transaction(self) -> Transaction:
181185
self._transaction = Transaction(self._create_request())
182186
self._transaction.begin()
183187
return self._transaction
184188

185-
def commit(self):
186-
if self.transaction is None:
189+
def commit(self) -> None:
190+
if self._transaction is None:
187191
return
188192
self._transaction.commit()
189193
self._transaction = None
190194

191-
def rollback(self):
192-
if self.transaction is None:
195+
def rollback(self) -> None:
196+
if self._transaction is None:
193197
raise RuntimeError("no transaction was started")
194198
self._transaction.rollback()
195199
self._transaction = None
196200

197-
def _create_request(self):
201+
def _create_request(self) -> trino.client.TrinoRequest:
198202
return trino.client.TrinoRequest(
199203
self.host,
200204
self.port,
@@ -207,7 +211,7 @@ def _create_request(self):
207211
self.request_timeout,
208212
)
209213

210-
def cursor(self, legacy_primitive_types: bool = None):
214+
def cursor(self, legacy_primitive_types: bool = None) -> 'trino.dbapi.Cursor':
211215
"""Return a new :py:class:`Cursor` object using the connection."""
212216
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
213217
if self.transaction is None:
@@ -271,7 +275,10 @@ class Cursor(object):
271275
272276
"""
273277

274-
def __init__(self, connection, request, legacy_primitive_types: bool = False):
278+
def __init__(self,
279+
connection: Connection,
280+
request: trino.client.TrinoRequest,
281+
legacy_primitive_types: bool = False) -> None:
275282
if not isinstance(connection, Connection):
276283
raise ValueError(
277284
"connection must be a Connection object: {}".format(type(connection))
@@ -280,32 +287,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280287
self._request = request
281288

282289
self.arraysize = 1
283-
self._iterator = None
284-
self._query = None
290+
self._iterator: Optional[Iterator[Any]] = None
291+
self._query: Optional[trino.client.TrinoQuery] = None
285292
self._legacy_primitive_types = legacy_primitive_types
286293

287-
def __iter__(self):
294+
def __iter__(self) -> Optional[Iterator[Any]]:
288295
return self._iterator
289296

290297
@property
291-
def connection(self):
298+
def connection(self) -> Connection:
292299
return self._connection
293300

294301
@property
295-
def info_uri(self):
302+
def info_uri(self) -> Optional[str]:
296303
if self._query is not None:
297304
return self._query.info_uri
298305
return None
299306

300307
@property
301-
def update_type(self):
308+
def update_type(self) -> Optional[str]:
302309
if self._query is not None:
303310
return self._query.update_type
304311
return None
305312

306313
@property
307-
def description(self) -> List[ColumnDescription]:
308-
if self._query.columns is None:
314+
def description(self) -> Optional[List[Tuple[Any, ...]]]:
315+
if self._query is None or self._query.columns is None:
309316
return None
310317

311318
# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +321,7 @@ def description(self) -> List[ColumnDescription]:
314321
]
315322

316323
@property
317-
def rowcount(self):
324+
def rowcount(self) -> int:
318325
"""Not supported.
319326
320327
Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +332,21 @@ def rowcount(self):
325332
return -1
326333

327334
@property
328-
def stats(self):
335+
def stats(self) -> Optional[Dict[Any, Any]]:
329336
if self._query is not None:
330337
return self._query.stats
331338
return None
332339

333340
@property
334-
def query_id(self) -> Optional[str]:
335-
if self._query is not None:
336-
return self._query.query_id
337-
return None
338-
339-
@property
340-
def warnings(self):
341+
def warnings(self) -> Optional[List[Dict[Any, Any]]]:
341342
if self._query is not None:
342343
return self._query.warnings
343344
return None
344345

345-
def setinputsizes(self, sizes):
346+
def setinputsizes(self, sizes: Sequence[Any]) -> None:
346347
raise trino.exceptions.NotSupportedError
347348

348-
def setoutputsize(self, size, column):
349+
def setoutputsize(self, size: int, column: Optional[int]) -> None:
349350
raise trino.exceptions.NotSupportedError
350351

351352
def _prepare_statement(self, statement: str, name: str) -> None:
@@ -363,13 +364,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363364

364365
def _execute_prepared_statement(
365366
self,
366-
statement_name,
367-
params
368-
):
367+
statement_name: str,
368+
params: Any
369+
) -> trino.client.TrinoQuery:
369370
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
370371
return trino.client.TrinoQuery(self._request, sql=sql, legacy_primitive_types=self._legacy_primitive_types)
371372

372-
def _format_prepared_param(self, param):
373+
def _format_prepared_param(self, param: Any) -> str:
373374
"""
374375
Formats parameters to be passed in an
375376
EXECUTE statement.
@@ -451,10 +452,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451452
legacy_primitive_types=self._legacy_primitive_types)
452453
query.execute()
453454

454-
def _generate_unique_statement_name(self):
455+
def _generate_unique_statement_name(self) -> str:
455456
return 'st_' + uuid.uuid4().hex.replace('-', '')
456457

457-
def execute(self, operation, params=None):
458+
def execute(self, operation: str, params: Optional[Any] = None) -> trino.client.TrinoResult:
458459
if params:
459460
assert isinstance(params, (list, tuple)), (
460461
'params must be a list or tuple containing the query '
@@ -484,7 +485,7 @@ def execute(self, operation, params=None):
484485
self._iterator = iter(self._query.execute())
485486
return self
486487

487-
def executemany(self, operation, seq_of_params):
488+
def executemany(self, operation: str, seq_of_params: Any) -> None:
488489
"""
489490
PEP-0249: Prepare a database operation (query or command) and then
490491
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -503,6 +504,7 @@ def executemany(self, operation, seq_of_params):
503504
for parameters in seq_of_params[:-1]:
504505
self.execute(operation, parameters)
505506
self.fetchall()
507+
assert self._query is not None
506508
if self._query.update_type is None:
507509
raise NotSupportedError("Query must return update type")
508510
if seq_of_params:
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529531
except trino.exceptions.HttpError as err:
530532
raise trino.exceptions.OperationalError(str(err))
531533

532-
def fetchmany(self, size=None) -> List[List[Any]]:
534+
def fetchmany(self, size: Optional[int] = None) -> List[List[Any]]:
533535
"""
534536
PEP-0249: Fetch the next set of rows of a query result, returning a
535537
sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -562,6 +564,7 @@ def fetchmany(self, size=None) -> List[List[Any]]:
562564

563565
return result
564566

567+
<<<<<<< HEAD
565568
def describe(self, sql: str) -> List[DescribeOutput]:
566569
"""
567570
List the output columns of a SQL statement, including the column name (or alias), catalog, schema, table, type,
@@ -584,66 +587,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584587

585588
return list(map(lambda x: DescribeOutput.from_row(x), result))
586589

587-
def genall(self):
590+
def genall(self) -> trino.client.TrinoResult:
588591
return self._query.result
589592

590593
def fetchall(self) -> List[List[Any]]:
591594
return list(self.genall())
592595

593-
def cancel(self):
596+
def cancel(self) -> None:
594597
if self._query is None:
595598
raise trino.exceptions.OperationalError(
596599
"Cancel query failed; no running query"
597600
)
598601
self._query.cancel()
599602

600-
def close(self):
603+
def close(self) -> None:
601604
self.cancel()
602605
# TODO: Cancel not only the last query executed on this cursor
603606
# but also any other outstanding queries executed through this cursor.
604-
605-
606-
Date = datetime.date
607-
Time = datetime.time
608-
Timestamp = datetime.datetime
609-
DateFromTicks = datetime.date.fromtimestamp
610-
TimestampFromTicks = datetime.datetime.fromtimestamp
611-
612-
613-
def TimeFromTicks(ticks):
614-
return datetime.time(*datetime.localtime(ticks)[3:6])
615-
616-
617-
def Binary(string):
618-
return string.encode("utf-8")
619-
620-
621-
class DBAPITypeObject:
622-
def __init__(self, *values):
623-
self.values = [v.lower() for v in values]
624-
625-
def __eq__(self, other):
626-
return other.lower() in self.values
627-
628-
629-
STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS")
630-
631-
BINARY = DBAPITypeObject(
632-
"ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest"
633-
)
634-
635-
NUMBER = DBAPITypeObject(
636-
"BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL"
637-
)
638-
639-
DATETIME = DBAPITypeObject(
640-
"DATE",
641-
"TIME",
642-
"TIME WITH TIME ZONE",
643-
"TIMESTAMP",
644-
"TIMESTAMP WITH TIME ZONE",
645-
"INTERVAL YEAR TO MONTH",
646-
"INTERVAL DAY TO SECOND",
647-
)
648-
649-
ROWID = DBAPITypeObject() # nothing indicates row id in Trino

0 commit comments

Comments
 (0)