diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py index 1e506125..567015ee 100644 --- a/sqlalchemy_bigquery/__init__.py +++ b/sqlalchemy_bigquery/__init__.py @@ -37,6 +37,7 @@ FLOAT64, INT64, INTEGER, + JSON, NUMERIC, RECORD, STRING, @@ -74,6 +75,7 @@ "FLOAT64", "INT64", "INTEGER", + "JSON", "NUMERIC", "RECORD", "STRING", diff --git a/sqlalchemy_bigquery/_json.py b/sqlalchemy_bigquery/_json.py new file mode 100644 index 00000000..ff800dc9 --- /dev/null +++ b/sqlalchemy_bigquery/_json.py @@ -0,0 +1,106 @@ +from enum import auto, Enum +import sqlalchemy +from sqlalchemy.sql import sqltypes + + +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSON(sqltypes.JSON): + def bind_expression(self, bindvalue): + # JSON query parameters are STRINGs + return sqlalchemy.func.PARSE_JSON(bindvalue, type_=self) + + def literal_processor(self, dialect): + super_proc = self.bind_processor(dialect) + + def process(value): + value = super_proc(value) + return repr(value) + + return process + + class Comparator(sqltypes.JSON.Comparator): + def _generate_converter(self, name, lax): + prefix = "LAX_" if lax else "" + func_ = getattr(sqlalchemy.func, f"{prefix}{name}") + return func_ + + def as_boolean(self, lax=False): + func_ = self._generate_converter("BOOL", lax) + return func_(self.expr, type_=sqltypes.Boolean) + + def as_string(self, lax=False): + func_ = self._generate_converter("STRING", lax) + return func_(self.expr, type_=sqltypes.String) + + def as_integer(self, lax=False): + func_ = self._generate_converter("INT64", lax) + return func_(self.expr, type_=sqltypes.Integer) + + def as_float(self, lax=False): + func_ = self._generate_converter("FLOAT64", lax) + return func_(self.expr, type_=sqltypes.Float) + + def as_numeric(self, precision, scale, asdecimal=True): + # No converter available in BigQuery + raise NotImplementedError() + + comparator_factory = Comparator + + class JSONPathMode(Enum): + LAX = auto() + LAX_RECURSIVE = auto() + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _mode_prefix(self, mode): + if mode == JSON.JSONPathMode.LAX: + mode_prefix = "lax" + elif mode == JSON.JSONPathMode.LAX_RECURSIVE: + mode_prefix = "lax recursive" + else: + raise NotImplementedError(f"Unhandled JSONPathMode: {mode}") + return mode_prefix + + def _format_value(self, value): + if isinstance(value[0], JSON.JSONPathMode): + mode = value[0] + mode_prefix = self._mode_prefix(mode) + value = value[1:] + else: + mode_prefix = "" + + return "%s$%s" % ( + mode_prefix + " " if mode_prefix else "", + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ), + ) diff --git a/sqlalchemy_bigquery/_types.py b/sqlalchemy_bigquery/_types.py index 8399e978..6a268ce9 100644 --- a/sqlalchemy_bigquery/_types.py +++ b/sqlalchemy_bigquery/_types.py @@ -27,6 +27,7 @@ except ImportError: # pragma: NO COVER pass +from ._json import JSON from ._struct import STRUCT _type_map = { @@ -41,6 +42,7 @@ "FLOAT": sqlalchemy.types.Float, "INT64": sqlalchemy.types.Integer, "INTEGER": sqlalchemy.types.Integer, + "JSON": JSON, "NUMERIC": sqlalchemy.types.Numeric, "RECORD": STRUCT, "STRING": sqlalchemy.types.String, @@ -61,6 +63,7 @@ FLOAT = _type_map["FLOAT"] INT64 = _type_map["INT64"] INTEGER = _type_map["INTEGER"] +JSON = _type_map["JSON"] NUMERIC = _type_map["NUMERIC"] RECORD = _type_map["RECORD"] STRING = _type_map["STRING"] diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index c36ca1b1..91168d53 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -59,7 +59,7 @@ import re from .parse_url import parse_url -from . import _helpers, _struct, _types +from . import _helpers, _json, _struct, _types import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql # Illegal characters is intended to be all characters that are not explicitly @@ -547,6 +547,13 @@ def visit_bindparam( bq_type = self.dialect.type_compiler.process(type_) bq_type = self.__remove_type_parameter(bq_type) + if bq_type == "JSON": + # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI + # For now, we hack around this by: + # - Rewriting the bindparam type to STRING + # - Applying a bind expression that converts the parameter back to JSON + bq_type = "STRING" + assert_(param != "%s", f"Unexpected param: {param}") if bindparam.expanding: # pragma: NO COVER @@ -571,6 +578,12 @@ def visit_getitem_binary(self, binary, operator_, **kw): right = self.process(binary.right, **kw) return f"{left}[OFFSET({right})]" + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return "JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + def _get_regexp_args(self, binary, kw): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) @@ -641,6 +654,12 @@ def visit_NUMERIC(self, type_, **kw): visit_DECIMAL = visit_NUMERIC + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_json_path(self, type_, **kw): + return "STRING" + class BigQueryDDLCompiler(DDLCompiler): option_datatype_mapping = { @@ -1076,6 +1095,8 @@ class BigQueryDialect(DefaultDialect): sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp, sqlalchemy.sql.sqltypes.ARRAY: BQArray, sqlalchemy.sql.sqltypes.Enum: sqlalchemy.sql.sqltypes.Enum, + sqlalchemy.sql.sqltypes.JSON: _json.JSON, + sqlalchemy.sql.sqltypes.JSON.JSONPathType: _json.JSONPathType, } def __init__( @@ -1086,6 +1107,8 @@ def __init__( credentials_info=None, credentials_base64=None, list_tables_page_size=1000, + json_serializer=None, + json_deserializer=None, *args, **kwargs, ): @@ -1098,6 +1121,8 @@ def __init__( self.identifier_preparer = self.preparer(self) self.dataset_id = None self.list_tables_page_size = list_tables_page_size + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer @classmethod def dbapi(cls): diff --git a/tests/unit/test__json.py b/tests/unit/test__json.py new file mode 100644 index 00000000..fcc3a9d7 --- /dev/null +++ b/tests/unit/test__json.py @@ -0,0 +1,189 @@ +import json +import pytest + +import sqlalchemy + + +@pytest.fixture +def json_table(metadata): + from sqlalchemy_bigquery import JSON + + return sqlalchemy.Table( + "json_table", + metadata, + sqlalchemy.Column("cart", JSON), + ) + + +@pytest.fixture +def json_column(json_table): + return json_table.c.cart + + +@pytest.fixture +def json_data(): + return { + "name": "Alice", + "items": [{"product": "book", "price": 10}, {"product": "food", "price": 5}], + } + + +def test_select_json(faux_conn, json_table, json_data): + faux_conn.ex("create table json_table (cart JSON)") + faux_conn.ex(f"insert into json_table values ('{json.dumps(json_data)}')") + + row = list(faux_conn.execute(sqlalchemy.select(json_table)))[0] + assert row.cart == json_data + + +def test_insert_json(faux_conn, metadata, json_table, json_data): + actual = str(json_table.insert().values(cart=json_data).compile(faux_conn.engine)) + + assert ( + actual + == "INSERT INTO `json_table` (`cart`) VALUES (PARSE_JSON(%(cart:STRING)s))" + ) + + +@pytest.mark.parametrize( + "path,literal_sql", + ( + ( + ["name"], + "JSON_QUERY(`json_table`.`cart`, '$.\"name\"')", + ), + ( + ["items", 0], + "JSON_QUERY(`json_table`.`cart`, '$.\"items\"[0]')", + ), + ( + ["items", 0, "price"], + 'JSON_QUERY(`json_table`.`cart`, \'$."items"[0]."price"\')', + ), + ), +) +def test_json_query(faux_conn, json_column, path, literal_sql): + expr = sqlalchemy.select(json_column[path]) + + expected_sql = ( + "SELECT JSON_QUERY(`json_table`.`cart`, %(cart_1:STRING)s) AS `anon_1` \n" + "FROM `json_table`" + ) + expected_literal_sql = f"SELECT {literal_sql} AS `anon_1` \nFROM `json_table`" + + actual_sql = expr.compile(faux_conn).string + actual_literal_sql = expr.compile( + faux_conn, compile_kwargs={"literal_binds": True} + ).string + + assert expected_sql == actual_sql + assert expected_literal_sql == actual_literal_sql + + +def test_json_value(faux_conn, json_column, json_data): + expr = sqlalchemy.select(json_column[["items", 0]].label("first_item")).where( + sqlalchemy.func.JSON_VALUE(json_column[["name"]]) == "Alice" + ) + + expected_sql = ( + "SELECT JSON_QUERY(`json_table`.`cart`, %(cart_1:STRING)s) AS `first_item` \n" + "FROM `json_table` \n" + "WHERE JSON_VALUE(JSON_QUERY(`json_table`.`cart`, %(cart_2:STRING)s)) = %(JSON_VALUE_1:STRING)s" + ) + expected_literal_sql = ( + "SELECT JSON_QUERY(`json_table`.`cart`, '$.\"items\"[0]') AS `first_item` \n" + "FROM `json_table` \n" + "WHERE JSON_VALUE(JSON_QUERY(`json_table`.`cart`, '$.\"name\"')) = 'Alice'" + ) + + actual_sql = expr.compile(faux_conn).string + actual_literal_sql = expr.compile( + faux_conn, compile_kwargs={"literal_binds": True} + ).string + + assert expected_sql == actual_sql + assert expected_literal_sql == actual_literal_sql + + +def test_json_literal(faux_conn): + from sqlalchemy_bigquery import JSON + + expr = sqlalchemy.select( + sqlalchemy.func.STRING( + sqlalchemy.sql.expression.literal("purple", type_=JSON) + ).label("color") + ) + + expected_sql = "SELECT STRING(PARSE_JSON(%(param_1:STRING)s)) AS `color`" + expected_literal_sql = "SELECT STRING(PARSE_JSON('\"purple\"')) AS `color`" + + actual_sql = expr.compile(faux_conn).string + actual_literal_sql = expr.compile( + faux_conn, compile_kwargs={"literal_binds": True} + ).string + + assert expected_sql == actual_sql + assert expected_literal_sql == actual_literal_sql + + +@pytest.mark.parametrize("lax,prefix", ((False, ""), (True, "LAX_"))) +def test_json_casts(faux_conn, json_column, json_data, lax, prefix): + from sqlalchemy_bigquery import JSON + + expr = sqlalchemy.select(1).where( + json_column[["name"]].as_string(lax=lax) == "Alice" + ) + assert expr.compile(faux_conn, compile_kwargs={"literal_binds": True}).string == ( + "SELECT 1 \n" + "FROM `json_table` \n" + f"WHERE {prefix}STRING(JSON_QUERY(`json_table`.`cart`, '$.\"name\"')) = 'Alice'" + ) + + expr = sqlalchemy.select(1).where( + json_column[["items", 1, "price"]].as_integer(lax=lax) == 10 + ) + assert expr.compile(faux_conn, compile_kwargs={"literal_binds": True}).string == ( + "SELECT 1 \n" + "FROM `json_table` \n" + f'WHERE {prefix}INT64(JSON_QUERY(`json_table`.`cart`, \'$."items"[1]."price"\')) = 10' + ) + + expr = sqlalchemy.select( + sqlalchemy.literal(10.0, type_=JSON).as_float(lax=lax) == 10.0 + ) + assert expr.compile(faux_conn, compile_kwargs={"literal_binds": True}).string == ( + f"SELECT {prefix}FLOAT64(PARSE_JSON('10.0')) = 10.0 AS `anon_1`" + ) + + expr = sqlalchemy.select( + sqlalchemy.literal(True, type_=JSON).as_boolean(lax=lax) == sqlalchemy.true() + ) + assert expr.compile(faux_conn, compile_kwargs={"literal_binds": True}).string == ( + f"SELECT {prefix}BOOL(PARSE_JSON('true')) = true AS `anon_1`" + ) + + +@pytest.mark.parametrize( + "mode,prefix", ((None, ""), ("LAX", "lax "), ("LAX_RECURSIVE", "lax recursive ")) +) +def test_json_path_mode(faux_conn, json_column, mode, prefix): + from sqlalchemy_bigquery import JSON + + if mode == "LAX": + path = [JSON.JSONPathMode.LAX, "items", "price"] + elif mode == "LAX_RECURSIVE": + path = [JSON.JSONPathMode.LAX_RECURSIVE, "items", "price"] + else: + path = ["items", "price"] + + expr = sqlalchemy.select(json_column[path]) + + expected_literal_sql = ( + f'SELECT JSON_QUERY(`json_table`.`cart`, \'{prefix}$."items"."price"\') AS `anon_1` \n' + "FROM `json_table`" + ) + actual_literal_sql = expr.compile( + faux_conn, compile_kwargs={"literal_binds": True} + ).string + + assert expected_literal_sql == actual_literal_sql diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py index 59481baa..f9e6eb3d 100644 --- a/tests/unit/test_engine.py +++ b/tests/unit/test_engine.py @@ -16,6 +16,8 @@ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +import json +from unittest import mock import pytest import sqlalchemy @@ -64,3 +66,34 @@ def test_arraysize_querystring_takes_precedence_over_default(faux_conn, metadata metadata.create_all(engine) assert conn.connection.test_data["arraysize"] == arraysize + + +def test_set_json_serde(faux_conn, metadata): + from sqlalchemy_bigquery import JSON + + json_serializer = mock.Mock(side_effect=json.dumps) + json_deserializer = mock.Mock(side_effect=json.loads) + + engine = sqlalchemy.create_engine( + "bigquery://myproject/mydataset", + json_serializer=json_serializer, + json_deserializer=json_deserializer, + ) + + json_data = {"foo": "bar"} + json_table = sqlalchemy.Table( + "json_table", metadata, sqlalchemy.Column("json", JSON) + ) + + metadata.create_all(engine) + faux_conn.ex(f"insert into json_table values ('{json.dumps(json_data)}')") + + with engine.begin() as conn: + row = conn.execute(sqlalchemy.select(json_table.c.json)).first() + assert row == (json_data,) + assert json_deserializer.mock_calls == [mock.call(json.dumps(json_data))] + + expr = sqlalchemy.select(sqlalchemy.literal(json_data, type_=JSON)) + literal_sql = expr.compile(engine, compile_kwargs={"literal_binds": True}).string + assert literal_sql == f"SELECT PARSE_JSON('{json.dumps(json_data)}') AS `anon_1`" + assert json_serializer.mock_calls == [mock.call(json_data)]