diff --git a/mindsdb_sql/parser/ast/set.py b/mindsdb_sql/parser/ast/set.py index bc48f238..f317d892 100644 --- a/mindsdb_sql/parser/ast/set.py +++ b/mindsdb_sql/parser/ast/set.py @@ -6,94 +6,120 @@ class Set(ASTNode): def __init__(self, category=None, - arg=None, + name=None, + value=None, + scope=None, params=None, + set_list=None, *args, **kwargs): super().__init__(*args, **kwargs) - self.category = category - self.arg = arg - self.params = params or {} - def to_tree(self, *args, level=0, **kwargs): - ind = indent(level) - category_str = f'category={self.category}, ' - arg_str = f'arg={self.arg.to_tree()},' if self.arg else '' - if self.params: - param_str = 'param=' + ', '.join([f'{k}:{v}' for k,v in self.params.items()]) - else: - param_str = '' - out_str = f'{ind}Set(' \ - f'{category_str}' \ - f'{arg_str} ' \ - f'{param_str}' \ - f')' - return out_str + # names / charset / transactions + self.category = category - def get_string(self, *args, **kwargs): - if self.params: - param_str = ' ' + ' '.join([f'{k} {v}' for k, v in self.params.items()]) - else: - param_str = '' - - if isinstance(self.arg, Tuple): - arg_str = ', '.join([str(i) for i in self.arg.items]) - else: - arg_str = f' {str(self.arg)}' if self.arg else '' - return f'SET {self.category if self.category else ""}{arg_str}{param_str}' + # name for variable assigment. category is None it this case + self.name = name + self.value = value + self.params = params or {} -class SetTransaction(ASTNode): - def __init__(self, - isolation_level=None, - access_mode=None, - scope=None, - *args, **kwargs): - super().__init__(*args, **kwargs) + # global / session / ... + self.scope = scope - if isolation_level is not None: - isolation_level = isolation_level.upper() - if access_mode is not None: - access_mode = access_mode.upper() - if scope is not None: - scope = scope.upper() + # contents all set subcommands + self.set_list = set_list - self.scope = scope - self.access_mode = access_mode - self.isolation_level = isolation_level def to_tree(self, *args, level=0, **kwargs): - ind = indent(level) - if self.scope is None: - scope_str = '' + if self.set_list is not None: + items = [set.render() for set in self.set_list] else: - scope_str = f'scope={self.scope}, ' + items = self.render() - properties = [] - if self.isolation_level is not None: - properties.append('ISOLATION LEVEL ' + self.isolation_level) - if self.access_mode is not None: - properties.append(self.access_mode) - prop_str = ', '.join(properties) + ind = indent(level) - out_str = f'{ind}SetTransaction(' \ - f'{scope_str}' \ - f'properties=[{prop_str}]' \ - f'\n{ind})' - return out_str + return f'{ind}Set(items={items})' def get_string(self, *args, **kwargs): - properties = [] - if self.isolation_level is not None: - properties.append('ISOLATION LEVEL ' + self.isolation_level) - if self.access_mode is not None: - properties.append(self.access_mode) + return 'SET ' + self.render() - prop_str = ', '.join(properties) + def render(self): + if self.set_list is not None: + render_list = [set.render() for set in self.set_list] + return ', '.join(render_list) - if self.scope is None: - scope_str = '' + if self.params: + param_str = ' ' + ' '.join([f'{k} {v}' for k, v in self.params.items()]) else: - scope_str = self.scope + ' ' + param_str = '' - return f'SET {scope_str}TRANSACTION {prop_str}' + if self.name is not None: + # category should be empty + content = f'{self.name.to_string()}={self.value.to_string()}' + elif self.value is not None: + content = f'{self.category} {self.value.to_string()}' + else: + content = f'{self.category}' + + scope = '' + if self.scope is not None: + scope = f'{self.scope} ' + + return f'{scope}{content}{param_str}' + + +# class SetTransaction(ASTNode): +# def __init__(self, +# isolation_level=None, +# access_mode=None, +# scope=None, +# *args, **kwargs): +# super().__init__(*args, **kwargs) +# +# if isolation_level is not None: +# isolation_level = isolation_level.upper() +# if access_mode is not None: +# access_mode = access_mode.upper() +# if scope is not None: +# scope = scope.upper() +# +# self.scope = scope +# self.access_mode = access_mode +# self.isolation_level = isolation_level +# +# def to_tree(self, *args, level=0, **kwargs): +# ind = indent(level) +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = f'scope={self.scope}, ' +# +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# prop_str = ', '.join(properties) +# +# out_str = f'{ind}SetTransaction(' \ +# f'{scope_str}' \ +# f'properties=[{prop_str}]' \ +# f'\n{ind})' +# return out_str +# +# def get_string(self, *args, **kwargs): +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# +# prop_str = ', '.join(properties) +# +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = self.scope + ' ' +# +# return f'SET {scope_str}TRANSACTION {prop_str}' diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 6a312edf..566559bc 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -42,7 +42,7 @@ class MindsDBLexer(Lexer): VARIABLES, SESSION, STATUS, GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, - PERSIST, PERSIST_ONLY, DEFAULT, + PERSIST, PERSIST_ONLY, IF_EXISTS, IF_NOT_EXISTS, COLUMNS, FIELDS, COLLATE, SEARCH_PATH, VARIABLE, SYSTEM_VARIABLE, @@ -172,7 +172,6 @@ class MindsDBLexer(Lexer): PLUGINS = r'\bPLUGINS\b' PERSIST = r'\bPERSIST\b' PERSIST_ONLY = r'\bPERSIST_ONLY\b' - DEFAULT = r'\bDEFAULT\b' IF_EXISTS = r'\bIF[\s]+EXISTS\b' IF_NOT_EXISTS = r'\bIF[\s]+NOT[\s]+EXISTS\b' COLUMNS = r'\bCOLUMNS\b' diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 72777809..2b720c20 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -334,42 +334,57 @@ def commit_transaction(self, p): def rollback_transaction(self, p): return RollbackTransaction() - # Set - - @_('SET id identifier', - 'SET id identifier COLLATE constant', - 'SET id identifier COLLATE DEFAULT', - 'SET id constant', - 'SET id constant COLLATE constant', - 'SET id constant COLLATE DEFAULT') + # --- Set --- + @_('SET set_item_list') def set(self, p): - if not p.id.lower() == 'names': - raise ParsingException(f'Expected "SET names", got "SET {p.id}"') - if isinstance(p[2], Constant): - arg = Identifier(p[2].value) + set_list = p[1] + if len(set_list) == 1: + return set_list[0] + return Set(set_list=set_list) + + @_('set_item', + 'set_item_list COMMA set_item') + def set_item_list(self, p): + arr = getattr(p, 'set_item_list', []) + arr.append(p.set_item) + return arr + + # set names + @_('id id', + 'id constant', + 'id id COLLATE constant', + 'id id COLLATE id', + 'id constant COLLATE constant', + 'id constant COLLATE id') + def set_item(self, p): + category = p[0] + if category.lower() != 'names': + raise ParsingException(f'Expected "SET names", got "SET {category}"') + if isinstance(p[1], Constant): + value = p[1] else: - # is identifier - arg = p[2] + # is id + value = Constant(p[1], with_quotes=False) params = {} if hasattr(p, 'COLLATE'): - if isinstance(p[4], Constant): - val = p[4] + if isinstance(p[3], Constant): + val = p[3] else: - val = Constant(p[4], with_quotes=False) + val = Constant(p[3], with_quotes=False) params['COLLATE'] = val - return Set(category=p.id.lower(), arg=arg, params=params) + return Set(category=category, value=value, params=params) # set charset - @_('SET charset constant', - 'SET charset DEFAULT') - def set(self, p): + @_('charset constant', + 'charset id') + def set_item(self, p): if hasattr(p, 'id'): arg = Constant(p.id, with_quotes=False) else: arg = p.constant - return Set(category='CHARSET', arg=arg) + return Set(category='CHARSET', value=arg) @_('CHARACTER SET', 'CHARSET', @@ -378,29 +393,30 @@ def charset(self, p): return p[0] # set transaction - @_('SET transact_scope TRANSACTION transact_property_list', - 'SET TRANSACTION transact_property_list') - def set(self, p): + @_('set_scope TRANSACTION transact_property_list', + 'TRANSACTION transact_property_list') + def set_item(self, p): isolation_level = None access_mode = None - transact_scope = getattr(p, 'transact_scope', None) + transact_scope = getattr(p, 'set_scope', None) for prop in p.transact_property_list: if prop['type'] == 'iso_level': isolation_level = prop['value'] else: access_mode = prop['value'] - return SetTransaction( - isolation_level=isolation_level, - access_mode=access_mode, + params = {} + if isolation_level is not None: + params['isolation_level'] = isolation_level + if access_mode is not None: + params['access_mode'] = access_mode + + return Set( + category='TRANSACTION', scope=transact_scope, + params=params ) - @_('GLOBAL', - 'SESSION') - def transact_scope(self, p): - return p[0] - @_('transact_property_list COMMA transact_property') def transact_property_list(self, p): return p.transact_property_list + [p.transact_property] @@ -429,30 +445,29 @@ def transact_level(self, p): def transact_access_mode(self, p): return ' '.join([x for x in p]) - @_('SET expr_list', - 'SET set_modifier expr_list') - def set(self, p): - if len(p.expr_list) == 1: - arg = p.expr_list[0] - else: - arg = Tuple(items=p.expr_list) + @_('identifier EQUALS expr', + 'set_scope identifier EQUALS expr', + 'variable EQUALS expr', + 'set_scope variable EQUALS expr') + def set_item(self, p): - if hasattr(p, 'set_modifier'): - category = p.set_modifier - else: - category = None + scope = None + name = p[0] + if hasattr(p, 'set_scope'): + scope = p.set_scope + name=p[1] - return Set(category=category, arg=arg) + return Set(name=name, value=p.expr, scope=scope) @_('GLOBAL', 'PERSIST', 'PERSIST_ONLY', 'SESSION', ) - def set_modifier(self, p): + def set_scope(self, p): return p[0] - # Show + # --- Show --- @_('show WHERE expr') def show(self, p): command = p.show diff --git a/mindsdb_sql/parser/dialects/mysql/parser.py b/mindsdb_sql/parser/dialects/mysql/parser.py index f5f9966b..ee6541c3 100644 --- a/mindsdb_sql/parser/dialects/mysql/parser.py +++ b/mindsdb_sql/parser/dialects/mysql/parser.py @@ -102,42 +102,58 @@ def commit_transaction(self, p): def rollback_transaction(self, p): return RollbackTransaction() - # Set - - @_('SET id identifier') - @_('SET id identifier COLLATE constant') - @_('SET id identifier COLLATE DEFAULT') - @_('SET id constant') - @_('SET id constant COLLATE constant') - @_('SET id constant COLLATE DEFAULT') + + # --- Set --- + @_('SET set_item_list') def set(self, p): - if not p.id.lower() == 'names': - raise ParsingException(f'Expected "SET names", got "SET {p.id}"') - if isinstance(p[2], Constant): - arg = Identifier(p[2].value) + set_list = p[1] + if len(set_list) == 1: + return set_list[0] + return Set(set_list=set_list) + + @_('set_item', + 'set_item_list COMMA set_item') + def set_item_list(self, p): + arr = getattr(p, 'set_item_list', []) + arr.append(p.set_item) + return arr + + # set names + @_('id id', + 'id constant', + 'id id COLLATE constant', + 'id id COLLATE id', + 'id constant COLLATE constant', + 'id constant COLLATE id') + def set_item(self, p): + category = p[0] + if category.lower() != 'names': + raise ParsingException(f'Expected "SET names", got "SET {category}"') + if isinstance(p[1], Constant): + value = p[1] else: - # is identifier - arg = p[2] + # is id + value = Constant(p[1], with_quotes=False) params = {} if hasattr(p, 'COLLATE'): - if isinstance(p[4], Constant): - val = p[4] + if isinstance(p[3], Constant): + val = p[3] else: - val = Constant('DEFAULT', with_quotes=False) + val = Constant(p[3], with_quotes=False) params['COLLATE'] = val - return Set(category=p.id.lower(), arg=arg, params=params) + return Set(category=category, value=value, params=params) # set charset - @_('SET charset constant') - @_('SET charset DEFAULT') - def set(self, p): - if hasattr(p, 'DEFAULT'): - arg = Constant('DEFAULT', with_quotes=False) + @_('charset constant', + 'charset id') + def set_item(self, p): + if hasattr(p, 'id'): + arg = Constant(p.id, with_quotes=False) else: arg = p.constant - return Set(category='CHARSET', arg=arg) + return Set(category='CHARSET', value=arg) @_('CHARACTER SET', 'CHARSET', @@ -146,29 +162,30 @@ def charset(self, p): return p[0] # set transaction - @_('SET transact_scope TRANSACTION transact_property_list') - @_('SET TRANSACTION transact_property_list') - def set(self, p): + @_('set_scope TRANSACTION transact_property_list', + 'TRANSACTION transact_property_list') + def set_item(self, p): isolation_level = None access_mode = None - transact_scope = getattr(p, 'transact_scope', None) + transact_scope = getattr(p, 'set_scope', None) for prop in p.transact_property_list: if prop['type'] == 'iso_level': isolation_level = prop['value'] else: access_mode = prop['value'] - return SetTransaction( - isolation_level=isolation_level, - access_mode=access_mode, + params = {} + if isolation_level is not None: + params['isolation_level'] = isolation_level + if access_mode is not None: + params['access_mode'] = access_mode + + return Set( + category='TRANSACTION', scope=transact_scope, + params=params ) - @_('GLOBAL', - 'SESSION') - def transact_scope(self, p): - return p[0] - @_('transact_property_list COMMA transact_property') def transact_property_list(self, p): return p.transact_property_list + [p.transact_property] @@ -181,9 +198,9 @@ def transact_property_list(self, p): 'transact_access_mode') def transact_property(self, p): if hasattr(p, 'transact_level'): - return {'type': 'iso_level', 'value': p.transact_level} + return {'type':'iso_level', 'value':p.transact_level} else: - return {'type': 'access_mode', 'value': p.transact_access_mode} + return {'type':'access_mode', 'value':p.transact_access_mode} @_('REPEATABLE READ', 'READ COMMITTED', @@ -197,30 +214,29 @@ def transact_level(self, p): def transact_access_mode(self, p): return ' '.join([x for x in p]) - @_('SET expr_list') - @_('SET set_modifier expr_list') - def set(self, p): - if len(p.expr_list) == 1: - arg = p.expr_list[0] - else: - arg = Tuple(items=p.expr_list) + @_('identifier EQUALS expr', + 'set_scope identifier EQUALS expr', + 'variable EQUALS expr', + 'set_scope variable EQUALS expr') + def set_item(self, p): - if hasattr(p, 'set_modifier'): - category = p.set_modifier - else: - category = None + scope = None + name = p[0] + if hasattr(p, 'set_scope'): + scope = p.set_scope + name=p[1] - return Set(category=category, arg=arg) + return Set(name=name, value=p.expr, scope=scope) @_('GLOBAL', 'PERSIST', 'PERSIST_ONLY', 'SESSION', ) - def set_modifier(self, p): + def set_scope(self, p): return p[0] - # Show + # --- Show --- @_('show WHERE expr') def show(self, p): command = p.show diff --git a/mindsdb_sql/parser/lexer.py b/mindsdb_sql/parser/lexer.py index 1df0ee87..6033689e 100644 --- a/mindsdb_sql/parser/lexer.py +++ b/mindsdb_sql/parser/lexer.py @@ -25,7 +25,7 @@ class SQLLexer(Lexer): VIEW, VARIABLES, SESSION, STATUS, GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, - PERSIST, PERSIST_ONLY, DEFAULT, + PERSIST, PERSIST_ONLY, IF_EXISTS, COLUMNS, FIELDS, COLLATE, # SELECT Keywords @@ -109,7 +109,6 @@ class SQLLexer(Lexer): PLUGINS = r'\bPLUGINS\b' PERSIST = r'\bPERSIST\b' PERSIST_ONLY = r'\bPERSIST_ONLY\b' - DEFAULT = r'\bDEFAULT\b' IF_EXISTS = r'\bIF[\s]+EXISTS\b' COLUMNS = r'\bCOLUMNS\b' FIELDS = r'\bFIELDS\b' diff --git a/tests/test_parser/test_base_sql/test_misc_sql_queries.py b/tests/test_parser/test_base_sql/test_misc_sql_queries.py index 53f09879..68acb3f9 100644 --- a/tests/test_parser/test_base_sql/test_misc_sql_queries.py +++ b/tests/test_parser/test_base_sql/test_misc_sql_queries.py @@ -3,22 +3,22 @@ from mindsdb_sql.parser.ast import * -@pytest.mark.parametrize('dialect', ['sqlite', 'mysql', 'mindsdb']) +@pytest.mark.parametrize('dialect', ['mysql', 'mindsdb']) class TestMiscQueries: def test_set(self, dialect): lexer, parser = get_lexer_parser(dialect) - sql = "SET NAMES some_name" + sql = "SET names some_name" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(category="names", arg=Identifier('some_name')) + expected_ast = Set(category="names", value=Identifier('some_name')) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) sql = "set character_set_results = NULL" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(arg=BinaryOperation('=', args=[Identifier('character_set_results'), NullConstant()])) + expected_ast = Set(name=Identifier('character_set_results'), value=NullConstant()) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -74,14 +74,8 @@ def test_autocommit(self, dialect): ast = parse_sql(sql, dialect=dialect) expected_ast = Set( - category=None, - arg=BinaryOperation( - op='=', - args=( - Identifier('autocommit'), - Constant(1) - ) - ) + name=Identifier('autocommit'), + value=Constant(1) ) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -94,11 +88,12 @@ def test_set(self, dialect): sql = "set var1 = NULL, var2 = 10" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(arg=Tuple(items=[ - BinaryOperation('=', args=[Identifier('var1'), NullConstant()]), - BinaryOperation('=', args=[Identifier('var2'), Constant(10)]), - ]) - ) + expected_ast = Set( + set_list=[ + Set(name=Identifier('var1'), value=NullConstant()), + Set(name=Identifier('var2'), value=Constant(10)), + ] + ) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -106,17 +101,17 @@ def test_set(self, dialect): sql = "SET NAMES some_name collate DEFAULT" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(category="names", - arg=Identifier('some_name'), + expected_ast = Set(category="NAMES", + value=Constant('some_name', with_quotes=False), params={'COLLATE': 'DEFAULT'}) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) - sql = "SET NAMES some_name collate 'utf8mb4_general_ci'" + sql = "SET names some_name collate 'utf8mb4_general_ci'" ast = parse_sql(sql, dialect=dialect) expected_ast = Set(category="names", - arg=Identifier('some_name'), + value=Constant('some_name', with_quotes=False), params={'COLLATE': Constant('utf8mb4_general_ci')}) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -126,14 +121,14 @@ def test_set_charset(self, dialect): sql = "SET CHARACTER SET DEFAULT" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(category='CHARSET', arg=Constant('DEFAULT', with_quotes=False)) + expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False)) assert ast.to_tree() == expected_ast.to_tree() sql = "SET CHARSET DEFAULT" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(category='CHARSET', arg=Constant('DEFAULT', with_quotes=False)) + expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False)) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -141,7 +136,7 @@ def test_set_charset(self, dialect): sql = "SET CHARSET 'utf8'" ast = parse_sql(sql, dialect=dialect) - expected_ast = Set(category='CHARSET', arg=Constant('utf8')) + expected_ast = Set(category='CHARSET', value=Constant('utf8')) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -151,10 +146,14 @@ def test_set_transaction(self, dialect): sql = "SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE" ast = parse_sql(sql, dialect=dialect) - expected_ast = SetTransaction( - isolation_level='REPEATABLE READ', - access_mode='READ WRITE', - scope='GLOBAL') + expected_ast = Set( + category='TRANSACTION', + params=dict( + isolation_level='REPEATABLE READ', + access_mode='READ WRITE', + ), + scope='GLOBAL' + ) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -162,10 +161,15 @@ def test_set_transaction(self, dialect): sql = "SET SESSION TRANSACTION READ ONLY, ISOLATION LEVEL SERIALIZABLE" ast = parse_sql(sql, dialect=dialect) - expected_ast = SetTransaction( - isolation_level='SERIALIZABLE', - access_mode='READ ONLY', - scope='SESSION') + + expected_ast = Set( + category='TRANSACTION', + params=dict( + isolation_level='SERIALIZABLE', + access_mode='READ ONLY', + ), + scope='SESSION' + ) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -173,8 +177,12 @@ def test_set_transaction(self, dialect): sql = "SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" ast = parse_sql(sql, dialect=dialect) - expected_ast = SetTransaction( - isolation_level='READ UNCOMMITTED' + + expected_ast = Set( + category='TRANSACTION', + params=dict( + isolation_level='READ UNCOMMITTED', + ) ) assert ast.to_tree() == expected_ast.to_tree() @@ -183,8 +191,12 @@ def test_set_transaction(self, dialect): sql = "SET TRANSACTION READ ONLY" ast = parse_sql(sql, dialect=dialect) - expected_ast = SetTransaction( - access_mode='READ ONLY' + + expected_ast = Set( + category='TRANSACTION', + params=dict( + access_mode='READ ONLY', + ) ) assert ast.to_tree() == expected_ast.to_tree() @@ -198,3 +210,14 @@ def test_begin(self, dialect): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) +class TestMindsdb: + def test_charset(self): + sql = "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci" + + ast = parse_sql(sql) + expected_ast = Set(category="NAMES", + value=Constant('utf8mb4', with_quotes=False), + params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)}) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + diff --git a/tests/test_parser/test_mindsdb/test_variables.py b/tests/test_parser/test_mindsdb/test_variables.py index e117eadb..27333033 100644 --- a/tests/test_parser/test_mindsdb/test_variables.py +++ b/tests/test_parser/test_mindsdb/test_variables.py @@ -16,22 +16,24 @@ def test_select_variable(self): assert str(ast).lower() == sql.lower() assert str(ast) == str(expected_ast) - sql = "set autocommit = 1, sql_mode = concat(@@sql_mode, ',STRICT_TRANS_TABLES')" + sql = "set autocommit=1, global sql_mode=concat(@@sql_mode, ',STRICT_TRANS_TABLES'), NAMES utf8mb4 COLLATE utf8mb4_unicode_ci" ast = parse_sql(sql) expected_ast = Set( - arg=Tuple([ - BinaryOperation(op='=', args=[ - Identifier('autocommit'), Constant(1) - ]), - BinaryOperation(op='=', args=[ - Identifier('sql_mode'), - Function(op='concat', args=[ + set_list=[ + Set(name=Identifier('autocommit'), value=Constant(1)), + Set(name=Identifier('sql_mode'), + scope='global', + value=Function(op='concat', args=[ Variable('sql_mode', is_system_var=True), Constant(',STRICT_TRANS_TABLES') ]) - ]) - ]) + ), + Set(category="NAMES", + value=Constant('utf8mb4', with_quotes=False), + params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)}) + ] ) + assert str(ast).lower() == sql.lower() assert str(ast) == str(expected_ast)