1- import datetime
2- import decimal
3- import importlib
4- import logging
5- import os
6- import re
7- import sqlalchemy
8- import sqlite3
9- import sqlparse
10- import sys
11- import termcolor
12- import warnings
13-
14-
151class SQL (object ):
162 """Wrap SQLAlchemy to provide a simple SQL API."""
173
@@ -25,6 +11,13 @@ def __init__(self, url, **kwargs):
2511 http://docs.sqlalchemy.org/en/latest/dialects/index.html
2612 """
2713
14+ # Lazily import
15+ import logging
16+ import os
17+ import re
18+ import sqlalchemy
19+ import sqlite3
20+
2821 # Get logger
2922 self ._logger = logging .getLogger ("cs50" )
3023
@@ -74,6 +67,14 @@ def connect(dbapi_connection, connection_record):
7467 def execute (self , sql , * args , ** kwargs ):
7568 """Execute a SQL statement."""
7669
70+ # Lazily import
71+ import decimal
72+ import re
73+ import sqlalchemy
74+ import sqlparse
75+ import termcolor
76+ import warnings
77+
7778 # Allow only one statement at a time, since SQLite doesn't support multiple
7879 # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
7980 statements = sqlparse .parse (sql )
@@ -212,65 +213,82 @@ def execute(self, sql, *args, **kwargs):
212213 "value" if len (keys ) == 1 else "values" ,
213214 ", " .join (keys )))
214215
215- # Join tokens into statement
216- statement = "" .join ([str (token ) for token in tokens ])
216+ # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
217+ # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
218+ for index , token in enumerate (tokens ):
217219
218- # Raise exceptions for warnings
219- warnings .filterwarnings ("error" )
220+ # In string literal
221+ # https://www.sqlite.org/lang_keywords.html
222+ if token .ttype in [sqlparse .tokens .Literal .String , sqlparse .tokens .Literal .String .Single ]:
223+ token .value = re .sub ("(^'|\s+):" , r"\1\:" , token .value )
220224
221- # Prepare, execute statement
222- try :
225+ # In identifier
226+ # https://www.sqlite.org/lang_keywords.html
227+ elif token .ttype == sqlparse .tokens .Literal .String .Symbol :
228+ token .value = re .sub ("(^\" |\s+):" , r"\1\:" , token .value )
223229
224- # Execute statement
225- result = self . engine . execute ( sqlalchemy . text ( statement ) )
230+ # Join tokens into statement
231+ statement = "" . join ([ str ( token ) for token in tokens ] )
226232
227- # Return value
228- ret = True
229- if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
230-
231- # Uppercase token's value
232- value = tokens [0 ].value .upper ()
233-
234- # If SELECT, return result set as list of dict objects
235- if value == "SELECT" :
236-
237- # Coerce any decimal.Decimal objects to float objects
238- # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
239- rows = [dict (row ) for row in result .fetchall ()]
240- for row in rows :
241- for column in row :
242- if type (row [column ]) is decimal .Decimal :
243- row [column ] = float (row [column ])
244- ret = rows
245-
246- # If INSERT, return primary key value for a newly inserted row
247- elif value == "INSERT" :
248- if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
249- result = self .engine .execute ("SELECT LASTVAL()" )
250- ret = result .first ()[0 ]
251- else :
252- ret = result .lastrowid
253-
254- # If DELETE or UPDATE, return number of rows matched
255- elif value in ["DELETE" , "UPDATE" ]:
256- ret = result .rowcount
257-
258- # If constraint violated, return None
259- except sqlalchemy .exc .IntegrityError :
260- self ._logger .debug (termcolor .colored (statement , "yellow" ))
261- return None
262-
263- # If user errror
264- except sqlalchemy .exc .OperationalError as e :
265- self ._logger .debug (termcolor .colored (statement , "red" ))
266- e = RuntimeError (_parse_exception (e ))
267- e .__cause__ = None
268- raise e
233+ # Catch SQLAlchemy warnings
234+ with warnings .catch_warnings ():
235+
236+ # Raise exceptions for warnings
237+ warnings .simplefilter ("error" )
238+
239+ # Prepare, execute statement
240+ try :
241+
242+ # Execute statement
243+ result = self .engine .execute (sqlalchemy .text (statement ))
244+
245+ # Return value
246+ ret = True
247+ if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
248+
249+ # Uppercase token's value
250+ value = tokens [0 ].value .upper ()
251+
252+ # If SELECT, return result set as list of dict objects
253+ if value == "SELECT" :
254+
255+ # Coerce any decimal.Decimal objects to float objects
256+ # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
257+ rows = [dict (row ) for row in result .fetchall ()]
258+ for row in rows :
259+ for column in row :
260+ if type (row [column ]) is decimal .Decimal :
261+ row [column ] = float (row [column ])
262+ ret = rows
263+
264+ # If INSERT, return primary key value for a newly inserted row
265+ elif value == "INSERT" :
266+ if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
267+ result = self .engine .execute ("SELECT LASTVAL()" )
268+ ret = result .first ()[0 ]
269+ else :
270+ ret = result .lastrowid
271+
272+ # If DELETE or UPDATE, return number of rows matched
273+ elif value in ["DELETE" , "UPDATE" ]:
274+ ret = result .rowcount
275+
276+ # If constraint violated, return None
277+ except sqlalchemy .exc .IntegrityError :
278+ self ._logger .debug (termcolor .colored (statement , "yellow" ))
279+ return None
280+
281+ # If user errror
282+ except sqlalchemy .exc .OperationalError as e :
283+ self ._logger .debug (termcolor .colored (statement , "red" ))
284+ e = RuntimeError (_parse_exception (e ))
285+ e .__cause__ = None
286+ raise e
269287
270- # Return value
271- else :
272- self ._logger .debug (termcolor .colored (statement , "green" ))
273- return ret
288+ # Return value
289+ else :
290+ self ._logger .debug (termcolor .colored (statement , "green" ))
291+ return ret
274292
275293 def _escape (self , value ):
276294 """
@@ -279,8 +297,15 @@ def _escape(self, value):
279297 https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
280298 """
281299
300+ # Lazily import
301+ import sqlparse
302+
282303 def __escape (value ):
283304
305+ # Lazily import
306+ import datetime
307+ import sqlalchemy
308+
284309 # bool
285310 if type (value ) is bool :
286311 return sqlparse .sql .Token (
@@ -349,6 +374,9 @@ def __escape(value):
349374def _parse_exception (e ):
350375 """Parses an exception, returns its message."""
351376
377+ # Lazily import
378+ import re
379+
352380 # MySQL
353381 matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
354382 if matches :
@@ -371,6 +399,10 @@ def _parse_exception(e):
371399def _parse_placeholder (token ):
372400 """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
373401
402+ # Lazily load
403+ import re
404+ import sqlparse
405+
374406 # Validate token
375407 if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
376408 raise TypeError ()
0 commit comments