From 420418edf0d266dfeda9f8ff514a94638421000d Mon Sep 17 00:00:00 2001
From: gileshd <gileshd@googlemail.com>
Date: Fri, 19 Jul 2024 22:04:48 +0100
Subject: [PATCH 1/3] Add utility function for sklearn kmeans

---
 dynamax/utils/cluster.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 dynamax/utils/cluster.py

diff --git a/dynamax/utils/cluster.py b/dynamax/utils/cluster.py
new file mode 100644
index 00000000..a31c9e05
--- /dev/null
+++ b/dynamax/utils/cluster.py
@@ -0,0 +1,28 @@
+from typing import Tuple
+
+from jax import numpy as jnp
+from jax import random as jr
+
+from jaxtyping import Array, Float
+
+
+def kmeans_sklearn(
+    k: int, X: Float[Array, "num_samples state_dim"], key: Array
+) -> Tuple[Float[Array, "num_states state_dim"], Float[Array, "num_samples"]]:
+    """
+    Compute the cluster centers and assignments using the sklearn K-means algorithm.
+
+    Args:
+        k (int): The number of clusters.
+        X (Array(N, D)): The input data array. N samples of dimension D.
+        key (Array): The random seed array.
+
+    Returns:
+        Array(k, D), Array(N,): The cluster centers and labels
+    """
+    from sklearn.cluster import KMeans
+
+    key, subkey = jr.split(key)  # Create a random seed for SKLearn.
+    sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
+    km = KMeans(k, random_state=int(sklearn_key)).fit(X)
+    return jnp.array(km.cluster_centers_), jnp.array(km.labels_)

From 828f1d144e3631cbc7d9e790e11e9c83863e4242 Mon Sep 17 00:00:00 2001
From: gileshd <gileshd@googlemail.com>
Date: Fri, 19 Jul 2024 22:00:51 +0100
Subject: [PATCH 2/3] Update SSMs to use kmeans utility function

---
 dynamax/hidden_markov_model/models/arhmm.py   |  7 ++--
 .../hidden_markov_model/models/gamma_hmm.py   |  9 ++----
 .../models/gaussian_hmm.py                    | 32 ++++---------------
 dynamax/hidden_markov_model/models/gmm_hmm.py | 15 +++------
 .../hidden_markov_model/models/linreg_hmm.py  |  7 ++--
 .../hidden_markov_model/models/logreg_hmm.py  | 16 ++++++----
 6 files changed, 28 insertions(+), 58 deletions(-)

diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py
index 2eff832b..07c1ee75 100644
--- a/dynamax/hidden_markov_model/models/arhmm.py
+++ b/dynamax/hidden_markov_model/models/arhmm.py
@@ -10,6 +10,7 @@
 from dynamax.parameters import ParameterProperties
 from dynamax.types import Scalar
 from dynamax.utils.bijectors import RealToPSDBijector
+from dynamax.utils.cluster import kmeans_sklearn
 from tensorflow_probability.substrates import jax as tfp
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -42,12 +43,8 @@ def initialize(self,
                    emissions=None):
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
             _emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
-            _emission_biases = jnp.array(km.cluster_centers_)
+            _emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
 
         elif method.lower() == "prior":
diff --git a/dynamax/hidden_markov_model/models/gamma_hmm.py b/dynamax/hidden_markov_model/models/gamma_hmm.py
index 2efcdd86..f44538ea 100644
--- a/dynamax/hidden_markov_model/models/gamma_hmm.py
+++ b/dynamax/hidden_markov_model/models/gamma_hmm.py
@@ -8,6 +8,7 @@
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
 from dynamax.types import Scalar
+from dynamax.utils.cluster import kmeans_sklearn
 import optax
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -38,13 +39,9 @@ def initialize(self,
 
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1))
-
+            cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, 1), key)
             _emission_concentrations = jnp.ones((self.num_states,))
-            _emission_rates = jnp.ravel(1.0 / km.cluster_centers_)
+            _emission_rates = jnp.ravel(1.0 / cluster_centers)
 
         elif method.lower() == "prior":
             _emission_concentrations = jnp.ones((self.num_states,))
