Skip to content

Commit

Permalink
Rebuild TSFinalTrainStep.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxjohn committed Mar 11, 2022
1 parent ec795fd commit 35c2290
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 10 deletions.
4 changes: 3 additions & 1 deletion hyperts/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def append_early_stopping_callbacks(cbs):
raise ValueError("Forecast task 'timestamp' cannot be None.")

if task in consts.TASK_LIST_FORECAST and covariables is None:
logger.warning('If the data contains covariables, specify the covariable column names.')
logger.info('If the data contains covariables, specify the covariable column names.')

if mode != consts.Mode_STATS:
try:
Expand Down Expand Up @@ -358,6 +358,8 @@ def append_early_stopping_callbacks(cbs):

if freq is None:
freq = tb.infer_ts_freq(X_train, ts_name=timestamp)
if freq is None:
raise RuntimeError('Unable to infer correct frequency, please check data or specify frequency.')

# 7. Covariate Transformer
if covariables is not None:
Expand Down
30 changes: 27 additions & 3 deletions hyperts/framework/compete.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def fit_transform(self, hyper_model, X_train, y_train, X_test=None, X_eval=None,
# 4. eval variables data process
if X_eval is None or y_eval is None:
if self.task in consts.TASK_LIST_FORECAST:
if X_train.shape[0] <= 2*consts.DEFAULT_FORECAST_EVAL_SIZE or isinstance(self.experiment.eval_size, int):
if int(X_train.shape[0]*consts.DEFAULT_MIN_EVAL_SIZE)<=10 or isinstance(self.experiment.eval_size, int):
eval_horizon = self.experiment.eval_size
else:
eval_horizon = consts.DEFAULT_FORECAST_EVAL_SIZE
eval_horizon = consts.DEFAULT_MIN_EVAL_SIZE
X_train, X_eval, y_train, y_eval = \
tb.temporal_train_test_split(X_train, y_train, test_size=eval_horizon)
self.step_progress('split into train set and eval set')
Expand Down Expand Up @@ -284,6 +284,30 @@ def get_ensemble(self, estimators, X_train, y_train):
return tb.greedy_ensemble(ensemble_task, estimators, scoring=self.scorer, ensemble_size=self.ensemble_size)


class TSFinalTrainStep(FinalTrainStep):
    """Final training step specialised for time series experiments.

    Depending on ``retrain_on_wholedata`` the step either reloads the
    estimator persisted by the best trial, or refits the best trial's
    search-space sample on the train and eval data concatenated together.

    Parameters
    ----------
    experiment : the experiment object this step belongs to.
    name : the step name.
    mode : run mode, compared against ``consts.Mode_STATS``. For any other
        value the final fit is given a fixed ``epochs`` keyword taken from
        ``consts.FINAL_TRAINING_EPOCHS``.
    retrain_on_wholedata : bool, whether to refit on train + eval data
        instead of loading the best trial's saved model.
    """

    def __init__(self, experiment, name, mode=None, retrain_on_wholedata=False):
        super().__init__(experiment, name)

        self.mode = mode
        self.retrain_on_wholedata = retrain_on_wholedata

    def build_estimator(self, hyper_model, X_train, y_train, X_test=None, X_eval=None, y_eval=None, **kwargs):
        """Return the final estimator for the best trial of ``hyper_model``.

        ``X_test`` is accepted for interface compatibility and is not used.
        """
        best_trial = hyper_model.get_best_trial()

        # No retraining requested: reuse the model saved by the best trial.
        if not self.retrain_on_wholedata:
            return hyper_model.load_estimator(best_trial.model_file)

        # Stack train and eval rows so the final fit sees all available data.
        toolbox = get_tool_box(X_train, X_eval)
        X_whole = toolbox.concat_df([X_train, X_eval], axis=0)
        y_whole = toolbox.concat_df([y_train, y_eval], axis=0)

        # Outside stats mode, any caller-supplied 'epochs' is replaced by
        # the fixed final-training epoch count.
        if self.mode != consts.Mode_STATS:
            kwargs['epochs'] = consts.FINAL_TRAINING_EPOCHS

        return hyper_model.final_train(best_trial.space_sample, X_whole, y_whole, **kwargs)


class TSPipeline:
"""Pipeline Extension for Time Series Analysis.
Expand Down Expand Up @@ -731,7 +755,7 @@ def __init__(self, hyper_model, X_train, y_train, X_eval=None, y_eval=None, X_te
# ensemble_size=ensemble_size))
# else:
# final train step
steps.append(FinalTrainStep(self, consts.StepName_FINAL_TRAINING, retrain_on_wholedata=False))
steps.append(TSFinalTrainStep(self, consts.StepName_FINAL_TRAINING, retrain_on_wholedata=True))

# ignore warnings
import warnings
Expand Down
9 changes: 5 additions & 4 deletions hyperts/framework/dl/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def fit(self,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False):
use_multiprocessing=False,
**kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
Parameters
Expand Down Expand Up @@ -666,15 +667,15 @@ def _from_tensor_slices(self, X, y, batch_size, epochs=None, shuffle=False, drop

dataset = tf.data.Dataset.from_tensor_slices((data, y))

if epochs is not None:
dataset = dataset.repeat(epochs)

if shuffle:
dataset = dataset.shuffle(y.shape[0])

dataset = dataset.batch(batch_size, drop_remainder=drop_remainder and y.shape[0] >= batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

if epochs is not None:
dataset = dataset.repeat(epochs+1)

return dataset

def _preprocessor(self, X, y):
Expand Down
2 changes: 1 addition & 1 deletion hyperts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def set_random_state(seed=9527, mode=consts.Mode_STATS):

random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ['TF_DETERMINISTIC_OPS'] = '0'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
np.random.seed(seed)

if mode == consts.Mode_DL:
Expand Down
3 changes: 2 additions & 1 deletion hyperts/utils/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

TIMESTAMP = 'timestamp'
DEFAULT_EVAL_SIZE = 0.2
DEFAULT_FORECAST_EVAL_SIZE = 10
DEFAULT_MIN_EVAL_SIZE = 0.05
NAN_DROP_SIZE = 0.6
FINAL_TRAINING_EPOCHS = 120

Task_UNIVARIATE_FORECAST = 'univariate-forecast'
Task_MULTIVARIATE_FORECAST = 'multivariate-forecast'
Expand Down

0 comments on commit 35c2290

Please sign in to comment.