
Commit

Merge pull request #7 from salcc/main
Optimize memory and improve code
davidbonet authored Dec 8, 2024
2 parents 9ec0033 + 860e8b5 commit 7c1331f
Showing 3 changed files with 67 additions and 66 deletions.
112 changes: 58 additions & 54 deletions hyperfast/hyperfast.py
@@ -41,6 +41,7 @@ class HyperFastClassifier(BaseEstimator, ClassifierMixin):
         n_ensemble (int): Number of ensemble models to use.
         batch_size (int): Size of the batch for weight prediction and ensembling.
         nn_bias (bool): Whether to use nearest neighbor bias.
+        nn_bias_mini_batches (bool): Whether to use mini-batches of size 128 for nearest neighbor bias.
         optimization (str or None): Strategy for optimization, can be None, 'optimize', or 'ensemble_optimize'.
         optimize_steps (int): Number of optimization steps.
         torch_pca (bool): Whether to use PyTorch-based PCA optimized for GPU (fast) or scikit-learn PCA (slower).
@@ -58,6 +59,7 @@ def __init__(
        n_ensemble: int = 16,
        batch_size: int = 2048,
        nn_bias: bool = False,
+       nn_bias_mini_batches: bool = True,
        optimization: str | None = "ensemble_optimize",
        optimize_steps: int = 64,
        torch_pca: bool = True,
@@ -72,6 +74,7 @@ def __init__(
        self.n_ensemble = n_ensemble
        self.batch_size = batch_size
        self.nn_bias = nn_bias
+       self.nn_bias_mini_batches = nn_bias_mini_batches
        self.optimization = optimization
        self.optimize_steps = optimize_steps
        self.torch_pca = torch_pca
@@ -112,7 +115,7 @@ def _initialize_model(self, cfg: SimpleNamespace) -> HyperFast:
                flush=True,
            )
            model.load_state_dict(
-               torch.load(cfg.model_path, map_location=torch.device(cfg.device))
+               torch.load(cfg.model_path, map_location=torch.device(cfg.device), weights_only=True)
            )
            print(
                f"Model loaded from {cfg.model_path} on {cfg.device} device.",
@@ -208,9 +211,7 @@ def _preprocess_fitting_data(
        y = column_or_1d(y, warn=True)
        self.n_features_in_ = x.shape[1]
        self.classes_, y = np.unique(y, return_inverse=True)
-       return torch.tensor(x, dtype=torch.float).to(self.device), torch.tensor(
-           y, dtype=torch.long
-       ).to(self.device)
+       return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.long)

    def _preprocess_test_data(
        self,
@@ -240,7 +241,7 @@ def _preprocess_test_data(
        x_test = check_array(x_test)
        # Standardize data
        x_test = self._scaler.transform(x_test)
-       return torch.tensor(x_test, dtype=torch.float).to(self.device)
+       return torch.tensor(x_test, dtype=torch.float)

    def _initialize_fit_attributes(self) -> None:
        self._rfs = []
@@ -314,6 +315,7 @@ def fit(

        for n in range(self.n_ensemble):
            X_pred, y_pred = self._sample_data(X, y)
+           X_pred, y_pred = X_pred.to(self.device), y_pred.to(self.device)
            self.n_classes_ = len(torch.unique(y_pred).cpu().numpy())

            rf, pca, main_network = self._model(X_pred, y_pred, self.n_classes_)
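The added line defers the host-to-device copy: fit() now keeps the full training set on the CPU and moves only the support sample drawn for each ensemble member to the GPU. A self-contained sketch of that pattern (shapes and names are made up):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

X_cpu = torch.randn(100_000, 64)             # full dataset stays CPU-resident
y_cpu = torch.randint(0, 10, (100_000,))

idx = torch.randperm(X_cpu.shape[0])[:2048]  # draw one support sample
X_sample = X_cpu[idx].to(device)             # only this subset hits the GPU
y_sample = y_cpu[idx].to(device)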
@@ -362,57 +364,59 @@ def fit(
    def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
        check_is_fitted(self)
        X = self._preprocess_test_data(X)
-       with torch.no_grad():
-           orig_X = X
-           yhats = []
-           for jj in range(len(self._main_networks)):
-               main_network = self._main_networks[jj]
-               rf = self._rfs[jj]
-               pca = self._pcas[jj]
-               X_pred = self._X_preds[jj]
-               y_pred = self._y_preds[jj]
-               if self.feature_bagging:
-                   X_ = X[:, self.selected_features[jj]]
-                   orig_X_ = orig_X[:, self.selected_features[jj]]
-               else:
-                   X_ = X
-                   orig_X_ = orig_X
-
-               X_transformed = transform_data_for_main_network(
-                   X=X_, cfg=self._cfg, rf=rf, pca=pca
-               )
-               outputs, intermediate_activations = forward_main_network(
-                   X_transformed, main_network
-               )
-
-               if self.nn_bias:
-                   X_pred_ = transform_data_for_main_network(
-                       X=X_pred, cfg=self._cfg, rf=rf, pca=pca
-                   )
-                   outputs_pred, intermediate_activations_pred = forward_main_network(
-                       X_pred_, main_network
-                   )
-                   for bb, bias in enumerate(self._model.nn_bias):
-                       if bb == 0:
-                           outputs = nn_bias_logits(
-                               outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_
-                           )
-                       elif bb == 1:
-                           outputs = nn_bias_logits(
-                               outputs,
-                               intermediate_activations,
-                               intermediate_activations_pred,
-                               y_pred,
-                               bias,
-                               self.n_classes_,
-                           )
-
-               predicted = F.softmax(outputs, dim=1)
-               yhats.append(predicted)
-
-           yhats = torch.stack(yhats)
-           yhats = torch.mean(yhats, axis=0)
-           return yhats.cpu().numpy()
+       X_dataset = torch.utils.data.TensorDataset(X)
+       X_loader = torch.utils.data.DataLoader(X_dataset, batch_size=self.batch_size, shuffle=False)
+       all_yhats = []
+       for X_batch in X_loader:
+           X_batch = X_batch[0].to(self.device)
+           with torch.no_grad():
+               orig_X = X_batch
+               yhats = []
+               for jj in range(len(self._main_networks)):
+                   main_network = self._main_networks[jj]
+                   rf = self._rfs[jj]
+                   pca = self._pcas[jj]
+                   X_pred = self._X_preds[jj]
+                   y_pred = self._y_preds[jj]
+                   if self.feature_bagging:
+                       X_ = X_batch[:, self.selected_features[jj]]
+                       orig_X_ = orig_X[:, self.selected_features[jj]]
+                   else:
+                       X_ = X_batch
+                       orig_X_ = orig_X
+
+                   X_transformed = transform_data_for_main_network(
+                       X=X_, cfg=self._cfg, rf=rf, pca=pca
+                   )
+                   outputs, intermediate_activations = forward_main_network(
+                       X_transformed, main_network
+                   )
+
+                   if self.nn_bias:
+                       X_pred_ = transform_data_for_main_network(
+                           X=X_pred, cfg=self._cfg, rf=rf, pca=pca
+                       )
+                       outputs_pred, intermediate_activations_pred = forward_main_network(
+                           X_pred_, main_network
+                       )
+                       for bb, bias in enumerate(self._model.nn_bias):
+                           if bb == 0:
+                               outputs = nn_bias_logits(
+                                   outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches
+                               )
+                           elif bb == 1:
+                               outputs = nn_bias_logits(
+                                   outputs, intermediate_activations, intermediate_activations_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches,
+                               )
+
+                   predicted = F.softmax(outputs, dim=1)
+                   yhats.append(predicted)
+
+               yhats = torch.stack(yhats)
+               yhats = torch.mean(yhats, axis=0)
+               yhats = yhats.cpu().numpy()
+               all_yhats.append(yhats)
+       return np.concatenate(all_yhats, axis=0)

    def predict(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
        outputs = self.predict_proba(X)
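The predict_proba restructuring above is the commit's main memory win: instead of pushing the whole test set through every ensemble member at once, the method now streams it through a DataLoader in batch_size chunks, moves one batch at a time to the device, and concatenates per-batch probabilities on the CPU. A standalone sketch of the pattern, assuming a generic logits-producing model_fn rather than the package's real networks:

import numpy as np
import torch

def batched_predict_proba(model_fn, X, batch_size=2048, device="cpu"):
    # Peak device memory is bounded by one batch, not the full test set.
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X), batch_size=batch_size, shuffle=False
    )
    probs = []
    with torch.no_grad():
        for (xb,) in loader:
            logits = model_fn(xb.to(device))
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    return np.concatenate(probs, axis=0)  # shape: (n_samples, n_classes)

# Toy usage with a random linear "model":
w = torch.randn(64, 10)
proba = batched_predict_proba(lambda x: x @ w, torch.randn(5_000, 64))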
2 changes: 1 addition & 1 deletion hyperfast/model.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
-from .utils import *
+from .utils import TorchPCA, get_main_weights, forward_linear_layer


class HyperFast(nn.Module):
19 changes: 8 additions & 11 deletions hyperfast/utils.py
@@ -7,7 +7,6 @@
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
-from types import SimpleNamespace


def seed_everything(seed: int):
@@ -22,11 +21,11 @@ def seed_everything(seed: int):


def nn_bias_logits(
-   test_logits, test_samples, train_samples, train_labels, bias_param, n_classes
+   test_logits, test_samples, train_samples, train_labels, bias_param, n_classes, mini_batches
):
    with torch.no_grad():
        nn = NN(train_samples, train_labels)
-       preds = nn.predict(test_samples)
+       preds = nn.predict(test_samples, mini_batches=mini_batches)
        preds_onehot = F.one_hot(preds, n_classes)
        test_logits[preds_onehot.bool()] += bias_param
        return test_logits
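A tiny self-contained illustration of the logit update performed by nn_bias_logits: each test sample's logit for the class carried by its nearest training neighbor is bumped by bias_param (all values below are made up):

import torch
import torch.nn.functional as F

test_logits = torch.zeros(4, 3)         # 4 test samples, 3 classes
nn_preds = torch.tensor([2, 0, 1, 2])   # hypothetical 1-NN labels
bias_param = 0.5

preds_onehot = F.one_hot(nn_preds, num_classes=3)
test_logits[preds_onehot.bool()] += bias_param  # boost the 1-NN class logit
# Each row now has a single 0.5 entry at its neighbor-predicted class.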
@@ -240,6 +239,7 @@ def fine_tune_main_network(

    for step in range(optimize_steps):
        for inputs, targets in dataloader:
+           inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = main_model(inputs, targets)
            loss = criterion(outputs, targets)
@@ -294,7 +294,7 @@ def transform_data_for_main_network(X, cfg, rf, pca):


def distance_matrix(x, y=None, p=2):
-   y = x if type(y) == type(None) else y
+   y = x if y is None else y

    n = x.size(0)
    m = y.size(0)
@@ -321,11 +321,8 @@ def train(self, X, Y):
        self.train_pts = X
        self.train_label = Y

-   def __call__(self, x, mini_batches=True):
-       return self.predict(x)
-
    def predict(self, x, mini_batches=True):
-       if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
+       if self.train_pts is None or self.train_label is None:
            name = self.__class__.__name__
            raise RuntimeError(
                f"{name} wasn't trained. Need to execute {name}.train() first"
@@ -341,7 +338,7 @@ def predict(self, x, mini_batches=True):
            num_batches = math.ceil(x.shape[0] / batch_size)
            labels = []
            for ii in range(num_batches):
-               x_ = x[batch_size * ii : batch_size * (ii + 1), :]
+               x_ = x[batch_size * ii:batch_size * (ii + 1), :]
                dist = distance_matrix(x_, self.train_pts, self.p)
                labels_ = torch.argmin(dist, dim=1)
                labels.append(labels_)
@@ -350,7 +347,7 @@ def predict(self, x, mini_batches=True):
        return self.train_label[labels]

    def predict_from_training_with_LOO(self, mini_batches=True):
-       if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
+       if self.train_pts is None or self.train_label is None:
            name = self.__class__.__name__
            raise RuntimeError(
                f"{name} wasn't trained. Need to execute {name}.train() first"
@@ -365,7 +362,7 @@ def predict_from_training_with_LOO(self, mini_batches=True):
            num_batches = math.ceil(self.train_pts.shape[0] / batch_size)
            labels = []
            for ii in range(num_batches):
-               x_ = self.train_pts[batch_size * ii : batch_size * (ii + 1), :]
+               x_ = self.train_pts[batch_size * ii:batch_size * (ii + 1), :]
                dist = distance_matrix(x_, self.train_pts, self.p)
                dist.fill_diagonal_(float("inf"))
                labels_ = torch.argmin(dist, dim=1)
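The mini-batched predict loops above cap the pairwise distance matrix at batch_size x n_train rather than n_test x n_train, which is where this file's memory saving comes from. A standalone sketch of the same loop (torch.cdist stands in for the repo's distance_matrix; sizes are made up):

import math
import torch

train_pts = torch.randn(10_000, 32)
train_label = torch.randint(0, 5, (10_000,))
x = torch.randn(4_096, 32)                     # query points

batch_size = 128                               # matches the 128-sample mini-batches noted in the docstring
labels = []
for ii in range(math.ceil(x.shape[0] / batch_size)):
    x_ = x[batch_size * ii:batch_size * (ii + 1), :]
    dist = torch.cdist(x_, train_pts)          # at most (batch_size, n_train)
    labels.append(torch.argmin(dist, dim=1))
pred = train_label[torch.cat(labels)]          # 1-NN label for each query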
