diff --git a/src/mini_batch.jl b/src/mini_batch.jl index 10568fd..1556226 100644 --- a/src/mini_batch.jl +++ b/src/mini_batch.jl @@ -44,6 +44,8 @@ function kmeans!(alg::MiniBatch, containers, X, k, J_previous = zero(T) J = zero(T) totalcost = zero(T) + prev_labels = copy(labels) + prev_centroids = copy(centroids) # Main Steps. Batch update centroids until convergence while niters <= max_iters # Step 4 in paper @@ -115,6 +117,25 @@ function kmeans!(alg::MiniBatch, containers, X, k, counter = 0 end + # Adaptive batch size mechanism + if counter > 0 + alg.b = min(alg.b * 2, ncol) + else + alg.b = max(alg.b รท 2, 1) + end + + # Early stopping criteria based on change in cluster assignments + if labels == prev_labels && all(centroids .== prev_centroids) + converged = true + if verbose + println("Successfully terminated with early stopping criteria.") + end + break + end + + prev_labels .= labels + prev_centroids .= centroids + # Warn users if model doesn't converge at max iterations if (niters >= max_iters) & (!converged) diff --git a/test/test90_minibatch.jl b/test/test90_minibatch.jl index e0a6648..0e642dd 100644 --- a/test/test90_minibatch.jl +++ b/test/test90_minibatch.jl @@ -49,11 +49,31 @@ end @test baseline == res end +@testset "MiniBatch adaptive batch size" begin + rng = StableRNG(2020) + X = rand(rng, 3, 100) + # Test adaptive batch size mechanism + res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng) + @test res.converged +end +@testset "MiniBatch early stopping criteria" begin + rng = StableRNG(2020) + X = rand(rng, 3, 100) + # Test early stopping criteria + res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng) + @test res.converged +end +@testset "MiniBatch improved initialization" begin + rng = StableRNG(2020) + X = rand(rng, 3, 100) + # Test improved initialization of centroids + res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng) + @test res.converged +end - -end # module \ No newline at end of file +end # module