Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix parser to allow no storage to be passed to knowledge base #314

Merged
merged 4 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mindsdb_sql/parser/dialects/mindsdb/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
self,
name,
model,
storage,
storage=None,
from_select=None,
params=None,
if_not_exists=False,
Expand All @@ -36,14 +36,14 @@ def __init__(

def to_tree(self, *args, level=0, **kwargs):
ind = indent(level)
storage_str = f"{ind} storage={self.storage.to_string()},\n" if self.storage else ""
out_str = f"""
{ind}CreateKnowledgeBase(
{ind} if_not_exists={self.if_not_exists},
{ind} name={self.name.to_string()},
{ind} from_query={self.from_query.to_tree(level=level+1) if self.from_query else None},
{ind} from_query={self.from_query.to_tree(level=level + 1) if self.from_query else None},
{ind} model={self.model.to_string()},
{ind} storage={self.storage.to_string()},
{ind} params={self.params}
{storage_str}{ind} params={self.params}
{ind})
"""
return out_str
Expand All @@ -55,13 +55,14 @@ def get_string(self, *args, **kwargs):
from_query_str = (
f"FROM ({self.from_query.get_string()})" if self.from_query else ""
)
storage_str = f" STORAGE = {self.storage.to_string()}" if self.storage else ""

out_str = (
f"CREATE KNOWLEDGE_BASE {'IF NOT EXISTS' if self.if_not_exists else ''}{self.name.to_string()} "
f"{from_query_str} "
f"USING {using_str},"
f" MODEL = {self.model.to_string()}, "
f" STORAGE {self.storage.to_string()} "
f"{storage_str}"
)

return out_str
Expand Down
68 changes: 32 additions & 36 deletions mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
"""
Unfortunately the rules are not iherited from base SQLParser, because it just doesn't work with Sly due to metaclass magic.
"""


class MindsDBParser(Parser):
log = ParserLogger()
tokens = MindsDBLexer.tokens
Expand Down Expand Up @@ -99,18 +101,21 @@ def create_kb(self, p):
from_query = getattr(p, 'select', None)
name = p.identifier
# check model and storage are in params
model = params.pop('model', None) or params.pop('MODEL', None) # case insensitive
storage = params.pop('storage', None) or params.pop('STORAGE', None) # case insensitive
params = {k.lower(): v for k, v in params.items()} # case insensitive
model = params.pop('model', None)
storage = params.pop('storage', None)

if not storage:
# convert to identifier
if isinstance(model, str):
storage = Identifier(storage)
dusvyat marked this conversation as resolved.
Show resolved Hide resolved

if not model:
if isinstance(model, str):
# convert to identifier
model = Identifier(model)
raise ParsingException('Missing model parameter')
if not storage:
if isinstance(storage, str):
# convert to identifier
storage = Identifier(storage)
raise ParsingException('Missing storage parameter')

if_not_exists = p.if_not_exists_or_empty

return CreateKnowledgeBase(
Expand All @@ -122,7 +127,6 @@ def create_kb(self, p):
if_not_exists=if_not_exists
)


@_('DROP KNOWLEDGE_BASE if_exists_or_empty identifier')
def drop_kb(self, p):
return DropKnowledgeBase(name=p.identifier, if_exists=p.if_exists_or_empty)
Expand Down Expand Up @@ -151,7 +155,6 @@ def create_chat_bot(self, p):
def update_chat_bot(self, p):
return UpdateChatBot(name=p.identifier, updated_params=p.kw_parameter_list)


@_('DROP CHATBOT identifier')
def drop_chat_bot(self, p):
return DropChatBot(name=p.identifier)
Expand All @@ -177,7 +180,6 @@ def create_trigger(self, p):
def drop_trigger(self, p):
return DropTrigger(name=p.identifier)


# -- Jobs --
@_('CREATE JOB if_not_exists_or_empty identifier LPAREN raw_query RPAREN job_schedule',
'CREATE JOB if_not_exists_or_empty identifier AS LPAREN raw_query RPAREN job_schedule',
Expand Down Expand Up @@ -221,7 +223,6 @@ def create_job(self, p):
'job_schedule job_schedule')
def job_schedule(self, p):


if isinstance(p[0], dict):
schedule = p[0]
for k in p[1].keys():
Expand All @@ -238,14 +239,13 @@ def job_schedule(self, p):
if hasattr(p, 'integer'):
value = f'{p[1]} {p[2]}'

schedule = {param: value}
schedule = {param:value}
return schedule

@_('DROP JOB if_exists_or_empty identifier')
def drop_job(self, p):
return DropJob(name=p.identifier, if_exists=p.if_exists_or_empty)


# Explain
@_('EXPLAIN identifier')
def explain(self, p):
Expand Down Expand Up @@ -368,9 +368,9 @@ def transact_property_list(self, p):
'transact_access_mode')
def transact_property(self, p):
if hasattr(p, 'transact_level'):
return {'type': 'iso_level', 'value': p.transact_level}
return {'type':'iso_level', 'value':p.transact_level}
else:
return {'type': 'access_mode', 'value': p.transact_access_mode}
return {'type':'access_mode', 'value':p.transact_access_mode}

