Skip to content

Commit

Permalink
feat: add on_epoch_end_cb to call at end of each epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Jun 14, 2024
1 parent 460c382 commit a867c36
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchensemble/soft_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def fit(
test_loader=None,
save_model=True,
save_dir=None,
on_epoch_end_cb=None
):

# Instantiate base estimators and set attributes
Expand Down Expand Up @@ -297,6 +298,9 @@ def fit(
else:
scheduler.step()

# Call on epoch end
if on_epoch_end_cb:
on_epoch_end_cb(epoch)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

Expand Down Expand Up @@ -392,6 +396,7 @@ def fit(
test_loader=None,
save_model=True,
save_dir=None,
on_epoch_end_cb=None
):
super().fit(
train_loader=train_loader,
Expand All @@ -401,6 +406,7 @@ def fit(
test_loader=test_loader,
save_model=save_model,
save_dir=save_dir,
on_epoch_end_cb=on_epoch_end_cb,
)

@torchensemble_model_doc(
Expand Down

0 comments on commit a867c36

Please sign in to comment.