diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 54e479a9..07638a20 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -182,6 +182,18 @@ def test_ignore_nulls(dialect, function, element): f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' \ f'\nFROM "table"' + # testing with compile kwargs + statement = select( + element( + func.round(table_without_catalog.c.id, 2), + ignore_nulls=True, + ).over(partition_by=table_without_catalog.c.name).label('window') + ) + query = statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert str(query) == \ + f'SELECT {function}(round("table".id, 2)) IGNORE NULLS OVER (PARTITION BY "table".name) AS window '\ + f'\nFROM "table"' + @pytest.mark.skipif( sqlalchemy_version() < "2.0", diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 174750b7..19555920 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -172,7 +172,7 @@ class Lag(GenericIgnoreNulls): @compiles(Lead) @compiles(Lag) def compile_ignore_nulls(element, compiler, **kwargs): - compiled = f'{element.name}({compiler.process(element.clauses)})' + compiled = f'{element.name}({compiler.process(element.clauses, **kwargs)})' if element.ignore_nulls: compiled += ' IGNORE NULLS' return compiled