@_('REPEATABLE READ',
'READ COMMITTED',
Expand Down Expand Up @@ -525,11 +525,11 @@ def show(self, p):
@_('SHOW REPLICA STATUS FOR CHANNEL id',
'SHOW SLAVE STATUS FOR CHANNEL id',
'SHOW REPLICA STATUS',
'SHOW SLAVE STATUS',)
'SHOW SLAVE STATUS', )
def show(self, p):
name = getattr(p, 'id', None)
return Show(
category='REPLICA STATUS', # slave = replica
category='REPLICA STATUS', # slave = replica
name=name
)

Expand Down Expand Up @@ -662,7 +662,7 @@ def drop_table(self, p):
return DropTables(tables=[p.identifier], if_exists=p.if_exists_or_empty)

# create table
@_('CREATE TABLE identifier select', # TODO tests failing without it
@_('CREATE TABLE identifier select', # TODO tests failing without it
'CREATE TABLE if_not_exists_or_empty identifier select',
'CREATE TABLE if_not_exists_or_empty identifier LPAREN select RPAREN',
'CREATE OR REPLACE TABLE identifier select',
Expand Down Expand Up @@ -772,7 +772,6 @@ def create_anomaly_detection_model(self, p):
p.create_anomaly_detection_model.using = p.kw_parameter_list
return p.create_anomaly_detection_model


# RETRAIN PREDICTOR

@_('RETRAIN identifier',
Expand Down Expand Up @@ -816,7 +815,7 @@ def create_predictor(self, p):
)

@_('EVALUATE identifier FROM LPAREN raw_query RPAREN',
'EVALUATE identifier FROM LPAREN raw_query RPAREN USING kw_parameter_list',)
'EVALUATE identifier FROM LPAREN raw_query RPAREN USING kw_parameter_list', )
def evaluate(self, p):
if hasattr(p, 'identifier'):
# single identifier field
Expand Down Expand Up @@ -872,10 +871,10 @@ def create_integration(self, p):
parameters = p.json

return CreateDatabase(name=p.database_engine['identifier'],
engine=p.database_engine['engine'],
is_replace=is_replace,
parameters=parameters,
if_not_exists=p.database_engine['if_not_exists'])
engine=p.database_engine['engine'],
is_replace=is_replace,
parameters=parameters,
if_not_exists=p.database_engine['if_not_exists'])

@_('DATABASE if_not_exists_or_empty identifier',
'DATABASE if_not_exists_or_empty identifier ENGINE string',
Expand All @@ -888,7 +887,7 @@ def database_engine(self, p):
engine = None
if hasattr(p, 'string'):
engine = p.string
return {'identifier': p.identifier, 'engine': engine, 'if_not_exists': p.if_not_exists_or_empty}
return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty}

# UNION / UNION ALL
@_('select UNION select')
Expand Down Expand Up @@ -1090,7 +1089,7 @@ def join_tables(self, p):

@_('from_table_aliased COMMA from_table_aliased',
'join_tables_implicit COMMA from_table_aliased')
def join_tables_implicit (self, p):
def join_tables_implicit(self, p):
return Join(left=p[0],
right=p[2],
join_type=JoinType.INNER_JOIN,
Expand Down Expand Up @@ -1199,7 +1198,6 @@ def result_column(self, p):
def result_column(self, p):
return p.star


@_('expr',
'function',
'window_function',
Expand Down Expand Up @@ -1366,7 +1364,6 @@ def expr(self, p):
arg1 = p.expr1
return BinaryOperation(op=p[1], args=(p[0], arg1))


@_('MINUS expr %prec UMINUS',
'NOT expr %prec UNOT', )
def expr(self, p):
Expand All @@ -1386,7 +1383,7 @@ def update_parameter_list(self, p):

@_('id EQUALS expr')
def update_parameter(self, p):
return {p.id: p.expr}
return {p.id:p.expr}

# EXPRESSIONS

Expand Down Expand Up @@ -1454,7 +1451,7 @@ def kw_parameter(self, p):
key = getattr(p, 'identifier', None) or getattr(p, 'identifier0', None)
assert key is not None
key = '.'.join(key.parts)
return {key: p[2]}
return {key:p[2]}

# json

Expand All @@ -1473,7 +1470,7 @@ def json_element_list(self, p):

@_('string COLON json_value')
def json_element(self, p):
return {p.string: p.json_value}
return {p.string:p.json_value}

# json_array

Expand Down Expand Up @@ -1508,7 +1505,6 @@ def json_value(self, p):
return False
return p[0]


@_('identifier DOT identifier',
'identifier DOT integer',
'identifier DOT star')
Expand Down Expand Up @@ -1536,7 +1532,7 @@ def string(self, p):
def parameter(self, p):
return Parameter(value=p.PARAMETER)

# convert to types
# convert to types
@_('ID',
'BEGIN',
'CAST',
Expand Down Expand Up @@ -1614,7 +1610,7 @@ def parameter(self, p):
'WARNINGS',
'MODEL',
'MODELS',
)
)
def id(self, p):
return p[0]

Expand All @@ -1638,11 +1634,11 @@ def dquote_string(self, p):

@_('LPAREN raw_query RPAREN')
def raw_query(self, p):
return [ p._slice[0] ] + p[1] + [ p._slice[2] ]
return [p._slice[0]] + p[1] + [p._slice[2]]

@_('raw_query LPAREN RPAREN')
def raw_query(self, p):
return p[0] + [ p._slice[1], p._slice[2] ]
return p[0] + [p._slice[1], p._slice[2]]

@_('raw_query raw_query')
def raw_query(self, p):
Expand Down
16 changes: 12 additions & 4 deletions tests/test_parser/test_mindsdb/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,23 @@ def test_create_knowledeg_base():
ast = parse_sql(sql, dialect="mindsdb")

# create without STORAGE
# TODO: this should be an error
# we may allow this in the future when we have a default storage
sql = """
CREATE KNOWLEDGE_BASE my_knowledge_base
USING
MODEL = mindsdb.my_embedding_model
"""
with pytest.raises(Exception):
ast = parse_sql(sql, dialect="mindsdb")

expected_ast = CreateKnowledgeBase(
name=Identifier("my_knowledge_base"),
if_not_exists=False,
model=Identifier(parts=["mindsdb", "my_embedding_model"]),
from_select=None,
params={},
)

ast = parse_sql(sql, dialect="mindsdb")

assert ast == expected_ast

# create if not exists
sql = """
Expand Down
Loading