Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic support for typed forecasting models #330

Draft
wants to merge 4 commits into
base: staging
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindsdb_sql/parser/dialects/mindsdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .agents import CreateAgent, DropAgent, UpdateAgent
from .create_view import CreateView
from .create_database import CreateDatabase
from .create_predictor import CreatePredictor, CreateAnomalyDetectionModel
from .create_predictor import CreatePredictor, CreateAnomalyDetectionModel, CreateForecastingModel
from .drop_predictor import DropPredictor
from .retrain_predictor import RetrainPredictor
from .finetune_predictor import FinetunePredictor
Expand Down
7 changes: 7 additions & 0 deletions mindsdb_sql/parser/dialects/mindsdb/create_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._command = 'CREATE ANOMALY DETECTION MODEL'
self.task = Identifier('AnomalyDetection')


class CreateForecastingModel(CreatePredictorBase):
    """AST node for the ``CREATE FORECASTING MODEL`` statement.

    Construction is delegated to ``CreatePredictorBase``; afterwards the
    ``_command`` and ``task`` attributes are set to the forecasting values.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, then set the forecasting fields."""
        super().__init__(*args, **kwargs)
        # Assigned after the base initialiser, so these values take effect
        # regardless of what the caller passed in.
        self._command = 'CREATE FORECASTING MODEL'
        self.task = Identifier('Forecasting')
2 changes: 2 additions & 0 deletions mindsdb_sql/parser/dialects/mindsdb/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MindsDBLexer(Lexer):
LATEST, LAST, HORIZON, USING,
ENGINE, TRAIN, PREDICT, PARAMETERS, JOB, CHATBOT, EVERY,PROJECT,
ANOMALY, DETECTION,
FORECASTING,
KNOWLEDGE_BASE, KNOWLEDGE_BASES,
SKILL,
AGENT,
Expand Down Expand Up @@ -118,6 +119,7 @@ class MindsDBLexer(Lexer):
# Typed models
ANOMALY = r'\bANOMALY\b'
DETECTION = r'\bDETECTION\b'
FORECASTING = r'\bFORECASTING\b'

KNOWLEDGE_BASE = r'\bKNOWLEDGE[_|\s]BASE\b'
KNOWLEDGE_BASES = r'\bKNOWLEDGE[_|\s]BASES\b'
Expand Down
58 changes: 57 additions & 1 deletion mindsdb_sql/parser/dialects/mindsdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from mindsdb_sql.parser.dialects.mindsdb.drop_predictor import DropPredictor
from mindsdb_sql.parser.dialects.mindsdb.drop_dataset import DropDataset
from mindsdb_sql.parser.dialects.mindsdb.drop_ml_engine import DropMLEngine
from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreatePredictor, CreateAnomalyDetectionModel
from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreatePredictor
from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreateAnomalyDetectionModel, CreateForecastingModel
from mindsdb_sql.parser.dialects.mindsdb.create_database import CreateDatabase
from mindsdb_sql.parser.dialects.mindsdb.create_ml_engine import CreateMLEngine
from mindsdb_sql.parser.dialects.mindsdb.create_view import CreateView
Expand Down Expand Up @@ -64,6 +65,7 @@ class MindsDBParser(Parser):
'create_integration',
'create_view',
'create_anomaly_detection_model',
'create_forecasting_model',
'drop_predictor',
'drop_datasource',
'drop_dataset',
Expand Down Expand Up @@ -817,6 +819,60 @@ def create_anomaly_detection_model(self, p):
p.create_anomaly_detection_model.using = p.kw_parameter_list
return p.create_anomaly_detection_model

