diff --git a/mindsdb_sql/__about__.py b/mindsdb_sql/__about__.py index ecd4e773..7987bfc5 100644 --- a/mindsdb_sql/__about__.py +++ b/mindsdb_sql/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql' __package_name__ = 'mindsdb_sql' -__version__ = '0.20.0' +__version__ = '0.21.0' __description__ = "Pure python SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql/parser/ast/select/__init__.py b/mindsdb_sql/parser/ast/select/__init__.py index d064242c..7dd64b61 100644 --- a/mindsdb_sql/parser/ast/select/__init__.py +++ b/mindsdb_sql/parser/ast/select/__init__.py @@ -1,6 +1,6 @@ from .select import Select from .common_table_expression import CommonTableExpression -from .union import Union +from .union import Union, Except, Intersect from .constant import Constant, NullConstant, Last from .star import Star from .identifier import Identifier diff --git a/mindsdb_sql/parser/ast/select/case.py b/mindsdb_sql/parser/ast/select/case.py index 02fa8275..1ae0df06 100644 --- a/mindsdb_sql/parser/ast/select/case.py +++ b/mindsdb_sql/parser/ast/select/case.py @@ -4,13 +4,14 @@ class Case(ASTNode): - def __init__(self, rules, default=None, *args, **kwargs): + def __init__(self, rules, default=None, arg=None, *args, **kwargs): super().__init__(*args, **kwargs) # structure: # [ # [ condition, result ] # ] + self.arg = arg self.rules = rules self.default = default @@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs): if self.default is not None: default_str = f'{ind1}default => {self.default.to_string()}\n' + arg_str = '' + if self.arg is not None: + arg_str = f'{ind1}arg => {self.arg.to_string()}\n' + return f'{ind}Case(\n' \ + f'{arg_str}'\ f'{rules_str}\n' \ f'{default_str}' \ f'{ind})' @@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs): default_str = '' if self.default is not None: default_str = f' ELSE {self.default.to_string()}' - return f"CASE {rules_str}{default_str} END" + + arg_str = '' + if self.arg is not None: + arg_str = f'{self.arg.to_string()} ' + return f"CASE {arg_str}{rules_str}{default_str} END" diff --git a/mindsdb_sql/parser/ast/select/operation.py b/mindsdb_sql/parser/ast/select/operation.py index d7cad6a6..e208d435 100644 --- a/mindsdb_sql/parser/ast/select/operation.py +++ b/mindsdb_sql/parser/ast/select/operation.py @@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs): class WindowFunction(ASTNode): - def __init__(self, function, partition=None, order_by=None, alias=None): + def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None): super().__init__() self.function = function self.partition = partition self.order_by = order_by self.alias = alias + self.modifier = modifier def to_tree(self, *args, level=0, **kwargs): fnc_str = self.function.to_tree(level=level+2) @@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs): alias_str = self.alias.to_string() else: alias_str = '' - return f'{fnc_str} over({partition_str} {order_str}) {alias_str}' + modifier_str = ' ' + self.modifier if self.modifier else '' + return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}' class Object(ASTNode): @@ -177,7 +179,12 @@ def __init__(self, info): super().__init__(op='interval', args=[info, ]) def get_string(self, *args, **kwargs): - return f'INTERVAL {self.args[0]}' + + arg = self.args[0] + items = arg.split(' ', maxsplit=1) + # quote first element + items[0] = f"'{items[0]}'" + return "INTERVAL " + " ".join(items) def to_tree(self, *args, level=0, **kwargs): return self.get_string( *args, **kwargs) diff --git a/mindsdb_sql/parser/ast/select/type_cast.py b/mindsdb_sql/parser/ast/select/type_cast.py index f0ae304c..7bad1ce9 100644 --- a/mindsdb_sql/parser/ast/select/type_cast.py +++ b/mindsdb_sql/parser/ast/select/type_cast.py @@ -3,19 +3,20 @@ class TypeCast(ASTNode): - def __init__(self, type_name, arg, length=None, *args, **kwargs): + def __init__(self, type_name, arg, precision=None, *args, **kwargs): super().__init__(*args, **kwargs) self.type_name = type_name self.arg = arg - self.length = length + self.precision = precision def to_tree(self, *args, level=0, **kwargs): - out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, length={self.length}, arg=\n{indent(level+1)}{self.arg.to_tree()})' + out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, precision={self.precision}, arg=\n{indent(level+1)}{self.arg.to_tree()})' return out_str def get_string(self, *args, **kwargs): type_name = self.type_name - if self.length is not None: - type_name += f'({self.length})' + if self.precision is not None: + precision = map(str, self.precision) + type_name += f'({",".join(precision)})' return f'CAST({str(self.arg)} AS {type_name})' diff --git a/mindsdb_sql/parser/ast/select/union.py b/mindsdb_sql/parser/ast/select/union.py index e78609ea..dce1da2b 100644 --- a/mindsdb_sql/parser/ast/select/union.py +++ b/mindsdb_sql/parser/ast/select/union.py @@ -2,7 +2,8 @@ from mindsdb_sql.parser.utils import indent -class Union(ASTNode): +class CombiningQuery(ASTNode): + operation = None def __init__(self, left, @@ -24,7 +25,8 @@ def to_tree(self, *args, level=0, **kwargs): left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level + 2)},' right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},' - out_str = f'{ind}Union(unique={repr(self.unique)},' \ + cls_name = self.__class__.__name__ + out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \ f'{left_str}' \ f'{right_str}' \ f'\n{ind})' @@ -33,7 +35,21 @@ def to_tree(self, *args, level=0, **kwargs): def get_string(self, *args, **kwargs): left_str = str(self.left) right_str = str(self.right) - keyword = 'UNION' if self.unique else 'UNION ALL' + keyword = self.operation + if not self.unique: + keyword += ' ALL' out_str = f"""{left_str}\n{keyword}\n{right_str}""" return out_str + + +class Union(CombiningQuery): + operation = 'UNION' + + +class Intersect(CombiningQuery): + operation = 'INTERSECT' + + +class Except(CombiningQuery): + operation = 'EXCEPT' diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 01024c97..9232efbc 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -55,7 +55,7 @@ class MindsDBLexer(Lexer): JOIN, INNER, OUTER, CROSS, LEFT, RIGHT, ON, - UNION, ALL, + UNION, ALL, INTERSECT, EXCEPT, # CASE CASE, ELSE, END, THEN, WHEN, @@ -238,6 +238,8 @@ class MindsDBLexer(Lexer): # UNION UNION = r'\bUNION\b' + INTERSECT = r'\bINTERSECT\b' + EXCEPT = r'\bEXCEPT\b' ALL = r'\bALL\b' # CASE diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 05233429..ae139597 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -46,7 +46,7 @@ class MindsDBParser(Parser): ('nonassoc', LESS, LEQ, GREATER, GEQ, IN, NOT_IN, BETWEEN, IS, IS_NOT, NOT_LIKE, LIKE), ('left', JSON_GET), ('left', PLUS, MINUS), - ('left', STAR, DIVIDE, TYPECAST), + ('left', STAR, DIVIDE, TYPECAST, MODULO), ('right', UMINUS), # Unary minus operator, unary not ) @@ -68,9 +68,9 @@ class MindsDBParser(Parser): 'drop_predictor', 'drop_datasource', 'drop_dataset', - 'union', 'select', 'insert', + 'union', 'update', 'delete', 'evaluate', @@ -615,10 +615,13 @@ def update(self, p): # INSERT @_('INSERT INTO identifier LPAREN column_list RPAREN select', - 'INSERT INTO identifier select') + 'INSERT INTO identifier LPAREN column_list RPAREN union', + 'INSERT INTO identifier select', + 'INSERT INTO identifier union') def insert(self, p): columns = getattr(p, 'column_list', None) - return Insert(table=p.identifier, columns=columns, from_select=p.select) + query = p.select if hasattr(p, 'select') else p.union + return Insert(table=p.identifier, columns=columns, from_select=query) @_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set', 'INSERT INTO identifier VALUES expr_list_set') @@ -999,21 +1002,35 @@ def database_engine(self, p): engine = p.string return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} - # UNION / UNION ALL + # Combining @_('select UNION select', - 'union UNION select') + 'union UNION select', + 'select UNION ALL select', + 'union UNION ALL select') def union(self, p): - return Union(left=p[0], right=p[2], unique=True) + unique = not hasattr(p, 'ALL') + return Union(left=p[0], right=p[2] if unique else p[3], unique=unique) - @_('select UNION ALL select', - 'union UNION ALL select',) + @_('select INTERSECT select', + 'union INTERSECT select', + 'select INTERSECT ALL select', + 'union INTERSECT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique) + @_('select EXCEPT select', + 'union EXCEPT select', + 'select EXCEPT ALL select', + 'union EXCEPT ALL select') def union(self, p): - return Union(left=p[0], right=p[3], unique=False) + unique = not hasattr(p, 'ALL') + return Except(left=p[0], right=p[2] if unique else p[3], unique=unique) # tableau @_('LPAREN select RPAREN') + @_('LPAREN union RPAREN') def select(self, p): - return p.select + return p[1] # WITH @_('ctes select') @@ -1033,13 +1050,14 @@ def ctes(self, p): ] return ctes - @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN') + @_('WITH identifier cte_columns_or_nothing AS LPAREN select RPAREN', + 'WITH identifier cte_columns_or_nothing AS LPAREN union RPAREN') def ctes(self, p): return [ CommonTableExpression( name=p.identifier, columns=p.cte_columns_or_nothing, - query=p.select) + query=p[5]) ] @_('empty') @@ -1334,6 +1352,15 @@ def column_list(self, p): def case(self, p): return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) + @_('CASE expr case_conditions ELSE expr END', + 'CASE expr case_conditions END') + def case(self, p): + if hasattr(p, 'expr'): + arg, default = p.expr, None + else: + arg, default = p.expr0, p.expr1 + return Case(rules=p.case_conditions, default=default, arg=arg) + @_('case_condition', 'case_conditions case_condition') def case_conditions(self, p): @@ -1346,13 +1373,18 @@ def case_condition(self, p): return [p.expr0, p.expr1] # Window function - @_('function OVER LPAREN window RPAREN') + @_('expr OVER LPAREN window RPAREN', + 'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN') def window_function(self, p): + modifier = None + if hasattr(p, 'BETWEEN'): + modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}' return WindowFunction( - function=p.function, + function=p.expr, order_by=p.window.get('order_by'), partition=p.window.get('partition'), + modifier=modifier, ) @_('window PARTITION_BY expr_list') @@ -1469,8 +1501,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/parser/dialects/mysql/parser.py b/mindsdb_sql/parser/dialects/mysql/parser.py index bb128562..0f28ba7a 100644 --- a/mindsdb_sql/parser/dialects/mysql/parser.py +++ b/mindsdb_sql/parser/dialects/mysql/parser.py @@ -821,8 +821,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/parser/parser.py b/mindsdb_sql/parser/parser.py index 18576c86..6794116b 100644 --- a/mindsdb_sql/parser/parser.py +++ b/mindsdb_sql/parser/parser.py @@ -581,8 +581,13 @@ def expr_list_or_nothing(self, p): pass @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') def expr(self, p): - return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer) + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) @_('CAST LPAREN expr AS id RPAREN') def expr(self, p): diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index 3089dc6f..1bb3afa5 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -21,7 +21,7 @@ class TableInfo: sub_select: ast.ASTNode = None predictor_info: dict = None join_condition = None - + index: int = None class PlanJoin: @@ -85,12 +85,15 @@ def __init__(self, planner): # index to lookup tables self.tables_idx = None + self.tables = [] + self.tables_fetch_step = {} self.step_stack = None self.query_context = {} self.partition = None + def plan(self, query): self.tables_idx = {} join_step = self.plan_join_tables(query) @@ -109,6 +112,7 @@ def plan(self, query): query2 = copy.deepcopy(query) query2.from_table = None query2.using = None + query2.cte = None sup_select = QueryStep(query2, from_table=join_step.result) self.planner.plan.add_step(sup_select) return sup_select @@ -145,7 +149,8 @@ def resolve_table(self, table): return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select) def get_table_for_column(self, column: Identifier): - + if not isinstance(column, Identifier): + return # to lowercase parts = tuple(map(str.lower, column.parts[:-1])) if parts in self.tables_idx: @@ -160,6 +165,9 @@ def get_join_sequence(self, node, condition=None): for alias in table_info.aliases: self.tables_idx[alias] = table_info + table_info.index = len(self.tables) + self.tables.append(table_info) + table_info.predictor_info = self.planner.get_predictor(node) if condition is not None: @@ -368,20 +376,25 @@ def process_subselect(self, item): self.step_stack.append(step2) def process_table(self, item, query_in): - query2 = Select(from_table=item.table, targets=[Star()]) + table = copy.deepcopy(item.table) + table.parts.insert(0, item.integration) + query2 = Select(from_table=table, targets=[Star()]) # parts = tuple(map(str.lower, table_name.parts)) conditions = item.conditions if 'or' in self.query_context['binary_ops']: # not use conditions conditions = [] + conditions += self.get_filters_from_join_conditions(item) + if self.query_context['use_limit']: order_by = None if query_in.order_by is not None: order_by = [] # all order column be from this table for col in query_in.order_by: - if self.get_table_for_column(col.field).table != item.table: + table_info = self.get_table_for_column(col.field) + if table_info is None or table_info.table != item.table: order_by = False break col = copy.deepcopy(col) @@ -404,8 +417,9 @@ def process_table(self, item, query_in): else: query2.where = cond - # step = self.planner.get_integration_select_step(query2) - step = FetchDataframeStep(integration=item.integration, query=query2) + step = self.planner.get_integration_select_step(query2) + self.tables_fetch_step[item.index] = step + self.add_plan_step(step) self.step_stack.append(step) @@ -440,6 +454,70 @@ def _check_conditions(node, **kwargs): query_traversal(model_table.join_condition, _check_conditions) return columns_map + def get_filters_from_join_conditions(self, fetch_table): + + binary_ops = set() + conditions = [] + data_conditions = [] + + def _check_conditions(node, **kwargs): + if not isinstance(node, BinaryOperation): + return + + if node.op != '=': + binary_ops.add(node.op.lower()) + return + + arg1, arg2 = node.args + table1 = self.get_table_for_column(arg1) if isinstance(arg1, Identifier) else None + table2 = self.get_table_for_column(arg2) if isinstance(arg2, Identifier) else None + + if table1 is not fetch_table: + if table2 is not fetch_table: + return + # set our table first + table1, table2 = table2, table1 + arg1, arg2 = arg2, arg1 + + if isinstance(arg2, Constant): + conditions.append(node) + elif table2 is not None: + data_conditions.append([arg1, arg2]) + + query_traversal(fetch_table.join_condition, _check_conditions) + + binary_ops.discard('and') + if len(binary_ops) > 0: + # other operations exists, skip + return [] + + for arg1, arg2 in data_conditions: + # is fetched? + table2 = self.get_table_for_column(arg2) + fetch_step = self.tables_fetch_step.get(table2.index) + + if fetch_step is None: + continue + + # extract distinct values + # remove aliases + arg1 = Identifier(parts=[arg1.parts[-1]]) + arg2 = Identifier(parts=[arg2.parts[-1]]) + + query2 = Select(targets=[arg2], distinct=True) + subselect_step = SubSelectStep(query2, fetch_step.result) + subselect_step = self.add_plan_step(subselect_step) + + conditions.append(BinaryOperation( + op='in', + args=[ + arg1, + Parameter(subselect_step.result) + ] + )) + + return conditions + def process_predictor(self, item, query_in): if len(self.step_stack) == 0: raise NotImplementedError("Predictor can't be first element of join syntax") diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index e6d58be2..35bea81c 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -3,7 +3,7 @@ from mindsdb_sql.exceptions import PlanningException from mindsdb_sql.parser import ast from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, Union, CreateTable, - Function, Insert, + Function, Insert, Except, Intersect, Update, NativeQuery, Parameter, Delete) from mindsdb_sql.planner import utils from mindsdb_sql.planner.query_plan import QueryPlan @@ -86,6 +86,8 @@ def __init__(self, self.statement = None + self.cte_results = {} + def is_predictor(self, identifier): if not isinstance(identifier, Identifier): return False @@ -158,6 +160,12 @@ def get_integration_select_step(self, select): else: integration_name, table = self.resolve_database_table(select.from_table) + # is it CTE? + table_name = table.parts[-1] + if integration_name == self.default_namespace and table_name in self.cte_results: + select.from_table = None + return SubSelectStep(select, self.cte_results[table_name], table_name=table_name) + fetch_df_select = copy.deepcopy(select) self.prepare_integration_select(integration_name, fetch_df_select) @@ -221,6 +229,19 @@ def find_objects(node, is_table, **kwargs): mdb_entities.append(node) query_traversal(query, find_objects) + + # cte names are not mdb objects + if isinstance(query, Select) and query.cte: + cte_names = [ + cte.name.parts[-1] + for cte in query.cte + ] + mdb_entities = [ + item + for item in mdb_entities + if '.'.join(item.parts) not in cte_names + ] + return { 'mdb_entities': mdb_entities, 'integrations': integrations, @@ -250,21 +271,21 @@ def find_selects(node, **kwargs): return find_selects def plan_select_identifier(self, query): - query_info = self.get_query_info(query) - - if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: - # select from predictor - return self.plan_select_from_predictor(query) - elif ( - len(query_info['integrations']) == 1 - and len(query_info['mdb_entities']) == 0 - and len(query_info['user_functions']) == 0 - ): - - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - # one integration without predictors, send all query to integration - return self.plan_integration_select(query) + # query_info = self.get_query_info(query) + # + # if len(query_info['integrations']) == 0 and len(query_info['predictors']) >= 1: + # # select from predictor + # return self.plan_select_from_predictor(query) + # elif ( + # len(query_info['integrations']) == 1 + # and len(query_info['mdb_entities']) == 0 + # and len(query_info['user_functions']) == 0 + # ): + # + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # # one integration without predictors, send all query to integration + # return self.plan_integration_select(query) # find subselects main_integration, _ = self.resolve_database_table(query.from_table) @@ -359,21 +380,21 @@ def plan_api_db_select(self, query): def plan_nested_select(self, select): - query_info = self.get_query_info(select) - # get all predictors - - if ( - len(query_info['mdb_entities']) == 0 - and len(query_info['integrations']) == 1 - and len(query_info['user_functions']) == 0 - and 'files' not in query_info['integrations'] - and 'views' not in query_info['integrations'] - ): - int_name = list(query_info['integrations'])[0] - if self.integrations.get(int_name, {}).get('class_type') != 'api': - - # if no predictor inside = run as is - return self.plan_integration_nested_select(select, int_name) + # query_info = self.get_query_info(select) + # # get all predictors + # + # if ( + # len(query_info['mdb_entities']) == 0 + # and len(query_info['integrations']) == 1 + # and len(query_info['user_functions']) == 0 + # and 'files' not in query_info['integrations'] + # and 'views' not in query_info['integrations'] + # ): + # int_name = list(query_info['integrations'])[0] + # if self.integrations.get(int_name, {}).get('class_type') != 'api': + # + # # if no predictor inside = run as is + # return self.plan_integration_nested_select(select, int_name) return self.plan_mdb_nested_select(select) @@ -663,7 +684,45 @@ def plan_delete(self, query: Delete): where=query.where )) + def plan_cte(self, query): + + for cte in query.cte: + step = self.plan_select(cte.query) + name = cte.name.parts[-1] + self.cte_results[name] = step.result + + def check_single_integration(self, query): + query_info = self.get_query_info(query) + + # can we send all query to integration? + + # one integration and not mindsdb objects in query + if ( + len(query_info['mdb_entities']) == 0 + and len(query_info['integrations']) == 1 + and 'files' not in query_info['integrations'] + and 'views' not in query_info['integrations'] + and len(query_info['user_functions']) == 0 + ): + + int_name = list(query_info['integrations'])[0] + # if is sql database + if self.integrations.get(int_name, {}).get('class_type') != 'api': + + # send to this integration + self.prepare_integration_select(int_name, query) + + last_step = self.plan.add_step(FetchDataframeStep(integration=int_name, query=query)) + return last_step + def plan_select(self, query, integration=None): + + if isinstance(query, (Union, Except, Intersect)): + return self.plan_union(query, integration=integration) + + if query.cte is not None: + self.plan_cte(query) + from_table = query.from_table if isinstance(from_table, Identifier): @@ -713,15 +772,16 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False): return sup_select return prev_step - def plan_union(self, query): - if isinstance(query.left, Union): - step1 = self.plan_union(query.left) - else: - # it is select - step1 = self.plan_select(query.left) - step2 = self.plan_select(query.right) + def plan_union(self, query, integration=None): + step1 = self.plan_select(query.left, integration=integration) + step2 = self.plan_select(query.right, integration=integration) + operation = 'union' + if isinstance(query, Except): + operation = 'except' + elif isinstance(query, Intersect): + operation = 'intersect' - return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique)) + return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique, operation=operation)) # method for compatibility def from_query(self, query=None): @@ -730,10 +790,10 @@ def from_query(self, query=None): if query is None: query = self.query - if isinstance(query, Select): + if isinstance(query, (Select, Union, Except, Intersect)): + if self.check_single_integration(query): + return self.plan self.plan_select(query) - elif isinstance(query, Union): - self.plan_union(query) elif isinstance(query, CreateTable): self.plan_create_table(query) elif isinstance(query, Insert): diff --git a/mindsdb_sql/planner/query_prepare.py b/mindsdb_sql/planner/query_prepare.py index 26a10c8b..9614dfc9 100644 --- a/mindsdb_sql/planner/query_prepare.py +++ b/mindsdb_sql/planner/query_prepare.py @@ -348,6 +348,8 @@ def find_predictors(node, is_table, **kwargs): elif column.name is not None: # is Identifier + if isinstance(column.name, ast.Star): + continue col_name = column.name.upper() if column.table is not None: table = column.table diff --git a/mindsdb_sql/planner/steps.py b/mindsdb_sql/planner/steps.py index 40e86e53..395f3079 100644 --- a/mindsdb_sql/planner/steps.py +++ b/mindsdb_sql/planner/steps.py @@ -75,11 +75,12 @@ def __init__(self, left, right, query, *args, **kwargs): class UnionStep(PlanStep): """Union of two dataframes, producing a new dataframe""" - def __init__(self, left, right, unique, *args, **kwargs): + def __init__(self, left, right, unique, operation='union', *args, **kwargs): super().__init__(*args, **kwargs) self.left = left self.right = right self.unique = unique + self.operation = operation # TODO remove diff --git a/mindsdb_sql/planner/utils.py b/mindsdb_sql/planner/utils.py index e4d52787..9baa16ec 100644 --- a/mindsdb_sql/planner/utils.py +++ b/mindsdb_sql/planner/utils.py @@ -145,7 +145,7 @@ def query_traversal(node, callback, is_table=False, is_target=False, parent_quer array.append(node_out) node.order_by = array - elif isinstance(node, ast.Union): + elif isinstance(node, (ast.Union, ast.Intersect, ast.Except)): node_out = query_traversal(node.left, callback, parent_query=node) if node_out is not None: node.left = node_out diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index caa9e232..c4c31aaf 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -254,8 +254,8 @@ def to_expression(self, t): elif isinstance(t, ast.TypeCast): arg = self.to_expression(t.arg) type = self.get_type(t.type_name) - if t.length is not None: - type = type(t.length) + if t.precision is not None: + type = type(*t.precision) col = sa.cast(arg, type) if t.alias: @@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case): conditions.append( (self.to_expression(condition), self.to_expression(result)) ) + default = None if t.default is not None: - conditions.append(self.to_expression(t.default)) + default = self.to_expression(t.default) - return sa.case(*conditions) + value = None + if t.arg is not None: + value = self.to_expression(t.arg) + + return sa.case(*conditions, else_=default, value=value) def to_function(self, t): op = getattr(sa.func, t.op) @@ -382,7 +387,7 @@ def to_table(self, node): if node.alias: table = aliased(table, name=self.get_alias(node.alias)) - elif isinstance(node, ast.Select): + elif isinstance(node, (ast.Select, ast.Union, ast.Intersect, ast.Except)): sub_stmt = self.prepare_select(node) alias = None if node.alias: @@ -396,6 +401,8 @@ def to_table(self, node): return table def prepare_select(self, node): + if isinstance(node, (ast.Union, ast.Except, ast.Intersect)): + return self.prepare_union(node) cols = [] for t in node.targets: @@ -454,17 +461,10 @@ def prepare_select(self, node): full=is_full ) elif isinstance(from_table, ast.Union): - tables = self.extract_union_list(from_table) - alias = None if from_table.alias: alias = self.get_alias(from_table.alias) - - table1 = tables[1] - tables_x = tables[1:] - - table = table1.union(*tables_x).subquery(alias) - + table = self.prepare_union(from_table).subquery(alias) query = query.select_from(table) elif isinstance(from_table, ast.Select): @@ -529,19 +529,18 @@ def prepare_select(self, node): return query - def extract_union_list(self, node): - if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)): - raise NotImplementedError( - f'Unknown UNION {node.left.__class__.__name__}, {node.right.__class__.__name__}') + def prepare_union(self, from_table): + step1 = self.prepare_select(from_table.left) + step2 = self.prepare_select(from_table.right) - tables = [] - if isinstance(node.left, ast.Union): - tables.extend(self.extract_union_list(node.left)) + if isinstance(from_table, ast.Except): + func = sa.except_ if from_table.unique else sa.except_all + elif isinstance(from_table, ast.Intersect): + func = sa.intersect if from_table.unique else sa.intersect_all else: - tables.append(self.prepare_select(node.left)) - tables.append(self.prepare_select(node.right)) - return tables + func = sa.union if from_table.unique else sa.union_all + return func(step1, step2) def prepare_create_table(self, ast_query): columns = [] @@ -690,7 +689,7 @@ def prepare_delete(self, ast_query: ast.Delete): def get_query(self, ast_query, with_params=False): params = None - if isinstance(ast_query, ast.Select): + if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)): stmt = self.prepare_select(ast_query) elif isinstance(ast_query, ast.Insert): stmt, params = self.prepare_insert(ast_query, with_params=with_params) diff --git a/tests/test_parser/test_base_sql/test_insert.py b/tests/test_parser/test_base_sql/test_insert.py index 1318b071..ed69a4b3 100644 --- a/tests/test_parser/test_base_sql/test_insert.py +++ b/tests/test_parser/test_base_sql/test_insert.py @@ -74,3 +74,25 @@ def test_insert_from_select_no_columns(self, dialect): assert str(ast).lower() == sql.lower() assert ast.to_tree() == expected_ast.to_tree() + +class TestInsertMDB: + + def test_insert_from_union(self): + from textwrap import dedent + sql = dedent(""" + INSERT INTO tbl_name(a, c) SELECT * from table1 + UNION + SELECT * from table2""")[1:] + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + columns=[Identifier('a'), Identifier('c')], + from_select=Union( + left=Select(targets=[Star()], from_table=Identifier('table1')), + right=Select(targets=[Star()], from_table=Identifier('table2')) + ) + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() \ No newline at end of file diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py index 02a99a43..b7a6543c 100644 --- a/tests/test_parser/test_base_sql/test_select_structure.py +++ b/tests/test_parser/test_base_sql/test_select_structure.py @@ -633,7 +633,15 @@ def test_type_cast(self, dialect): sql = f"""SELECT CAST(a AS CHAR(10))""" ast = parse_sql(sql, dialect=dialect) expected_ast = Select(targets=[ - TypeCast(type_name='CHAR', arg=Identifier('a'), length=10) + TypeCast(type_name='CHAR', arg=Identifier('a'), precision=[10]) + ]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST(a AS DECIMAL(10, 1))""" + ast = parse_sql(sql, dialect=dialect) + expected_ast = Select(targets=[ + TypeCast(type_name='DECIMAL', arg=Identifier('a'), precision=[10, 1]) ]) assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) @@ -1018,6 +1026,40 @@ def test_case(self): assert ast.to_tree() == expected_ast.to_tree() assert str(ast) == str(expected_ast) + def test_case_simple_form(self): + sql = f'''SELECT + CASE R.DELETE_RULE + WHEN 'CASCADE' THEN 0 + WHEN 'SET NULL' THEN 2 + ELSE 3 + END AS DELETE_RULE + FROM COLLATIONS''' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Case( + arg=Identifier('R.DELETE_RULE'), + rules=[ + [ + Constant('CASCADE'), + Constant(0) + ], + [ + Constant('SET NULL'), + Constant(2) + ] + ], + default=Constant(3), + alias=Identifier('DELETE_RULE') + ) + ], + from_table=Identifier('COLLATIONS') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + def test_select_left(self): sql = f'select left(a, 1) from tab1' ast = parse_sql(sql) @@ -1144,3 +1186,23 @@ def test_table_double_quote(self): ast = parse_sql(sql) assert str(ast) == str(expected_ast) + + def test_window_function_mindsdb(self): + + # modifier + query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + order_by=[OrderBy(field=Identifier('col2'))], + modifier='rows BETWEEN unbounded preceding AND current row' + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py index 1545e4b0..d8d5dddf 100644 --- a/tests/test_parser/test_base_sql/test_union.py +++ b/tests/test_parser/test_base_sql/test_union.py @@ -11,40 +11,42 @@ def test_single_select_error(self): parse_sql(sql) def test_union_base(self): - sql = """SELECT col1 FROM tab1 - UNION - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=True, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=True, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) def test_union_all(self): - sql = """SELECT col1 FROM tab1 - UNION ALL - SELECT col1 FROM tab2""" + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} ALL + SELECT col1 FROM tab2""" - ast = parse_sql(sql) - expected_ast = Union(unique=False, - left=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab1']), - ), - right=Select(targets=[Identifier('col1')], - from_table=Identifier(parts=['tab2']), - ), - ) - assert ast.to_tree() == expected_ast.to_tree() - assert str(ast) == str(expected_ast) + ast = parse_sql(sql) + expected_ast = cls(unique=False, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) - def xtest_union_alias(self): + def test_union_alias(self): sql = """SELECT * FROM ( SELECT col1 FROM tab1 UNION diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py index 5180465d..dab5e65a 100644 --- a/tests/test_planner/test_integration_select.py +++ b/tests/test_planner/test_integration_select.py @@ -290,7 +290,7 @@ def test_integration_select_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -378,7 +378,7 @@ def test_integration_select_default_namespace_subquery_in_from(self): steps=[ FetchDataframeStep(integration='int', query=Select( - targets=[Identifier('column1')], + targets=[Identifier('column1', alias=Identifier('column1')),], from_table=Select( targets=[Identifier('column1', alias=Identifier('column1'))], from_table=Identifier('tab'), @@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self): plan = plan_query( query, integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}], - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], ) assert plan.steps == expected_plan.steps @@ -583,6 +583,53 @@ def test_select_from_table_subselect_sql_integration(self): assert plan.steps == expected_plan.steps + def test_select_from_single_integration(self): + sql_parsed = ''' + with tab2 as ( + select * from int1.tabl2 + ) + select a from ( + select x from tab2 + union + select y from int1.tab1 + where x1 in (select id from int1.tab1) + limit 1 + ) + ''' + + sql_integration = ''' + with tab2 as ( + select * from tabl2 + ) + select a as a from ( + select x as x from tab2 + union + select y as y from tab1 + where x1 in (select id as id from tab1) + limit 1 + ) + ''' + query = parse_sql(sql_parsed, dialect='mindsdb') + + expected_plan = QueryPlan( + predictor_namespace='mindsdb', + steps=[ + FetchDataframeStep( + integration='int1', + query=parse_sql(sql_integration), + ), + ], + ) + + plan = plan_query( + query, + integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}], + predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}], + default_namespace='mindsdb', + ) + + assert plan.steps == expected_plan.steps + def test_delete_from_table_subselect_api_integration(self): query = parse_sql(''' delete from int1.tab1 diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index 40c792a5..94afda8f 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -22,8 +22,6 @@ def test_join_predictor_plan(self): """ query = parse_sql(sql) - query_step = parse_sql("select tab1.column1, pred.predicted") - query_step.from_table = Parameter(Result(2)) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', @@ -75,7 +73,7 @@ def test_join_predictor_plan_aliases(self): plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) assert plan.steps == expected_plan.steps - + def test_join_predictor_plan_limit(self): @@ -116,7 +114,7 @@ def test_join_predictor_plan_limit(self): plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) assert plan.steps == expected_plan.steps - + # def test_join_predictor_error_when_filtering_on_predictions(self): # """ @@ -144,14 +142,6 @@ def test_join_predictor_plan_limit(self): # plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}}) def test_join_predictor_plan_complex_query(self): - query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')], - from_table=Join(left=Identifier('int.tab'), - right=Identifier('mindsdb.pred'), - join_type=JoinType.INNER_JOIN, - implicit=True), - group_by=[Identifier('tab.asset')], - having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')]) - ) sql = """ select t.asset, t.time, m.predicted @@ -673,15 +663,16 @@ def test_complex_subselect(self): sql = ''' select t2.x, m.id, (select a from int.tab0 where x=0) from int.tab1 t1 - join int.tab2 t2 on t1.x = t2.x + join int.tab2 t2 on t1.x = t2.a join mindsdb.pred m where m.a=(select a from int.tab3 where x=3) and t2.x=(select a from int.tab4 where x=4) and t1.b=1 and t2.b=2 and t1.a = t2.a ''' - q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ') - q_table2.where.args[0].args[1] = Parameter(Result(2)) + q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 AND a IN 1') + q_table2.where.args[0].args[0].args[1] = Parameter(Result(2)) + q_table2.where.args[1].args[1] = Parameter(Result(4)) subquery = parse_sql(""" select t2.x, m.id, x @@ -708,22 +699,23 @@ def test_complex_subselect(self): # tables FetchDataframeStep(integration='int', query=parse_sql('select * from tab1 as t1 where b=1')), + SubSelectStep(dataframe=Result(3), query=Select(targets=[Identifier('x')], distinct=True)), FetchDataframeStep(integration='int', query=q_table2), - JoinStep(left=Result(3), right=Result(4), + JoinStep(left=Result(3), right=Result(5), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN, - condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.x')]) + condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.a')]) ) ), # model - ApplyPredictorStep(namespace='mindsdb', dataframe=Result(5), + ApplyPredictorStep(namespace='mindsdb', dataframe=Result(6), predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(1)}), - JoinStep(left=Result(5), right=Result(6), + JoinStep(left=Result(6), right=Result(7), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - QueryStep(subquery, from_table=Result(7)), + QueryStep(subquery, from_table=Result(8)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) diff --git a/tests/test_planner/test_join_tables.py b/tests/test_planner/test_join_tables.py index f44483c5..b85bafbb 100644 --- a/tests/test_planner/test_join_tables.py +++ b/tests/test_planner/test_join_tables.py @@ -16,7 +16,7 @@ def test_join_tables_plan(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ) ) @@ -35,7 +35,7 @@ def test_join_tables_plan(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -45,13 +45,13 @@ def test_join_tables_plan(self): ) assert plan.steps == expected_plan.steps - + def test_join_tables_where_plan(self): query = parse_sql(''' SELECT tab1.column1, tab2.column1, tab2.column2 FROM int.tab1 - INNER JOIN int2.tab2 ON tab1.column1 = tab2.column1 + INNER JOIN int2.tab2 ON tab1.column1 > tab2.column1 WHERE ((tab1.column1 = 1) AND (tab2.column1 = 0)) AND (tab1.column3 = tab2.column3) @@ -71,7 +71,7 @@ def test_join_tables_where_plan(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -90,7 +90,7 @@ def test_join_tables_plan_groupby(self): Function('sum', args=[Identifier('tab2.column2')], alias=Identifier('total'))], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), @@ -117,7 +117,7 @@ def test_join_tables_plan_groupby(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -126,13 +126,13 @@ def test_join_tables_plan_groupby(self): ], ) assert plan.steps == expected_plan.steps - + def test_join_tables_plan_limit_offset(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), limit=Constant(10), @@ -161,7 +161,7 @@ def test_join_tables_plan_limit_offset(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -171,13 +171,13 @@ def test_join_tables_plan_limit_offset(self): ) assert plan.steps == expected_plan.steps - + def test_join_tables_plan_order_by(self): query = Select(targets=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')], from_table=Join(left=Identifier('int.tab1'), right=Identifier('int2.tab2'), - condition=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN ), limit=Constant(10), @@ -203,7 +203,7 @@ def test_join_tables_plan_order_by(self): JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - condition=BinaryOperation(op='=', + condition=BinaryOperation(op='>', args=[Identifier('tab1.column1'), Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN @@ -213,7 +213,7 @@ def test_join_tables_plan_order_by(self): ) assert plan.steps == expected_plan.steps - + # This quiery should be sent to integration without raising exception # def test_join_tables_where_ambigous_column_error(self): @@ -278,7 +278,7 @@ def test_join_tables_disambiguate_identifiers_in_condition(self): for i in range(len(plan.steps)): assert plan.steps[i] == expected_plan.steps[i] - + def _disabled_test_join_tables_error_on_unspecified_table_in_condition(self): # disabled: identifier can be environment of system variable @@ -328,7 +328,7 @@ def test_join_tables_plan_default_namespace(self): def test_complex_join_tables(self): query = parse_sql(''' select * from int1.tbl1 t1 - right join int2.tbl2 t2 on t1.id=t2.id + right join int2.tbl2 t2 on t1.id>t2.id join pred m left join tbl3 on tbl3.id=t1.id where t1.a=1 and t2.b=2 and 1=1 @@ -337,6 +337,9 @@ def test_complex_join_tables(self): subquery = copy.deepcopy(query) subquery.from_table = None + q_table3 = parse_sql('select * from tbl3 where id in 0') + q_table3.where.args[1] = Parameter(Result(5)) + plan = plan_query(query, integrations=['int1', 'int2', 'proj'], default_namespace='proj', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) @@ -349,7 +352,7 @@ def test_complex_join_tables(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), condition=BinaryOperation( - op='=', + op='>', args=[Identifier('t1.id'), Identifier('t2.id')]), join_type=JoinType.RIGHT_JOIN)), @@ -359,9 +362,10 @@ def test_complex_join_tables(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - FetchDataframeStep(integration='proj', query=parse_sql('select * from tbl3')), + SubSelectStep(dataframe=Result(0), query=Select(targets=[Identifier('id')], distinct=True)), + FetchDataframeStep(integration='proj', query=q_table3), JoinStep(left=Result(4), - right=Result(5), + right=Result(6), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), condition=BinaryOperation( @@ -369,7 +373,7 @@ def test_complex_join_tables(self): args=[Identifier('tbl3.id'), Identifier('t1.id')]), join_type=JoinType.LEFT_JOIN)), - QueryStep(subquery, from_table=Result(6)), + QueryStep(subquery, from_table=Result(7)), ] ) @@ -485,4 +489,35 @@ def test_join_one_integration(self): ) plan = plan_query(query, integrations=['int'], default_namespace='int') - assert plan.steps == expected_plan.steps \ No newline at end of file + assert plan.steps == expected_plan.steps + + def test_cte(self): + query = parse_sql(''' + with t1 as ( + select * from int1.tbl1 + ) + select t1.id, t2.* from t1 + join int2.tbl2 t2 on t1.id>t2.id + ''') + + subquery = copy.deepcopy(query) + subquery.from_table = None + + plan = plan_query(query, integrations=['int1', 'int2'], default_namespace='mindsdb') + + expected_plan = QueryPlan( + steps=[ + FetchDataframeStep(integration='int1', query=parse_sql('select * from tbl1')), + SubSelectStep(dataframe=Result(0), query=Select(targets=[Star()]), table_name='t1'), + FetchDataframeStep(integration='int2', query=parse_sql('select * from tbl2 as t2')), + JoinStep(left=Result(1), + right=Result(2), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), + condition=BinaryOperation(op='>', args=[Identifier('t1.id'), Identifier('t2.id')]), + join_type=JoinType.JOIN)), + QueryStep(parse_sql('SELECT t1.`id`, t2.*'), from_table=Result(3)), + ] + ) + + assert plan.steps == expected_plan.steps