From e487c931bfd02179add123e48abedeb6de6c393d Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 12:14:29 +0300 Subject: [PATCH 1/2] support CAST(a AS decimal(10, 1)) --- mindsdb_sql/parser/ast/select/type_cast.py | 11 ++++++----- mindsdb_sql/parser/dialects/mindsdb/parser.py | 7 ++++++- mindsdb_sql/parser/dialects/mysql/parser.py | 7 ++++++- mindsdb_sql/parser/parser.py | 7 ++++++- mindsdb_sql/render/sqlalchemy_render.py | 4 ++-- .../test_base_sql/test_select_structure.py | 10 +++++++++- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/mindsdb_sql/parser/ast/select/type_cast.py b/mindsdb_sql/parser/ast/select/type_cast.py index f0ae304c..7bad1ce9 100644 --- a/mindsdb_sql/parser/ast/select/type_cast.py +++ b/mindsdb_sql/parser/ast/select/type_cast.py @@ -3,19 +3,20 @@ class TypeCast(ASTNode): - def __init__(self, type_name, arg, length=None, *args, **kwargs): + def __init__(self, type_name, arg, precision=None, *args, **kwargs): super().__init__(*args, **kwargs) self.type_name = type_name self.arg = arg - self.length = length + self.precision = precision def to_tree(self, *args, level=0, **kwargs): - out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, length={self.length}, arg=\n{indent(level+1)}{self.arg.to_tree()})' + out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, precision={self.precision}, arg=\n{indent(level+1)}{self.arg.to_tree()})' return out_str def get_string(self, *args, **kwargs): type_name = self.type_name - if self.length is not None: - type_name += f'({self.length})' + if self.precision is not None: + precision = map(str, self.precision) + type_name += f'({",".join(precision)})' return f'CAST({str(self.arg)} AS {type_name})' diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index c0d1edde..c740cf4d 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -1466,8 +1466,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/parser/dialects/mysql/parser.py b/mindsdb_sql/parser/dialects/mysql/parser.py index bb128562..0f28ba7a 100644 --- a/mindsdb_sql/parser/dialects/mysql/parser.py +++ b/mindsdb_sql/parser/dialects/mysql/parser.py @@ -821,8 +821,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/parser/parser.py b/mindsdb_sql/parser/parser.py index 18576c86..6794116b 100644 --- a/mindsdb_sql/parser/parser.py +++ b/mindsdb_sql/parser/parser.py @@ -581,8 +581,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index 4d7e513d..bcd54f4e 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -254,8 +254,8 @@ def to_expression(self, t): elif isinstance(t, ast.TypeCast): arg = self.to_expression(t.arg) type = self.get_type(t.type_name) - if t.length is not None: - type = type(t.length) + if t.precision is not None: + type = type(*t.precision) col = sa.cast(arg, type) if t.alias: diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 02a99a43..65d686e7 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -633,7 +633,15 @@ def test_type_cast(self, dialect): sql = f"""SELECT CAST(a AS CHAR(10))""" ast = parse_sql(sql, dialect=dialect) expected_ast = Select(targets=[ - TypeCast(type_name='CHAR', arg=Identifier('a'), length=10) + TypeCast(type_name='CHAR', arg=Identifier('a'), precision=[10]) + ]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST(a AS decimal(10, 1))""" + ast = parse_sql(sql, dialect=dialect) + expected_ast = Select(targets=[ + TypeCast(type_name='decimal', arg=Identifier('a'), precision=[10, 1]) ]) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) From ec6751b00e649fb3e2eed565224e70519d26a046 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 30 Oct 2024 14:17:29 +0300 Subject: [PATCH 2/2] support CAST(a AS decimal(10, 1)) --- tests/test_parser/test_base_sql/test_select_structure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 65d686e7..07f8abee 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -638,10 +638,10 @@ def test_type_cast(self, dialect): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) - sql = f"""SELECT CAST(a AS decimal(10, 1))""" + sql = f"""SELECT CAST(a AS DECIMAL(10, 1))""" ast = parse_sql(sql, dialect=dialect) expected_ast = Select(targets=[ - TypeCast(type_name='decimal', arg=Identifier('a'), precision=[10, 1]) + TypeCast(type_name='DECIMAL', arg=Identifier('a'), precision=[10, 1]) ]) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast)