diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index 0ffdf72eb..e61f7fc94 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; import io.github.jbellis.jvector.quantization.BQVectors; import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -89,14 +90,16 @@ public interface BuildScoreProvider { * Helper method for the special case that mapping between graph node IDs and ravv ordinals is the identity function. */ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { - return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction); + // We don't know the max ordinal, and that's fine, so we just use Integer.MAX_VALUE + return randomAccessScoreProvider(ravv, new OrdinalMapper.IdentityMapper(Integer.MAX_VALUE), similarityFunction); } /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. - * graphToRavvOrdMap maps graph node IDs to ravv ordinals. + * graphToRavvOrdMap maps graph node IDs to ravv ordinals. The OrdinalMapper is used to map between the graph's ordinals + * and the ordinals used by the RandomAccessVectorValues via the oldToNew method. */ - static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { + static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, OrdinalMapper ordinalMapper, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared. var vectors = ravv.threadLocalSupplier(); @@ -125,22 +128,22 @@ public VectorFloat approximateCentroid() { @Override public SearchScoreProvider searchProviderFor(VectorFloat vector) { var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(vector, ordinalMapper, similarityFunction, vc); } @Override public SearchScoreProvider searchProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); + var v = randomAccessVectorValues.getVector(ordinalMapper.oldToNew(node1)); return searchProviderFor(v); } @Override public SearchScoreProvider diversityProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); + var v = randomAccessVectorValues.getVector(ordinalMapper.oldToNew(node1)); var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(v, ordinalMapper, similarityFunction, vc); } }; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java index de46762b2..a19f4e5db 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -82,14 +83,15 @@ public float similarityTo(int node2) { /** * A SearchScoreProvider for a single-pass search based on exact similarity. * Generally only suitable when your RandomAccessVectorValues is entirely in-memory, - * e.g. during construction. + * e.g. during construction. The ordinal mapper is used to map between the graph's ordinals + * and the ordinals used by the RandomAccessVectorValues. */ - public static DefaultSearchScoreProvider exact(VectorFloat v, int[] graphToRavvOrdMap ,VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { + public static DefaultSearchScoreProvider exact(VectorFloat v, OrdinalMapper ordinalMapper, VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { // don't use ESF.reranker, we need thread safety here var sf = new ScoreFunction.ExactScoreFunction() { @Override public float similarityTo(int node2) { - return vsf.compare(v, ravv.getVector(graphToRavvOrdMap[node2])); + return vsf.compare(v, ravv.getVector(ordinalMapper.oldToNew(node2))); } }; return new DefaultSearchScoreProvider(sf); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 716621d21..59b248584 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -156,4 +156,19 @@ public void testSaveAndLoad() throws IOException { } assertGraphEquals(graph, builder.graph); } + + // Because RandomAccessVectorValues is exposed in such a way that it allows for subsequent additions to the + // vector source, we need to ensure that GraphIndexBuilder can handle this. + @Test + public void testAddNodesToVectorValuesIteratively() throws IOException { + int dimension = randomIntBetween(2, 32); + var mutableVectors = new ArrayList>(); + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(mutableVectors, dimension); + try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true)) { + for (int i = 0; i < 10; i++) { + mutableVectors.add(TestUtil.randomVector(random(), dimension)); + builder.addGraphNode(i, ravv.getVector(i)); + } + } + } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java index 4942b8efb..51458dac0 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -24,6 +25,7 @@ import org.junit.Test; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import static org.junit.Assert.assertEquals; @@ -46,9 +48,13 @@ public void testOrdinalMapping() { var ravv = new ListRandomAccessVectorValues(vectors, 2); // Create non-identity mapping: graph node 0 -> ravv ordinal 2, graph node 1 -> ravv ordinal 0, graph node 2 -> ravv ordinal 1 - int[] graphToRavvOrdMap = {2, 0, 1}; + var oldToNew = new HashMap(); + oldToNew.put(0, 2); + oldToNew.put(1, 0); + oldToNew.put(2, 1); + var ordinalMapper = new OrdinalMapper.MapMapper(oldToNew); - var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, graphToRavvOrdMap, vsf); + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, ordinalMapper, vsf); // Test that searchProviderFor(graphNode) uses the correct RAVV ordinal var ssp0 = bsp.searchProviderFor(0); // should use ravv ordinal 2 (vector [-1, 0])