Skip to content

Commit

Permalink
Merge pull request #8 from salcc/main
Browse files Browse the repository at this point in the history
Fix n_classes_
  • Loading branch information
davidbonet authored Dec 10, 2024
2 parents 7465834 + 08146c4 commit ffdd4e1
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions hyperfast/hyperfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _preprocess_fitting_data(
y = column_or_1d(y, warn=True)
self.n_features_in_ = x.shape[1]
self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = len(self.classes_)
return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.long)

def _preprocess_test_data(
Expand Down Expand Up @@ -286,7 +287,7 @@ def _sample_data(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
X_pred = torch.repeat_interleave(X_pred, n_repeats, axis=0)
y_pred = torch.repeat_interleave(y_pred, n_repeats, axis=0)
return X_pred, y_pred

def _move_to_device(self, data, device=None):
if device is None:
device = self.device
Expand All @@ -295,7 +296,7 @@ def _move_to_device(self, data, device=None):
elif isinstance(data, TorchPCA):
data.mean_, data.components_ = data.mean_.to(device), data.components_.to(device)
return data
elif isinstance(data, PCA): # scikit-learn PCA
elif isinstance(data, PCA): # scikit-learn PCA
return data
return data.to(device)

Expand Down Expand Up @@ -342,7 +343,6 @@ def fit(
for n in range(self.n_ensemble):
X_pred, y_pred = self._sample_data(X, y)
X_pred, y_pred = X_pred.to(self.device), y_pred.to(self.device)
self.n_classes_ = len(torch.unique(y_pred).cpu().numpy())
with torch.no_grad():
rf, pca, main_network, nnbias = self._model(X_pred, y_pred, self.n_classes_)
if self.optimization == "ensemble_optimize":
Expand Down Expand Up @@ -438,14 +438,14 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:

predicted = F.softmax(outputs, dim=1)
yhats.append(predicted)
for data in [rf, pca, main_network, nnbias, X_pred, y_pred,

for data in [rf, pca, main_network, nnbias, X_pred, y_pred,
X_transformed, outputs, intermediate_activations]:
data = self._move_to_cpu(data)
if self.nn_bias:
for data in [X_pred_, outputs_pred, intermediate_activations_pred]:
data = self._move_to_cpu(data)

yhats = torch.stack(yhats)
yhats = torch.mean(yhats, axis=0)
yhats = yhats.cpu().numpy()
Expand Down

0 comments on commit ffdd4e1

Please sign in to comment.