Skip to content

Commit

Permalink
Merge pull request #410 from mindsdb/fix-cast-decimal
Browse files Browse the repository at this point in the history
Support CAST(a AS decimal(x, y))
  • Loading branch information
ea-rus authored Oct 31, 2024
2 parents fd1ae98 + ec6751b commit c50e06c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 11 deletions.
11 changes: 6 additions & 5 deletions mindsdb_sql/parser/ast/select/type_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@


class TypeCast(ASTNode):
def __init__(self, type_name, arg, length=None, *args, **kwargs):
def __init__(self, type_name, arg, precision=None, *args, **kwargs):
super().__init__(*args, **kwargs)

self.type_name = type_name
self.arg = arg
self.length = length
self.precision = precision

def to_tree(self, *args, level=0, **kwargs):
out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, length={self.length}, arg=\n{indent(level+1)}{self.arg.to_tree()})'
out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, precision={self.precision}, arg=\n{indent(level+1)}{self.arg.to_tree()})'
return out_str

def get_string(self, *args, **kwargs):
type_name = self.type_name
if self.length is not None:
type_name += f'({self.length})'
if self.precision is not None:
precision = map(str, self.precision)
type_name += f'({",".join(precision)})'
return f'CAST({str(self.arg)} AS {type_name})'
7 changes: 6 additions & 1 deletion mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,8 +1466,13 @@ def expr_list_or_nothing(self, p):
pass

@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
def expr(self, p):
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
if hasattr(p, 'integer'):
precision=[p.integer]
else:
precision=[p.integer0, p.integer1]
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)

@_('CAST LPAREN expr AS id RPAREN')
def expr(self, p):
Expand Down
7 changes: 6 additions & 1 deletion mindsdb_sql/parser/dialects/mysql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,13 @@ def expr_list_or_nothing(self, p):
pass

@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
def expr(self, p):
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
if hasattr(p, 'integer'):
precision=[p.integer]
else:
precision=[p.integer0, p.integer1]
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)

@_('CAST LPAREN expr AS id RPAREN')
def expr(self, p):
Expand Down
7 changes: 6 additions & 1 deletion mindsdb_sql/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,13 @@ def expr_list_or_nothing(self, p):
pass

@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
def expr(self, p):
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
if hasattr(p, 'integer'):
precision=[p.integer]
else:
precision=[p.integer0, p.integer1]
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)

@_('CAST LPAREN expr AS id RPAREN')
def expr(self, p):
Expand Down
4 changes: 2 additions & 2 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ def to_expression(self, t):
elif isinstance(t, ast.TypeCast):
arg = self.to_expression(t.arg)
type = self.get_type(t.type_name)
if t.length is not None:
type = type(t.length)
if t.precision is not None:
type = type(*t.precision)
col = sa.cast(arg, type)

if t.alias:
Expand Down
10 changes: 9 additions & 1 deletion tests/test_parser/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,15 @@ def test_type_cast(self, dialect):
sql = f"""SELECT CAST(a AS CHAR(10))"""
ast = parse_sql(sql, dialect=dialect)
expected_ast = Select(targets=[
TypeCast(type_name='CHAR', arg=Identifier('a'), length=10)
TypeCast(type_name='CHAR', arg=Identifier('a'), precision=[10])
])
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

sql = f"""SELECT CAST(a AS DECIMAL(10, 1))"""
ast = parse_sql(sql, dialect=dialect)
expected_ast = Select(targets=[
TypeCast(type_name='DECIMAL', arg=Identifier('a'), precision=[10, 1])
])
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)
Expand Down

0 comments on commit c50e06c

Please sign in to comment.