diff --git a/mindsdb_sql/parser/ast/select/constant.py b/mindsdb_sql/parser/ast/select/constant.py index 0b31af1e..7869e956 100644 --- a/mindsdb_sql/parser/ast/select/constant.py +++ b/mindsdb_sql/parser/ast/select/constant.py @@ -15,7 +15,8 @@ def to_tree(self, *args, level=0, **kwargs): def get_string(self, *args, **kwargs): if isinstance(self.value, str) and self.with_quotes: - out_str = f"\'{self.value}\'" + val = self.value.replace("'", "\\'") + out_str = f"\'{val}\'" elif isinstance(self.value, bool): out_str = 'TRUE' if self.value else 'FALSE' elif isinstance(self.value, (dt.date, dt.datetime, dt.timedelta)): diff --git a/mindsdb_sql/parser/dialects/mindsdb/chatbot.py b/mindsdb_sql/parser/dialects/mindsdb/chatbot.py index f3366b80..5a3a5223 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/chatbot.py +++ b/mindsdb_sql/parser/dialects/mindsdb/chatbot.py @@ -39,7 +39,8 @@ def get_string(self, *args, **kwargs): params = self.params.copy() params['model'] = self.model.to_string() if self.model else 'NULL' params['database'] = self.database.to_string() - params['agent'] = self.agent.to_string() if self.agent else 'NULL' + if self.agent: + params['agent'] = self.agent.to_string() using_ar = [f'{k}={repr(v)}' for k, v in params.items()] diff --git a/mindsdb_sql/parser/dialects/mindsdb/create_database.py b/mindsdb_sql/parser/dialects/mindsdb/create_database.py index 759dbd57..c60fe64b 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/create_database.py +++ b/mindsdb_sql/parser/dialects/mindsdb/create_database.py @@ -43,8 +43,12 @@ def get_string(self, *args, **kwargs): if self.is_replace: replace_str = f' OR REPLACE' + engine_str = '' + if self.engine: + engine_str = f'ENGINE = {repr(self.engine)} ' + parameters_str = '' if self.parameters: parameters_str = f', PARAMETERS = {json.dumps(self.parameters)}' - out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} WITH ENGINE = {repr(self.engine)}{parameters_str}' + out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} {engine_str}{parameters_str}' return out_str diff --git a/mindsdb_sql/parser/dialects/mindsdb/create_job.py b/mindsdb_sql/parser/dialects/mindsdb/create_job.py index eb06a8df..d046e8ca 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/create_job.py +++ b/mindsdb_sql/parser/dialects/mindsdb/create_job.py @@ -79,7 +79,7 @@ def get_string(self, *args, **kwargs): if_query_str = '' if self.if_query_str is not None: - if_query_str = f" IF '{self.if_query_str}'" + if_query_str = f" IF ({self.if_query_str})" out_str = f'CREATE JOB {"IF NOT EXISTS" if self.if_not_exists else ""} {self.name.to_string()} ({self.query_str}){start_str}{end_str}{repeat_str}{if_query_str}' return out_str diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index ada46cf3..cc057926 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -819,6 +819,7 @@ def create_predictor(self, p): 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN', 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns', 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns FROM identifier LPAREN raw_query RPAREN', + 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', # TODO add IF_NOT_EXISTS elegantly (should be low level BNF expansion) ) def create_anomaly_detection_model(self, p): diff --git a/tests/test_parser/test_mindsdb/test_create_view.py b/tests/test_parser/test_mindsdb/test_create_view.py index f73e4b99..b6773ad4 100644 --- a/tests/test_parser/test_mindsdb/test_create_view.py +++ b/tests/test_parser/test_mindsdb/test_create_view.py @@ -15,7 +15,7 @@ def test_create_view_lexer(self): assert tokens[1].value == 'VIEW' assert tokens[1].type == 'VIEW' - def test_create_view_raises_wrong_dialect(self): + def test_create_view_raises_wrong_dialect_error(self): sql = "CREATE VIEW my_view FROM integr AS ( SELECT * FROM pred )" for dialect in ['sqlite', 'mysql']: with pytest.raises(ParsingException): diff --git a/tests/test_parser/test_standard_render.py b/tests/test_parser/test_standard_render.py new file mode 100644 index 00000000..bd3d011f --- /dev/null +++ b/tests/test_parser/test_standard_render.py @@ -0,0 +1,85 @@ +import inspect +import pkgutil +import sys +import os +import importlib + +from mindsdb_sql import parse_sql + + +def load_all_modules_from_dir(dir_names): + for importer, package_name, _ in pkgutil.iter_modules(dir_names): + full_package_name = package_name + if full_package_name not in sys.modules: + spec = importer.find_spec(package_name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + yield module + + +def check_module(module): + if module.__name__ in ('test_mysql_lexer', 'test_base_lexer'): + # skip + return + + for class_name, klass in inspect.getmembers(module, predicate=inspect.isclass): + if not class_name.startswith('Test'): + continue + + tests = klass() + for test_name, test_method in inspect.getmembers(tests, predicate=inspect.ismethod): + if not test_name.startswith('test_') or test_name.endswith('_error'): + # skip tests that expected error + continue + sig = inspect.signature(test_method) + args = [] + # add dialect + if 'dialect' in sig.parameters: + args.append('mindsdb') + if 'cat' in sig.parameters: + # skip it + continue + + test_method(*args) + + +def parse_sql2(sql, dialect='mindsdb'): + + query = parse_sql(sql, dialect) + + # render + sql2 = query.to_string() + + # Parse again + try: + query2 = parse_sql(sql2, dialect) + except Exception as e: + # TODO fix queries + raise e + print(sql2) + return query + + # compare result from first and second parsing + assert str(query) == str(query2) + + # return to test: it compares it with expected_ast + return query2 + + +def test_standard_render(): + + base_dir = os.path.dirname(__file__) + dir_names = [ + os.path.join(base_dir, folder) + for folder in os.listdir(base_dir) + if folder.startswith('test_') + ] + + for module in load_all_modules_from_dir(dir_names): + + # inject function + module.parse_sql = parse_sql2 + + check_module(module) + +