Skip to content

Commit

Permalink
Merge pull request #1141 from mindsdb/nhits_fixes_5_30
Browse files Browse the repository at this point in the history
[nhits] early stop by default
  • Loading branch information
paxcema authored May 30, 2023
2 parents b8f750a + bd5731f commit c47eae4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion lightwood/mixer/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
self.ts_analysis = ts_analysis
self.grouped_by = ['__default'] if not ts_analysis['tss'].group_by else ts_analysis['tss'].group_by
self.train_args = train_args.get('trainer_args', {}) if train_args else {}
self.train_args['early_stop_patience_steps'] = self.train_args.get('early_stop_patience_steps', 10)
self.conf_level = self.train_args.pop('conf_level', [90])
for level in self.conf_level:
assert 0 <= level <= 100, f'A provided level is not in the [0, 100] range (found: {level})'
Expand Down Expand Up @@ -131,7 +132,7 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None:
# TODO: reimplement this with automatic novel-row differential
self.hyperparam_search = False
self.fit(dev_data, train_data)
self.fit(dev_data, train_data) # TODO: add support for passing args (e.g. n_epochs)
self.prepared = True

def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs],
Expand Down

0 comments on commit c47eae4

Please sign in to comment.