Skip to content

Commit

Permalink
Merge pull request #397 from mindsdb/staging
Browse files Browse the repository at this point in the history
Release 0.18.0
  • Loading branch information
ea-rus authored Sep 2, 2024
2 parents 5c265a0 + 7d4993d commit 53405b5
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 26 deletions.
2 changes: 1 addition & 1 deletion mindsdb_sql/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'mindsdb_sql'
__package_name__ = 'mindsdb_sql'
__version__ = '0.17.3'
__version__ = '0.18.0'
__description__ = "Pure python SQL parser"
__email__ = "[email protected]"
__author__ = 'MindsDB Inc'
Expand Down
15 changes: 11 additions & 4 deletions mindsdb_sql/parser/ast/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@


class TableColumn():
def __init__(self, name, type='integer', length=None):
def __init__(self, name, type='integer', length=None, default=None,
is_primary_key=False, nullable=None):
self.name = name
self.type = type
self.is_primary_key = False
self.default = None
self.is_primary_key = is_primary_key
self.default = default
self.length = length
self.nullable = nullable

def __eq__(self, other):
if type(self) != type(other):
Expand Down Expand Up @@ -96,7 +98,12 @@ def get_string(self, *args, **kwargs):
type = str(col.type)
if col.length is not None:
type = f'{type}({col.length})'
columns.append( f'{col.name} {type}')
col_str = f'{col.name} {type}'
if col.nullable is True:
col_str += ' NULL'
elif col.nullable is False:
col_str += ' NOT NULL'
columns.append(col_str)

columns_str = '({})'.format(', '.join(columns))

Expand Down
2 changes: 1 addition & 1 deletion mindsdb_sql/parser/ast/select/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(self, info):
super().__init__(op='interval', args=[info, ])

def get_string(self, *args, **kwargs):
return f'INTERVAL {repr(self.args[0])}'
return f'INTERVAL {self.args[0]}'

def to_tree(self, *args, level=0, **kwargs):
return self.get_string( *args, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion mindsdb_sql/parser/dialects/mindsdb/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MindsDBLexer(Lexer):
SHOW, SCHEMAS, SCHEMA, DATABASES, DATABASE, TABLES, TABLE, FULL, EXTENDED, PROCESSLIST,
MUTEX, CODE, SLAVE, REPLICA, REPLICAS, CHANNEL, TRIGGERS, TRIGGER, KEYS, STORAGE, LOGS, BINARY,
MASTER, PRIVILEGES, PROFILES, HOSTS, OPEN, INDEXES,
VARIABLES, SESSION, STATUS,
VARIABLES, SESSION, STATUS, PRIMARY_KEY, DEFAULT,
GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS,
ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER,
PERSIST, PERSIST_ONLY,
Expand Down Expand Up @@ -164,6 +164,8 @@ class MindsDBLexer(Lexer):
STATUS = r'\bSTATUS\b'
GLOBAL = r'\bGLOBAL\b'
PROCEDURE = r'\bPROCEDURE\b'
PRIMARY_KEY = r'\bPRIMARY[_|\s]KEY\b'
DEFAULT = r'\bDEFAULT\b'
FUNCTION = r'\bFUNCTION\b'
INDEX = r'\bINDEX\b'
CREATE = r'\bCREATE\b'
Expand Down
70 changes: 59 additions & 11 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,16 +614,16 @@ def update(self, p):
from_select=p.select)

# INSERT
@_('INSERT INTO identifier LPAREN result_columns RPAREN select',
@_('INSERT INTO identifier LPAREN column_list RPAREN select',
'INSERT INTO identifier select')
def insert(self, p):
columns = getattr(p, 'result_columns', None)
columns = getattr(p, 'column_list', None)
return Insert(table=p.identifier, columns=columns, from_select=p.select)

@_('INSERT INTO identifier LPAREN result_columns RPAREN VALUES expr_list_set',
@_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set',
'INSERT INTO identifier VALUES expr_list_set')
def insert(self, p):
columns = getattr(p, 'result_columns', None)
columns = getattr(p, 'column_list', None)
return Insert(table=p.identifier, columns=columns, values=p.expr_list_set)

@_('expr_list_set COMMA expr_list_set')
Expand Down Expand Up @@ -706,14 +706,43 @@ def drop_table(self, p):

# create table
@_('id id',
'id id LPAREN INTEGER RPAREN')
'id id DEFAULT id',
'id id PRIMARY_KEY',
'id id LPAREN INTEGER RPAREN',
'id id LPAREN INTEGER RPAREN DEFAULT id',
'PRIMARY_KEY LPAREN column_list RPAREN',
)
def table_column(self, p):
default = None
if hasattr(p, 'DEFAULT'):
# get last element
default = p[len(p) - 1]

is_primary_key = False
if hasattr(p, 'column_list'):
# is list of primary keys
return p.column_list

elif hasattr(p, 'PRIMARY_KEY'):
is_primary_key = True

return TableColumn(
name=p[0],
type=p[1],
length=getattr(p, 'INTEGER', None)
length=getattr(p, 'INTEGER', None),
default=default,
is_primary_key=is_primary_key
)

@_('table_column NULL',
'table_column NOT NULL')
def table_column(self, p):
nullable = True
if hasattr(p, 'NOT'):
nullable = False
p.table_column.nullable = nullable
return p.table_column

@_('table_column',
'table_column_list COMMA table_column')
def table_column_list(self, p):
Expand All @@ -723,9 +752,20 @@ def table_column_list(self, p):

