Skip to content

Commit 6bcc19b

Browse files
sam-hermanjshookmarianotepper
authored
Reconstruct heap graph from disk graph (#536)
* Support OnHeapGraphReconstruction * Get incremental graph construction working * Add serialization for NeighborsCache * Add ord-mapping logic from PQ vectors to RAVV * Add ord mapping from RAVV to graph creation Signed-off-by: Samuel Herman <[email protected]> * Rebase fix Signed-off-by: Samuel Herman <[email protected]> * Switch interface for convert graph Signed-off-by: Samuel Herman <[email protected]> * Remove explicit mentioning of OnDiskGraphIndex in builder Signed-off-by: Samuel Herman <[email protected]> * Add dimension Signed-off-by: Samuel Herman <[email protected]> * Remove NeighborsCache. Refactoring to respect interface boundaries * Implementation that serializes/deserializes the OnHeapGraphIndex * Label OnHeapGraphIndex.save and OnHeapGraphIndex.load as experimental. * Add experimental tag to GraphIndexBuilder.buildAndMergeNewNodes * Bug fixes to make tests pass * Added deprecated tags to OnHeapGraphIndex.load and OnHeapGraphIndex.save * Fix documentation of MutableGraphIndex.allMutationsCompleted --------- Signed-off-by: Samuel Herman <[email protected]> Co-authored-by: Jonathan Shook <[email protected]> Co-authored-by: Mariano Tepper <[email protected]>
1 parent 47cc10b commit 6bcc19b

File tree

9 files changed

+614
-52
lines changed

9 files changed

+614
-52
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package io.github.jbellis.jvector.graph;
1818

19+
import io.github.jbellis.jvector.annotations.Experimental;
1920
import io.github.jbellis.jvector.annotations.VisibleForTesting;
2021
import io.github.jbellis.jvector.disk.RandomAccessReader;
2122
import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel;
@@ -325,6 +326,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
325326
this.parallelExecutor = parallelExecutor;
326327

327328
this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
329+
328330
this.searchers = ExplicitThreadLocal.withInitial(() -> {
329331
var gs = new GraphSearcher(graph);
330332
gs.usePruning(false);
@@ -338,6 +340,58 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
338340
this.rng = new Random(0);
339341
}
340342

343+
/**
344+
* Create this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}, this is useful when we just loaded a graph from disk
345+
* copy it into {@link OnHeapGraphIndex} and then start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object
346+
*
347+
* @param buildScoreProvider the provider responsible for calculating build scores.
348+
* @param mutableGraphIndex a mutable graph index.
349+
* @param beamWidth the width of the beam used during the graph building process.
350+
* @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
351+
* @param alpha the weight factor for balancing score computations.
352+
* @param addHierarchy whether to add hierarchical structures while building the graph.
353+
* @param refineFinalGraph whether to perform a refinement step on the final graph structure.
354+
* @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
355+
* @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
356+
*
357+
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
358+
*/
359+
private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
360+
if (beamWidth <= 0) {
361+
throw new IllegalArgumentException("beamWidth must be positive");
362+
}
363+
if (neighborOverflow < 1.0f) {
364+
throw new IllegalArgumentException("neighborOverflow must be >= 1.0");
365+
}
366+
if (alpha <= 0) {
367+
throw new IllegalArgumentException("alpha must be positive");
368+
}
369+
370+
this.scoreProvider = buildScoreProvider;
371+
this.neighborOverflow = neighborOverflow;
372+
this.dimension = dimension;
373+
this.alpha = alpha;
374+
this.addHierarchy = addHierarchy;
375+
this.refineFinalGraph = refineFinalGraph;
376+
this.beamWidth = beamWidth;
377+
this.simdExecutor = simdExecutor;
378+
this.parallelExecutor = parallelExecutor;
379+
380+
this.graph = mutableGraphIndex;
381+
382+
this.searchers = ExplicitThreadLocal.withInitial(() -> {
383+
var gs = new GraphSearcher(graph);
384+
gs.usePruning(false);
385+
return gs;
386+
});
387+
388+
// in scratch, we store candidates in reverse order: worse candidates are first
389+
this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
390+
this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
391+
392+
this.rng = new Random(0);
393+
}
394+
341395
// used by Cassandra when it fine-tunes the PQ codebook
342396
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
343397
var newBuilder = new GraphIndexBuilder(newProvider,
@@ -450,13 +504,13 @@ public void cleanup() {
450504
// clean up overflowed neighbor lists
451505
parallelExecutor.submit(() -> {
452506
IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> {
453-
for (int layer = 0; layer <= graph.getMaxLevel(); layer++) {
507+
for (int level = 0; level <= graph.getMaxLevel(); level++) {
454508
graph.enforceDegree(id);
455509
}
456510
});
457511
}).join();
458512

459-
graph.allMutationsCompleted();
513+
graph.setAllMutationsCompleted();
460514
}
461515

462516
private void improveConnections(int node) {
@@ -825,6 +879,9 @@ public void load(RandomAccessReader in) throws IOException {
825879
loadV3(in, size);
826880
} else {
827881
version = in.readInt();
882+
if (version != 4) {
883+
throw new IOException("Unsupported version: " + version);
884+
}
828885
loadV4(in);
829886
}
830887
}
@@ -836,15 +893,18 @@ private void loadV4(RandomAccessReader in) throws IOException {
836893
}
837894

838895
int layerCount = in.readInt();
839-
int entryNode = in.readInt();
840896
var layerDegrees = new ArrayList<Integer>(layerCount);
897+
for (int level = 0; level < layerCount; level++) {
898+
layerDegrees.add(in.readInt());
899+
}
900+
901+
int entryNode = in.readInt();
841902

842903
Map<Integer, Integer> nodeLevelMap = new HashMap<>();
843904

844905
// Read layer info
845906
for (int level = 0; level < layerCount; level++) {
846907
int layerSize = in.readInt();
847-
layerDegrees.add(in.readInt());
848908
for (int i = 0; i < layerSize; i++) {
849909
int nodeId = in.readInt();
850910
int nNeighbors = in.readInt();
@@ -860,6 +920,7 @@ private void loadV4(RandomAccessReader in) throws IOException {
860920
var ca = new NodeArray(nNeighbors);
861921
for (int j = 0; j < nNeighbors; j++) {
862922
int neighbor = in.readInt();
923+
float score = in.readFloat();
863924
ca.addInOrder(neighbor, sf.similarityTo(neighbor));
864925
}
865926
graph.connectNode(level, nodeId, ca);
@@ -909,4 +970,61 @@ private void loadV3(RandomAccessReader in, int size) throws IOException {
909970
graph.updateEntryNode(new NodeAtLevel(0, entryNode));
910971
graph.setDegrees(List.of(maxDegree));
911972
}
973+
974+
/**
975+
* Convenience method to build a new graph from an existing one, with the addition of new nodes.
976+
* This is useful when we want to merge a new set of vectors into an existing graph that is already on disk.
977+
*
978+
* @param in a reader from which to read the on-heap graph.
979+
* @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph
980+
* @param buildScoreProvider the provider responsible for calculating build scores.
981+
* @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start
982+
* @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids
983+
* @param beamWidth the width of the beam used during the graph building process.
984+
* @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node.
985+
* @param alpha the weight factor for balancing score computations.
986+
* @param addHierarchy whether to add hierarchical structures while building the graph.
987+
*
988+
* @return the in-memory representation of the graph index.
989+
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
990+
*/
991+
@Experimental
992+
public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
993+
RandomAccessVectorValues newVectors,
994+
BuildScoreProvider buildScoreProvider,
995+
int startingNodeOffset,
996+
int[] graphToRavvOrdMap,
997+
int beamWidth,
998+
float overflowRatio,
999+
float alpha,
1000+
boolean addHierarchy) throws IOException {
1001+
1002+
var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha);
1003+
1004+
try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider);) {
1005+
1006+
GraphIndexBuilder builder = new GraphIndexBuilder(
1007+
buildScoreProvider,
1008+
newVectors.dimension(),
1009+
graph,
1010+
beamWidth,
1011+
overflowRatio,
1012+
alpha,
1013+
addHierarchy,
1014+
true,
1015+
PhysicalCoreExecutor.pool(),
1016+
ForkJoinPool.commonPool()
1017+
);
1018+
1019+
var vv = newVectors.threadLocalSupplier();
1020+
1021+
// parallel graph construction from the merge documents Ids
1022+
PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
1023+
builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord]));
1024+
})).join();
1025+
1026+
builder.cleanup();
1027+
return builder.getGraph();
1028+
}
1029+
}
9121030
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,10 @@ interface MutableGraphIndex extends ImmutableGraphIndex {
166166
* Signals that all mutations have been completed and the graph will not be mutated any further.
167167
* Should be called by the builder after all mutations are completed (during cleanup).
168168
*/
169-
void allMutationsCompleted();
169+
void setAllMutationsCompleted();
170+
171+
/**
172+
* Returns true if all mutations have been completed. This is signaled by calling setAllMutationsCompleted.
173+
*/
174+
boolean allMutationsCompleted();
170175
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java

Lines changed: 101 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
package io.github.jbellis.jvector.graph;
2626

27+
import io.github.jbellis.jvector.annotations.Experimental;
28+
import io.github.jbellis.jvector.disk.RandomAccessReader;
2729
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
2830
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
2931
import io.github.jbellis.jvector.util.Accountable;
@@ -37,9 +39,10 @@
3739

3840
import java.io.DataOutput;
3941
import java.io.IOException;
40-
import java.io.UncheckedIOException;
4142
import java.util.ArrayList;
43+
import java.util.HashMap;
4244
import java.util.List;
45+
import java.util.Map;
4346
import java.util.NoSuchElementException;
4447
import java.util.concurrent.atomic.AtomicInteger;
4548
import java.util.concurrent.atomic.AtomicIntegerArray;
@@ -367,10 +370,14 @@ public void setDegrees(List<Integer> layerDegrees) {
367370
}
368371

369372
@Override
370-
public void allMutationsCompleted() {
373+
public void setAllMutationsCompleted() {
371374
allMutationsCompleted = true;
372375
}
373376

377+
@Override
378+
public boolean allMutationsCompleted() {
379+
return allMutationsCompleted;
380+
}
374381

375382
/**
376383
* A concurrent View of the graph that is safe to search concurrently with updates and with other
@@ -490,44 +497,101 @@ public String toString() {
490497
/**
491498
* Saves the graph to the given DataOutput for reloading into memory later
492499
*/
500+
@Experimental
493501
@Deprecated
494-
public void save(DataOutput out) {
495-
if (deletedNodes.cardinality() > 0) {
496-
throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first");
497-
}
498-
499-
try (var view = getView()) {
500-
out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number
501-
out.writeInt(4); // The version
502-
503-
// Write graph-level properties.
504-
out.writeInt(layers.size());
505-
assert view.entryNode().level == getMaxLevel();
506-
out.writeInt(view.entryNode().node);
507-
508-
for (int level = 0; level < layers.size(); level++) {
509-
out.writeInt(size(level));
510-
out.writeInt(getDegree(level));
511-
512-
// Save neighbors from the layer.
513-
var baseLayer = layers.get(level);
514-
baseLayer.forEach((nodeId, neighbors) -> {
515-
try {
516-
NodesIterator iterator = neighbors.iterator();
517-
out.writeInt(nodeId);
518-
out.writeInt(iterator.size());
519-
for (int n = 0; n < iterator.size(); n++) {
520-
out.writeInt(iterator.nextInt());
521-
}
522-
assert !iterator.hasNext();
523-
} catch (IOException e) {
524-
throw new UncheckedIOException(e);
525-
}
526-
});
502+
public void save(DataOutput out) throws IOException {
503+
if (!allMutationsCompleted()) {
504+
throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first");
505+
}
506+
507+
out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number
508+
out.writeInt(4); // The version
509+
510+
// Write graph-level properties.
511+
out.writeInt(layers.size());
512+
for (int level = 0; level < layers.size(); level++) {
513+
out.writeInt(getDegree(level));
514+
}
515+
516+
var entryNode = entryPoint.get();
517+
assert entryNode.level == getMaxLevel();
518+
out.writeInt(entryNode.node);
519+
520+
for (int level = 0; level < layers.size(); level++) {
521+
out.writeInt(size(level));
522+
523+
// Save neighbors from the layer.
524+
var it = nodeStream(level).iterator();
525+
while (it.hasNext()) {
526+
int nodeId = it.nextInt();
527+
var neighbors = layers.get(level).get(nodeId);
528+
out.writeInt(nodeId);
529+
out.writeInt(neighbors.size());
530+
531+
for (int n = 0; n < neighbors.size(); n++) {
532+
out.writeInt(neighbors.getNode(n));
533+
out.writeFloat(neighbors.getScore(n));
534+
}
535+
}
536+
}
537+
}
538+
539+
/**
540+
* Saves the graph to the given DataOutput for reloading into memory later
541+
*/
542+
@Experimental
543+
@Deprecated
544+
public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
545+
int magic = in.readInt(); // the magic number
546+
if (magic != OnHeapGraphIndex.MAGIC) {
547+
throw new IOException("Unsupported magic number: " + magic);
548+
}
549+
550+
int version = in.readInt(); // The version
551+
if (version != 4) {
552+
throw new IOException("Unsupported version: " + version);
553+
}
554+
555+
// Write graph-level properties.
556+
int layerCount = in.readInt();
557+
var layerDegrees = new ArrayList<Integer>(layerCount);
558+
for (int level = 0; level < layerCount; level++) {
559+
layerDegrees.add(in.readInt());
560+
}
561+
562+
int entryNode = in.readInt();
563+
564+
var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider);
565+
566+
Map<Integer, Integer> nodeLevelMap = new HashMap<>();
567+
568+
for (int level = 0; level < layerCount; level++) {
569+
int layerSize = in.readInt();
570+
571+
for (int i = 0; i < layerSize; i++) {
572+
int nodeId = in.readInt();
573+
int nNeighbors = in.readInt();
574+
575+
var ca = new NodeArray(nNeighbors);
576+
for (int j = 0; j < nNeighbors; j++) {
577+
int neighbor = in.readInt();
578+
float score = in.readFloat();
579+
ca.addInOrder(neighbor, score);
580+
}
581+
graph.connectNode(level, nodeId, ca);
582+
nodeLevelMap.put(nodeId, level);
527583
}
528-
} catch (IOException e) {
529-
throw new UncheckedIOException(e);
530584
}
585+
586+
for (var k : nodeLevelMap.keySet()) {
587+
NodeAtLevel nal = new NodeAtLevel(nodeLevelMap.get(k), k);
588+
graph.markComplete(nal);
589+
}
590+
591+
graph.setDegrees(layerDegrees);
592+
graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode));
593+
594+
return graph;
531595
}
532596

533597
/**

jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,6 @@ public Set<FeatureId> getFeatureSet() {
227227
return features.keySet();
228228
}
229229

230-
public int getDimension() {
231-
return dimension;
232-
}
233-
234230
@Override
235231
public int size(int level) {
236232
return layerInfo.get(level).size;

0 commit comments

Comments
 (0)