From 218018f35c992ae391bc433698a919f868a9425c Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 6 Mar 2024 22:54:48 +0300 Subject: [PATCH] fix escaping variant: "a \\\n b" --- mindsdb_sql/parser/dialects/mindsdb/lexer.py | 4 +-- .../test_base_sql/test_base_sql.py | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 398a1d2b..a6a00485 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -308,12 +308,12 @@ def FLOAT(self, t): def INTEGER(self, t): return t - @_(r"'(?:[^\'\\]|\\.)*'") + @_(r"'(?:\\.|[^'])*'") def QUOTE_STRING(self, t): t.value = t.value.replace('\\"', '"').replace("\\'", "'") return t - @_(r'"(?:[^\"\\]|\\.)*"') + @_(r'"(?:\\.|[^"])*"') def DQUOTE_STRING(self, t): t.value = t.value.replace('\\"', '"').replace("\\'", "'") return t diff --git a/tests/test_parser/test_base_sql/test_base_sql.py b/tests/test_parser/test_base_sql/test_base_sql.py index 3294b897..420218b5 100644 --- a/tests/test_parser/test_base_sql/test_base_sql.py +++ b/tests/test_parser/test_base_sql/test_base_sql.py @@ -1,4 +1,4 @@ -import pytest +from textwrap import dedent from mindsdb_sql import parse_sql from mindsdb_sql.parser.ast import * @@ -34,22 +34,26 @@ def test_not_equal(self): def test_escaping(self): expected_ast = Select( - targets=[Constant(value="a ' \" b")] + targets=[ + Constant(value="a ' \" b"), + Constant(value="a ' \" b"), + Constant(value="a \\n b"), + Constant(value="a \\\n b"), + Constant(value="a \\\n b"), + Constant(value="a\nb"), + ] ) - sql = """ - select 'a \\' \\" b' - """ - - ast = parse_sql(sql) - - assert str(ast).lower() == str(expected_ast).lower() - assert ast.to_tree() == expected_ast.to_tree() - - # in double quotes - sql = """ - select "a \\' \\" b" - """ + sql = dedent(''' +select +'a \\' \\" b', -- double quote +"a \\' \\" b", -- single quote +"a \\n b", +"a \\\n b", -- double quote +'a \\\n b', -- single quote +"a +b" + ''') ast = parse_sql(sql)