diff --git a/dynamax/hidden_markov_model/models/gaussian_hmm.py b/dynamax/hidden_markov_model/models/gaussian_hmm.py
index c1904878..fe1e8dc3 100644
--- a/dynamax/hidden_markov_model/models/gaussian_hmm.py
+++ b/dynamax/hidden_markov_model/models/gaussian_hmm.py
@@ -17,6 +17,7 @@
 from dynamax.utils.distributions import niw_posterior_update
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.utils import pytree_sum
+from dynamax.utils.cluster import kmeans_sklearn
 from typing import NamedTuple, Optional, Tuple, Union
 
 
@@ -70,12 +71,7 @@ def initialize(self, key=jr.PRNGKey(0),
                    emissions=None):
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
-
-            _emission_means = jnp.array(km.cluster_centers_)
+            _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
 
         elif method.lower() == "prior":
@@ -168,11 +164,7 @@ def initialize(self, key=jr.PRNGKey(0),
 
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
-            _emission_means = jnp.array(km.cluster_centers_)
+            _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_scale_diags = jnp.ones((self.num_states, self.emission_dim))
 
         elif method.lower() == "prior":
@@ -289,11 +281,7 @@ def initialize(self, key=jr.PRNGKey(0),
         """
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
-            _emission_means = jnp.array(km.cluster_centers_)
+            _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_scales = jnp.ones((self.num_states,))
 
         elif method.lower() == "prior":
@@ -391,11 +379,7 @@ def initialize(self, key=jr.PRNGKey(0),
         """
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
-            _emission_means = jnp.array(km.cluster_centers_)
+            _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_cov = jnp.eye(self.emission_dim)
 
         elif method.lower() == "prior":
@@ -513,11 +497,7 @@ def initialize(self, key=jr.PRNGKey(0),
         """
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
-            _emission_means = jnp.array(km.cluster_centers_)
+            _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_cov_diag_factors = jnp.ones((self.num_states, self.emission_dim))
             _emission_cov_low_rank_factors = jnp.zeros((self.num_states, self.emission_dim, self.emission_rank))
 
diff --git a/dynamax/hidden_markov_model/models/gmm_hmm.py b/dynamax/hidden_markov_model/models/gmm_hmm.py
index 8b7e778c..e6f55d84 100644
--- a/dynamax/hidden_markov_model/models/gmm_hmm.py
+++ b/dynamax/hidden_markov_model/models/gmm_hmm.py
@@ -15,6 +15,7 @@
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.utils import pytree_sum
+from dynamax.utils.cluster import kmeans_sklearn
 from dynamax.types import Scalar
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -77,12 +78,9 @@ def initialize(self, key=jr.PRNGKey(0),
                    emissions=None):
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
+            cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
-            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
+            _emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1))
             _emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1))
 
         elif method.lower() == "prior":
@@ -299,12 +297,9 @@ def initialize(self, key=jr.PRNGKey(0),
                    emissions=None):
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
+            cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
-            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
+            _emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1))
             _emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim))
 
         elif method.lower() == "prior":
diff --git a/dynamax/hidden_markov_model/models/linreg_hmm.py b/dynamax/hidden_markov_model/models/linreg_hmm.py
index df63c1bd..26947f61 100644
--- a/dynamax/hidden_markov_model/models/linreg_hmm.py
+++ b/dynamax/hidden_markov_model/models/linreg_hmm.py
@@ -9,6 +9,7 @@
 from dynamax.types import Scalar
 from dynamax.utils.utils import pytree_sum
 from dynamax.utils.bijectors import RealToPSDBijector
+from dynamax.utils.cluster import kmeans_sklearn
 from tensorflow_probability.substrates import jax as tfp
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -58,12 +59,8 @@ def initialize(self,
                    emissions=None):
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
             _emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.input_dim))
-            _emission_biases = jnp.array(km.cluster_centers_)
+            _emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
             _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
 
         elif method.lower() == "prior":
diff --git a/dynamax/hidden_markov_model/models/logreg_hmm.py b/dynamax/hidden_markov_model/models/logreg_hmm.py
index 2da4dd84..b9957a09 100644
--- a/dynamax/hidden_markov_model/models/logreg_hmm.py
+++ b/dynamax/hidden_markov_model/models/logreg_hmm.py
@@ -8,6 +8,7 @@
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
 from dynamax.types import Scalar
+from dynamax.utils.cluster import kmeans_sklearn
 import optax
 from typing import NamedTuple, Optional, Tuple, Union
 
@@ -48,16 +49,19 @@ def initialize(self,
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             assert inputs is not None, "Need inputs to initialize the model with K-Means!"
-            from sklearn.cluster import KMeans
 
             flat_emissions = emissions.reshape(-1,)
             flat_inputs = inputs.reshape(-1, self.input_dim)
-            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
-            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
-            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_inputs)
+
+            _, km_labels = kmeans_sklearn(self.num_states, flat_inputs, key)
             _emission_weights = jnp.zeros((self.num_states, self.input_dim))
-            _emission_biases = jnp.array([tfb.Sigmoid().inverse(flat_emissions[km.labels_ == k].mean())
-                                          for k in range(self.num_states)])
+            cluster_emissions_means = jnp.array(
+                [jnp.mean(flat_emissions, where=km_labels == k) for k in range(self.num_states)]
+            )
+            cluster_emissions_means = jnp.where(
+                jnp.isnan(cluster_emissions_means), flat_emissions.mean(), cluster_emissions_means
+            )
+            _emission_biases = tfb.Sigmoid().inverse(cluster_emissions_means)
 
         elif method.lower() == "prior":
             # TODO: Use an MNIW prior

From 0f8646b45e67e89212f974ba5d2c9901fb887610 Mon Sep 17 00:00:00 2001
From: gileshd <gileshd@googlemail.com>
Date: Fri, 19 Jul 2024 22:14:02 +0100
Subject: [PATCH 3/3] Add jax implementation of kmeans

---
 dynamax/utils/cluster.py      | 82 +++++++++++++++++++++++++++++++++--
 dynamax/utils/cluster_test.py | 50 +++++++++++++++++++++
 2 files changed, 128 insertions(+), 4 deletions(-)
 create mode 100644 dynamax/utils/cluster_test.py

diff --git a/dynamax/utils/cluster.py b/dynamax/utils/cluster.py
index a31c9e05..cda9fb82 100644
--- a/dynamax/utils/cluster.py
+++ b/dynamax/utils/cluster.py
@@ -1,9 +1,9 @@
-from typing import Tuple
-
+from functools import partial
+from jax import lax, jit
 from jax import numpy as jnp
 from jax import random as jr
-
-from jaxtyping import Array, Float
+from jaxtyping import Array, Int, Float
+from typing import NamedTuple, Tuple
 
 
 def kmeans_sklearn(
@@ -26,3 +26,77 @@ def kmeans_sklearn(
     sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
     km = KMeans(k, random_state=int(sklearn_key)).fit(X)
     return jnp.array(km.cluster_centers_), jnp.array(km.labels_)
+
+
+class KMeansState(NamedTuple):
+    centroids: Float[Array, "num_states state_dim"]
+    assignments: Int[Array, "num_samples"]
+    prev_centroids: Float[Array, "num_states state_dim"]
+    itr: int
+
+
+@partial(jit, static_argnums=(1, 3))
+def kmeans_jax(
+    X: Float[Array, "num_samples state_dim"],
+    k: int,
+    key: Array = jr.PRNGKey(0),
+    max_iters: int = 1000,
+) -> KMeansState:
+    """
+    Perform k-means clustering using JAX.
+
+    K-means++ initialization is used to initialize the centroids.
+
+    Args:
+        X (Array): The input data array of shape (n_samples, n_features).
+        k (int): The number of clusters.
+        max_iters (int, optional): The maximum number of iterations. Defaults to 1000.
+        key (PRNGKey, optional): The random key for initialization. Defaults to jr.PRNGKey(0).
+
+    Returns:
+        KMeansState: A named tuple containing the final centroids array of shape (k, n_features),
+        the assignments array of shape (n_samples,) indicating the cluster index for each sample,
+        the previous centroids array of shape (k, n_features), and the number of iterations.
+    """
+
+    def _update_centroids(X: Array, assignments: Array):
+        new_centroids = jnp.array([jnp.mean(X, axis=0, where=(assignments == i)[:, None]) for i in range(k)])
+        return new_centroids
+
+    def _update_assignments(X, centroids):
+        return jnp.argmin(jnp.linalg.norm(X[:, None] - centroids, axis=2), axis=1)
+
+    def body(carry: KMeansState):
+        centroids, assignments, *_ = carry
+        new_centroids = _update_centroids(X, assignments)
+        new_assignments = _update_assignments(X, new_centroids)
+        return KMeansState(new_centroids, new_assignments, centroids, carry.itr + 1)
+
+    def cond(carry: KMeansState):
+        return jnp.any(carry.centroids != carry.prev_centroids) & (carry.itr < max_iters)
+
+    def init(key):
+        """kmeans++ initialization of centroids
+
+        Iteratively sample new centroids with probability proportional to the squared distance
+        from the closest centroid. This initialization method is more stable than random
+        initialization and leads to faster convergence.
+        Ref: Arthur, D., & Vassilvitskii, S. (2006).
+        """
+        centroids = jnp.zeros((k, X.shape[1]))
+        centroids = centroids.at[0, :].set(jr.choice(key, X))
+        for i in range(1, k):
+            squared_diffs = jnp.sum((X[:, None, :] - centroids[None, :i, :]) ** 2, axis=2)
+            min_squared_dists = jnp.min(squared_diffs, axis=1)
+            probs = min_squared_dists / jnp.sum(min_squared_dists)
+            centroids = centroids.at[i, :].set(jr.choice(key, X, p=probs))
+        assignments = _update_assignments(X, centroids)
+        # Perform one iteration to update centroids
+        updated_centroids = _update_centroids(X, assignments)
+        updated_assignments = _update_assignments(X, updated_centroids)
+        return KMeansState(updated_centroids, updated_assignments, centroids, 1)
+
+    init_state = init(key)
+    state = lax.while_loop(cond, body, init_state)
+
+    return state
diff --git a/dynamax/utils/cluster_test.py b/dynamax/utils/cluster_test.py
new file mode 100644
index 00000000..414120b3
--- /dev/null
+++ b/dynamax/utils/cluster_test.py
@@ -0,0 +1,50 @@
+from jax import numpy as jnp
+from jax import random as jr
+from jax import vmap
+
+from dynamax.utils.cluster import kmeans_jax
+
+
+def test_kmeans_jax_toy():
+    """Checks that kmeans works against toy example.
+
+    Ref: scikit-learn tests
+    """
+
+    key = jr.PRNGKey(101)
+    x = jnp.array([[0, 0], [0.5, 0], [0.5, 1], [1, 1]])
+
+    centroids, assignments, *_ = kmeans_jax(x, 2, key)
+
+    # There are two possible solutions for the centroids and assignments
+    try:
+        expected_labels = jnp.array([0, 0, 1, 1])
+        expected_centers = jnp.array([[0.25, 0], [0.75, 1]])
+        assert jnp.all(assignments == expected_labels)
+        assert jnp.allclose(centroids, expected_centers)
+    except AssertionError:
+        expected_labels = jnp.array([1, 1, 0, 0])
+        expected_centers = jnp.array([[0.75, 1.0], [0.25, 0.0]])
+        assert jnp.all(assignments == expected_labels)
+        assert jnp.allclose(centroids, expected_centers)
+
+
+def test_kmeans_jax_vmap():
+    """Test that kmeans_jax works with vmap."""
+
+    def _gen_data(key):
+        """Generate 3 clusters of 10 samples each."""
+        subkeys = jr.split(key, 3)
+        means = jnp.array([-2., 0., 2.])
+        _2D_normal = lambda key, mean: jr.normal(key, (10, 2))*0.2 + mean
+        return vmap(_2D_normal)(subkeys, means).reshape(-1, 2)
+
+    key = jr.PRNGKey(5)
+    key, *data_subkeys = jr.split(key,3)
+    # Generate 2 samples of the 3-cluster data
+    x = vmap(_gen_data)(jnp.array(data_subkeys))
+
+    alg_subkeys = jr.split(key, 2)
+    _, assignments, *_ = vmap(kmeans_jax, (0, None, 0))(x, 3, alg_subkeys)
+    # Check that the assignments are the same for both samples (clusters are very distinct)
+    assert jnp.all(assignments[0] == assignments[1])