Wrapper around some basic sklearn utilities for clustering

rutujagurav/clustutils4r

Clustering Utilities

This package provides a simple convenience wrapper around some basic sklearn utilities for clustering. The only function available is eval_clustering().

Installation

pip install clustutils4r

Available Parameters

model: Clustering model object (untrained)

X: Numpy array containing preprocessed, normalized, complete dataset features

gt_labels: Numpy array containing encoded ground-truth labels for X (often not available)

num_runs: Number of times to fit the model

best_model_metric: Metric to use to choose the best model

make_silhoutte_plots: Whether to make silhouette plots for the best model (default = False).

embed_data_in_2d: Whether to compute TSNE embeddings of X to be plotted alongside the silhouette plot; if False, the first 2 features are plotted instead (default = False).

save_dir: location to store results; directory will be created if it does not exist

save: set True if you want to save all results in save_dir; defaults to False

show: display all results; useful in notebooks; defaults to False
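The num_runs and best_model_metric parameters describe a "fit several times, keep the best" pattern. The sketch below illustrates that pattern with plain sklearn calls; it is not the package's implementation. The helper name pick_best_of_runs is hypothetical, and it assumes the model accepts a random_state parameter (as KMeans does). It scores with the Fowlkes-Mallows index when ground-truth labels are given and falls back to the silhouette score otherwise.

```python
import numpy as np
from sklearn.base import clone
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import fowlkes_mallows_score, silhouette_score


def pick_best_of_runs(model, X, num_runs=5, gt_labels=None):
    """Hypothetical helper (not part of clustutils4r).

    Fits a fresh copy of `model` num_runs times with different seeds and
    returns the copy with the highest score, together with that score.
    Assumes `model` exposes a `random_state` parameter.
    """
    best_model, best_score = None, -np.inf
    for run in range(num_runs):
        # clone() gives an unfitted copy, so the runs do not share state
        candidate = clone(model).set_params(random_state=run)
        labels = candidate.fit_predict(X)
        if gt_labels is not None:
            # external metric: compares predicted labels to ground truth
            score = fowlkes_mallows_score(gt_labels, labels)
        else:
            # internal metric: needs only the data and the predicted labels
            score = silhouette_score(X, labels)
        if score > best_score:
            best_model, best_score = candidate, score
    return best_model, best_score


X, y = make_blobs(n_samples=300, centers=5, n_features=2,
                  cluster_std=0.60, random_state=0)
best_model, best_score = pick_best_of_runs(
    KMeans(n_clusters=5, n_init=1), X, num_runs=5, gt_labels=y
)
print(f"best FMI over 5 runs: {best_score:.3f}")
```

Both metrics used here are "higher is better", which is why a single comparison suffices; a metric where lower is better would need the comparison flipped.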

Example Usage

import os
import numpy as np
from sklearn.datasets import make_blobs
from eval_clustering import eval_clustering

## For testing purposes
rng = np.random.RandomState(0)
n_samples=1000
X, y = make_blobs(n_samples=n_samples, centers=5, n_features=2, cluster_std=0.60, random_state=rng)

save_dir = "results"
os.makedirs(save_dir, exist_ok=True)

best_model, grid_search_results = eval_clustering(
                                       X=X,                                               # dataset to cluster
                                       gt_labels=y,                                       # ground-truth labels; often these aren't available so don't pass this argument
                                       num_runs=10,                                       # number of times to fit a model
                                       best_model_metric="FMI",                           # metric to use to choose the best model
                                       make_silhoutte_plots=True, embed_data_in_2d=False, # whether to make silhouette plots
                                       show=False,                                        # whether to display the plots; this is used in a notebook
                                       save=True, save_dir=save_dir                       # whether to save the plots
                                    )

[Figures: grid_search_results, sil]
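For readers who want to see what goes into a silhouette plot and the optional 2D view, here is a minimal sketch using sklearn directly. It is an illustration under stated assumptions, not the code eval_clustering() runs: the variable embed_in_2d merely mirrors the embed_data_in_2d flag, and the KMeans and TSNE settings are arbitrary choices for the demo.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_samples, silhouette_score

# Small synthetic dataset with more than 2 features, so an embedding is useful
X, _ = make_blobs(n_samples=200, centers=5, n_features=4,
                  cluster_std=0.60, random_state=0)
labels = KMeans(n_clusters=5, n_init=10, random_state=0).fit_predict(X)

# Per-sample silhouette values are the bars of a silhouette plot;
# their mean is the overall silhouette score.
sample_sil = silhouette_samples(X, labels)
mean_sil = silhouette_score(X, labels)

# Illustrative stand-in for the embed_data_in_2d flag (assumption)
embed_in_2d = True
if embed_in_2d:
    # t-SNE projects the data to 2D for the scatter panel
    points_2d = TSNE(n_components=2, perplexity=30,
                     random_state=0).fit_transform(X)
else:
    # otherwise simply take the first two features
    points_2d = X[:, :2]

print(f"mean silhouette: {mean_sil:.3f}, 2D points: {points_2d.shape}")
```

The arrays sample_sil and points_2d can then be passed to any plotting library, with points colored by labels.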
