Skip to content

Commit

Permalink
Add an RNG to HdbscanTrainer to be used for the determination of cluster exemplars in some edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffreydstewart committed Nov 29, 2023
1 parent 1de1444 commit 6df1fc0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;

Expand Down Expand Up @@ -70,7 +71,7 @@ public String getOptionsDescription() {
*/
@Option(longName = "hdbscan-exemplar-sample-seed", usage = "The seed to use when sampling cluster exemplars at " +
"random from members of a cluster.")
public long exemplarSampleSeed = 12345L;
public long exemplarSampleSeed = Trainer.DEFAULT_SEED;

/**
* Gets the configured HdbscanTrainer using the options in this object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ public DistanceType getDistanceType() {

@Config(description = "The seed to use in the event cluster exemplars need to be determined with random samples " +
"from the members of a cluster.")
private long exemplarSampleSeed = 12345L;
private long exemplarSampleSeed = Trainer.DEFAULT_SEED;

private SplittableRandom rng = new SplittableRandom(exemplarSampleSeed);

private int trainInvocationCounter;

Expand Down Expand Up @@ -225,6 +227,7 @@ public HdbscanTrainer(int minClusterSize, org.tribuo.math.distance.Distance dist
this.numThreads = numThreads;
this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, dist, numThreads);
this.exemplarSampleSeed = exemplarSampleSeed;
this.rng = new SplittableRandom(exemplarSampleSeed);
}

/**
Expand Down Expand Up @@ -271,9 +274,11 @@ public synchronized void postConfig() {
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
// increment the invocation count.
TrainerProvenance trainerProvenance;
SplittableRandom localRNG;
synchronized (this) {
trainerProvenance = getProvenance();
trainInvocationCounter++;
localRNG = rng.split();
}
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();

Expand Down Expand Up @@ -309,7 +314,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

// Compute the cluster exemplars.
List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, dist);
List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, dist, localRNG);

// Get the outlier score value for points that are predicted as noise points.
double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
Expand Down Expand Up @@ -337,8 +342,10 @@ public int getInvocationCount() {
/**
 * Sets the train invocation count and moves the exemplar sampling RNG to the matching state.
 * <p>
 * The RNG is rebuilt from {@code exemplarSampleSeed} and then split once per invocation, so the
 * resulting state depends only on the seed and {@code newInvocationCount}, not on how many times
 * the RNG was split before this call.
 *
 * @param newInvocationCount The number of train invocations to replay, must be non-negative.
 * @throws IllegalArgumentException If {@code newInvocationCount} is negative.
 */
public synchronized void setInvocationCount(int newInvocationCount) {
    if (newInvocationCount < 0) {
        throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
    }

    // Restart from the configured seed; without this the splits below would stack on top of
    // whatever splits earlier train calls (or earlier calls to this method) already made.
    rng = new SplittableRandom(exemplarSampleSeed);

    // Each train call performs exactly one split, so replay one split per invocation.
    for (trainInvocationCounter = 0; trainInvocationCounter < newInvocationCount; trainInvocationCounter++) {
        rng.split();
    }
}

Expand Down Expand Up @@ -780,10 +787,11 @@ private static Map<Integer, List<Pair<Double, Integer>>> generateClusterAssignme
* @param data An array of {@link DenseVector} containing the data.
* @param clusterAssignments A map of the cluster labels, and the points assigned to them.
* @param dist The distance metric to employ.
* @param rng The RNG to use.
* @return A list of {@link ClusterExemplar}s which are used for predictions.
*/
private List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, List<Pair<Double, Integer>>> clusterAssignments,
org.tribuo.math.distance.Distance dist) {
org.tribuo.math.distance.Distance dist, SplittableRandom rng) {
List<ClusterExemplar> clusterExemplars = new ArrayList<>();
// The formula to calculate the exemplar number. This calculates the number of exemplars to be used for this
// configuration. The appropriate number of exemplars is important for prediction. At the time, this
Expand Down Expand Up @@ -846,9 +854,8 @@ private List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, Li
// To determine the remaining exemplars, the best thing to do is randomly sample them from all the
// points in this cluster. This could introduce duplicate exemplar points, but that is safer than
// reducing the number of exemplars.
SplittableRandom rand = new SplittableRandom(exemplarSampleSeed);
int numSamples = numExemplarsThisCluster - outlierScoreTreeSize;
Stream<Integer> intStreamSamples = rand.ints(numSamples, 0, outlierScoreIndexList.size()).boxed();
Stream<Integer> intStreamSamples = rng.ints(numSamples, 0, outlierScoreIndexList.size()).boxed();
intStreamSamples.forEach((i) -> partialClusterExemplarsList.add(outlierScoreIndexList.get(i)));

// For each of the partial exemplars in this cluster, iterate the nodes in the list to find the
Expand Down

0 comments on commit 6df1fc0

Please sign in to comment.