Skip to content

Commit af8dd92

Browse files
author
Kareem Zidane
authored
Merge pull request #93 from cs50/develop
Fixes colons in strings, catches SQLAlchemy warnings
2 parents e31b2e7 + 72a1706 commit af8dd92

File tree

6 files changed

+125
-73
lines changed

6 files changed

+125
-73
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="4.0.2"
19+
version="4.0.3"
2020
)

src/cs50/flask.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22

33
from distutils.version import StrictVersion
4-
from os import getenv
54
from pkg_resources import get_distribution
65

76
from .cs50 import _formatException
@@ -44,7 +43,8 @@ def _execute_after(*args, **kwargs):
4443

4544
# When behind CS50 IDE's proxy, ensure that flask.redirect doesn't redirect from HTTPS to HTTP
4645
# https://werkzeug.palletsprojects.com/en/0.15.x/middleware/proxy_fix/#module-werkzeug.middleware.proxy_fix
47-
if getenv("C9_HOSTNAME") and not getenv("IDE_OFFLINE"):
46+
from os import getenv
47+
if getenv("CS50_IDE_TYPE") == "online":
4848
try:
4949
import flask
5050
from werkzeug.middleware.proxy_fix import ProxyFix

src/cs50/sql.py

+100-68
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
import datetime
2-
import decimal
3-
import importlib
4-
import logging
5-
import os
6-
import re
7-
import sqlalchemy
8-
import sqlite3
9-
import sqlparse
10-
import sys
11-
import termcolor
12-
import warnings
13-
14-
151
class SQL(object):
162
"""Wrap SQLAlchemy to provide a simple SQL API."""
173

