Skip to content

Commit

Permalink
Set active model_name.11
Browse files Browse the repository at this point in the history
  • Loading branch information
ea-rus committed Apr 25, 2024
1 parent 2ecdae0 commit d34194b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 20 deletions.
21 changes: 6 additions & 15 deletions mindsdb_sql/parser/dialects/mindsdb/create_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self,
targets=None,
integration_name=None,
query_str=None,
datasource_name=None,
order_by=None,
group_by=None,
window=None,
Expand All @@ -26,7 +25,6 @@ def __init__(self,
self.integration_name = integration_name
self.query_str = query_str
self.targets = targets
self.datasource_name = datasource_name
self.order_by = order_by
self.group_by = group_by
self.window = window
Expand All @@ -50,10 +48,6 @@ def to_tree(self, *args, level=0, **kwargs):

query_str = f'\n{ind1}query={self.query_str},'

datasource_name_str = ''
if self.datasource_name:
datasource_name_str = f'\n{ind1}datasource_name={self.datasource_name.to_tree()},'

if self.targets is not None:
target_trees = ',\n'.join([t.to_tree(level=level+2) for t in self.targets])
targets_str = f'\n{ind1}targets=[\n{target_trees}\n{ind1}],'
Expand Down Expand Up @@ -83,7 +77,6 @@ def to_tree(self, *args, level=0, **kwargs):
f'{name_str}' \
f'{integration_name_str}' \
f'{query_str}' \
f'{datasource_name_str}' \
f'{targets_str}' \
f'{order_by_str}' \
f'{group_by_str}' \
Expand Down Expand Up @@ -119,22 +112,20 @@ def get_string(self, *args, **kwargs):
using_ar.append(f'{Identifier(key).to_string()}={value}')

using_str = f'USING ' + ', '.join(using_ar)
datasource_name_str = f'AS {self.datasource_name.to_string()} ' if self.datasource_name is not None else ''

query_str = ''
if self.query_str is not None:
query_str = f'({self.query_str}) '
integration_name_str = ''
if self.integration_name is not None:
integration_name_str = f' {self.integration_name.to_string()}'

integration_name_str = ''
if self.integration_name is not None:
integration_name_str = f'FROM {self.integration_name.to_string()} '
query_str = f'FROM{integration_name_str} ({self.query_str}) '

or_replace_str = ' OR REPLACE' if self.is_replace else ''
if_not_exists_str = 'IF NOT EXISTS ' if self.if_not_exists else ''
object_str = self._object + ' ' if self._object else ''

out_str = f'{self._action}{or_replace_str} {object_str}{if_not_exists_str}{self.name.to_string()} {integration_name_str}{query_str}' \
f'{datasource_name_str}' \
out_str = f'{self._action}{or_replace_str} {object_str}{if_not_exists_str}{self.name.to_string()} {query_str}' \
f'{targets_str} ' \
f'{order_by_str}' \
f'{group_by_str}' \
Expand All @@ -148,7 +139,7 @@ def get_string(self, *args, **kwargs):
class CreatePredictor(CreatePredictorBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._object = 'PREDICTOR'
self._object = 'MODEL'


# Models by task type
Expand Down
18 changes: 14 additions & 4 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,22 +365,25 @@ def set_item_list(self, p):
# set names
@_('id id',
'id constant',
'id identifier',
'id id COLLATE constant',
'id id COLLATE id',
'id constant COLLATE constant',
'id constant COLLATE id')
def set_item(self, p):
category = p[0]
if category.lower() != 'names':
raise ParsingException(f'Expected "SET names", got "SET {category}"')
if isinstance(p[1], Constant):

if isinstance(p[1], (Constant, Identifier)):
value = p[1]
else:
# is id
value = Constant(p[1], with_quotes=False)

params = {}
if hasattr(p, 'COLLATE'):
if category.lower() != 'names':
raise ParsingException(f'Expected "SET names", got "SET {category}"')

if isinstance(p[3], Constant):
val = p[3]
else:
Expand Down Expand Up @@ -776,6 +779,7 @@ def create_predictor(self, p):
@_('CREATE replace_or_empty PREDICTOR if_not_exists_or_empty identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns',
'CREATE replace_or_empty PREDICTOR if_not_exists_or_empty identifier PREDICT result_columns',
'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns',
'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier FROM LPAREN raw_query RPAREN PREDICT result_columns',
'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier PREDICT result_columns'
)
def create_predictor(self, p):
Expand Down Expand Up @@ -837,11 +841,15 @@ def create_anomaly_detection_model(self, p):

@_('RETRAIN identifier',
'RETRAIN identifier PREDICT result_columns',
'RETRAIN identifier FROM LPAREN raw_query RPAREN',
'RETRAIN identifier FROM LPAREN raw_query RPAREN PREDICT result_columns',
'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN',
'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns',
'RETRAIN MODEL identifier',
'RETRAIN MODEL identifier PREDICT result_columns',
'RETRAIN MODEL identifier FROM LPAREN raw_query RPAREN',
'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN',
'RETRAIN MODEL identifier FROM LPAREN raw_query RPAREN PREDICT result_columns',
'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns')
def create_predictor(self, p):
query_str = None
Expand All @@ -862,7 +870,9 @@ def create_predictor(self, p):
)

@_('FINETUNE identifier FROM identifier LPAREN raw_query RPAREN',
'FINETUNE MODEL identifier FROM identifier LPAREN raw_query RPAREN')
'FINETUNE identifier FROM LPAREN raw_query RPAREN',
'FINETUNE MODEL identifier FROM identifier LPAREN raw_query RPAREN',
'FINETUNE MODEL identifier FROM LPAREN raw_query RPAREN')
def create_predictor(self, p):
query_str = None
if hasattr(p, 'raw_query'):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_parser/test_base_sql/test_misc_sql_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,15 @@ def test_charset(self):
assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

