diff --git a/mindsdb_sql/__about__.py b/mindsdb_sql/__about__.py index 72d76d8f..ce7faebc 100644 --- a/mindsdb_sql/__about__.py +++ b/mindsdb_sql/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql' __package_name__ = 'mindsdb_sql' -__version__ = '0.10.4' +__version__ = '0.10.5' __description__ = "Pure python SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql/parser/ast/select/constant.py b/mindsdb_sql/parser/ast/select/constant.py index 0b31af1e..7869e956 100644 --- a/mindsdb_sql/parser/ast/select/constant.py +++ b/mindsdb_sql/parser/ast/select/constant.py @@ -15,7 +15,8 @@ def to_tree(self, *args, level=0, **kwargs): def get_string(self, *args, **kwargs): if isinstance(self.value, str) and self.with_quotes: - out_str = f"\'{self.value}\'" + val = self.value.replace("'", "\\'") + out_str = f"\'{val}\'" elif isinstance(self.value, bool): out_str = 'TRUE' if self.value else 'FALSE' elif isinstance(self.value, (dt.date, dt.datetime, dt.timedelta)): diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index a67c69dd..77ac949b 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -44,9 +44,9 @@ def get_string(self, *args, **kwargs): arg_strs = [] for arg in self.args: arg_str = arg.to_string() - if isinstance(arg, BinaryOperation) or isinstance(arg, BetweenOperation): - # to parens - arg_str = f'({arg_str})' + # if isinstance(arg, BinaryOperation) or isinstance(arg, BetweenOperation): + # # to parens + # arg_str = f'({arg_str})' arg_strs.append(arg_str) return f'{arg_strs[0]} {self.op.upper()} {arg_strs[1]}' diff --git a/mindsdb_sql/parser/ast/set.py b/mindsdb_sql/parser/ast/set.py index f317d892..b0c1165f 100644 --- a/mindsdb_sql/parser/ast/set.py +++ b/mindsdb_sql/parser/ast/set.py @@ -49,7 +49,13 @@ def render(self): return ', '.join(render_list) if self.params: - param_str = ' ' + ' '.join([f'{k} {v}' for k, v in self.params.items()]) + params = [] + for k, v in self.params.items(): + if k.lower() == 'access_mode': + params.append(v) + else: + params.append(f'{k} {v}') + param_str = ' ' + ', '.join(params) else: param_str = '' diff --git a/mindsdb_sql/parser/ast/show.py b/mindsdb_sql/parser/ast/show.py index 64dbcfda..cb94db0f 100644 --- a/mindsdb_sql/parser/ast/show.py +++ b/mindsdb_sql/parser/ast/show.py @@ -69,11 +69,11 @@ def get_string(self, *args, **kwargs): in_str = ' ' + ' '.join(ar) modes_str = f' {" ".join(self.modes)}' if self.modes else '' - like_str = f' LIKE {self.like}' if self.like else '' + like_str = f" LIKE '{self.like}'" if self.like else "" where_str = f' WHERE {str(self.where)}' if self.where else '' # custom commands - if self.category in ('FUNCTION CODE', 'PROCEDURE CODE', 'ENGINE'): + if self.category in ('FUNCTION CODE', 'PROCEDURE CODE', 'ENGINE') or self.category.startswith('ENGINE '): return f'SHOW {self.category} {self.name}' elif self.category == 'REPLICA STATUS': channel = '' diff --git a/mindsdb_sql/parser/dialects/mindsdb/chatbot.py b/mindsdb_sql/parser/dialects/mindsdb/chatbot.py index f3366b80..5a3a5223 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/chatbot.py +++ b/mindsdb_sql/parser/dialects/mindsdb/chatbot.py @@ -39,7 +39,8 @@ def get_string(self, *args, **kwargs): params = self.params.copy() params['model'] = self.model.to_string() if self.model else 'NULL' params['database'] = self.database.to_string() - params['agent'] = self.agent.to_string() if self.agent else 'NULL' + if self.agent: + params['agent'] = self.agent.to_string() using_ar = [f'{k}={repr(v)}' for k, v in params.items()] diff --git a/mindsdb_sql/parser/dialects/mindsdb/create_database.py b/mindsdb_sql/parser/dialects/mindsdb/create_database.py index 759dbd57..c60fe64b 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/create_database.py +++ b/mindsdb_sql/parser/dialects/mindsdb/create_database.py @@ -43,8 +43,12 @@ def get_string(self, *args, **kwargs): if self.is_replace: replace_str = f' OR REPLACE' + engine_str = '' + if self.engine: + engine_str = f'ENGINE = {repr(self.engine)} ' + parameters_str = '' if self.parameters: parameters_str = f', PARAMETERS = {json.dumps(self.parameters)}' - out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} WITH ENGINE = {repr(self.engine)}{parameters_str}' + out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} {engine_str}{parameters_str}' return out_str diff --git a/mindsdb_sql/parser/dialects/mindsdb/create_job.py b/mindsdb_sql/parser/dialects/mindsdb/create_job.py index eb06a8df..d046e8ca 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/create_job.py +++ b/mindsdb_sql/parser/dialects/mindsdb/create_job.py @@ -79,7 +79,7 @@ def get_string(self, *args, **kwargs): if_query_str = '' if self.if_query_str is not None: - if_query_str = f" IF '{self.if_query_str}'" + if_query_str = f" IF ({self.if_query_str})" out_str = f'CREATE JOB {"IF NOT EXISTS" if self.if_not_exists else ""} {self.name.to_string()} ({self.query_str}){start_str}{end_str}{repeat_str}{if_query_str}' return out_str diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index d9c1c0b7..398a1d2b 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -72,6 +72,7 @@ class MindsDBLexer(Lexer): EQUALS, NEQUALS, GREATER, GEQ, LESS, LEQ, AND, OR, NOT, IS, IS_NOT, IN, LIKE, NOT_LIKE, CONCAT, BETWEEN, WINDOW, OVER, PARTITION_BY, + JSON_GET, JSON_GET_STR, # Data types CAST, ID, INTEGER, FLOAT, QUOTE_STRING, DQUOTE_STRING, NULL, TRUE, FALSE, @@ -263,6 +264,8 @@ class MindsDBLexer(Lexer): SEMICOLON = r'\;' # Operators + JSON_GET = r'->' + JSON_GET_STR = r'->>' PLUS = r'\+' MINUS = r'-' DIVIDE = r'/' @@ -288,7 +291,6 @@ class MindsDBLexer(Lexer): OVER = r'\bOVER\b' PARTITION_BY = r'\bPARTITION BY\b' - # Data types NULL = r'\bNULL\b' TRUE = r'\bTRUE\b' diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index ada46cf3..dfe973fa 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -44,10 +44,12 @@ class MindsDBParser(Parser): ('left', AND), ('right', UNOT), ('left', EQUALS, NEQUALS), + ('nonassoc', LESS, LEQ, GREATER, GEQ, IN, BETWEEN, IS, IS_NOT, NOT_LIKE, LIKE), + ('left', JSON_GET), ('left', PLUS, MINUS), ('left', STAR, DIVIDE), ('right', UMINUS), # Unary minus operator, unary not - ('nonassoc', LESS, LEQ, GREATER, GEQ, IN, BETWEEN, IS, IS_NOT, NOT_LIKE, LIKE), + ) # Top-level statements @@ -402,6 +404,8 @@ def set_item(self, p): 'CHARSET', ) def charset(self, p): + if hasattr(p, 'SET'): + return f'{p[0]} {p[1]}' return p[0] # set transaction @@ -419,7 +423,7 @@ def set_item(self, p): params = {} if isolation_level is not None: - params['isolation_level'] = isolation_level + params['isolation level'] = isolation_level if access_mode is not None: params['access_mode'] = access_mode @@ -523,75 +527,28 @@ def show(self, p): modes=modes ) - @_('SCHEMAS', - 'DATABASES', - 'TABLES', - 'OPEN TABLES', - 'TRIGGERS', - 'COLUMNS', - 'FIELDS', - 'PLUGINS', - 'VARIABLES', - 'INDEXES', - 'KEYS', - 'SESSION VARIABLES', - 'GLOBAL VARIABLES', - 'GLOBAL STATUS', - 'SESSION STATUS', - 'PROCEDURE STATUS', - 'FUNCTION STATUS', - 'TABLE STATUS', - 'MASTER STATUS', - 'STATUS', - 'STORAGE ENGINES', - 'PROCESSLIST', - 'INDEX', - 'CREATE TABLE', - 'WARNINGS', - 'ENGINES', - 'CHARSET', - 'CHARACTER SET', - 'COLLATION', - 'BINARY LOGS', - 'MASTER LOGS', - 'PRIVILEGES', - 'PROFILES', - 'REPLICAS', - 'SLAVE HOSTS', - # Mindsdb specific - 'VIEWS', - 'STREAMS', - 'PREDICTORS', - 'INTEGRATIONS', - 'DATASOURCES', - 'PUBLICATIONS', - 'DATASETS', - 'MODELS', - 'ML_ENGINES', - 'HANDLERS', - 'SEARCH_PATH', - 'KNOWLEDGE_BASES', - 'ALL') + @_( + 'id', + 'id id', + ) def show_category(self, p): - return ' '.join([x for x in p]) + if hasattr(p, 'id'): + return p.id + return f"{p.id0} {p.id1}" # custom show commands - @_('SHOW ENGINE identifier STATUS', - 'SHOW ENGINE identifier MUTEX') - def show(self, p): - return Show( - category=p[1], - name=p.identifier.to_string(), - modes=[p[3]] - ) - @_('SHOW FUNCTION CODE identifier', - 'SHOW PROCEDURE CODE identifier') + @_('SHOW id id identifier') def show(self, p): category = p[1] + ' ' + p[2] + + if p[1].lower() == 'engine': + name = p.identifier.parts[0] + else: + name = p.identifier.to_string() return Show( category=category, - name=p.identifier.to_string() + name=name ) @_('SHOW REPLICA STATUS FOR CHANNEL id', @@ -819,6 +776,7 @@ def create_predictor(self, p): 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN', 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns', 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns FROM identifier LPAREN raw_query RPAREN', + 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', # TODO add IF_NOT_EXISTS elegantly (should be low level BNF expansion) ) def create_anomaly_detection_model(self, p): @@ -1135,14 +1093,6 @@ def select(self, p): select.where = where_expr return select - # Special cases for keyword-like identifiers - @_('select FROM TABLES') - def select(self, p): - select = p.select - ensure_select_keyword_order(select, 'FROM') - select.from_table = Identifier(p.TABLES) - return select - @_('select FROM from_table_aliased', 'select FROM join_tables_implicit', 'select FROM join_tables') @@ -1437,17 +1387,15 @@ def expr(self, p): 'expr LIKE expr', 'expr NOT_LIKE expr', 'expr CONCAT expr', + 'expr JSON_GET constant', + 'expr JSON_GET_STR constant', 'expr IN expr') def expr(self, p): if hasattr(p, 'LAST'): arg1 = Last() else: - arg1 = p.expr1 - if len(p) > 3: - op = ' '.join([p[i] for i in range(1, len(p)-1)]) - else: - op = p[1] - return BinaryOperation(op=op, args=(p[0], arg1)) + arg1 = p[2] + return BinaryOperation(op=p[1], args=(p[0], arg1)) @_('MINUS expr %prec UMINUS', 'NOT expr %prec UNOT', ) @@ -1645,6 +1593,7 @@ def parameter(self, p): 'HORIZON', 'HOSTS', 'INDEXES', + 'INDEX', 'INTEGRATION', 'INTEGRATIONS', 'ISOLATION', @@ -1686,6 +1635,7 @@ def parameter(self, p): 'STREAM', 'STREAMS', 'TABLES', + 'TABLE', 'TRAIN', 'TRANSACTION', 'TRIGGERS', @@ -1696,7 +1646,17 @@ def parameter(self, p): 'WARNINGS', 'MODEL', 'MODELS', - 'AGENT' + 'AGENT', + 'SCHEMAS', + 'FUNCTION', + 'charset', + 'PROCEDURE', + 'ML_ENGINES', + 'HANDLERS', + 'BINARY', + 'KNOWLEDGE_BASES', + 'ALL', + 'CREATE', ) def id(self, p): return p[0] diff --git a/mindsdb_sql/parser/dialects/mysql/parser.py b/mindsdb_sql/parser/dialects/mysql/parser.py index ee6541c3..fde778b5 100644 --- a/mindsdb_sql/parser/dialects/mysql/parser.py +++ b/mindsdb_sql/parser/dialects/mysql/parser.py @@ -176,7 +176,7 @@ def set_item(self, p): params = {} if isolation_level is not None: - params['isolation_level'] = isolation_level + params['isolation level'] = isolation_level if access_mode is not None: params['access_mode'] = access_mode @@ -594,14 +594,6 @@ def select(self, p): select.where = where_expr return select - # Special cases for keyword-like identifiers - @_('select FROM TABLES') - def select(self, p): - select = p.select - ensure_select_keyword_order(select, 'FROM') - select.from_table = Identifier(p.TABLES) - return select - @_('select FROM from_table_aliased', 'select FROM join_tables_implicit', 'select FROM join_tables') diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index acd1003c..da20aa54 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -104,6 +104,7 @@ def plan(self, query): ): query2 = copy.deepcopy(query) query2.from_table = None + query2.using = None sup_select = QueryStep(query2, from_table=join_step.result) self.planner.plan.add_step(sup_select) return sup_select @@ -429,11 +430,25 @@ def process_predictor(self, item, query_in): # exclude condition el._orig_node.args = [Constant(0), Constant(0)] + # params for model + model_params = None + + if query_in.using is not None: + model_params = {} + for param, value in query_in.using.items(): + if '.' in param: + alias = param.split('.')[0] + if (alias,) in item.aliases: + new_param = '.'.join(param.split('.')[1:]) + model_params[new_param] = value + else: + model_params[param] = value + predictor_step = self.planner.plan.add_step(ApplyPredictorStep( namespace=item.integration, dataframe=data_step.result, predictor=item.table, - params=query_in.using, + params=model_params, row_dict=row_dict, )) self.step_stack.append(predictor_step) diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index c138b113..7e4f4288 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -592,28 +592,33 @@ def prepare_update(self, ast_query): return stmt + def get_query(self, ast_query): + if isinstance(ast_query, ast.Select): + stmt = self.prepare_select(ast_query) + elif isinstance(ast_query, ast.Insert): + stmt = self.prepare_insert(ast_query) + elif isinstance(ast_query, ast.Update): + stmt = self.prepare_update(ast_query) + elif isinstance(ast_query, ast.CreateTable): + stmt = self.prepare_create_table(ast_query) + elif isinstance(ast_query, ast.DropTables): + stmt = self.prepare_drop_table(ast_query) + else: + raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}') + return stmt + def get_string(self, ast_query, with_failback=True): + if isinstance(ast_query, (ast.CreateTable, ast.DropTables)): + render_func = render_ddl_query + else: + render_func = render_dml_query + try: - if isinstance(ast_query, ast.Select): - stmt = self.prepare_select(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.Insert): - stmt = self.prepare_insert(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.Update): - stmt = self.prepare_update(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.CreateTable): - stmt = self.prepare_create_table(ast_query) - sql = render_ddl_query(stmt, self.dialect) - elif isinstance(ast_query, ast.DropTables): - stmt = self.prepare_drop_table(ast_query) - sql = render_ddl_query(stmt, self.dialect) - else: - raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}') + stmt = self.get_query(ast_query) - return sql + sql = render_func(stmt, self.dialect) + return sql except (SQLAlchemyError, NotImplementedError) as e: if not with_failback: diff --git a/tests/test_parser/test_base_sql/test_misc_sql_queries.py b/tests/test_parser/test_base_sql/test_misc_sql_queries.py index 68acb3f9..1c6e3145 100644 --- a/tests/test_parser/test_base_sql/test_misc_sql_queries.py +++ b/tests/test_parser/test_base_sql/test_misc_sql_queries.py @@ -148,10 +148,10 @@ def test_set_transaction(self, dialect): ast = parse_sql(sql, dialect=dialect) expected_ast = Set( category='TRANSACTION', - params=dict( - isolation_level='REPEATABLE READ', - access_mode='READ WRITE', - ), + params={ + 'isolation level': 'REPEATABLE READ', + 'access_mode': 'READ WRITE', + }, scope='GLOBAL' ) @@ -164,10 +164,10 @@ def test_set_transaction(self, dialect): expected_ast = Set( category='TRANSACTION', - params=dict( - isolation_level='SERIALIZABLE', - access_mode='READ ONLY', - ), + params={ + 'isolation level': 'SERIALIZABLE', + 'access_mode': 'READ ONLY', + }, scope='SESSION' ) @@ -180,9 +180,9 @@ def test_set_transaction(self, dialect): expected_ast = Set( category='TRANSACTION', - params=dict( - isolation_level='READ UNCOMMITTED', - ) + params={ + 'isolation level': 'READ UNCOMMITTED' + }, ) assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_parser/test_base_sql/test_show.py b/tests/test_parser/test_base_sql/test_show.py index 44f24c5e..d32fd2b2 100644 --- a/tests/test_parser/test_base_sql/test_show.py +++ b/tests/test_parser/test_base_sql/test_show.py @@ -36,12 +36,6 @@ def test_show_category(self, dialect): assert str(ast) == str(expected_ast) assert ast.to_tree() == expected_ast.to_tree() - def test_show_unknown_category_error(self, dialect): - sql = "SHOW abracadabra" - - with pytest.raises(ParsingException): - parse_sql(sql, dialect=dialect) - def test_show_unknown_condition_error(self, dialect): sql = "SHOW databases WITH" with pytest.raises(ParsingException): @@ -249,18 +243,6 @@ def test_common_like_double_where_from_in_modes(self, dialect): def test_custom(self, dialect): - for arg in ['STATUS', 'MUTEX']: - sql = f"SHOW ENGINE engine_name {arg}" - ast = parse_sql(sql, dialect=dialect) - expected_ast = Show( - category='ENGINE', - name='engine_name', - modes=[arg], - ) - - assert str(ast) == str(expected_ast) - assert ast.to_tree() == expected_ast.to_tree() - for arg in ['FUNCTION', 'PROCEDURE']: sql = f"SHOW {arg} CODE obj_name" ast = parse_sql(sql, dialect=dialect) @@ -310,6 +292,18 @@ def test_show_database_adapted(self): class TestMindsdb: + def test_show_engine(self): + for arg in ['STATUS', 'MUTEX']: + sql = f"SHOW ENGINE engine_name {arg}" + ast = parse_sql(sql) + expected_ast = Show( + category='ENGINE engine_name', + name=arg, + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + def test_show(self): sql = ''' show full databases diff --git a/tests/test_parser/test_mindsdb/test_create_view.py b/tests/test_parser/test_mindsdb/test_create_view.py index f73e4b99..b6773ad4 100644 --- a/tests/test_parser/test_mindsdb/test_create_view.py +++ b/tests/test_parser/test_mindsdb/test_create_view.py @@ -15,7 +15,7 @@ def test_create_view_lexer(self): assert tokens[1].value == 'VIEW' assert tokens[1].type == 'VIEW' - def test_create_view_raises_wrong_dialect(self): + def test_create_view_raises_wrong_dialect_error(self): sql = "CREATE VIEW my_view FROM integr AS ( SELECT * FROM pred )" for dialect in ['sqlite', 'mysql']: with pytest.raises(ParsingException): diff --git a/tests/test_parser/test_mindsdb/test_selects.py b/tests/test_parser/test_mindsdb/test_selects.py index 5b68515a..f811afd0 100644 --- a/tests/test_parser/test_mindsdb/test_selects.py +++ b/tests/test_parser/test_mindsdb/test_selects.py @@ -157,5 +157,29 @@ def test_last(self): assert str(ast) == str(expected_ast) + def test_json(self): + sql = """SELECT col->1->'c' from TAB1""" + + ast = parse_sql(sql, dialect='mindsdb') + expected_ast = Select( + targets=[BinaryOperation( + op='->', + args=[ + BinaryOperation( + op='->', + args=[ + Identifier('col'), + Constant(1) + ] + ), + Constant('c') + ] + )], + from_table=Identifier(parts=['TAB1']), + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + diff --git a/tests/test_parser/test_standard_render.py b/tests/test_parser/test_standard_render.py new file mode 100644 index 00000000..ef12db80 --- /dev/null +++ b/tests/test_parser/test_standard_render.py @@ -0,0 +1,90 @@ +import inspect +import pkgutil +import sys +import os +import importlib + +from mindsdb_sql import parse_sql, Parameter +from mindsdb_sql.planner.utils import query_traversal + + +def load_all_modules_from_dir(dir_names): + for importer, package_name, _ in pkgutil.iter_modules(dir_names): + full_package_name = package_name + if full_package_name not in sys.modules: + spec = importer.find_spec(package_name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + yield module + + +def check_module(module): + if module.__name__ in ('test_mysql_lexer', 'test_base_lexer'): + # skip + return + + for class_name, klass in inspect.getmembers(module, predicate=inspect.isclass): + if not class_name.startswith('Test'): + continue + + tests = klass() + for test_name, test_method in inspect.getmembers(tests, predicate=inspect.ismethod): + if not test_name.startswith('test_') or test_name.endswith('_error'): + # skip tests that expected error + continue + sig = inspect.signature(test_method) + args = [] + # add dialect + if 'dialect' in sig.parameters: + args.append('mindsdb') + if 'cat' in sig.parameters: + # skip it + continue + + test_method(*args) + + +def parse_sql2(sql, dialect='mindsdb'): + + params = [] + def check_param_f(node, **kwargs): + if isinstance(node, Parameter): + params.append(node) + + query = parse_sql(sql, dialect) + + # skip queries with params + query_traversal(query, check_param_f) + if len(params) > 0: + return query + + # render + sql2 = query.to_string() + + # Parse again + query2 = parse_sql(sql2, dialect) + + # compare result from first and second parsing + assert str(query) == str(query2) + + # return to test: it compares it with expected_ast + return query2 + + +def test_standard_render(): + + base_dir = os.path.dirname(__file__) + dir_names = [ + os.path.join(base_dir, folder) + for folder in os.listdir(base_dir) + if folder.startswith('test_') + ] + + for module in load_all_modules_from_dir(dir_names): + + # inject function + module.parse_sql = parse_sql2 + + check_module(module) + + diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index d2d80a4a..517f34ae 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -736,6 +736,9 @@ def test_model_join_model(self): join mindsdb.pred m join mindsdb.pred m2 where m.a = 2 + using m.param1 = 'a', + m2.param2 = 'b', + param3 = 'c' ''' subquery = parse_sql(""" @@ -750,13 +753,15 @@ def test_model_join_model(self): FetchDataframeStep(integration='int', query=parse_sql('select * from tab1 as t')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), - predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 2}), + predictor=Identifier('pred', alias=Identifier('m')), + row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(2), - predictor=Identifier('pred', alias=Identifier('m2'))), + predictor=Identifier('pred', alias=Identifier('m2')), + params={ 'param2': 'b', 'param3': 'c' }), JoinStep(left=Result(2), right=Result(3), query=Join(left=Identifier('tab1'), right=Identifier('tab2'),