@@ -25,6 +11,13 @@ def __init__(self, url, **kwargs):
2511
http://docs.sqlalchemy.org/en/latest/dialects/index.html
2612
"""
2713

14+
# Lazily import
15+
import logging
16+
import os
17+
import re
18+
import sqlalchemy
19+
import sqlite3
20+
2821
# Get logger
2922
self._logger = logging.getLogger("cs50")
3023

@@ -74,6 +67,14 @@ def connect(dbapi_connection, connection_record):
7467
def execute(self, sql, *args, **kwargs):
7568
"""Execute a SQL statement."""
7669

70+
# Lazily import
71+
import decimal
72+
import re
73+
import sqlalchemy
74+
import sqlparse
75+
import termcolor
76+
import warnings
77+
7778
# Allow only one statement at a time, since SQLite doesn't support multiple
7879
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
7980
statements = sqlparse.parse(sql)
@@ -212,65 +213,82 @@ def execute(self, sql, *args, **kwargs):
212213
"value" if len(keys) == 1 else "values",
213214
", ".join(keys)))
214215

215-
# Join tokens into statement
216-
statement = "".join([str(token) for token in tokens])
216+
# For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
217+
# https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
218+
for index, token in enumerate(tokens):
217219

218-
# Raise exceptions for warnings
219-
warnings.filterwarnings("error")
220+
# In string literal
221+
# https://www.sqlite.org/lang_keywords.html
222+
if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]:
223+
token.value = re.sub("(^'|\s+):", r"\1\:", token.value)
220224

221-
# Prepare, execute statement
222-
try:
225+
# In identifier
226+
# https://www.sqlite.org/lang_keywords.html
227+
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
228+
token.value = re.sub("(^\"|\s+):", r"\1\:", token.value)
223229

224-
# Execute statement
225-
result = self.engine.execute(sqlalchemy.text(statement))
230+
# Join tokens into statement
231+
statement = "".join([str(token) for token in tokens])
226232

227-
# Return value
228-
ret = True
229-
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
230-
231-
# Uppercase token's value
232-
value = tokens[0].value.upper()
233-
234-
# If SELECT, return result set as list of dict objects
235-
if value == "SELECT":
236-
237-
# Coerce any decimal.Decimal objects to float objects
238-
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
239-
rows = [dict(row) for row in result.fetchall()]
240-
for row in rows:
241-
for column in row:
242-
if type(row[column]) is decimal.Decimal:
243-
row[column] = float(row[column])
244-
ret = rows
245-
246-
# If INSERT, return primary key value for a newly inserted row
247-
elif value == "INSERT":
248-
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
249-
result = self.engine.execute("SELECT LASTVAL()")
250-
ret = result.first()[0]
251-
else:
252-
ret = result.lastrowid
253-
254-
# If DELETE or UPDATE, return number of rows matched
255-
elif value in ["DELETE", "UPDATE"]:
256-
ret = result.rowcount
257-
258-
# If constraint violated, return None
259-
except sqlalchemy.exc.IntegrityError:
260-
self._logger.debug(termcolor.colored(statement, "yellow"))
261-
return None
262-
263-
# If user errror
264-
except sqlalchemy.exc.OperationalError as e:
265-
self._logger.debug(termcolor.colored(statement, "red"))
266-
e = RuntimeError(_parse_exception(e))
267-
e.__cause__ = None
268-
raise e
233+
# Catch SQLAlchemy warnings
234+
with warnings.catch_warnings():
235+
236+
# Raise exceptions for warnings
237+
warnings.simplefilter("error")
238+
239+
# Prepare, execute statement
240+
try:
241+
242+
# Execute statement
243+
result = self.engine.execute(sqlalchemy.text(statement))
244+
245+
# Return value
246+
ret = True
247+
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
248+
249+
# Uppercase token's value
250+
value = tokens[0].value.upper()
251+
252+
# If SELECT, return result set as list of dict objects
253+
if value == "SELECT":
254+
255+
# Coerce any decimal.Decimal objects to float objects
256+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
257+
rows = [dict(row) for row in result.fetchall()]
258+
for row in rows:
259+
for column in row:
260+
if type(row[column]) is decimal.Decimal:
261+
row[column] = float(row[column])
262+
ret = rows
263+
264+
# If INSERT, return primary key value for a newly inserted row
265+
elif value == "INSERT":
266+
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
267+
result = self.engine.execute("SELECT LASTVAL()")
268+
ret = result.first()[0]
269+
else:
270+
ret = result.lastrowid
271+
272+
# If DELETE or UPDATE, return number of rows matched
273+
elif value in ["DELETE", "UPDATE"]:
274+
ret = result.rowcount
275+
276+
# If constraint violated, return None
277+
except sqlalchemy.exc.IntegrityError:
278+
self._logger.debug(termcolor.colored(statement, "yellow"))
279+
return None
280+
281+
# If user errror
282+
except sqlalchemy.exc.OperationalError as e:
283+
self._logger.debug(termcolor.colored(statement, "red"))
284+
e = RuntimeError(_parse_exception(e))
285+
e.__cause__ = None
286+
raise e
269287

270-
# Return value
271-
else:
272-
self._logger.debug(termcolor.colored(statement, "green"))
273-
return ret
288+
# Return value
289+
else:
290+
self._logger.debug(termcolor.colored(statement, "green"))
291+
return ret
274292

275293
def _escape(self, value):
276294
"""
@@ -279,8 +297,15 @@ def _escape(self, value):
279297
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
280298
"""
281299

300+
# Lazily import
301+
import sqlparse
302+
282303
def __escape(value):
283304

305+
# Lazily import
306+
import datetime
307+
import sqlalchemy
308+
284309
# bool
285310
if type(value) is bool:
286311
return sqlparse.sql.Token(
@@ -349,6 +374,9 @@ def __escape(value):
349374
def _parse_exception(e):
350375
"""Parses an exception, returns its message."""
351376

377+
# Lazily import
378+
import re
379+
352380
# MySQL
353381
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
354382
if matches:
@@ -371,6 +399,10 @@ def _parse_exception(e):
371399
def _parse_placeholder(token):
372400
"""Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
373401

402+
# Lazily load
403+
import re
404+
import sqlparse
405+
374406
# Validate token
375407
if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder:
376408
raise TypeError()

tests/python.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44

55
import cs50
66

7-
i = cs50.get_int()
8-
print(i)
7+
i = cs50.get_int("Input: ")
8+
print(f"Output: {i}")

tests/sql.py

+18
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,24 @@ def test_update_returns_affected_rows(self):
7575
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2)
7676
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0)
7777

78+
def test_string_literal_with_colon(self):
79+
rows = [
80+
{"id": 1, "val": ":foo"},
81+
{"id": 2, "val": "foo:bar"},
82+
{"id": 3, "val": " :baz"},
83+
{"id": 3, "val": ":bar :baz"},
84+
{"id": 3, "val": " :bar :baz"}
85+
]
86+
for row in rows:
87+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
88+
89+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ':foo'"), [{"val": ":foo"}])
90+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ':bar'"), [])
91+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = 'foo:bar'"), [{"val": "foo:bar"}])
92+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ' :baz'"), [{"val": " :baz"}])
93+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ':bar :baz'"), [{"val": ":bar :baz"}])
94+
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ' :bar :baz'"), [{"val": " :bar :baz"}])
95+
7896
def tearDown(self):
7997
self.db.execute("DROP TABLE cs50")
8098

tests/sqlite.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"])
3232
db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams"))
3333

34+
db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'")
35+
3436
db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams")
3537
db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"})
3638

0 commit comments

Comments
 (0)