Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release 0.18.0 #397

Merged
merged 7 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindsdb_sql/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'mindsdb_sql'
__package_name__ = 'mindsdb_sql'
__version__ = '0.17.3'
__version__ = '0.18.0'
__description__ = "Pure python SQL parser"
__email__ = "[email protected]"
__author__ = 'MindsDB Inc'
Expand Down
15 changes: 11 additions & 4 deletions mindsdb_sql/parser/ast/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@


class TableColumn():
def __init__(self, name, type='integer', length=None):
def __init__(self, name, type='integer', length=None, default=None,
is_primary_key=False, nullable=None):
self.name = name
self.type = type
self.is_primary_key = False
self.default = None
self.is_primary_key = is_primary_key
self.default = default
self.length = length
self.nullable = nullable

def __eq__(self, other):
if type(self) != type(other):
Expand Down Expand Up @@ -96,7 +98,12 @@ def get_string(self, *args, **kwargs):
type = str(col.type)
if col.length is not None:
type = f'{type}({col.length})'
columns.append( f'{col.name} {type}')
col_str = f'{col.name} {type}'
if col.nullable is True:
col_str += ' NULL'
elif col.nullable is False:
col_str += ' NOT NULL'
columns.append(col_str)

columns_str = '({})'.format(', '.join(columns))

Expand Down
2 changes: 1 addition & 1 deletion mindsdb_sql/parser/ast/select/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(self, info):
super().__init__(op='interval', args=[info, ])

def get_string(self, *args, **kwargs):
return f'INTERVAL {repr(self.args[0])}'
return f'INTERVAL {self.args[0]}'

def to_tree(self, *args, level=0, **kwargs):
return self.get_string( *args, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion mindsdb_sql/parser/dialects/mindsdb/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MindsDBLexer(Lexer):
SHOW, SCHEMAS, SCHEMA, DATABASES, DATABASE, TABLES, TABLE, FULL, EXTENDED, PROCESSLIST,
MUTEX, CODE, SLAVE, REPLICA, REPLICAS, CHANNEL, TRIGGERS, TRIGGER, KEYS, STORAGE, LOGS, BINARY,
MASTER, PRIVILEGES, PROFILES, HOSTS, OPEN, INDEXES,
VARIABLES, SESSION, STATUS,
VARIABLES, SESSION, STATUS, PRIMARY_KEY, DEFAULT,
GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS,
ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER,
PERSIST, PERSIST_ONLY,
Expand Down Expand Up @@ -164,6 +164,8 @@ class MindsDBLexer(Lexer):
# Keyword token patterns. The relative order of these assignments is left
# exactly as it was.
STATUS = r'\bSTATUS\b'
GLOBAL = r'\bGLOBAL\b'
PROCEDURE = r'\bPROCEDURE\b'
# "PRIMARY KEY" is lexed as a single token. The two words may be joined by one
# underscore or by a run of whitespace. The earlier character class `[_|\s]`
# also matched a literal '|' (so "PRIMARY|KEY" lexed as this token) and allowed
# only one whitespace character between the words. The group is non-capturing
# so it adds no extra capture groups to the pattern.
PRIMARY_KEY = r'\bPRIMARY(?:_|\s+)KEY\b'
DEFAULT = r'\bDEFAULT\b'
FUNCTION = r'\bFUNCTION\b'
INDEX = r'\bINDEX\b'
CREATE = r'\bCREATE\b'
Expand Down
70 changes: 59 additions & 11 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,16 +614,16 @@ def update(self, p):
from_select=p.select)

# INSERT
@_('INSERT INTO identifier LPAREN result_columns RPAREN select',
@_('INSERT INTO identifier LPAREN column_list RPAREN select',
'INSERT INTO identifier select')
def insert(self, p):
columns = getattr(p, 'result_columns', None)
columns = getattr(p, 'column_list', None)
return Insert(table=p.identifier, columns=columns, from_select=p.select)

@_('INSERT INTO identifier LPAREN result_columns RPAREN VALUES expr_list_set',
@_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set',
'INSERT INTO identifier VALUES expr_list_set')
def insert(self, p):
columns = getattr(p, 'result_columns', None)
columns = getattr(p, 'column_list', None)
return Insert(table=p.identifier, columns=columns, values=p.expr_list_set)

@_('expr_list_set COMMA expr_list_set')
Expand Down Expand Up @@ -706,14 +706,43 @@ def drop_table(self, p):

# create table
@_('id id',
'id id LPAREN INTEGER RPAREN')
'id id DEFAULT id',
'id id PRIMARY_KEY',
'id id LPAREN INTEGER RPAREN',
'id id LPAREN INTEGER RPAREN DEFAULT id',
'PRIMARY_KEY LPAREN column_list RPAREN',
)
def table_column(self, p):
default = None
if hasattr(p, 'DEFAULT'):
# get last element
default = p[len(p) - 1]

is_primary_key = False
if hasattr(p, 'column_list'):
# is list of primary keys
return p.column_list

elif hasattr(p, 'PRIMARY_KEY'):
is_primary_key = True

return TableColumn(
name=p[0],
type=p[1],
length=getattr(p, 'INTEGER', None)
length=getattr(p, 'INTEGER', None),
default=default,
is_primary_key=is_primary_key
)

@_('table_column NULL',
   'table_column NOT NULL')
def table_column(self, p):
    # Nullability suffix on an already-parsed column definition: a bare NULL
    # marks the column nullable, a NOT token in the production marks it
    # NOT NULL. The same column object is updated in place and returned.
    # NOTE(review): the 'PRIMARY_KEY LPAREN column_list RPAREN' form of
    # table_column returns a plain list, which cannot take a `nullable`
    # attribute; confirm that input such as "PRIMARY KEY (a) NULL" is meant
    # to fail here.
    column = p.table_column
    column.nullable = not hasattr(p, 'NOT')
    return column

