Skip to content

Commit

Permalink
fix(sql_parse): Provide more lenient logic when extracting latest[_su…
Browse files Browse the repository at this point in the history
…b]_partition (apache#28152)
  • Loading branch information
john-bodley authored Apr 26, 2024
1 parent 1e47e65 commit c5e7d87
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
21 changes: 12 additions & 9 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,16 +1554,19 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
"latest_partition",
"latest_sub_partition",
):
# Extract the table referenced in the macro.
tables.add(
Table(
*[
remove_quotes(part.strip())
for part in node.args[0].as_const().split(".")[::-1]
if len(node.args) == 1
]
# Try to extract the table referenced in the macro.
try:
tables.add(
Table(
*[
remove_quotes(part.strip())
for part in node.args[0].as_const().split(".")[::-1]
if len(node.args) == 1
]
)
)
)
except nodes.Impossible:
pass

# Replace the potentially problematic Jinja macro with some benign SQL.
node.__class__ = nodes.TemplateData
Expand Down
40 changes: 22 additions & 18 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,36 +1857,40 @@ def test_sqlstatement() -> None:
],
)
@pytest.mark.parametrize(
"macro",
[
"latest_partition('foo.bar')",
"latest_partition(' foo.bar ')", # Non-atypical user error which works
"latest_partition('foo.%s'|format('bar'))",
"latest_sub_partition('foo.bar', baz='qux')",
],
)
@pytest.mark.parametrize(
"sql,expected",
"macro,expected",
[
(
"SELECT '{{{{ {engine}.{macro} }}}}'",
"latest_partition('foo.bar')",
{Table(table="bar", schema="foo")},
),
(
"latest_partition(' foo.bar ')", # Non-atypical user error which works
{Table(table="bar", schema="foo")},
),
(
"SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
{Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
"latest_partition('foo.%s'|format('bar'))",
{Table(table="bar", schema="foo")},
),
(
"latest_sub_partition('foo.bar', baz='qux')",
{Table(table="bar", schema="foo")},
),
(
"latest_partition('foo.%s'|format(str('bar')))",
set(),
),
(
"latest_partition('foo.{}'.format('bar'))",
set(),
),
],
)
def test_extract_tables_from_jinja_sql(
engine: str,
macro: str,
sql: str,
expected: set[Table],
engine: str, macro: str, expected: set[Table]
) -> None:
assert (
extract_tables_from_jinja_sql(
sql=sql.format(engine=engine, macro=macro),
sql=f"'{{{{ {engine}.{macro} }}}}'",
database=Mock(),
)
== expected
Expand Down

0 comments on commit c5e7d87

Please sign in to comment.