Skip to content

Commit

Permalink
Merge pull request #406 from mindsdb/staging
Browse files Browse the repository at this point in the history
Release 0.20.0
  • Loading branch information
ea-rus authored Oct 4, 2024
2 parents 565ee04 + bc8518c commit f747b5b
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 28 deletions.
2 changes: 1 addition & 1 deletion mindsdb_sql/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'mindsdb_sql'
__package_name__ = 'mindsdb_sql'
__version__ = '0.19.0'
__version__ = '0.20.0'
__description__ = "Pure python SQL parser"
__email__ = "[email protected]"
__author__ = 'MindsDB Inc'
Expand Down
12 changes: 9 additions & 3 deletions mindsdb_sql/parser/ast/select/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Case(ASTNode):
def __init__(self, rules, default, *args, **kwargs):
def __init__(self, rules, default=None, *args, **kwargs):
super().__init__(*args, **kwargs)

# structure:
Expand Down Expand Up @@ -32,10 +32,13 @@ def to_tree(self, *args, level=0, **kwargs):
f'{ind1}{condition.to_string()} => {result.to_string()}'
)
rules_str = '\n'.join(rules_ar)
default_str = ''
if self.default is not None:
default_str = f'{ind1}default => {self.default.to_string()}\n'

return f'{ind}Case(\n' \
f'{rules_str}\n' \
f'{ind1}default => {self.default.to_string()}\n' \
f'{default_str}' \
f'{ind})'

def get_string(self, *args, alias=True, **kwargs):
Expand All @@ -47,4 +50,7 @@ def get_string(self, *args, alias=True, **kwargs):
)
rules_str = ' '.join(rules_ar)

return f"CASE {rules_str} ELSE {self.default.to_string()} END"
default_str = ''
if self.default is not None:
default_str = f' ELSE {self.default.to_string()}'
return f"CASE {rules_str}{default_str} END"
25 changes: 17 additions & 8 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class MindsDBParser(Parser):
('nonassoc', LESS, LEQ, GREATER, GEQ, IN, NOT_IN, BETWEEN, IS, IS_NOT, NOT_LIKE, LIKE),
('left', JSON_GET),
('left', PLUS, MINUS),
('left', STAR, DIVIDE),
('left', STAR, DIVIDE, TYPECAST),
('right', UMINUS), # Unary minus operator, unary not

)
Expand Down Expand Up @@ -1329,9 +1329,10 @@ def column_list(self, p):
return column_list

# case
@_('CASE case_conditions ELSE expr END')
@_('CASE case_conditions ELSE expr END',
'CASE case_conditions END')
def case(self, p):
return Case(rules=p.case_conditions, default=p.expr)
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))

@_('case_condition',
'case_conditions case_condition')
Expand Down Expand Up @@ -1415,6 +1416,14 @@ def function(self, p):
args = p.expr_list_or_nothing
if not args:
args = []
for i, arg in enumerate(args):
if (
isinstance(arg, Identifier)
and len(arg.parts) == 1
and arg.parts[0].lower() == 'last'
):
args[i] = Last()

namespace = None
if hasattr(p, 'identifier'):
if len(p.identifier.parts) > 1:
Expand Down Expand Up @@ -1690,16 +1699,16 @@ def identifier(self, p):
node.parts += p[2].parts
return node

@_('id')
def identifier(self, p):
value = p[0]
return Identifier.from_path_str(value)

@_('quote_string',
'dquote_string')
def string(self, p):
return p[0]

@_('id', 'dquote_string')
def identifier(self, p):
value = p[0]
return Identifier.from_path_str(value)

@_('PARAMETER')
def parameter(self, p):
return Parameter(value=p.PARAMETER)
Expand Down
13 changes: 13 additions & 0 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,25 @@ def to_expression(self, t):
elif isinstance(t, ast.NotExists):
sub_stmt = self.prepare_select(t.query)
col = ~sub_stmt.exists()
elif isinstance(t, ast.Case):
col = self.prepare_case(t)
else:
# some other complex object?
raise NotImplementedError(f'Column {t}')

