From 90334143d6fbe966a35be57941c3e8ee55e7755f Mon Sep 17 00:00:00 2001 From: Xiaojing Zhang <80235074+zhangxjohn@users.noreply.github.com> Date: Fri, 11 Mar 2022 09:57:58 +0800 Subject: [PATCH] Refactor forecast stats wrappers. --- hyperts/framework/wrappers/stats_wrappers.py | 6 ++++++ hyperts/tests/experiment/task_test.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/hyperts/framework/wrappers/stats_wrappers.py b/hyperts/framework/wrappers/stats_wrappers.py index 5337e83..04cbceb 100644 --- a/hyperts/framework/wrappers/stats_wrappers.py +++ b/hyperts/framework/wrappers/stats_wrappers.py @@ -83,6 +83,9 @@ def fit(self, X, y=None, **kwargs): def predict(self, X, **kwargs): last_date = X[self.timestamp].tail(1).to_list()[0].to_pydatetime() + if last_date == self._end_date: + raise ValueError('The end date of the valid set must be ' + 'less than the end date of the test set.') steps = int((last_date - self._end_date).total_seconds() / self._freq) predict_result = self.model.forecast(steps=steps).values @@ -120,6 +123,9 @@ def fit(self, X, y=None, **kwargs): def predict(self, X, **kwargs): last_date = X[self.timestamp].tail(1).to_list()[0].to_pydatetime() + if last_date == self._end_date: + raise ValueError('The end date of the valid set must be ' + 'less than the end date of the test set.') steps = int((last_date - self._end_date).total_seconds() / self._freq) predict_result = self.model.forecast(self.model.y, steps=steps) diff --git a/hyperts/tests/experiment/task_test.py b/hyperts/tests/experiment/task_test.py index 8c6f49e..871d58f 100644 --- a/hyperts/tests/experiment/task_test.py +++ b/hyperts/tests/experiment/task_test.py @@ -29,7 +29,7 @@ def test_univariate_forecast(self): optimize_direction=OptimizeDirection.Minimize) hyper_model = HyperTS(rs, task='univariate-forecast', reward_metric='rmse', callbacks=[SummaryCallback()]) - exp = TSCompeteExperiment(hyper_model, X_train, y_train, X_eval=X_test, y_eval=y_test, + exp = TSCompeteExperiment(hyper_model, X_train, y_train, timestamp_col='ds', covariate_cols=[['id'], cs.covariables_], covariate_cleaner=cs) @@ -46,7 +46,7 @@ def test_multivariate_forecast(self): optimize_direction=OptimizeDirection.Minimize) hyper_model = HyperTS(rs, task='multivariate-forecast', reward_metric='rmse', callbacks=[SummaryCallback()]) - exp = TSCompeteExperiment(hyper_model, X_train, y_train, X_eval=X_test, y_eval=y_test, timestamp_col='ds') + exp = TSCompeteExperiment(hyper_model, X_train, y_train, timestamp_col='ds') pipeline_model = exp.run(max_trials=3) y_pred = pipeline_model.predict(X_test) @@ -60,7 +60,7 @@ def test_univariate_classification(self): optimize_direction=OptimizeDirection.Maximize) hyper_model = HyperTS(rs, task='univariate-multiclass', reward_metric='accuracy', callbacks=[SummaryCallback()]) - exp = TSCompeteExperiment(hyper_model, X_train, y_train, X_eval=X_test, y_eval=y_test) + exp = TSCompeteExperiment(hyper_model, X_train, y_train) pipeline_model = exp.run(max_trials=3) y_pred = pipeline_model.predict(X_test)