Skip to content

Commit

Permalink
fix: dont instantiate if estimators loaded from save_dir
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Jun 14, 2024
1 parent 6726a99 commit 460c382
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchensemble/soft_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,10 @@ def fit(
):

# Instantiate base estimators and set attributes
for _ in range(self.n_estimators):
self.estimators_.append(self._make_estimator())
# dont instantiate if estimators loaded from save_dir
if len(self.estimators_) != self.n_estimators:
for _ in range(self.n_estimators):
self.estimators_.append(self._make_estimator())
self._validate_parameters(epochs, log_interval)
self.n_outputs = self._decide_n_outputs(train_loader)

Expand Down

0 comments on commit 460c382

Please sign in to comment.