1
- import datetime
2
- import decimal
3
- import importlib
4
- import logging
5
- import os
6
- import re
7
- import sqlalchemy
8
- import sqlite3
9
- import sqlparse
10
- import sys
11
- import termcolor
12
- import warnings
13
-
14
-
15
1
class SQL (object ):
16
2
"""Wrap SQLAlchemy to provide a simple SQL API."""
17
3
@@ -25,6 +11,13 @@ def __init__(self, url, **kwargs):
25
11
http://docs.sqlalchemy.org/en/latest/dialects/index.html
26
12
"""
27
13
14
+ # Lazily import
15
+ import logging
16
+ import os
17
+ import re
18
+ import sqlalchemy
19
+ import sqlite3
20
+
28
21
# Get logger
29
22
self ._logger = logging .getLogger ("cs50" )
30
23
@@ -74,6 +67,14 @@ def connect(dbapi_connection, connection_record):
74
67
def execute (self , sql , * args , ** kwargs ):
75
68
"""Execute a SQL statement."""
76
69
70
+ # Lazily import
71
+ import decimal
72
+ import re
73
+ import sqlalchemy
74
+ import sqlparse
75
+ import termcolor
76
+ import warnings
77
+
77
78
# Allow only one statement at a time, since SQLite doesn't support multiple
78
79
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
79
80
statements = sqlparse .parse (sql )
@@ -212,65 +213,82 @@ def execute(self, sql, *args, **kwargs):
212
213
"value" if len (keys ) == 1 else "values" ,
213
214
", " .join (keys )))
214
215
215
- # Join tokens into statement
216
- statement = "" .join ([str (token ) for token in tokens ])
216
+ # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
217
+ # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
218
+ for index , token in enumerate (tokens ):
217
219
218
- # Raise exceptions for warnings
219
- warnings .filterwarnings ("error" )
220
+ # In string literal
221
+ # https://www.sqlite.org/lang_keywords.html
222
+ if token .ttype in [sqlparse .tokens .Literal .String , sqlparse .tokens .Literal .String .Single ]:
223
+ token .value = re .sub ("(^'|\s+):" , r"\1\:" , token .value )
220
224
221
- # Prepare, execute statement
222
- try :
225
+ # In identifier
226
+ # https://www.sqlite.org/lang_keywords.html
227
+ elif token .ttype == sqlparse .tokens .Literal .String .Symbol :
228
+ token .value = re .sub ("(^\" |\s+):" , r"\1\:" , token .value )
223
229
224
- # Execute statement
225
- result = self . engine . execute ( sqlalchemy . text ( statement ) )
230
+ # Join tokens into statement
231
+ statement = "" . join ([ str ( token ) for token in tokens ] )
226
232
227
- # Return value
228
- ret = True
229
- if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
230
-
231
- # Uppercase token's value
232
- value = tokens [0 ].value .upper ()
233
-
234
- # If SELECT, return result set as list of dict objects
235
- if value == "SELECT" :
236
-
237
- # Coerce any decimal.Decimal objects to float objects
238
- # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
239
- rows = [dict (row ) for row in result .fetchall ()]
240
- for row in rows :
241
- for column in row :
242
- if type (row [column ]) is decimal .Decimal :
243
- row [column ] = float (row [column ])
244
- ret = rows
245
-
246
- # If INSERT, return primary key value for a newly inserted row
247
- elif value == "INSERT" :
248
- if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
249
- result = self .engine .execute ("SELECT LASTVAL()" )
250
- ret = result .first ()[0 ]
251
- else :
252
- ret = result .lastrowid
253
-
254
- # If DELETE or UPDATE, return number of rows matched
255
- elif value in ["DELETE" , "UPDATE" ]:
256
- ret = result .rowcount
257
-
258
- # If constraint violated, return None
259
- except sqlalchemy .exc .IntegrityError :
260
- self ._logger .debug (termcolor .colored (statement , "yellow" ))
261
- return None
262
-
263
- # If user errror
264
- except sqlalchemy .exc .OperationalError as e :
265
- self ._logger .debug (termcolor .colored (statement , "red" ))
266
- e = RuntimeError (_parse_exception (e ))
267
- e .__cause__ = None
268
- raise e
233
+ # Catch SQLAlchemy warnings
234
+ with warnings .catch_warnings ():
235
+
236
+ # Raise exceptions for warnings
237
+ warnings .simplefilter ("error" )
238
+
239
+ # Prepare, execute statement
240
+ try :
241
+
242
+ # Execute statement
243
+ result = self .engine .execute (sqlalchemy .text (statement ))
244
+
245
+ # Return value
246
+ ret = True
247
+ if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
248
+
249
+ # Uppercase token's value
250
+ value = tokens [0 ].value .upper ()
251
+
252
+ # If SELECT, return result set as list of dict objects
253
+ if value == "SELECT" :
254
+
255
+ # Coerce any decimal.Decimal objects to float objects
256
+ # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
257
+ rows = [dict (row ) for row in result .fetchall ()]
258
+ for row in rows :
259
+ for column in row :
260
+ if type (row [column ]) is decimal .Decimal :
261
+ row [column ] = float (row [column ])
262
+ ret = rows
263
+
264
+ # If INSERT, return primary key value for a newly inserted row
265
+ elif value == "INSERT" :
266
+ if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
267
+ result = self .engine .execute ("SELECT LASTVAL()" )
268
+ ret = result .first ()[0 ]
269
+ else :
270
+ ret = result .lastrowid
271
+
272
+ # If DELETE or UPDATE, return number of rows matched
273
+ elif value in ["DELETE" , "UPDATE" ]:
274
+ ret = result .rowcount
275
+
276
+ # If constraint violated, return None
277
+ except sqlalchemy .exc .IntegrityError :
278
+ self ._logger .debug (termcolor .colored (statement , "yellow" ))
279
+ return None
280
+
281
+ # If user errror
282
+ except sqlalchemy .exc .OperationalError as e :
283
+ self ._logger .debug (termcolor .colored (statement , "red" ))
284
+ e = RuntimeError (_parse_exception (e ))
285
+ e .__cause__ = None
286
+ raise e
269
287
270
- # Return value
271
- else :
272
- self ._logger .debug (termcolor .colored (statement , "green" ))
273
- return ret
288
+ # Return value
289
+ else :
290
+ self ._logger .debug (termcolor .colored (statement , "green" ))
291
+ return ret
274
292
275
293
def _escape (self , value ):
276
294
"""
@@ -279,8 +297,15 @@ def _escape(self, value):
279
297
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
280
298
"""
281
299
300
+ # Lazily import
301
+ import sqlparse
302
+
282
303
def __escape (value ):
283
304
305
+ # Lazily import
306
+ import datetime
307
+ import sqlalchemy
308
+
284
309
# bool
285
310
if type (value ) is bool :
286
311
return sqlparse .sql .Token (
@@ -349,6 +374,9 @@ def __escape(value):
349
374
def _parse_exception (e ):
350
375
"""Parses an exception, returns its message."""
351
376
377
+ # Lazily import
378
+ import re
379
+
352
380
# MySQL
353
381
matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
354
382
if matches :
@@ -371,6 +399,10 @@ def _parse_exception(e):
371
399
def _parse_placeholder (token ):
372
400
"""Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
373
401
402
+ # Lazily load
403
+ import re
404
+ import sqlparse
405
+
374
406
# Validate token
375
407
if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
376
408
raise TypeError ()
0 commit comments