@_('table_column',
'table_column_list COMMA table_column')
def table_column_list(self, p):
Expand All @@ -723,9 +752,20 @@ def table_column_list(self, p):

@_('CREATE replace_or_empty TABLE if_not_exists_or_empty identifier LPAREN table_column_list RPAREN')
def create_table(self, p):
table_columns = {}
primary_keys = []
for item in p.table_column_list:
if isinstance(item, TableColumn):
table_columns[item.name] = item
else:
primary_keys = item
for col_name in primary_keys:
if col_name in table_columns:
table_columns[col_name].is_primary_key = True

return CreateTable(
name=p.identifier,
columns=p.table_column_list,
columns=list(table_columns.values()),
is_replace=getattr(p, 'replace_or_empty', False),
if_not_exists=getattr(p, 'if_not_exists_or_empty', False)
)
Expand Down Expand Up @@ -1183,17 +1223,17 @@ def from_table(self, p):

@_('LPAREN query RPAREN')
@_('LPAREN query RPAREN AS id')
@_('LPAREN query RPAREN AS id LPAREN result_columns RPAREN')
@_('LPAREN query RPAREN AS id LPAREN column_list RPAREN')
def from_table(self, p):
query = p.query
query.parentheses = True
if hasattr(p, 'id'):
query.alias = Identifier(parts=[p.id])
if hasattr(p, 'result_columns'):
for i, col in enumerate(p.result_columns):
if hasattr(p, 'column_list'):
for i, col in enumerate(p.column_list):
if i >= len(query.targets):
break
query.targets[i].alias = col
query.targets[i].alias = Identifier(parts=[col])
return query

# keywords for table
Expand Down Expand Up @@ -1277,6 +1317,13 @@ def result_column(self, p):
def result_column(self, p):
return p[0]

@_('column_list COMMA id',
   'id')
def column_list(self, p):
    # Comma-separated list of bare identifiers. When the production carries
    # the list built so far, extend that same list; otherwise this is the
    # first identifier and a new list is started.
    names = p.column_list if hasattr(p, 'column_list') else []
    names.append(p.id)
    return names

# case
@_('CASE case_conditions ELSE expr END')
def case(self, p):
Expand Down Expand Up @@ -1735,6 +1782,7 @@ def function_name(self, p):
'VIEWS',
'WARNINGS',
'MODEL',
'DEFAULT',
'MODELS',
'AGENT',
'SCHEMAS',
Expand Down
23 changes: 18 additions & 5 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def __init__(self, info):

@compiles(INTERVAL)
def _compile_interval(element, compiler, **kw):
return f"INTERVAL '{element.info}'"
items = element.info.split(' ', maxsplit=1)
# quote first element
items[0] = f"'{items[0]}'"
return "INTERVAL " + " ".join(items)


class SqlalchemyRender:
Expand Down Expand Up @@ -67,6 +70,7 @@ def __init__(self, dialect_name):
self.types_map = {}
for type_name in sa_type_names:
self.types_map[type_name.upper()] = getattr(sa.types, type_name)
self.types_map['BOOL'] = self.types_map['BOOLEAN']

def to_column(self, parts):
# because sqlalchemy doesn't allow columns consist from parts therefore we do it manually
Expand Down Expand Up @@ -532,15 +536,24 @@ def prepare_create_table(self, ast_query):
for col in ast_query.columns:
default = None
if col.default is not None:
if isinstance(col.default, ast.Function):
default = self.to_function(col.default)
if isinstance(col.default, str):
default = sa.text(col.default)
if col.type.lower() == 'serial':
col.is_primary_key = True
col.type = 'INT'

kwargs = {
'primary_key': col.is_primary_key,
'server_default': default,
}
if col.nullable is not None:
kwargs['nullable'] = col.nullable

columns.append(
sa.Column(
col.name,
self.get_type(col.type),
primary_key=col.is_primary_key,
default=default,
**kwargs
)
)

Expand Down
70 changes: 69 additions & 1 deletion tests/test_parser/test_base_sql/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,77 @@ def test_create(self):
City varchar
)
'''
print(sql)
ast = parse_sql(sql)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# test with primary keys / defaults
# using serial

sql = f'''
CREATE TABLE mydb.Persons(
PersonID serial,
active BOOL NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='PersonID', type='serial'),
TableColumn(name='active', type='BOOL', nullable=False),
TableColumn(name='created_at', type='TIMESTAMP', default='CURRENT_TIMESTAMP'),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# using primary key column

sql = f'''
CREATE TABLE mydb.Persons(
PersonID INT PRIMARY KEY,
name TEXT NULL
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='PersonID', type='INT', is_primary_key=True),
TableColumn(name='name', type='TEXT', nullable=True),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# multiple primary keys

sql = f'''
CREATE TABLE mydb.Persons(
location_id INT,
num INT,
name TEXT,
PRIMARY KEY (location_id, num)
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='location_id', type='INT', is_primary_key=True),
TableColumn(name='num', type='INT', is_primary_key=True),
TableColumn(name='name', type='TEXT'),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

4 changes: 2 additions & 2 deletions tests/test_parser/test_base_sql/test_misc_sql_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_interval(self):
for value in ('1 day', "'1' day", "'1 day'"):
sql = f"""
select interval {value} + 1 from aaa
where 'a' > interval "3 day 1 min"
where 'a' > interval "1 min"
"""

expected_ast = Select(
Expand All @@ -250,7 +250,7 @@ def test_interval(self):
op='>',
args=[
Constant('a'),
Interval('3 day 1 min'),
Interval('1 min'),
]
)
)
Expand Down