diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c730d91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,60 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.uv/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Jupyter +.ipynb_checkpoints +*.ipynb + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# Model checkpoints +NeurCAMCheckpoints/ +*.pth +*.pt + +# Logs +*.log diff --git a/Loss.py b/Loss.py deleted file mode 100644 index 6d0b7f5..0000000 --- a/Loss.py +++ /dev/null @@ -1,35 +0,0 @@ -from numpy import real -import torch -from torch import nn - -class FuzzyCMeansLoss(nn.Module): - def __init__(self, m=1.0, return_centroids=False): - super(FuzzyCMeansLoss, self).__init__() - self.m = m # hyperparameter that controls the fuzziness of the cluster - self.return_centroids = return_centroids - - def forward(self, X, W, centroids=None): - """ - X is the input data of shape (batch_size, n_features) - W is the fuzzy membership matrix of shape (batch_size, cluster_size) - centroids is the cluster centroids of shape (cluster_size, n_features) - """ - # Raise W to the power m - W_raised = torch.pow(W, self.m) - - # Calculate centroids if not provided - if centroids is None: - centroids_num = torch.sum(W_raised.unsqueeze(2) * X.unsqueeze(1), axis=0) - centroids_den = torch.sum(W_raised, axis=0).unsqueeze(1) + 1e-8 # Adding epsilon for numerical stability - centroids = centroids_num / centroids_den - - # Calculate distances (batch_size, cluster_size, n_features) - distances = torch.norm(X.unsqueeze(1) - centroids, dim=2, p=2) # Euclidean distance - - # Calculate the loss - loss = torch.mean(torch.pow(distances, 2) * W_raised) - - if 
self.return_centroids: - return loss, centroids - else: - return loss \ No newline at end of file diff --git a/NeurCAM.py b/NeurCAM.py deleted file mode 100644 index 659829e..0000000 --- a/NeurCAM.py +++ /dev/null @@ -1,594 +0,0 @@ -import pandas as pd -import numpy as np -import os -import gc -from tqdm import trange - -import torch -from torch import nn -import torch.nn.functional as F -from torch.utils.data import DataLoader, TensorDataset -from sklearn.cluster import MiniBatchKMeans, KMeans -import random - -from entmax import Entmax15 -from Loss import FuzzyCMeansLoss - - -class NeurCAM: - def __init__(self, - k, - random_state:int = 42, - m: float=1.05, - hidden_layers:list[int] = [128,128], - n_bases:int = 64, - learning_rate: float = 2e-3, - epochs: int = 5000, - batch_size: int = 512, - single_feature_channels: float | int = 1.0, - pairwise_feature_channels: float | int = 0.0, - warmup_ratio: float | int = 0.4, - o1_anneal_ratio: float | int = 0.1, - o2_anneal_ratio: float | int = 0.1, - min_temp: float =1e-5, - kl_weight:float = 1.0, - smart_init: str = 'none', - model_dir: str = 'NeurCAMCheckpoints', - device: str = 'auto', - verbose = True - ): - """ - NeurCAM class for interpretable clustering. - - Attributes: - k (int): Number of clusters. - random_state (int): Random seed for reproducibility. - m (float): Fuzziness parameter. - hidden_dim (int): Dimension of the hidden layer for the backbone. - n_bases (int): output dimension of the backbone. - learning_rate (float): Learning rate for the optimizer. - epochs (int): Total number of training epochs. - batch_size (int): Batch size for training. - single_feature_channels (float | int): Number of channels for single feature interactions. If values are <=1.0, it is interpreted as a ratio of the number of features. If values are >1.0, it is interpreted as the number of channels. 
(set to 1.01 for one channel per feature) - pairwise_feature_channels (float | int): Number of channels for pairwise feature interactions. - warmup_ratio (float | int): Ratio of warmup epochs. - o1_anneal_ratio (float | int): Ratio of first annealing phase. - o2_anneal_ratio (float | int): Ratio of second annealing phase. - min_temp (float): Minimum temperature for annealing. - kl_weight (float): Weight for the KL divergence loss. - smart_init (str): Clustering initialization method. - model_dir (str): Directory to save model checkpoints. - """ - self.k = k - self.random_state = random_state - self.m = m - self.hidden_layers = hidden_layers - self.n_bases = n_bases - self.learning_rate = learning_rate - self.epochs = epochs - self.batch_size = batch_size - self.sf_channels = single_feature_channels - self.pf_channels = pairwise_feature_channels - self.warmup_ratio = warmup_ratio - self.o1_anneal_ratio = o1_anneal_ratio - self.o2_anneal_ratio = o2_anneal_ratio - self.min_temp = min_temp - self.kl_weight = kl_weight - self.smart_init = smart_init - self.model_dir = model_dir - self.model = None - self.verbose = verbose - - self.warmup_epochs = int(self.epochs * self.warmup_ratio) - self.o1_anneal_epochs = int(self.epochs * self.o1_anneal_ratio) - self.o2_anneal_epochs = int(self.epochs * self.o2_anneal_ratio) - self.feature_names = None - self.n_features = -1 - self.repr_dim = -1 - self.device = device - if self.device == 'auto': - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - - - def fit(self, X: pd.DataFrame| np.ndarray, X_repr: pd.DataFrame | np.ndarray = None): - """ - Fit the NeurCAM model. - Args: - X (pd.DataFrame | np.ndarray): Input data in the interpretable space. - X_repr (pd.DataFrame | np.ndarray): Input data in transformed/latent space (optional). 
- """ - # seed random state - torch.manual_seed(self.random_state) - np.random.seed(self.random_state) - random.seed(self.random_state) - - # create model dir - if isinstance(X, pd.DataFrame): - self.feature_names = X.columns - X = X.values - if X_repr is not None: - if isinstance(X_repr, pd.DataFrame): - X_repr = X_repr.values - else: - X_repr = X - - self.n_features = X.shape[1] - self.repr_dim = X_repr.shape[1] - dataset = TensorDataset(torch.tensor(X, dtype=torch.float32, device=self.device), torch.tensor(X_repr, dtype=torch.float32, device=self.device)) - - dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) - - - if self.sf_channels <= 1.0: - self.single_feature_channels = max(0, int(self.sf_channels * self.n_features)) - else: - self.single_feature_channels = int(self.sf_channels) - if self.pf_channels <= 1.0: - self.pairwise_feature_channels = max(0, int(self.pf_channels * self.n_features)) - else: - self.pairwise_feature_channels = int(self.pf_channels) - - model = NeurCAMModel( - input_dim = self.n_features, - repr_dim = self.repr_dim, - o1_channels = self.single_feature_channels, - o2_channels = self.pairwise_feature_channels, - n_bases = self.n_bases, - hidden_layers = self.hidden_layers, - n_clusters = self.k - ) - # move to device - model = model.to(self.device) - - - if self.smart_init == 'kmeans': - kmeans = KMeans(n_clusters=self.k, random_state=self.random_state, n_init=10) - kmeans.fit(X_repr) - tens = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32) - tens.to(model.centroids.data.device) - model.centroids.data = tens - - elif self.smart_init == 'mbkmeans': - kmeans = MiniBatchKMeans(n_clusters=self.k, random_state=self.random_state, n_init=10) - kmeans.fit(X_repr) - tens = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32) - tens.to(model.centroids.data.device) - model.centroids.data = tens - else: - model._initialize_centroids(dataloader, init_size=3) - - - optimizer = 
torch.optim.Adam(model.parameters(), lr=self.learning_rate) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50, verbose=False) - loss_func = FuzzyCMeansLoss(m=self.m) - kld_loss = nn.KLDivLoss(reduction='batchmean') - - if self.verbose: - print('Starting warmup phase...') - p1_bar = trange(self.warmup_epochs, desc='Warmup Phase') - else: - p1_bar = range(self.warmup_epochs) - - best_loss = np.inf - best_ckpt = None - model.train() - epoch_losses = [] - best_epoch = 0 - for epoch in p1_bar: - model.train() - epoch_loss = { - 'epoch': epoch, - 'clust_loss': 0.0, - 'kl_div': 0.0, - } - n_points = 0 - for batch in dataloader: - optimizer.zero_grad() - x, x_repr = batch - network_result = model(x) - assignments = network_result['assignments'] - clust_loss = loss_func(x_repr, assignments, centroids=model.centroids) - loss = clust_loss - loss.backward() - optimizer.step() - scheduler.step(loss) - epoch_loss['clust_loss'] += clust_loss.item() - n_points += x.shape[0] - p1_bar.set_postfix({'clust_loss': loss.item()}) - - epoch_loss['clust_loss'] /= n_points - epoch_losses.append(epoch_loss) - if epoch_loss['clust_loss'] < best_loss: - best_loss = epoch_loss['clust_loss'] - best_ckpt = model.state_dict() - best_epoch = epoch - if epoch - best_epoch > 100: - break - - model.load_state_dict(best_ckpt) - model_copy = NeurCAMModel( - input_dim = self.n_features, - repr_dim = self.repr_dim, - o1_channels = self.single_feature_channels, - o2_channels = self.pairwise_feature_channels, - n_bases = self.n_bases, - hidden_layers = self.hidden_layers, - n_clusters = self.k - ) - model_copy.to(self.device) - model_copy.load_state_dict(best_ckpt) - model_copy.eval() - del best_ckpt - gc.collect() - - if self.pairwise_feature_channels > 0: - if self.verbose and self.pairwise_feature_channels > 0: - print('Starting pairwise shape function annealing...') - if self.verbose: - p2_bar = trange(self.o2_anneal_epochs, desc='O2 Annealing 
Phase') - else: - p2_bar = range(self.o2_anneal_epochs) - - for epoch in p2_bar: - model.train() - model._anneal_o2(epoch, self.o2_anneal_epochs, self.min_temp) - valid_o2 = model._o2_valid_cuts() - if valid_o2: - break - epoch_loss = { - 'epoch': epoch, - 'clust_loss': 0.0, - 'kl_div': 0.0, - } - n_points = 0 - for batch in dataloader: - optimizer.zero_grad() - x, x_repr = batch - network_result = model(x) - assignments = network_result['assignments'] - log_assignments = network_result['log_assignments'] - old_assignments = model_copy(x)['assignments'] - - clust_loss = loss_func(x_repr, assignments, centroids=model.centroids) - kl_div = kld_loss(log_assignments, old_assignments) * self.kl_weight - loss = kl_div + clust_loss - - loss.backward() - optimizer.step() - scheduler.step(loss) - epoch_loss['clust_loss'] += clust_loss.item() - epoch_loss['kl_div'] += kl_div.item() - n_points += x.shape[0] - p2_bar.set_postfix({'clust_loss': loss.item()}) - - epoch_loss['clust_loss'] /= n_points - epoch_loss['kl_div'] /= n_points - epoch_losses.append(epoch_loss) - model._lock_in_o2() - - if self.verbose and self.single_feature_channels > 0: - print('Starting single feature shape function annealing...') - if self.verbose: - p3_bar = trange(self.o1_anneal_epochs, desc='O1 Annealing Phase') - else: - p3_bar = range(self.o1_anneal_epochs) - - if self.single_feature_channels > 0: - for epoch in p3_bar: - model.train() - model._anneal_o1(epoch, self.o1_anneal_epochs, self.min_temp) - valid_o1 = model._o1_valid_cuts() - if valid_o1: - break - epoch_loss = { - 'epoch': epoch, - 'clust_loss': 0.0, - 'kl_div': 0.0, - } - n_points = 0 - for batch in dataloader: - optimizer.zero_grad() - x, x_repr = batch - network_result = model(x) - assignments = network_result['assignments'] - log_assignments = network_result['log_assignments'] - old_assignments = model_copy(x)['assignments'] - - clust_loss = loss_func(x_repr, assignments, centroids=model.centroids) - kl_div = 
kld_loss(log_assignments, old_assignments) * self.kl_weight - loss = kl_div + clust_loss - - loss.backward() - optimizer.step() - scheduler.step(loss) - epoch_loss['clust_loss'] += clust_loss.item() - epoch_loss['kl_div'] += kl_div.item() - n_points += x.shape[0] - p3_bar.set_postfix({'clust_loss': loss.item()}) - - epoch_loss['clust_loss'] /= n_points - epoch_loss['kl_div'] /= n_points - epoch_losses.append(epoch_loss) - model._lock_in_o1() - - if self.verbose: - print('Starting Final Phase...') - p4_bar = trange(self.epochs - self.warmup_epochs - self.o1_anneal_epochs - self.o2_anneal_epochs, desc='Training Phase') - else: - p4_bar = range(self.epochs - self.warmup_epochs - self.o1_anneal_epochs - self.o2_anneal_epochs) - best_loss = np.inf - best_ckpt = None - best_epoch = 0 - for epoch in p4_bar: - model.train() - epoch_loss = { - 'epoch': epoch, - 'clust_loss': 0.0, - 'kl_div': 0.0, - } - n_points = 0 - for batch in dataloader: - optimizer.zero_grad() - x, x_repr = batch - network_result = model(x) - assignments = network_result['assignments'] - log_assignments = network_result['log_assignments'] - old_assignments = model_copy(x)['assignments'] - - clust_loss = loss_func(x_repr, assignments, centroids=model.centroids) - kl_div = kld_loss(log_assignments, old_assignments) * self.kl_weight - loss = kl_div + clust_loss - - loss.backward() - optimizer.step() - scheduler.step(loss) - epoch_loss['clust_loss'] += clust_loss.item() - epoch_loss['kl_div'] += kl_div.item() - n_points += x.shape[0] - p4_bar.set_postfix({'clust_loss': loss.item()}) - - epoch_loss['clust_loss'] /= n_points - epoch_loss['kl_div'] /= n_points - epoch_losses.append(epoch_loss) - if epoch_loss['clust_loss'] < best_loss: - best_loss = epoch_loss['clust_loss'] - best_ckpt = model.state_dict() - best_epoch = epoch - if epoch - best_epoch > 100: - break - model.load_state_dict(best_ckpt) - self.model = model - del model_copy - gc.collect() - torch.cuda.empty_cache() - return self - - def 
predict_proba(self, X: pd.DataFrame | np.ndarray): - """ - Predict the soft cluster assignments for the input data. - Args: - X (pd.DataFrame | np.ndarray): Input data in the interpretable space. - Returns: - pd.Series: Cluster assignments. - """ - if isinstance(X, pd.DataFrame): - X = X.values - X = torch.tensor(X, dtype=torch.float32, device=self.device) - test_loader = DataLoader(X, batch_size=self.batch_size, shuffle=False) - predictions = [] - self.model.eval() - with torch.no_grad(): - for x in test_loader: - result = self.model(x) - predictions.append(result['assignments'].cpu().numpy()) - predictions = np.concatenate(predictions, axis=0) - return predictions - def predict(self, X: pd.DataFrame | np.ndarray): - """ - Predict the cluster assignments for the input data. - Args: - X (pd.DataFrame | np.ndarray): Input data in the interpretable space. - Returns: - pd.Series: Cluster assignments. - """ - return np.argmax(self.predict_proba(X), axis=1) - - -class NeurCAMModel(nn.Module): - def __init__(self, - input_dim: int, - repr_dim: int, - o1_channels: int, - o2_channels: int, - n_bases: int, - hidden_layers: list[int], - n_clusters): - super(NeurCAMModel, self).__init__() - self.input_dim = input_dim - self.repr_dim = repr_dim - self.o1_channels = o1_channels - self.o2_channels = o2_channels - self.n_bases = n_bases - self.hidden_layers = hidden_layers - self.n_clusters = n_clusters - self.centroids = nn.Parameter(torch.zeros(n_clusters, self.repr_dim), requires_grad=True) - nn.init.uniform_(self.centroids) - - if self.o1_channels > 0: - self.o1_selection = nn.Parameter(torch.zeros(self.o1_channels, self.input_dim), requires_grad=True) - self.o1_choice_temp = nn.Parameter(torch.tensor(1.0), requires_grad=False) - nn.init.uniform_(self.o1_selection) - if len(hidden_layers) == 0: - layers = [nn.Linear(1, n_bases)] - else: - layers = [nn.Linear(1, hidden_layers[0])] - for i in range(1, len(hidden_layers)): - layers.append(nn.ReLU()) - 
layers.append(nn.Linear(hidden_layers[i-1], hidden_layers[i])) - - layers.append(nn.ReLU()) - layers.append(nn.Linear(hidden_layers[-1], n_bases)) - self.o1_projection = nn.Sequential(*layers) - self.o1_weights = nn.ModuleList([nn.Linear(n_bases, n_clusters) for _ in range(self.o1_channels)]) - - - if self.o2_channels > 0: - self.o2_selection = nn.Parameter(torch.zeros(self.o2_channels, self.input_dim, 2), requires_grad=True) - self.o2_choice_temp = nn.Parameter(torch.tensor(1.0), requires_grad=False) - nn.init.uniform_(self.o2_selection) - if len(hidden_layers) == 0: - layers = [nn.Linear(2, n_bases)] - else: - layers = [nn.Linear(2, hidden_layers[0])] - for i in range(1, len(hidden_layers)): - layers.append(nn.ReLU()) - layers.append(nn.Linear(hidden_layers[i-1], hidden_layers[i])) - - layers.append(nn.ReLU()) - layers.append(nn.Linear(hidden_layers[-1], n_bases)) - self.o2_projection = nn.Sequential(*layers) - self.o2_weights = nn.ModuleList([nn.Linear(n_bases, n_clusters) for _ in range(self.o2_channels)]) - - - self.valid_cuts = False - self.choice = Entmax15(dim=1) - self.sm = nn.Softmax(dim=-1) - self.log_sm = nn.LogSoftmax(dim=-1) - - def _initialize_centroids(self, train_loader, init_size = 10): - """ - Similar to init_size argument in MB Kmeans - """ - # get first batch - # x, x_repr = next(iter(train_loader)) - - # get number of batches in train_loader - n_batches = len(train_loader) - init_size = min(init_size, n_batches) - # get n_runs batches - temp_centroids = torch.zeros_like(self.centroids) - n_points = 0 - for i, (x, x_repr) in enumerate(train_loader): - if i == init_size: - break - # get assignments - W = self.forward(x)['assignments'] - # turn W into one-hot - - # use assignments to get the centroids - centroids_num = torch.sum(W.unsqueeze(2) * x_repr.unsqueeze(1), axis=0) - centroids_den = torch.sum(W, axis=0).unsqueeze(1) - temp_centroids += centroids_num / centroids_den * x.shape[0] - n_points += x.shape[0] - temp_centroids /= n_points - 
temp_centroids = temp_centroids.to(self.centroids.data.device) - - self.centroids.data = temp_centroids - - def _get_o1_selection(self): - return self.choice(self.o1_selection / self.o1_choice_temp) - - def _get_o2_selection(self): - return self.choice(self.o2_selection / self.o2_choice_temp) - - def _o1_valid_cuts(self): - val_cuts_o1 = True - total_nonzero = 0 - if self.o1_channels > 0: - o1_selection = self._get_o1_selection() - for i in range(self.o1_channels): - n_non_zero = torch.count_nonzero(o1_selection[i,:]) - total_nonzero += n_non_zero - val_cuts_o1 = val_cuts_o1 and n_non_zero <= 1 - return val_cuts_o1 - - def _o2_valid_cuts(self): - val_cuts_o2 = True - if self.o2_channels > 0: - o2_selection = self._get_o2_selection() - for i in range(self.o2_channels): - rel_tensor1 = o2_selection[i,:,0] - rel_tensor2 = o2_selection[i,:,1] - n_non_zero1 = torch.count_nonzero(rel_tensor1) - n_non_zero2 = torch.count_nonzero(rel_tensor2) - val_cuts_o2 = val_cuts_o2 and n_non_zero1 <= 1 and n_non_zero2 <= 1 - return val_cuts_o2 - - def _anneal_o1(self, o1_rel_epoch, o1_anneal_steps, min_temp): - if self.o1_channels > 0: - tau = min(o1_rel_epoch / o1_anneal_steps, 1.0) - new_temperature = tau * np.log10(min_temp) - self.o1_choice_temp.data = torch.tensor(10 ** new_temperature, dtype=torch.float32) - - def _anneal_o2(self, o2_rel_epoch, o2_anneal_steps, min_temp): - if self.o2_channels > 0: - tau = min(o2_rel_epoch / o2_anneal_steps, 1.0) - new_temperature = tau * np.log10(min_temp) - self.o2_choice_temp.data = torch.tensor(10 ** new_temperature, dtype=torch.float32) - - def _lock_in_o1(self): - if self.o1_channels > 0: - self.o1_selection.requires_grad = False - self.o1_choice_temp.requires_grad = False - def _lock_in_o2(self): - if self.o2_channels > 0: - self.o2_selection.requires_grad = False - self.o2_choice_temp.requires_grad = False - - def forward(self, x): - logits = self._forward(x) - assignments = self.sm(logits) - log_assignments = self.log_sm(logits) - 
return { - 'assignments': assignments, - 'log_assignments': log_assignments - } - - def _seperated_forward_o1(self, X): - o1_selection_weights = self._get_o1_selection() - o1_select_save = F.linear(X, o1_selection_weights, bias=None) - o1_select = o1_select_save.unsqueeze(2) - # o1_select: (batch_size, o1_channels, 1) - # o1_bases: (batch_size, o1_channels, n_bases) - o1_bases = self.o1_projection(o1_select) - results = {} - for i in range(self.o1_channels): - rel_selection = o1_selection_weights[i,:] - # get the index of the non-zero element - non_zero_index = torch.argmax(rel_selection).item() - - rel_bases = o1_bases[:,i,:] - - if non_zero_index not in results.keys(): - results[non_zero_index] = self.o1_weights[i](rel_bases) - - else: - results[non_zero_index] += self.o1_weights[i](rel_bases) - - return results - - def _forward(self, X): - # X: (batch_size, input_dim) - # o1_selection: (input_dim, o1_channels, 1) - result = torch.zeros(X.shape[0], self.n_clusters).to(X.device) - if self.o1_channels> 0: - o1_selection_weights = self._get_o1_selection() - o1_select_save = F.linear(X, o1_selection_weights, bias=None) - o1_select = o1_select_save.unsqueeze(2) - # o1_select: (batch_size, o1_channels, 1) - # o1_bases: (batch_size, o1_channels, n_bases) - o1_bases = self.o1_projection(o1_select) - - for i in range(self.o1_channels): - rel_bases = o1_bases[:,i,:] - result += self.o1_weights[i](rel_bases) - if self.o2_channels > 0: - o2_selection_weights = self._get_o2_selection() - o2_select = torch.einsum('bi,nio->bno', X, o2_selection_weights) - o2_bases = self.o2_projection(o2_select) - # o2_select: (batch_size, o2_channels, 2) - # o2_bases: (batch_size, o2_channels, n_bases) - for i in range(self.o2_channels): - rel_bases = o2_bases[:,i,:] - result += self.o2_weights[i](rel_bases) - return result diff --git a/README.md b/README.md index d834ca2..61b686a 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,75 @@ # NeurCAM + This is the official implementation of the 
Neural Clustering Additive Model (NeurCAM): https://arxiv.org/abs/2408.13361 -## Environment Setup -To set up the environment, run the following commands: +## Installation + +### Using uv (Recommended) + +First, install [uv](https://github.com/astral-sh/uv) if you haven't already: +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh ``` -conda env create -f environment.yml -conda activate NeurCAM + +Then, install NeurCAM: + +```bash +# Clone the repository +git clone https://github.com/alexwolson/NeurCAM.git +cd NeurCAM + +# Install the package with uv +uv pip install -e . + +# For development with optional dependencies +uv pip install -e ".[dev]" ``` -## Example +### Using pip + +```bash +pip install -e . ``` + +### PyTorch with CUDA Support + +If you need CUDA support for GPU acceleration, install PyTorch with CUDA separately: + +```bash +# For CUDA 11.8 +uv pip install torch --index-url https://download.pytorch.org/whl/cu118 + +# For CUDA 12.1 +uv pip install torch --index-url https://download.pytorch.org/whl/cu121 +``` + +## Quick Start + +```python from sklearn.datasets import load_iris -from NeurCAM import NeurCAM +from neurcam import NeurCAM +# Load data iris = load_iris() X = iris.data +# Create and fit model nc = NeurCAM(k=3, epochs=5000) nc = nc.fit(X) + +# Make predictions neurcam_pred = nc.predict(X) ``` + +## Development + +To set up the development environment: + +```bash +# Install with development dependencies +uv pip install -e ".[dev]" + +# Format code with black +black neurcam/ +``` diff --git a/environment.yml b/environment.yml deleted file mode 100644 index e015274..0000000 --- a/environment.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: NeurCAM -channels: - - conda-forge - - pytorch - - plotly - - nvidia -dependencies: - - python=3.11 - - jupyterlab - - matplotlib - - numpy - - pandas - - pytorch - - pytorch-cuda>=11.8 - - plotly - - scipy - - tqdm - - pip - - ipywidgets - - scikit-learn - - seaborn - - pip: - - nbdev - - black - - wrds - - 
kaleido - - entmax diff --git a/neurcam/__init__.py b/neurcam/__init__.py new file mode 100644 index 0000000..1197ce5 --- /dev/null +++ b/neurcam/__init__.py @@ -0,0 +1,15 @@ +""" +NeurCAM: Neural Clustering Additive Model + +A Python package for interpretable clustering using neural networks. + +This package implements the Neural Clustering Additive Model (NeurCAM), +which combines neural networks with interpretable clustering to provide +both accurate clustering and model interpretability. +""" + +from neurcam.loss import FuzzyCMeansLoss +from neurcam.model import NeurCAM, NeurCAMModel + +__version__ = "0.1.0" +__all__ = ["NeurCAM", "NeurCAMModel", "FuzzyCMeansLoss"] diff --git a/neurcam/loss.py b/neurcam/loss.py new file mode 100644 index 0000000..86eae12 --- /dev/null +++ b/neurcam/loss.py @@ -0,0 +1,67 @@ +"""Loss functions for NeurCAM clustering.""" + +import torch +from torch import nn +from typing import Optional, Union + + +class FuzzyCMeansLoss(nn.Module): + """ + Fuzzy C-Means loss function for soft clustering. + + This loss implements the objective function of fuzzy c-means clustering, + which minimizes the weighted sum of squared distances to cluster centroids. + """ + + def __init__(self, m: float = 1.0, return_centroids: bool = False) -> None: + """ + Initialize the Fuzzy C-Means loss. + + Args: + m: Fuzziness parameter that controls the degree of cluster overlap. + Higher values lead to fuzzier clusters. Must be >= 1.0. + return_centroids: If True, return both loss and centroids. + """ + super(FuzzyCMeansLoss, self).__init__() + if m < 1.0: + raise ValueError(f"Fuzziness parameter m must be >= 1.0, got {m}") + self.m = m + self.return_centroids = return_centroids + + def forward( + self, + X: torch.Tensor, + W: torch.Tensor, + centroids: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + Compute the fuzzy c-means loss. + + Args: + X: Input data of shape (batch_size, n_features). 
+ W: Fuzzy membership matrix of shape (batch_size, n_clusters). + centroids: Cluster centroids of shape (n_clusters, n_features). + If None, centroids are computed from X and W. + + Returns: + Loss value, or tuple of (loss, centroids) if return_centroids is True. + """ + # Raise W to the power m for fuzzy weighting + W_raised = torch.pow(W, self.m) + + # Calculate centroids if not provided + if centroids is None: + centroids_num = torch.sum(W_raised.unsqueeze(2) * X.unsqueeze(1), axis=0) + centroids_den = torch.sum(W_raised, axis=0).unsqueeze(1) + 1e-8 + centroids = centroids_num / centroids_den + + # Calculate Euclidean distances: (batch_size, n_clusters) + distances = torch.cdist(X, centroids, p=2) + + # Calculate the loss: mean of squared distances weighted by membership + loss = torch.mean(torch.pow(distances, 2) * W_raised) + + if self.return_centroids: + return loss, centroids + else: + return loss diff --git a/neurcam/model.py b/neurcam/model.py new file mode 100644 index 0000000..01278de --- /dev/null +++ b/neurcam/model.py @@ -0,0 +1,891 @@ +import gc +import random +from typing import Optional, Union, Dict, Any + +import numpy as np +import pandas as pd +import torch +from entmax import Entmax15 +from sklearn.cluster import KMeans, MiniBatchKMeans +from torch import nn +from torch.utils.data import DataLoader, TensorDataset +from tqdm import trange +import torch.nn.functional as F + +from neurcam.loss import FuzzyCMeansLoss + +# Constants +DEFAULT_RANDOM_STATE = 42 +DEFAULT_M = 1.05 +DEFAULT_HIDDEN_LAYERS = [128, 128] +DEFAULT_N_BASES = 64 +DEFAULT_LEARNING_RATE = 2e-3 +DEFAULT_EPOCHS = 5000 +DEFAULT_BATCH_SIZE = 512 +DEFAULT_WARMUP_RATIO = 0.4 +DEFAULT_O1_ANNEAL_RATIO = 0.1 +DEFAULT_O2_ANNEAL_RATIO = 0.1 +DEFAULT_MIN_TEMP = 1e-5 +DEFAULT_KL_WEIGHT = 1.0 +DEFAULT_MODEL_DIR = "NeurCAMCheckpoints" +DEFAULT_PATIENCE = 100 +LR_SCHEDULER_FACTOR = 0.5 +LR_SCHEDULER_PATIENCE = 50 +CENTROID_INIT_SIZE = 3 + + +class NeurCAM: + def __init__( + self, + k: int, + 
random_state: int = DEFAULT_RANDOM_STATE, + m: float = DEFAULT_M, + hidden_layers: list[int] = None, + n_bases: int = DEFAULT_N_BASES, + learning_rate: float = DEFAULT_LEARNING_RATE, + epochs: int = DEFAULT_EPOCHS, + batch_size: int = DEFAULT_BATCH_SIZE, + single_feature_channels: Union[float, int] = 1.0, + pairwise_feature_channels: Union[float, int] = 0.0, + warmup_ratio: Union[float, int] = DEFAULT_WARMUP_RATIO, + o1_anneal_ratio: Union[float, int] = DEFAULT_O1_ANNEAL_RATIO, + o2_anneal_ratio: Union[float, int] = DEFAULT_O2_ANNEAL_RATIO, + min_temp: float = DEFAULT_MIN_TEMP, + kl_weight: float = DEFAULT_KL_WEIGHT, + smart_init: str = "none", + model_dir: str = DEFAULT_MODEL_DIR, + device: str = "auto", + verbose: bool = True, + ) -> None: + """ + NeurCAM class for interpretable clustering. + + Args: + k: Number of clusters. + random_state: Random seed for reproducibility. + m: Fuzziness parameter. + hidden_layers: List of hidden layer dimensions for the backbone network. + n_bases: Output dimension of the backbone. + learning_rate: Learning rate for the optimizer. + epochs: Total number of training epochs. + batch_size: Batch size for training. + single_feature_channels: Number of channels for single feature interactions. + If values are <=1.0, interpreted as ratio of number of features. + If values are >1.0, interpreted as number of channels. + pairwise_feature_channels: Number of channels for pairwise feature interactions. + warmup_ratio: Ratio of warmup epochs. + o1_anneal_ratio: Ratio of first annealing phase. + o2_anneal_ratio: Ratio of second annealing phase. + min_temp: Minimum temperature for annealing. + kl_weight: Weight for the KL divergence loss. + smart_init: Clustering initialization method ('none', 'kmeans', 'mbkmeans'). + model_dir: Directory to save model checkpoints. + device: Device to use for training ('auto', 'cuda', 'cpu'). + verbose: Whether to print training progress. 
+ """ + if hidden_layers is None: + hidden_layers = DEFAULT_HIDDEN_LAYERS.copy() + + # Validate parameters + if k <= 0: + raise ValueError(f"k must be positive, got {k}") + if m < 1.0: + raise ValueError(f"m must be >= 1.0 for fuzzy clustering, got {m}") + if epochs <= 0: + raise ValueError(f"epochs must be positive, got {epochs}") + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + if learning_rate <= 0: + raise ValueError(f"learning_rate must be positive, got {learning_rate}") + if n_bases <= 0: + raise ValueError(f"n_bases must be positive, got {n_bases}") + if min_temp <= 0: + raise ValueError(f"min_temp must be positive, got {min_temp}") + if kl_weight < 0: + raise ValueError(f"kl_weight must be non-negative, got {kl_weight}") + if single_feature_channels < 0: + raise ValueError( + f"single_feature_channels must be non-negative, got {single_feature_channels}" + ) + if pairwise_feature_channels < 0: + raise ValueError( + f"pairwise_feature_channels must be non-negative, got {pairwise_feature_channels}" + ) + if not (0 <= warmup_ratio <= 1): + raise ValueError(f"warmup_ratio must be in [0, 1], got {warmup_ratio}") + if not (0 <= o1_anneal_ratio <= 1): + raise ValueError(f"o1_anneal_ratio must be in [0, 1], got {o1_anneal_ratio}") + if not (0 <= o2_anneal_ratio <= 1): + raise ValueError(f"o2_anneal_ratio must be in [0, 1], got {o2_anneal_ratio}") + if smart_init not in ["none", "kmeans", "mbkmeans"]: + raise ValueError( + f"smart_init must be one of ['none', 'kmeans', 'mbkmeans'], got '{smart_init}'" + ) + # Validate device is a valid PyTorch device specifier + # Allow "auto", or any string that torch.device() accepts + if device != "auto": + try: + torch.device(device) + except RuntimeError as e: + raise ValueError(f"Invalid device specifier: '{device}'. 
def _set_random_seeds(self) -> None:
    """Seed the stdlib, NumPy and PyTorch random generators from ``self.random_state``."""
    seed = self.random_state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _prepare_data(
    self,
    X: Union[pd.DataFrame, np.ndarray],
    X_repr: Optional[Union[pd.DataFrame, np.ndarray]],
) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert the training inputs to 2-D numpy arrays and validate them.

    When ``X`` is a DataFrame its column index is remembered in
    ``self.feature_names``. When ``X_repr`` is omitted, clustering happens
    directly in the interpretable space, i.e. ``X`` is returned for both.

    Args:
        X: Input data in the interpretable space, shape (n_samples, n_features).
        X_repr: Input data in transformed/latent space (optional).

    Returns:
        Tuple of (X, X_repr) as numpy arrays.

    Raises:
        ValueError: If either input is not 2-dimensional, if the two inputs
            disagree on the number of samples, or if there are fewer samples
            than clusters.
    """
    if isinstance(X, pd.DataFrame):
        self.feature_names = X.columns
        X = X.values
    else:
        # Accept any array-like (lists, tuples, ...) instead of failing later on `.shape`.
        X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional (n_samples, n_features), got shape {X.shape}")

    if X_repr is None:
        X_repr = X
    else:
        X_repr = X_repr.values if isinstance(X_repr, pd.DataFrame) else np.asarray(X_repr)
        if X_repr.ndim != 2:
            raise ValueError(
                f"X_repr must be 2-dimensional (n_samples, repr_dim), got shape {X_repr.shape}"
            )
        if X.shape[0] != X_repr.shape[0]:
            raise ValueError(
                f"X and X_repr must have the same number of samples. "
                f"Got X.shape[0]={X.shape[0]} and X_repr.shape[0]={X_repr.shape[0]}"
            )

    if X.shape[0] < self.k:
        raise ValueError(
            f"Number of samples ({X.shape[0]}) must be at least as large as "
            f"number of clusters ({self.k})"
        )

    return X, X_repr


def _compute_channel_counts(self) -> None:
    """
    Resolve the channel settings into absolute channel counts.

    A setting ``<= 1.0`` is treated as a fraction of ``self.n_features``
    (truncated towards zero); anything larger is taken as an absolute count.
    Results are stored in ``self.single_feature_channels`` and
    ``self.pairwise_feature_channels``.
    """
    n_features = self.n_features
    resolved = []
    for setting in (self.sf_channels, self.pf_channels):
        if setting <= 1.0:
            resolved.append(max(0, int(setting * n_features)))
        else:
            resolved.append(int(setting))
    self.single_feature_channels, self.pairwise_feature_channels = resolved


def _create_model(self) -> "NeurCAMModel":
    """Build a fresh ``NeurCAMModel`` from the fitted dimensions and move it to ``self.device``."""
    config = {
        "input_dim": self.n_features,
        "repr_dim": self.repr_dim,
        "o1_channels": self.single_feature_channels,
        "o2_channels": self.pairwise_feature_channels,
        "n_bases": self.n_bases,
        "hidden_layers": self.hidden_layers,
        "n_clusters": self.k,
    }
    return NeurCAMModel(**config).to(self.device)
def _train_epoch(
    self,
    model: "NeurCAMModel",
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    loss_func: "FuzzyCMeansLoss",
    kld_loss: nn.KLDivLoss,
    model_copy: Optional["NeurCAMModel"] = None,
) -> Dict[str, float]:
    """
    Train for one epoch.

    Args:
        model: The model to train.
        dataloader: DataLoader yielding ``(x, x_repr)`` batches.
        optimizer: Optimizer for model parameters.
        scheduler: Learning rate scheduler; stepped once with the epoch's clustering loss.
        loss_func: Fuzzy C-means loss function.
        kld_loss: KL divergence loss function.
        model_copy: Optional frozen reference model. When given, the KL divergence
            between the current and the reference assignments (scaled by
            ``self.kl_weight``) is added to the loss.

    Returns:
        Dictionary with the per-sample averaged ``"clust_loss"`` and ``"kl_div"``
        for the epoch (``"kl_div"`` is 0.0 when no reference model is used).
    """
    model.train()
    use_kl = model_copy is not None
    clust_total = 0.0
    kl_total = 0.0
    n_points = 0

    for x, x_repr in dataloader:
        batch_size = x.shape[0]
        optimizer.zero_grad()
        network_result = model(x)

        clust_loss = loss_func(x_repr, network_result["assignments"], centroids=model.centroids)
        loss = clust_loss

        if use_kl:
            # The reference model is a fixed target: evaluate it without autograd so
            # no graph is built for it and no gradients accumulate on its parameters.
            with torch.no_grad():
                old_assignments = model_copy(x)["assignments"]
            kl_div = kld_loss(network_result["log_assignments"], old_assignments) * self.kl_weight
            loss = loss + kl_div
            kl_total += kl_div.item() * batch_size

        loss.backward()
        optimizer.step()

        # Both losses are treated as per-batch means (KLDivLoss uses "batchmean";
        # loss_func is assumed to average as well), so weight each by its batch
        # size before dividing by the sample count to get a true epoch average.
        clust_total += clust_loss.item() * batch_size
        n_points += batch_size

    denom = max(n_points, 1)  # an empty loader reports zeros instead of dividing by zero
    epoch_loss = {"clust_loss": clust_total / denom, "kl_div": kl_total / denom}

    # Call scheduler once per epoch with the epoch-averaged loss
    if n_points > 0:
        scheduler.step(epoch_loss["clust_loss"])

    return epoch_loss


def _train_phase(
    self,
    model: "NeurCAMModel",
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    loss_func: "FuzzyCMeansLoss",
    kld_loss: nn.KLDivLoss,
    n_epochs: int,
    phase_name: str,
    model_copy: Optional["NeurCAMModel"] = None,
    track_best: bool = True,
) -> Optional[Dict[str, Any]]:
    """
    Run a complete training phase with optional best-checkpoint tracking.

    Args:
        model: The model to train.
        dataloader: DataLoader for training data.
        optimizer: Optimizer for model parameters.
        scheduler: Learning rate scheduler.
        loss_func: Fuzzy C-means loss function.
        kld_loss: KL divergence loss function.
        n_epochs: Number of epochs for this phase.
        phase_name: Name of the training phase for logging.
        model_copy: Optional copy of model for KL divergence calculation.
        track_best: Whether to track and return the best checkpoint. When set,
            the phase also stops early once more than ``DEFAULT_PATIENCE``
            epochs pass without a new best clustering loss.

    Returns:
        An independent copy of the best model state dict if ``track_best`` is
        True (the current state when no epoch improved or none was run),
        None otherwise.
    """
    if self.verbose:
        print(f"Starting {phase_name}...")
        progress_bar = trange(n_epochs, desc=phase_name)
    else:
        progress_bar = range(n_epochs)

    best_loss = np.inf
    best_ckpt = None
    best_epoch = 0

    for epoch in progress_bar:
        epoch_loss = self._train_epoch(
            model, dataloader, optimizer, scheduler, loss_func, kld_loss, model_copy
        )
        current = epoch_loss["clust_loss"]

        if self.verbose:
            progress_bar.set_postfix({"clust_loss": current})

        if not track_best:
            continue

        if current < best_loss:
            best_loss = current
            # state_dict() hands back references to the live parameters, so the
            # tensors must be cloned or the "best" checkpoint would silently
            # follow every later optimizer step.
            best_ckpt = {
                name: tensor.detach().clone() for name, tensor in model.state_dict().items()
            }
            best_epoch = epoch

        if epoch - best_epoch > DEFAULT_PATIENCE:
            break

    if not track_best:
        return None

    if best_ckpt is None:
        # Zero-epoch phase (or no finite loss): fall back to the current weights so
        # callers always receive something loadable.
        best_ckpt = {
            name: tensor.detach().clone() for name, tensor in model.state_dict().items()
        }
    return best_ckpt


def _train_annealing_phase(
    self,
    model: "NeurCAMModel",
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    loss_func: "FuzzyCMeansLoss",
    kld_loss: nn.KLDivLoss,
    model_copy: "NeurCAMModel",
    n_epochs: int,
    phase_name: str,
    anneal_fn: str,
    valid_cuts_fn: str,
    lock_fn: str,
) -> None:
    """
    Run an annealing training phase.

    Each epoch first lowers the selection temperature via ``anneal_fn``; the
    phase ends early as soon as ``valid_cuts_fn`` reports valid selections.
    ``lock_fn`` is always called once the loop finishes.

    Args:
        model: The model to train.
        dataloader: DataLoader for training data.
        optimizer: Optimizer for model parameters.
        scheduler: Learning rate scheduler.
        loss_func: Fuzzy C-means loss function.
        kld_loss: KL divergence loss function.
        model_copy: Copy of model for KL divergence calculation.
        n_epochs: Number of epochs for this phase.
        phase_name: Name of the training phase for logging.
        anneal_fn: Name of the model method called as ``(epoch, n_epochs, min_temp)``.
        valid_cuts_fn: Name of the model method returning whether selections are valid.
        lock_fn: Name of the model method that locks the selections in.
    """
    anneal = getattr(model, anneal_fn)
    cuts_are_valid = getattr(model, valid_cuts_fn)
    lock_in = getattr(model, lock_fn)

    if self.verbose:
        print(f"Starting {phase_name}...")
        progress_bar = trange(n_epochs, desc=phase_name)
    else:
        progress_bar = range(n_epochs)

    for epoch in progress_bar:
        model.train()
        anneal(epoch, n_epochs, self.min_temp)

        if cuts_are_valid():
            break

        epoch_loss = self._train_epoch(
            model, dataloader, optimizer, scheduler, loss_func, kld_loss, model_copy
        )

        if self.verbose:
            progress_bar.set_postfix({"clust_loss": epoch_loss["clust_loss"]})

    lock_in()


def fit(
    self,
    X: Union[pd.DataFrame, np.ndarray],
    X_repr: Optional[Union[pd.DataFrame, np.ndarray]] = None,
) -> "NeurCAM":
    """
    Fit the NeurCAM model.

    Training runs in up to four phases: warmup, pairwise-feature (O2)
    annealing, single-feature (O1) annealing, and a final training phase.
    After warmup a frozen copy of the model serves as the KL-divergence
    reference for all later phases.

    Args:
        X: Input data in the interpretable space.
        X_repr: Input data in transformed/latent space (optional).

    Returns:
        Self for method chaining.
    """
    # Set random seeds for reproducibility
    self._set_random_seeds()

    # Prepare data
    X, X_repr = self._prepare_data(X, X_repr)
    self.n_features = X.shape[1]
    self.repr_dim = X_repr.shape[1]

    # Create dataset and dataloader
    dataset = TensorDataset(
        torch.tensor(X, dtype=torch.float32, device=self.device),
        torch.tensor(X_repr, dtype=torch.float32, device=self.device),
    )
    dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    # Compute channel counts
    self._compute_channel_counts()

    # Create and initialize model
    model = self._create_model()
    self._initialize_centroids(model, X_repr, dataloader)

    # Setup training components
    optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE
    )
    loss_func = FuzzyCMeansLoss(m=self.m)
    kld_loss = nn.KLDivLoss(reduction="batchmean")

    # Arguments shared by every phase, in the positional order the helpers expect.
    shared = (model, dataloader, optimizer, scheduler, loss_func, kld_loss)

    # Warmup phase
    best_ckpt = self._train_phase(
        *shared, self.warmup_epochs, "Warmup Phase", model_copy=None, track_best=True
    )

    # Restore the best warmup weights. A zero-length warmup may yield no
    # checkpoint, in which case the freshly initialised weights are kept.
    if best_ckpt is not None:
        model.load_state_dict(best_ckpt)
    del best_ckpt

    # Frozen reference model for the KL-divergence term
    model_copy = self._create_model()
    model_copy.load_state_dict(model.state_dict())
    model_copy.eval()
    for param in model_copy.parameters():
        param.requires_grad_(False)
    gc.collect()

    # Annealing phases: pairwise features first, then single features.
    annealing_plan = (
        (
            self.pairwise_feature_channels,
            self.o2_anneal_epochs,
            "O2 Annealing Phase",
            "_anneal_o2",
            "_o2_valid_cuts",
            "_lock_in_o2",
        ),
        (
            self.single_feature_channels,
            self.o1_anneal_epochs,
            "O1 Annealing Phase",
            "_anneal_o1",
            "_o1_valid_cuts",
            "_lock_in_o1",
        ),
    )
    for n_channels, n_epochs, label, anneal_fn, valid_cuts_fn, lock_fn in annealing_plan:
        if n_channels > 0:
            self._train_annealing_phase(
                *shared, model_copy, n_epochs, label, anneal_fn, valid_cuts_fn, lock_fn
            )

    # Final training phase gets whatever epoch budget is left (never negative).
    final_epochs = max(
        0, self.epochs - self.warmup_epochs - self.o1_anneal_epochs - self.o2_anneal_epochs
    )
    best_ckpt = self._train_phase(
        *shared, final_epochs, "Final Training Phase", model_copy=model_copy, track_best=True
    )

    # Load best checkpoint
    if best_ckpt is not None:
        model.load_state_dict(best_ckpt)
    self.model = model

    # Cleanup
    del model_copy
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return self
class NeurCAMModel(nn.Module):
    """
    Neural network model for NeurCAM clustering.

    The model produces soft cluster assignments as a softmax over logits that
    are the sum of per-channel contributions. Each single-feature (O1) channel
    softly selects one input feature, and each pairwise (O2) channel softly
    selects two; the selected values pass through a projection network shared
    by all channels of that order and then a per-channel linear layer.
    Cluster centroids in the representation space are stored as a learnable
    parameter.
    """

    def __init__(
        self,
        input_dim: int,
        repr_dim: int,
        o1_channels: int,
        o2_channels: int,
        n_bases: int,
        hidden_layers: list[int],
        n_clusters: int,
    ) -> None:
        """
        Initialize the NeurCAM model.

        Args:
            input_dim: Dimension of input features.
            repr_dim: Dimension of representation space.
            o1_channels: Number of single-feature channels.
            o2_channels: Number of pairwise-feature channels.
            n_bases: Number of basis functions (output width of the projection networks).
            hidden_layers: List of hidden layer dimensions.
            n_clusters: Number of clusters.

        Raises:
            ValueError: If a dimension is not positive or a channel count is negative.
        """
        super().__init__()

        # Validate parameters (name, value, must be strictly positive), in signature order.
        checks = (
            ("input_dim", input_dim, True),
            ("repr_dim", repr_dim, True),
            ("o1_channels", o1_channels, False),
            ("o2_channels", o2_channels, False),
            ("n_bases", n_bases, True),
            ("n_clusters", n_clusters, True),
        )
        for name, value, strictly_positive in checks:
            if strictly_positive and value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
            if not strictly_positive and value < 0:
                raise ValueError(f"{name} must be non-negative, got {value}")

        self.input_dim = input_dim
        self.repr_dim = repr_dim
        self.o1_channels = o1_channels
        self.o2_channels = o2_channels
        self.n_bases = n_bases
        self.hidden_layers = hidden_layers
        self.n_clusters = n_clusters
        self.centroids = nn.Parameter(torch.zeros(n_clusters, self.repr_dim), requires_grad=True)
        nn.init.uniform_(self.centroids)

        if self.o1_channels > 0:
            self._initialize_o1_layers()

        if self.o2_channels > 0:
            self._initialize_o2_layers()

        self.valid_cuts = False
        self.choice = Entmax15(dim=1)
        self.sm = nn.Softmax(dim=-1)
        self.log_sm = nn.LogSoftmax(dim=-1)

    def _build_projection_network(self, input_size: int) -> nn.Sequential:
        """
        Build an MLP mapping ``input_size`` values to ``n_bases`` outputs.

        The network is ``Linear`` layers through ``self.hidden_layers`` with a
        ``ReLU`` between consecutive layers; with no hidden layers it is a
        single ``Linear``.

        Args:
            input_size: Size of the input layer.

        Returns:
            Sequential neural network module.
        """
        widths = [input_size, *self.hidden_layers, self.n_bases]
        layers = []
        for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if index > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        return nn.Sequential(*layers)

    def _initialize_o1_layers(self) -> None:
        """Create the selection logits, temperature, projection and heads for single features."""
        self.o1_selection = nn.Parameter(
            torch.zeros(self.o1_channels, self.input_dim), requires_grad=True
        )
        self.o1_choice_temp = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        nn.init.uniform_(self.o1_selection)
        self.o1_projection = self._build_projection_network(input_size=1)
        self.o1_weights = nn.ModuleList(
            [nn.Linear(self.n_bases, self.n_clusters) for _ in range(self.o1_channels)]
        )

    def _initialize_o2_layers(self) -> None:
        """Create the selection logits, temperature, projection and heads for feature pairs."""
        self.o2_selection = nn.Parameter(
            torch.zeros(self.o2_channels, self.input_dim, 2), requires_grad=True
        )
        self.o2_choice_temp = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        nn.init.uniform_(self.o2_selection)
        self.o2_projection = self._build_projection_network(input_size=2)
        self.o2_weights = nn.ModuleList(
            [nn.Linear(self.n_bases, self.n_clusters) for _ in range(self.o2_channels)]
        )

    def _initialize_centroids(self, train_loader: DataLoader, init_size: int = 10) -> None:
        """
        Initialize centroids using initial batches from the training data.

        Similar to init_size argument in MiniBatch KMeans. For each of the
        first ``init_size`` batches the assignment-weighted mean of ``x_repr``
        is computed per cluster; the batch results are averaged weighted by
        batch size. The centroids are left untouched if no batch is seen.

        Args:
            train_loader: DataLoader yielding ``(x, x_repr)`` batches.
            init_size: Number of batches to use for initialization.
        """
        init_size = min(init_size, len(train_loader))
        accumulated = torch.zeros_like(self.centroids.data)
        n_points = 0

        # Pure initialisation: no autograd graph is needed (or wanted) here.
        with torch.no_grad():
            for batch_index, (x, x_repr) in enumerate(train_loader):
                if batch_index >= init_size:
                    break
                W = self.forward(x)["assignments"]
                centroids_num = torch.sum(W.unsqueeze(2) * x_repr.unsqueeze(1), dim=0)
                centroids_den = torch.sum(W, dim=0).unsqueeze(1)
                accumulated += centroids_num / centroids_den * x.shape[0]
                n_points += x.shape[0]

        if n_points == 0:
            # Nothing to average; dividing would turn every centroid into NaN.
            return

        self.centroids.data = (accumulated / n_points).to(self.centroids.data.device)

    def _get_o1_selection(self) -> torch.Tensor:
        """Get single-feature selection weights with temperature scaling."""
        return self.choice(self.o1_selection / self.o1_choice_temp)

    def _get_o2_selection(self) -> torch.Tensor:
        """Get pairwise-feature selection weights with temperature scaling."""
        return self.choice(self.o2_selection / self.o2_choice_temp)

    def _o1_valid_cuts(self) -> bool:
        """
        Check if single-feature selections are valid (at most one feature per channel).

        Returns:
            True if all channels have at most one selected feature.
        """
        if self.o1_channels == 0:
            return True

        selected_per_channel = torch.count_nonzero(self._get_o1_selection(), dim=1)
        return bool(torch.all(selected_per_channel <= 1))

    def _o2_valid_cuts(self) -> bool:
        """
        Check if pairwise-feature selections are valid (at most one feature per slot).

        Returns:
            True if all channels have at most one selected feature per slot.
        """
        if self.o2_channels == 0:
            return True

        o2_selection = self._get_o2_selection()
        for slot in (0, 1):
            selected_per_channel = torch.count_nonzero(o2_selection[:, :, slot], dim=1)
            if not bool(torch.all(selected_per_channel <= 1)):
                return False
        return True

    @staticmethod
    def _annealed_temperature(rel_epoch: int, anneal_steps: int, min_temp: float) -> float:
        """
        Log-linear temperature schedule from 1.0 down to ``min_temp``.

        Returns ``min_temp ** tau`` with ``tau = min(rel_epoch / anneal_steps, 1)``.
        A non-positive ``anneal_steps`` jumps straight to ``min_temp`` instead of
        dividing by zero.
        """
        tau = 1.0 if anneal_steps <= 0 else min(rel_epoch / anneal_steps, 1.0)
        return float(10.0 ** (tau * np.log10(min_temp)))

    def _anneal_o1(self, o1_rel_epoch: int, o1_anneal_steps: int, min_temp: float) -> None:
        """
        Anneal the temperature for single-feature selection.

        Args:
            o1_rel_epoch: Current epoch in the annealing phase.
            o1_anneal_steps: Total number of annealing steps.
            min_temp: Minimum temperature to reach.
        """
        if self.o1_channels > 0:
            temperature = self._annealed_temperature(o1_rel_epoch, o1_anneal_steps, min_temp)
            # Update in place so the temperature stays on the model's device.
            self.o1_choice_temp.data.fill_(temperature)

    def _anneal_o2(self, o2_rel_epoch: int, o2_anneal_steps: int, min_temp: float) -> None:
        """
        Anneal the temperature for pairwise-feature selection.

        Args:
            o2_rel_epoch: Current epoch in the annealing phase.
            o2_anneal_steps: Total number of annealing steps.
            min_temp: Minimum temperature to reach.
        """
        if self.o2_channels > 0:
            temperature = self._annealed_temperature(o2_rel_epoch, o2_anneal_steps, min_temp)
            # Update in place so the temperature stays on the model's device.
            self.o2_choice_temp.data.fill_(temperature)

    def _lock_in_o1(self) -> None:
        """Lock single-feature selections by disabling gradient updates."""
        if self.o1_channels > 0:
            self.o1_selection.requires_grad = False
            self.o1_choice_temp.requires_grad = False

    def _lock_in_o2(self) -> None:
        """Lock pairwise-feature selections by disabling gradient updates."""
        if self.o2_channels > 0:
            self.o2_selection.requires_grad = False
            self.o2_choice_temp.requires_grad = False

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            Dictionary containing 'assignments' (softmax of the logits) and
            'log_assignments' (log-softmax of the same logits).
        """
        logits = self._forward(x)
        return {"assignments": self.sm(logits), "log_assignments": self.log_sm(logits)}

    def _separated_forward_o1(self, X: torch.Tensor) -> Dict[int, torch.Tensor]:
        """
        Compute single-feature contributions separately for each feature.

        Each channel is attributed to the feature with the largest selection
        weight; channels attributed to the same feature are summed.

        Args:
            X: Input tensor of shape (batch_size, input_dim).

        Returns:
            Dictionary mapping feature indices to their logit contributions
            (empty when the model has no single-feature channels).
        """
        results: Dict[int, torch.Tensor] = {}
        if self.o1_channels == 0:
            # No O1 layers were created, so there is nothing to attribute.
            return results

        o1_selection_weights = self._get_o1_selection()
        o1_select = F.linear(X, o1_selection_weights, bias=None).unsqueeze(2)
        o1_bases = self.o1_projection(o1_select)

        for channel, head in enumerate(self.o1_weights):
            feature_index = torch.argmax(o1_selection_weights[channel, :]).item()
            contribution = head(o1_bases[:, channel, :])
            if feature_index in results:
                results[feature_index] = results[feature_index] + contribution
            else:
                results[feature_index] = contribution

        return results

    def _forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Internal forward pass computing logits.

        Args:
            X: Input tensor of shape (batch_size, input_dim).

        Returns:
            Logits tensor of shape (batch_size, n_clusters).
        """
        result = torch.zeros(X.shape[0], self.n_clusters, device=X.device)

        if self.o1_channels > 0:
            # (batch, channels) softly selected feature values -> (batch, channels, 1)
            o1_select = F.linear(X, self._get_o1_selection(), bias=None).unsqueeze(2)
            o1_bases = self.o1_projection(o1_select)
            for channel, head in enumerate(self.o1_weights):
                result = result + head(o1_bases[:, channel, :])

        if self.o2_channels > 0:
            # (batch, channels, 2): two softly selected feature values per channel
            o2_select = torch.einsum("bi,nio->bno", X, self._get_o2_selection())
            o2_bases = self.o2_projection(o2_select)
            for channel, head in enumerate(self.o2_weights):
                result = result + head(o2_bases[:, channel, :])

        return result
+ "wrds", +] + +[project.urls] +Homepage = "https://github.com/alexwolson/NeurCAM" +Repository = "https://github.com/alexwolson/NeurCAM" +Documentation = "https://arxiv.org/abs/2408.13361" + +[tool.hatch.build.targets.wheel] +packages = ["neurcam"] + +[tool.black] +line-length = 100 +target-version = ['py311'] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hatch + | \.venv + | _build + | build + | dist +)/ +'''