def test_set_version(self):
sql = "SET active model_name.1"

ast = parse_sql(sql)
expected_ast = Set(category='active', value=Identifier(parts=['model_name', '1']))

assert ast.to_tree() == expected_ast.to_tree()
assert str(ast) == str(expected_ast)

def test_interval(self):
sql = """
select interval '1 day'+1 from aaa
Expand Down
44 changes: 43 additions & 1 deletion tests/test_parser/test_mindsdb/test_create_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_create_predictor_full(self):
assert ast.to_tree() == ast2.to_tree()

def test_create_predictor_minimal(self):
sql = """CREATE PREDICTOR IF NOT EXISTS pred
sql = """CREATE MODEL IF NOT EXISTS pred
FROM integration_name
(select * FROM table_name)
PREDICT f1 as f1_alias, f2
Expand Down Expand Up @@ -196,3 +196,45 @@ def test_create_anomaly_detection_model(self):

assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

def test_optional_db(self):
sql = "CREATE MODEL xxx from (select 1) PREDICT sss"
ast = parse_sql(sql, dialect='mindsdb')
expected_ast = CreatePredictor(
name=Identifier('xxx'),
query_str='select 1',
targets=[Identifier('sss')],
)
assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

# retrain
sql = "RETRAIN MODEL xxx from (select 1)"
ast = parse_sql(sql, dialect='mindsdb')
expected_ast = RetrainPredictor(
name=Identifier('xxx'),
query_str='select 1',
)
assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

sql = "RETRAIN xxx from (select 1)"
ast = parse_sql(sql, dialect='mindsdb')
assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

# finetune
sql = "FINETUNE MODEL xxx from (select 1)"
ast = parse_sql(sql, dialect='mindsdb')
expected_ast = FinetunePredictor(
name=Identifier('xxx'),
query_str='select 1',
)
assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

sql = "FINETUNE xxx from (select 1)"
ast = parse_sql(sql, dialect='mindsdb')
assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

0 comments on commit d34194b

Please sign in to comment.