22
22
import math
23
23
import uuid
24
24
from decimal import Decimal
25
- from typing import Any , Dict , List , NamedTuple , Optional # NOQA for mypy types
25
+ from types import TracebackType
26
+ from typing import Any , Dict , Iterator , List , NamedTuple , Optional , Sequence , Tuple , Type , Union
26
27
27
28
import trino .client
28
29
import trino .exceptions
72
73
logger = trino .logging .get_logger (__name__ )
73
74
74
75
75
- def connect (* args , ** kwargs ) :
76
+ def connect (* args : Any , ** kwargs : Any ) -> trino . dbapi . Connection :
76
77
"""Constructor for creating a connection to the database.
77
78
78
79
See class :py:class:`Connection` for arguments.
@@ -92,28 +93,28 @@ class Connection(object):
92
93
93
94
def __init__ (
94
95
self ,
95
- host ,
96
- port = constants .DEFAULT_PORT ,
97
- user = None ,
98
- source = constants .DEFAULT_SOURCE ,
99
- catalog = constants .DEFAULT_CATALOG ,
100
- schema = constants .DEFAULT_SCHEMA ,
101
- session_properties = None ,
102
- http_headers = None ,
103
- http_scheme = constants .HTTP ,
104
- auth = constants .DEFAULT_AUTH ,
105
- extra_credential = None ,
106
- redirect_handler = None ,
107
- max_attempts = constants .DEFAULT_MAX_ATTEMPTS ,
108
- request_timeout = constants .DEFAULT_REQUEST_TIMEOUT ,
109
- isolation_level = IsolationLevel .AUTOCOMMIT ,
110
- verify = True ,
111
- http_session = None ,
112
- client_tags = None ,
113
- legacy_primitive_types = False ,
114
- roles = None ,
96
+ host : str ,
97
+ port : int = constants .DEFAULT_PORT ,
98
+ user : Optional [ str ] = None ,
99
+ source : str = constants .DEFAULT_SOURCE ,
100
+ catalog : Optional [ str ] = constants .DEFAULT_CATALOG ,
101
+ schema : Optional [ str ] = constants .DEFAULT_SCHEMA ,
102
+ session_properties : Optional [ Dict [ str , str ]] = None ,
103
+ http_headers : Optional [ Dict [ str , str ]] = None ,
104
+ http_scheme : str = constants .HTTP ,
105
+ auth : Optional [ trino . auth . Authentication ] = constants .DEFAULT_AUTH ,
106
+ extra_credential : Optional [ List [ Tuple [ str , str ]]] = None ,
107
+ redirect_handler : Optional [ str ] = None ,
108
+ max_attempts : int = constants .DEFAULT_MAX_ATTEMPTS ,
109
+ request_timeout : float = constants .DEFAULT_REQUEST_TIMEOUT ,
110
+ isolation_level : IsolationLevel = IsolationLevel .AUTOCOMMIT ,
111
+ verify : Union [ bool | str ] = True ,
112
+ http_session : Optional [ trino . client . TrinoRequest . http . Session ] = None ,
113
+ client_tags : Optional [ List [ str ]] = None ,
114
+ legacy_primitive_types : Optional [ bool ] = False ,
115
+ roles : Optional [ Dict [ str , str ]] = None ,
115
116
timezone = None ,
116
- ):
117
+ ) -> None :
117
118
self .host = host
118
119
self .port = port
119
120
self .user = user
@@ -151,50 +152,53 @@ def __init__(
151
152
152
153
self ._isolation_level = isolation_level
153
154
self ._request = None
154
- self ._transaction = None
155
+ self ._transaction : Optional [ Transaction ] = None
155
156
self .legacy_primitive_types = legacy_primitive_types
156
157
157
158
@property
158
- def isolation_level (self ):
159
+ def isolation_level (self ) -> IsolationLevel :
159
160
return self ._isolation_level
160
161
161
162
@property
162
- def transaction (self ):
163
+ def transaction (self ) -> Optional [ Transaction ] :
163
164
return self ._transaction
164
165
165
- def __enter__ (self ):
166
+ def __enter__ (self ) -> object :
166
167
return self
167
168
168
- def __exit__ (self , exc_type , exc_value , traceback ):
169
+ def __exit__ (self ,
170
+ exc_type : Optional [Type [BaseException ]],
171
+ exc_value : Optional [BaseException ],
172
+ traceback : Optional [TracebackType ]) -> None :
169
173
try :
170
174
self .commit ()
171
175
except Exception :
172
176
self .rollback ()
173
177
else :
174
178
self .close ()
175
179
176
- def close (self ):
180
+ def close (self ) -> None :
177
181
# TODO cancel outstanding queries?
178
182
self ._http_session .close ()
179
183
180
- def start_transaction (self ):
184
+ def start_transaction (self ) -> Transaction :
181
185
self ._transaction = Transaction (self ._create_request ())
182
186
self ._transaction .begin ()
183
187
return self ._transaction
184
188
185
- def commit (self ):
186
- if self .transaction is None :
189
+ def commit (self ) -> None :
190
+ if self ._transaction is None :
187
191
return
188
192
self ._transaction .commit ()
189
193
self ._transaction = None
190
194
191
- def rollback (self ):
192
- if self .transaction is None :
195
+ def rollback (self ) -> None :
196
+ if self ._transaction is None :
193
197
raise RuntimeError ("no transaction was started" )
194
198
self ._transaction .rollback ()
195
199
self ._transaction = None
196
200
197
- def _create_request (self ):
201
+ def _create_request (self ) -> trino . client . TrinoRequest :
198
202
return trino .client .TrinoRequest (
199
203
self .host ,
200
204
self .port ,
@@ -207,7 +211,7 @@ def _create_request(self):
207
211
self .request_timeout ,
208
212
)
209
213
210
- def cursor (self , legacy_primitive_types : bool = None ):
214
+ def cursor (self , legacy_primitive_types : bool = None ) -> 'trino.dbapi.Cursor' :
211
215
"""Return a new :py:class:`Cursor` object using the connection."""
212
216
if self .isolation_level != IsolationLevel .AUTOCOMMIT :
213
217
if self .transaction is None :
@@ -271,7 +275,10 @@ class Cursor(object):
271
275
272
276
"""
273
277
274
- def __init__ (self , connection , request , legacy_primitive_types : bool = False ):
278
+ def __init__ (self ,
279
+ connection : Connection ,
280
+ request : trino .client .TrinoRequest ,
281
+ legacy_primitive_types : bool = False ) -> None :
275
282
if not isinstance (connection , Connection ):
276
283
raise ValueError (
277
284
"connection must be a Connection object: {}" .format (type (connection ))
@@ -280,32 +287,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280
287
self ._request = request
281
288
282
289
self .arraysize = 1
283
- self ._iterator = None
284
- self ._query = None
290
+ self ._iterator : Optional [ Iterator [ Any ]] = None
291
+ self ._query : Optional [ trino . client . TrinoQuery ] = None
285
292
self ._legacy_primitive_types = legacy_primitive_types
286
293
287
- def __iter__ (self ):
294
+ def __iter__ (self ) -> Optional [ Iterator [ Any ]] :
288
295
return self ._iterator
289
296
290
297
@property
291
- def connection (self ):
298
+ def connection (self ) -> Connection :
292
299
return self ._connection
293
300
294
301
@property
295
- def info_uri (self ):
302
+ def info_uri (self ) -> Optional [ str ] :
296
303
if self ._query is not None :
297
304
return self ._query .info_uri
298
305
return None
299
306
300
307
@property
301
- def update_type (self ):
308
+ def update_type (self ) -> Optional [ str ] :
302
309
if self ._query is not None :
303
310
return self ._query .update_type
304
311
return None
305
312
306
313
@property
307
- def description (self ) -> List [ColumnDescription ]:
308
- if self ._query .columns is None :
314
+ def description (self ) -> Optional [ List [Tuple [ Any , ...]] ]:
315
+ if self ._query is None or self . _query .columns is None :
309
316
return None
310
317
311
318
# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +321,7 @@ def description(self) -> List[ColumnDescription]:
314
321
]
315
322
316
323
@property
317
- def rowcount (self ):
324
+ def rowcount (self ) -> int :
318
325
"""Not supported.
319
326
320
327
Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +332,21 @@ def rowcount(self):
325
332
return - 1
326
333
327
334
@property
328
- def stats (self ):
335
+ def stats (self ) -> Optional [ Dict [ Any , Any ]] :
329
336
if self ._query is not None :
330
337
return self ._query .stats
331
338
return None
332
339
333
340
@property
334
- def query_id (self ) -> Optional [str ]:
335
- if self ._query is not None :
336
- return self ._query .query_id
337
- return None
338
-
339
- @property
340
- def warnings (self ):
341
+ def warnings (self ) -> Optional [List [Dict [Any , Any ]]]:
341
342
if self ._query is not None :
342
343
return self ._query .warnings
343
344
return None
344
345
345
- def setinputsizes (self , sizes ) :
346
+ def setinputsizes (self , sizes : Sequence [ Any ]) -> None :
346
347
raise trino .exceptions .NotSupportedError
347
348
348
- def setoutputsize (self , size , column ) :
349
+ def setoutputsize (self , size : int , column : Optional [ int ]) -> None :
349
350
raise trino .exceptions .NotSupportedError
350
351
351
352
def _prepare_statement (self , statement : str , name : str ) -> None :
@@ -363,13 +364,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363
364
364
365
def _execute_prepared_statement (
365
366
self ,
366
- statement_name ,
367
- params
368
- ):
367
+ statement_name : str ,
368
+ params : Any
369
+ ) -> trino . client . TrinoQuery :
369
370
sql = 'EXECUTE ' + statement_name + ' USING ' + ',' .join (map (self ._format_prepared_param , params ))
370
371
return trino .client .TrinoQuery (self ._request , sql = sql , legacy_primitive_types = self ._legacy_primitive_types )
371
372
372
- def _format_prepared_param (self , param ) :
373
+ def _format_prepared_param (self , param : Any ) -> str :
373
374
"""
374
375
Formats parameters to be passed in an
375
376
EXECUTE statement.
@@ -451,10 +452,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451
452
legacy_primitive_types = self ._legacy_primitive_types )
452
453
query .execute ()
453
454
454
- def _generate_unique_statement_name (self ):
455
+ def _generate_unique_statement_name (self ) -> str :
455
456
return 'st_' + uuid .uuid4 ().hex .replace ('-' , '' )
456
457
457
- def execute (self , operation , params = None ):
458
+ def execute (self , operation : str , params : Optional [ Any ] = None ) -> trino . client . TrinoResult :
458
459
if params :
459
460
assert isinstance (params , (list , tuple )), (
460
461
'params must be a list or tuple containing the query '
@@ -484,7 +485,7 @@ def execute(self, operation, params=None):
484
485
self ._iterator = iter (self ._query .execute ())
485
486
return self
486
487
487
- def executemany (self , operation , seq_of_params ) :
488
+ def executemany (self , operation : str , seq_of_params : Any ) -> None :
488
489
"""
489
490
PEP-0249: Prepare a database operation (query or command) and then
490
491
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -503,6 +504,7 @@ def executemany(self, operation, seq_of_params):
503
504
for parameters in seq_of_params [:- 1 ]:
504
505
self .execute (operation , parameters )
505
506
self .fetchall ()
507
+ assert self ._query is not None
506
508
if self ._query .update_type is None :
507
509
raise NotSupportedError ("Query must return update type" )
508
510
if seq_of_params :
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529
531
except trino .exceptions .HttpError as err :
530
532
raise trino .exceptions .OperationalError (str (err ))
531
533
532
- def fetchmany (self , size = None ) -> List [List [Any ]]:
534
+ def fetchmany (self , size : Optional [ int ] = None ) -> List [List [Any ]]:
533
535
"""
534
536
PEP-0249: Fetch the next set of rows of a query result, returning a
535
537
sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -562,6 +564,7 @@ def fetchmany(self, size=None) -> List[List[Any]]:
562
564
563
565
return result
564
566
567
+ < << << << HEAD
565
568
def describe (self , sql : str ) -> List [DescribeOutput ]:
566
569
"""
567
570
List the output columns of a SQL statement, including the column name (or alias), catalog, schema, table, type,
@@ -584,66 +587,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584
587
585
588
return list (map (lambda x : DescribeOutput .from_row (x ), result ))
586
589
587
- def genall (self ):
590
+ def genall (self ) -> trino . client . TrinoResult :
588
591
return self ._query .result
589
592
590
593
def fetchall (self ) -> List [List [Any ]]:
591
594
return list (self .genall ())
592
595
593
- def cancel (self ):
596
+ def cancel (self ) -> None :
594
597
if self ._query is None :
595
598
raise trino .exceptions .OperationalError (
596
599
"Cancel query failed; no running query"
597
600
)
598
601
self ._query .cancel ()
599
602
600
- def close (self ):
603
+ def close (self ) -> None :
601
604
self .cancel ()
602
605
# TODO: Cancel not only the last query executed on this cursor
603
606
# but also any other outstanding queries executed through this cursor.
604
-
605
-
606
- Date = datetime .date
607
- Time = datetime .time
608
- Timestamp = datetime .datetime
609
- DateFromTicks = datetime .date .fromtimestamp
610
- TimestampFromTicks = datetime .datetime .fromtimestamp
611
-
612
-
613
- def TimeFromTicks (ticks ):
614
- return datetime .time (* datetime .localtime (ticks )[3 :6 ])
615
-
616
-
617
- def Binary (string ):
618
- return string .encode ("utf-8" )
619
-
620
-
621
- class DBAPITypeObject :
622
- def __init__ (self , * values ):
623
- self .values = [v .lower () for v in values ]
624
-
625
- def __eq__ (self , other ):
626
- return other .lower () in self .values
627
-
628
-
629
- STRING = DBAPITypeObject ("VARCHAR" , "CHAR" , "VARBINARY" , "JSON" , "IPADDRESS" )
630
-
631
- BINARY = DBAPITypeObject (
632
- "ARRAY" , "MAP" , "ROW" , "HyperLogLog" , "P4HyperLogLog" , "QDigest"
633
- )
634
-
635
- NUMBER = DBAPITypeObject (
636
- "BOOLEAN" , "TINYINT" , "SMALLINT" , "INTEGER" , "BIGINT" , "REAL" , "DOUBLE" , "DECIMAL"
637
- )
638
-
639
- DATETIME = DBAPITypeObject (
640
- "DATE" ,
641
- "TIME" ,
642
- "TIME WITH TIME ZONE" ,
643
- "TIMESTAMP" ,
644
- "TIMESTAMP WITH TIME ZONE" ,
645
- "INTERVAL YEAR TO MONTH" ,
646
- "INTERVAL DAY TO SECOND" ,
647
- )
648
-
649
- ROWID = DBAPITypeObject () # nothing indicates row id in Trino
0 commit comments