## Forecasting
@_(
    'CREATE FORECASTING MODEL identifier PREDICT result_columns',  # for pre-trained models (e.g. TimeGPT)
    'CREATE FORECASTING MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns',
    # TODO add IF_NOT_EXISTS elegantly (should be low level BNF expansion)
)
def create_forecasting_model(self, p):
    """Build the initial ``CreateForecastingModel`` node.

    Two statement shapes are accepted: one with only a model name and
    ``PREDICT`` targets, and one that also names an integration and a raw
    training query via ``FROM integration (query)``.

    :param p: production object for the matched rule.
    :return: a new ``CreateForecastingModel`` instance.
    """
    # Only the FROM variant carries a raw query to turn back into text.
    training_query = tokens_to_string(p.raw_query) if hasattr(p, 'raw_query') else None

    # When the rule has a single identifier it is exposed as `identifier`;
    # otherwise the model name is the first of the numbered identifiers.
    model_name = p.identifier if hasattr(p, 'identifier') else p.identifier0

    node_kwargs = dict(
        name=model_name,
        targets=getattr(p, 'result_columns', None),
        integration_name=getattr(p, 'identifier1', None),
        query_str=training_query,
        # None of the rules above contains IF_NOT_EXISTS (see the TODO),
        # so this only becomes True once such a rule is added.
        if_not_exists=hasattr(p, 'IF_NOT_EXISTS'),
    )
    return CreateForecastingModel(**node_kwargs)

@_('create_forecasting_model WINDOW integer')
def create_forecasting_model(self, p):
    """Store the ``WINDOW`` integer on the node built so far and return that node."""
    node = p.create_forecasting_model
    node.window = p.integer
    return node

@_('create_forecasting_model HORIZON integer')
def create_forecasting_model(self, p):
    """Store the ``HORIZON`` integer on the node built so far and return that node."""
    node = p.create_forecasting_model
    node.horizon = p.integer
    return node

@_('create_forecasting_model GROUP_BY expr_list')
def create_forecasting_model(self, p):
    """Store the ``GROUP BY`` expressions on the node built so far.

    ``group_by`` is always stored as a list: a non-list ``expr_list`` value
    is wrapped in a one-element list.
    """
    node = p.create_forecasting_model
    columns = p.expr_list
    node.group_by = columns if isinstance(columns, list) else [columns]
    return node

@_('create_forecasting_model ORDER_BY ordering_terms')
def create_forecasting_model(self, p):
    """Store the ``ORDER BY`` terms on the node built so far and return that node."""
    node = p.create_forecasting_model
    node.order_by = p.ordering_terms
    return node

@_('create_forecasting_model USING kw_parameter_list')
def create_forecasting_model(self, p):
    """Store the ``USING`` keyword parameters on the node built so far and return that node."""
    node = p.create_forecasting_model
    node.using = p.kw_parameter_list
    return node

# RETRAIN PREDICTOR

@_('RETRAIN identifier',
Expand Down
33 changes: 33 additions & 0 deletions tests/test_parser/test_mindsdb/test_create_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,36 @@ def test_create_anomaly_detection_model(self):

assert to_single_line(str(ast)) == to_single_line(str(expected_ast))
assert ast.to_tree() == expected_ast.to_tree()

def test_create_forecasting_model(self):
    """Parse a full ``CREATE FORECASTING MODEL`` statement and compare it to a hand-built AST.

    The statement uses every clause supported by the forecasting rules:
    FROM, PREDICT, WINDOW, HORIZON, ORDER BY, GROUP BY and USING.
    """
    statement_head = "CREATE FORECASTING MODEL forecasting_model"
    statement_tail = """
FROM integration_name (select * FROM table)
PREDICT y
WINDOW 10
HORIZON 5
ORDER BY time
GROUP BY group
USING
param='a'
"""
    parsed = parse_sql(statement_head + statement_tail, dialect='mindsdb')

    expected = CreateForecastingModel(
        name=Identifier('forecasting_model'),
        task=Identifier('Forecasting'),
        integration_name=Identifier('integration_name'),
        query_str='select * FROM table',
        targets=[Identifier('y')],
        window=10,
        horizon=5,
        order_by=[OrderBy(Identifier('time'), direction='default')],
        group_by=[Identifier('group')],
        using={'param': 'a'},
    )

    # Compare both the rendered SQL and the structural tree.
    rendered_parsed = to_single_line(str(parsed))
    rendered_expected = to_single_line(str(expected))
    assert rendered_parsed == rendered_expected
    assert parsed.to_tree() == expected.to_tree()
Loading