diff --git a/mindsdb_sql/parser/ast/create.py b/mindsdb_sql/parser/ast/create.py index 679b0a0..1fcecca 100644 --- a/mindsdb_sql/parser/ast/create.py +++ b/mindsdb_sql/parser/ast/create.py @@ -9,12 +9,14 @@ class TableColumn(): - def __init__(self, name, type='integer', length=None): + def __init__(self, name, type='integer', length=None, default=None, + is_primary_key=False, nullable=None): self.name = name self.type = type - self.is_primary_key = False - self.default = None + self.is_primary_key = is_primary_key + self.default = default self.length = length + self.nullable = nullable def __eq__(self, other): if type(self) != type(other): @@ -96,7 +98,12 @@ def get_string(self, *args, **kwargs): type = str(col.type) if col.length is not None: type = f'{type}({col.length})' - columns.append( f'{col.name} {type}') + col_str = f'{col.name} {type}' + if col.nullable is True: + col_str += ' NULL' + elif col.nullable is False: + col_str += ' NOT NULL' + columns.append(col_str) columns_str = '({})'.format(', '.join(columns)) diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 9e4a971..1a9cb35 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -40,7 +40,7 @@ class MindsDBLexer(Lexer): SHOW, SCHEMAS, SCHEMA, DATABASES, DATABASE, TABLES, TABLE, FULL, EXTENDED, PROCESSLIST, MUTEX, CODE, SLAVE, REPLICA, REPLICAS, CHANNEL, TRIGGERS, TRIGGER, KEYS, STORAGE, LOGS, BINARY, MASTER, PRIVILEGES, PROFILES, HOSTS, OPEN, INDEXES, - VARIABLES, SESSION, STATUS, + VARIABLES, SESSION, STATUS, PRIMARY_KEY, DEFAULT, GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, PERSIST, PERSIST_ONLY, @@ -164,6 +164,8 @@ class MindsDBLexer(Lexer): STATUS = r'\bSTATUS\b' GLOBAL = r'\bGLOBAL\b' PROCEDURE = r'\bPROCEDURE\b' + PRIMARY_KEY = r'\bPRIMARY[_|\s]KEY\b' + DEFAULT = r'\bDEFAULT\b' FUNCTION = r'\bFUNCTION\b' INDEX = r'\bINDEX\b' CREATE = r'\bCREATE\b' diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 26d9b26..c53a2b0 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -614,16 +614,16 @@ def update(self, p): from_select=p.select) # INSERT - @_('INSERT INTO identifier LPAREN result_columns RPAREN select', + @_('INSERT INTO identifier LPAREN column_list RPAREN select', 'INSERT INTO identifier select') def insert(self, p): - columns = getattr(p, 'result_columns', None) + columns = getattr(p, 'column_list', None) return Insert(table=p.identifier, columns=columns, from_select=p.select) - @_('INSERT INTO identifier LPAREN result_columns RPAREN VALUES expr_list_set', + @_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set', 'INSERT INTO identifier VALUES expr_list_set') def insert(self, p): - columns = getattr(p, 'result_columns', None) + columns = getattr(p, 'column_list', None) return Insert(table=p.identifier, columns=columns, values=p.expr_list_set) @_('expr_list_set COMMA expr_list_set') @@ -706,14 +706,43 @@ def drop_table(self, p): # create table @_('id id', - 'id id LPAREN INTEGER RPAREN') + 'id id DEFAULT id', + 'id id PRIMARY_KEY', + 'id id LPAREN INTEGER RPAREN', + 'id id LPAREN INTEGER RPAREN DEFAULT id', + 'PRIMARY_KEY LPAREN column_list RPAREN', + ) def table_column(self, p): + default = None + if hasattr(p, 'DEFAULT'): + # get last element + default = p[len(p) - 1] + + is_primary_key = False + if hasattr(p, 'column_list'): + # is list of primary keys + return p.column_list + + elif hasattr(p, 'PRIMARY_KEY'): + is_primary_key = True + return TableColumn( name=p[0], type=p[1], - length=getattr(p, 'INTEGER', None) + length=getattr(p, 'INTEGER', None), + default=default, + is_primary_key=is_primary_key ) + @_('table_column NULL', + 'table_column NOT NULL') + def table_column(self, p): + nullable = True + if hasattr(p, 'NOT'): + nullable = False + p.table_column.nullable = nullable + return p.table_column + @_('table_column', 'table_column_list COMMA table_column') def table_column_list(self, p): @@ -723,9 +752,20 @@ def table_column_list(self, p): @_('CREATE replace_or_empty TABLE if_not_exists_or_empty identifier LPAREN table_column_list RPAREN') def create_table(self, p): + table_columns = {} + primary_keys = [] + for item in p.table_column_list: + if isinstance(item, TableColumn): + table_columns[item.name] = item + else: + primary_keys = item + for col_name in primary_keys: + if col_name in table_columns: + table_columns[col_name].is_primary_key = True + return CreateTable( name=p.identifier, - columns=p.table_column_list, + columns=list(table_columns.values()), is_replace=getattr(p, 'replace_or_empty', False), if_not_exists=getattr(p, 'if_not_exists_or_empty', False) ) @@ -1183,17 +1223,17 @@ def from_table(self, p): @_('LPAREN query RPAREN') @_('LPAREN query RPAREN AS id') - @_('LPAREN query RPAREN AS id LPAREN result_columns RPAREN') + @_('LPAREN query RPAREN AS id LPAREN column_list RPAREN') def from_table(self, p): query = p.query query.parentheses = True if hasattr(p, 'id'): query.alias = Identifier(parts=[p.id]) - if hasattr(p, 'result_columns'): - for i, col in enumerate(p.result_columns): + if hasattr(p, 'column_list'): + for i, col in enumerate(p.column_list): if i >= len(query.targets): break - query.targets[i].alias = col + query.targets[i].alias = Identifier(parts=[col]) return query # keywords for table @@ -1277,6 +1317,13 @@ def result_column(self, p): def result_column(self, p): return p[0] + @_('column_list COMMA id', + 'id') + def column_list(self, p): + column_list = getattr(p, 'column_list', []) + column_list.append(p.id) + return column_list + # case @_('CASE case_conditions ELSE expr END') def case(self, p): @@ -1735,6 +1782,7 @@ def function_name(self, p): 'VIEWS', 'WARNINGS', 'MODEL', + 'DEFAULT', 'MODELS', 'AGENT', 'SCHEMAS', diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index 351d2f5..be8ac49 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -67,6 +67,7 @@ def __init__(self, dialect_name): self.types_map = {} for type_name in sa_type_names: self.types_map[type_name.upper()] = getattr(sa.types, type_name) + self.types_map['BOOL'] = self.types_map['BOOLEAN'] def to_column(self, parts): # because sqlalchemy doesn't allow columns consist from parts therefore we do it manually @@ -532,15 +533,24 @@ def prepare_create_table(self, ast_query): for col in ast_query.columns: default = None if col.default is not None: - if isinstance(col.default, ast.Function): - default = self.to_function(col.default) + if isinstance(col.default, str): + default = sa.text(col.default) + if col.type.lower() == 'serial': + col.is_primary_key = True + col.type = 'INT' + + kwargs = { + 'primary_key': col.is_primary_key, + 'server_default': default, + } + if col.nullable is not None: + kwargs['nullable'] = col.nullable columns.append( sa.Column( col.name, self.get_type(col.type), - primary_key=col.is_primary_key, - default=default, + **kwargs ) ) diff --git a/tests/test_parser/test_base_sql/test_create.py b/tests/test_parser/test_base_sql/test_create.py index 0e9a7f5..f888af0 100644 --- a/tests/test_parser/test_base_sql/test_create.py +++ b/tests/test_parser/test_base_sql/test_create.py @@ -81,9 +81,77 @@ def test_create(self): City varchar ) ''' - print(sql) ast = parse_sql(sql) assert str(ast).lower() == str(expected_ast).lower() assert ast.to_tree() == expected_ast.to_tree() + # test with primary keys / defaults + # using serial + + sql = f''' + CREATE TABLE mydb.Persons( + PersonID serial, + active BOOL NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='PersonID', type='serial'), + TableColumn(name='active', type='BOOL', nullable=False), + TableColumn(name='created_at', type='TIMESTAMP', default='CURRENT_TIMESTAMP'), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # using primary key column + + sql = f''' + CREATE TABLE mydb.Persons( + PersonID INT PRIMARY KEY, + name TEXT NULL + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='PersonID', type='INT', is_primary_key=True), + TableColumn(name='name', type='TEXT', nullable=True), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # multiple primary keys + + sql = f''' + CREATE TABLE mydb.Persons( + location_id INT, + num INT, + name TEXT, + PRIMARY KEY (location_id, num) + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='location_id', type='INT', is_primary_key=True), + TableColumn(name='num', type='INT', is_primary_key=True), + TableColumn(name='name', type='TEXT'), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() +