diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py index 6a6e60e3f0..96644d2464 100644 --- a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py +++ b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py @@ -551,6 +551,7 @@ def test_window_functions( "Date Sales": "SUM([Group Sales] WITHIN [order_date])", "City Sales": "SUM([Group Sales] AMONG [order_date])", "Total RSUM": 'RSUM([Group Sales], "asc" TOTAL)', + "First Round Count": "FIRST(ROUND(COUNT(1)))", }, ) @@ -567,6 +568,7 @@ def test_window_functions( ds.find_field(title="Date Sales"), ds.find_field(title="City Sales"), ds.find_field(title="Total RSUM"), + ds.find_field(title="First Round Count"), ], order_by=[ ds.find_field(title="order_date"), @@ -585,7 +587,7 @@ def test_window_functions( assert {row[3] for row in data_rows}.issubset({str(i) for i in range(1, cnt + 1)}) # There are as many [Unique Rank of Sales] values as there are rows - assert {row[4] for row in data_rows} == ({str(i) for i in range(1, cnt + 1)}) + assert {row[4] for row in data_rows} == {str(i) for i in range(1, cnt + 1)} # [Rank of City Sales for Date] values are not greater than the number of [City] values assert len({row[5] for row in data_rows}) <= len({row[1] for row in data_rows}) @@ -603,6 +605,8 @@ def test_window_functions( # RSUM = previous RSUM value + value of current arg assert pytest.approx(float(data_rows[i][9])) == float(data_rows[i - 1][9]) + float(data_rows[i][2]) + assert all(float(row[10]) == 1 for row in data_rows) + class DefaultBasicNativeFunctionTestSuite( RegulatedTestCase, DataApiTestBase, DatasetTestBase, DbServiceFixtureTextClass diff --git a/lib/dl_connector_trino/dl_connector_trino/__init__.py b/lib/dl_connector_trino/dl_connector_trino/__init__.py index e69de29bb2..7339b6693d 100644 --- a/lib/dl_connector_trino/dl_connector_trino/__init__.py +++ b/lib/dl_connector_trino/dl_connector_trino/__init__.py @@ -0,0 +1 @@ +from dl_connector_trino import vendor_patches # noqa: F401 diff --git a/lib/dl_connector_trino/dl_connector_trino/vendor_patches.py b/lib/dl_connector_trino/dl_connector_trino/vendor_patches.py new file mode 100644 index 0000000000..8735b11f2e --- /dev/null +++ b/lib/dl_connector_trino/dl_connector_trino/vendor_patches.py @@ -0,0 +1,23 @@ +from typing import Any + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from trino.sqlalchemy.compiler import TrinoSQLCompiler + + +# This is a temporary patch to fix https://github.com/trinodb/trino-python-client/pull/586 +# BI-6846 +@compiles(TrinoSQLCompiler.FirstValue) +@compiles(TrinoSQLCompiler.LastValue) +@compiles(TrinoSQLCompiler.NthValue) +@compiles(TrinoSQLCompiler.Lead) +@compiles(TrinoSQLCompiler.Lag) +def compile_ignore_nulls( + element: TrinoSQLCompiler.GenericIgnoreNulls, + compiler: SQLCompiler, + **kwargs: Any, +) -> str: + compiled = f"{element.name}({compiler.process(element.clauses, **kwargs)})" + if element.ignore_nulls: + compiled += " IGNORE NULLS" + return compiled diff --git a/metapkg/poetry.lock b/metapkg/poetry.lock index 77343cd223..308cccc703 100644 --- a/metapkg/poetry.lock +++ b/metapkg/poetry.lock @@ -6979,6 +6979,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -9426,4 +9427,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.13" -content-hash = "86cedfaef2e0514d48710078ebed3cae755e3b63a336e24ef3a57791e24f81ba" +content-hash = "b64d2bde7486ad6217fe030ab9281355b5f3ebc97f00ba3d6bed738132c717bd" diff --git a/metapkg/pyproject.toml b/metapkg/pyproject.toml index 4c02276db4..8c73ff7659 100644 --- a/metapkg/pyproject.toml +++ b/metapkg/pyproject.toml @@ -83,6 +83,7 @@ sqlalchemy = "==1.4.46, <2.0" sqlalchemy-bigquery = "==1.9.0" tabulate = "==0.9.0" tornado = "==6.4.2" +trino = {extras = ["sqlalchemy"], version = "==0.331.0"} typeguard = "==4.1.5" typing-extensions = "==4.15.0" ujson = "==1.35"