@_('CREATE replace_or_empty TABLE if_not_exists_or_empty identifier LPAREN table_column_list RPAREN')
def create_table(self, p):
table_columns = {}
primary_keys = []
for item in p.table_column_list:
if isinstance(item, TableColumn):
table_columns[item.name] = item
else:
primary_keys = item
for col_name in primary_keys:
if col_name in table_columns:
table_columns[col_name].is_primary_key = True

return CreateTable(
name=p.identifier,
columns=p.table_column_list,
columns=list(table_columns.values()),
is_replace=getattr(p, 'replace_or_empty', False),
if_not_exists=getattr(p, 'if_not_exists_or_empty', False)
)
Expand Down Expand Up @@ -1183,17 +1223,17 @@ def from_table(self, p):

@_('LPAREN query RPAREN')
@_('LPAREN query RPAREN AS id')
@_('LPAREN query RPAREN AS id LPAREN result_columns RPAREN')
@_('LPAREN query RPAREN AS id LPAREN column_list RPAREN')
def from_table(self, p):
query = p.query
query.parentheses = True
if hasattr(p, 'id'):
query.alias = Identifier(parts=[p.id])
if hasattr(p, 'result_columns'):
for i, col in enumerate(p.result_columns):
if hasattr(p, 'column_list'):
for i, col in enumerate(p.column_list):
if i >= len(query.targets):
break
query.targets[i].alias = col
query.targets[i].alias = Identifier(parts=[col])
return query

# keywords for table
Expand Down Expand Up @@ -1277,6 +1317,13 @@ def result_column(self, p):
def result_column(self, p):
return p[0]

@_('column_list COMMA id',
'id')
def column_list(self, p):
column_list = getattr(p, 'column_list', [])
column_list.append(p.id)
return column_list

# case
@_('CASE case_conditions ELSE expr END')
def case(self, p):
Expand Down Expand Up @@ -1735,6 +1782,7 @@ def function_name(self, p):
'VIEWS',
'WARNINGS',
'MODEL',
'DEFAULT',
'MODELS',
'AGENT',
'SCHEMAS',
Expand Down
23 changes: 18 additions & 5 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def __init__(self, info):

@compiles(INTERVAL)
def _compile_interval(element, compiler, **kw):
return f"INTERVAL '{element.info}'"
items = element.info.split(' ', maxsplit=1)
# quote first element
items[0] = f"'{items[0]}'"
return "INTERVAL " + " ".join(items)


class SqlalchemyRender:
Expand Down Expand Up @@ -67,6 +70,7 @@ def __init__(self, dialect_name):
self.types_map = {}
for type_name in sa_type_names:
self.types_map[type_name.upper()] = getattr(sa.types, type_name)
self.types_map['BOOL'] = self.types_map['BOOLEAN']

def to_column(self, parts):
# because sqlalchemy doesn't allow columns consist from parts therefore we do it manually
Expand Down Expand Up @@ -532,15 +536,24 @@ def prepare_create_table(self, ast_query):
for col in ast_query.columns:
default = None
if col.default is not None:
if isinstance(col.default, ast.Function):
default = self.to_function(col.default)
if isinstance(col.default, str):
default = sa.text(col.default)
if col.type.lower() == 'serial':
col.is_primary_key = True
col.type = 'INT'

kwargs = {
'primary_key': col.is_primary_key,
'server_default': default,
}
if col.nullable is not None:
kwargs['nullable'] = col.nullable

columns.append(
sa.Column(
col.name,
self.get_type(col.type),
primary_key=col.is_primary_key,
default=default,
**kwargs
)
)

Expand Down
70 changes: 69 additions & 1 deletion tests/test_parser/test_base_sql/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,77 @@ def test_create(self):
City varchar
)
'''
print(sql)
ast = parse_sql(sql)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# test with primary keys / defaults
# using serial

sql = f'''
CREATE TABLE mydb.Persons(
PersonID serial,
active BOOL NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='PersonID', type='serial'),
TableColumn(name='active', type='BOOL', nullable=False),
TableColumn(name='created_at', type='TIMESTAMP', default='CURRENT_TIMESTAMP'),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# using primary key column

sql = f'''
CREATE TABLE mydb.Persons(
PersonID INT PRIMARY KEY,
name TEXT NULL
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='PersonID', type='INT', is_primary_key=True),
TableColumn(name='name', type='TEXT', nullable=True),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

# multiple primary keys

sql = f'''
CREATE TABLE mydb.Persons(
location_id INT,
num INT,
name TEXT,
PRIMARY KEY (location_id, num)
)
'''
ast = parse_sql(sql)

expected_ast = CreateTable(
name=Identifier('mydb.Persons'),
columns=[
TableColumn(name='location_id', type='INT', is_primary_key=True),
TableColumn(name='num', type='INT', is_primary_key=True),
TableColumn(name='name', type='TEXT'),
]
)

assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()

4 changes: 2 additions & 2 deletions tests/test_parser/test_base_sql/test_misc_sql_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_interval(self):
for value in ('1 day', "'1' day", "'1 day'"):
sql = f"""
select interval {value} + 1 from aaa
where 'a' > interval "3 day 1 min"
where 'a' > interval "1 min"
"""

expected_ast = Select(
Expand All @@ -250,7 +250,7 @@ def test_interval(self):
op='>',
args=[
Constant('a'),
Interval('3 day 1 min'),
Interval('1 min'),
]
)
)
Expand Down

0 comments on commit 53405b5

Please sign in to comment.