return col

def prepare_case(self, t: ast.Case):
conditions = []
for condition, result in t.rules:
conditions.append(
(self.to_expression(condition), self.to_expression(result))
)
if t.default is not None:
conditions.append(self.to_expression(t.default))

return sa.case(*conditions)

def to_function(self, t):
op = getattr(sa.func, t.op)
if t.from_arg is not None:
Expand Down
32 changes: 25 additions & 7 deletions tests/test_parser/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,6 @@ def test_case(self):
sum(
CASE
WHEN 1 = 1 THEN 1
ELSE 0
END
)
FROM INFORMATION_SCHEMA.COLLATIONS'''
Expand Down Expand Up @@ -1009,7 +1008,6 @@ def test_case(self):
Constant(1)
],
],
default=Constant(0)
)
]
)
Expand Down Expand Up @@ -1108,12 +1106,22 @@ def test_alternative_casting(self):
assert str(ast) == str(expected_ast)

# date
expected_ast = Select(targets=[
TypeCast(type_name='DATE', arg=Constant('1998-12-01')),
TypeCast(type_name='DATE', arg=Identifier('col1'), alias=Identifier('col2'))
])
expected_ast = Select(
targets=[
TypeCast(type_name='DATE', arg=Constant('1998-12-01')),
BinaryOperation(op='+', args=[
Identifier('col0'),
TypeCast(type_name='DATE', arg=Identifier('col1'), alias=Identifier('col2')),
])
],
from_table=Identifier('t1'),
where=BinaryOperation(op='>', args=[
Identifier('col0'),
TypeCast(type_name='DATE', arg=Identifier('col1')),
])
)

sql = f"SELECT '1998-12-01'::DATE, col1::DATE col2"
sql = f"SELECT '1998-12-01'::DATE, col0 + col1::DATE col2 from t1 where col0 > col1::DATE"
ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

Expand All @@ -1126,3 +1134,13 @@ def test_alternative_casting(self):
ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

def test_table_double_quote(self):
expected_ast = Select(
targets=[Identifier('account_id')],
from_table=Identifier(parts=['order'])
)

sql = 'select account_id from "order"'

ast = parse_sql(sql)
assert str(ast) == str(expected_ast)
29 changes: 20 additions & 9 deletions tests/test_parser/test_mindsdb/test_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,32 @@ def test_select_limit_negative(self):
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)


def test_last(self):
sql = """SELECT * FROM t1 t where t.id>last"""
sql = """SELECT * FROM t1 t where t.id>last and t.x > coalence(last, 0)"""

ast = parse_sql(sql, dialect='mindsdb')
expected_ast = Select(
targets=[Star()],
from_table=Identifier(parts=['t1'], alias=Identifier('t')),
where=BinaryOperation(
op='>',
args=[
Identifier(parts=['t', 'id']),
Last()
]
)
where=BinaryOperation(op='and', args=[
BinaryOperation(
op='>',
args=[
Identifier(parts=['t', 'id']),
Last()
]
),
BinaryOperation(
op='>',
args=[
Identifier(parts=['t', 'x']),
Function(op='coalence', args=[
Last(),
Constant(0)
])
]
),
])
)

assert ast.to_tree() == expected_ast.to_tree()
Expand Down
1 change: 1 addition & 0 deletions tests/test_render/test_sqlalchemyrender.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def parse_sql2(sql, dialect='mindsdb'):
or 'current_user()' in sql # replaced to CURRENT_USER
or 'user()' in sql # replaced to USER
or 'not exists' in sql # replaced to not(exits(
or "WHEN R.DELETE_RULE = 'CASCADE'" in sql # wrapped in parens by sqlalchemy
):

# sqlalchemy could add own aliases for constant
Expand Down

0 comments on commit f747b5b

Please sign in to comment.