Skip to content

Commit 11eb2c8

Browse files
authored
Add some basic HNSW graph checks to CheckIndex (#13984)
1 parent d4f0a32 commit 11eb2c8

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

lucene/core/src/java/org/apache/lucene/index/CheckIndex.java

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
import java.nio.file.Path;
2727
import java.nio.file.Paths;
2828
import java.text.NumberFormat;
29+
import java.util.ArrayDeque;
2930
import java.util.ArrayList;
3031
import java.util.Arrays;
32+
import java.util.Collections;
33+
import java.util.Deque;
3134
import java.util.HashMap;
3235
import java.util.Iterator;
3336
import java.util.List;
@@ -52,12 +55,14 @@
5255
import org.apache.lucene.codecs.StoredFieldsReader;
5356
import org.apache.lucene.codecs.TermVectorsReader;
5457
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
58+
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
5559
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
5660
import org.apache.lucene.document.Document;
5761
import org.apache.lucene.document.DocumentStoredFieldVisitor;
5862
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
5963
import org.apache.lucene.index.PointValues.IntersectVisitor;
6064
import org.apache.lucene.index.PointValues.Relation;
65+
import org.apache.lucene.internal.hppc.IntIntHashMap;
6166
import org.apache.lucene.search.DocIdSetIterator;
6267
import org.apache.lucene.search.FieldExistsQuery;
6368
import org.apache.lucene.search.KnnCollector;
@@ -74,6 +79,7 @@
7479
import org.apache.lucene.store.Lock;
7580
import org.apache.lucene.util.ArrayUtil;
7681
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
82+
import org.apache.lucene.util.BitSet;
7783
import org.apache.lucene.util.Bits;
7884
import org.apache.lucene.util.BytesRef;
7985
import org.apache.lucene.util.BytesRefBuilder;
@@ -91,6 +97,7 @@
9197
import org.apache.lucene.util.automaton.ByteRunAutomaton;
9298
import org.apache.lucene.util.automaton.CompiledAutomaton;
9399
import org.apache.lucene.util.automaton.Operations;
100+
import org.apache.lucene.util.hnsw.HnswGraph;
94101

95102
/**
96103
* Basic tool and API to check the health of an index and write a new segments file that removes
@@ -249,6 +256,9 @@ public static class SegmentInfoStatus {
249256
/** Status of vectors */
250257
public VectorValuesStatus vectorValuesStatus;
251258

259+
/** Status of HNSW graph */
260+
public HnswGraphsStatus hnswGraphsStatus;
261+
252262
/** Status of soft deletes */
253263
public SoftDeletesStatus softDeletesStatus;
254264

@@ -406,6 +416,32 @@ public static final class VectorValuesStatus {
406416
public Throwable error;
407417
}
408418

419+
/** Status from testing a single HNSW graph */
420+
public static final class HnswGraphStatus {
421+
422+
HnswGraphStatus() {}
423+
424+
/** Number of nodes at each level */
425+
public List<Integer> numNodesAtLevel;
426+
427+
/** Connectedness at each level represented as a fraction */
428+
public List<String> connectednessAtLevel;
429+
}
430+
431+
/** Status from testing all HNSW graphs */
432+
public static final class HnswGraphsStatus {
433+
434+
HnswGraphsStatus() {
435+
this.hnswGraphsStatusByField = new HashMap<>();
436+
}
437+
438+
/** Status of the HNSW graph keyed with field name */
439+
public Map<String, HnswGraphStatus> hnswGraphsStatusByField;
440+
441+
/** Exception thrown during term index test (null on success) */
442+
public Throwable error;
443+
}
444+
409445
/** Status from testing index sort */
410446
public static final class IndexSortStatus {
411447
IndexSortStatus() {}
@@ -1085,6 +1121,9 @@ private Status.SegmentInfoStatus testSegment(
10851121
// Test FloatVectorValues and ByteVectorValues
10861122
segInfoStat.vectorValuesStatus = testVectors(reader, infoStream, failFast);
10871123

1124+
// Test HNSW graph
1125+
segInfoStat.hnswGraphsStatus = testHnswGraphs(reader, infoStream, failFast);
1126+
10881127
// Test Index Sort
10891128
if (indexSort != null) {
10901129
segInfoStat.indexSortStatus = testSort(reader, indexSort, infoStream, failFast);
@@ -2746,6 +2785,196 @@ public static Status.VectorValuesStatus testVectors(
27462785
return status;
27472786
}
27482787

2788+
/** Test the HNSW graph. */
2789+
public static Status.HnswGraphsStatus testHnswGraphs(
2790+
CodecReader reader, PrintStream infoStream, boolean failFast) throws IOException {
2791+
if (infoStream != null) {
2792+
infoStream.print(" test: hnsw graphs.........");
2793+
}
2794+
long startNS = System.nanoTime();
2795+
Status.HnswGraphsStatus status = new Status.HnswGraphsStatus();
2796+
KnnVectorsReader vectorsReader = reader.getVectorReader();
2797+
FieldInfos fieldInfos = reader.getFieldInfos();
2798+
2799+
try {
2800+
if (fieldInfos.hasVectorValues()) {
2801+
for (FieldInfo fieldInfo : fieldInfos) {
2802+
if (fieldInfo.hasVectorValues()) {
2803+
KnnVectorsReader fieldReader = getFieldReaderForName(vectorsReader, fieldInfo.name);
2804+
if (fieldReader instanceof HnswGraphProvider graphProvider) {
2805+
HnswGraph hnswGraph = graphProvider.getGraph(fieldInfo.name);
2806+
testHnswGraph(hnswGraph, fieldInfo.name, status);
2807+
}
2808+
}
2809+
}
2810+
}
2811+
msg(
2812+
infoStream,
2813+
String.format(
2814+
Locale.ROOT,
2815+
"OK [%d fields] [took %.3f sec]",
2816+
status.hnswGraphsStatusByField.size(),
2817+
nsToSec(System.nanoTime() - startNS)));
2818+
printHnswInfo(infoStream, status.hnswGraphsStatusByField);
2819+
} catch (Exception e) {
2820+
if (failFast) {
2821+
throw IOUtils.rethrowAlways(e);
2822+
}
2823+
msg(infoStream, "ERROR: " + e);
2824+
status.error = e;
2825+
if (infoStream != null) {
2826+
e.printStackTrace(infoStream);
2827+
}
2828+
}
2829+
2830+
return status;
2831+
}
2832+
2833+
private static KnnVectorsReader getFieldReaderForName(
2834+
KnnVectorsReader vectorsReader, String fieldName) {
2835+
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
2836+
return fieldsReader.getFieldReader(fieldName);
2837+
} else {
2838+
return vectorsReader;
2839+
}
2840+
}
2841+
2842+
private static void printHnswInfo(
2843+
PrintStream infoStream, Map<String, CheckIndex.Status.HnswGraphStatus> fieldsStatus) {
2844+
for (Map.Entry<String, CheckIndex.Status.HnswGraphStatus> entry : fieldsStatus.entrySet()) {
2845+
String fieldName = entry.getKey();
2846+
CheckIndex.Status.HnswGraphStatus status = entry.getValue();
2847+
msg(infoStream, " hnsw field name: " + fieldName);
2848+
2849+
int numLevels = Math.min(status.numNodesAtLevel.size(), status.connectednessAtLevel.size());
2850+
for (int level = numLevels - 1; level >= 0; level--) {
2851+
int numNodes = status.numNodesAtLevel.get(level);
2852+
String connectedness = status.connectednessAtLevel.get(level);
2853+
msg(
2854+
infoStream,
2855+
String.format(
2856+
Locale.ROOT,
2857+
" level %d: %d nodes, %s connected",
2858+
level,
2859+
numNodes,
2860+
connectedness));
2861+
}
2862+
}
2863+
}
2864+
2865+
private static void testHnswGraph(
2866+
HnswGraph hnswGraph, String fieldName, Status.HnswGraphsStatus status)
2867+
throws IOException, CheckIndexException {
2868+
if (hnswGraph != null) {
2869+
status.hnswGraphsStatusByField.put(fieldName, new Status.HnswGraphStatus());
2870+
final int numLevels = hnswGraph.numLevels();
2871+
status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel =
2872+
new ArrayList<>(Collections.nCopies(numLevels, null));
2873+
status.hnswGraphsStatusByField.get(fieldName).connectednessAtLevel =
2874+
new ArrayList<>(Collections.nCopies(numLevels, null));
2875+
// Perform checks on each level of the HNSW graph
2876+
for (int level = numLevels - 1; level >= 0; level--) {
2877+
// Collect BitSet of all nodes on this level
2878+
BitSet nodesOnThisLevel = new FixedBitSet(hnswGraph.size());
2879+
HnswGraph.NodesIterator nodesIterator = hnswGraph.getNodesOnLevel(level);
2880+
while (nodesIterator.hasNext()) {
2881+
nodesOnThisLevel.set(nodesIterator.nextInt());
2882+
}
2883+
2884+
nodesIterator = hnswGraph.getNodesOnLevel(level);
2885+
// Perform checks on each node on the level
2886+
while (nodesIterator.hasNext()) {
2887+
int node = nodesIterator.nextInt();
2888+
if (node < 0 || node > hnswGraph.size() - 1) {
2889+
throw new CheckIndexException(
2890+
"Field \""
2891+
+ fieldName
2892+
+ "\" has node: "
2893+
+ node
2894+
+ " not in the expected range [0, "
2895+
+ (hnswGraph.size() - 1)
2896+
+ "]");
2897+
}
2898+
2899+
// Perform checks on the node's neighbors
2900+
hnswGraph.seek(level, node);
2901+
int nbr, lastNeighbor = -1, firstNeighbor = -1;
2902+
while ((nbr = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
2903+
if (!nodesOnThisLevel.get(nbr)) {
2904+
throw new CheckIndexException(
2905+
"Field \""
2906+
+ fieldName
2907+
+ "\" has node: "
2908+
+ node
2909+
+ " with a neighbor "
2910+
+ nbr
2911+
+ " which is not on its level ("
2912+
+ level
2913+
+ ")");
2914+
}
2915+
if (firstNeighbor == -1) {
2916+
firstNeighbor = nbr;
2917+
}
2918+
if (nbr < lastNeighbor) {
2919+
throw new CheckIndexException(
2920+
"Field \""
2921+
+ fieldName
2922+
+ "\" has neighbors out of order for node "
2923+
+ node
2924+
+ ": "
2925+
+ nbr
2926+
+ "<"
2927+
+ lastNeighbor
2928+
+ " 1st="
2929+
+ firstNeighbor);
2930+
} else if (nbr == lastNeighbor) {
2931+
throw new CheckIndexException(
2932+
"Field \""
2933+
+ fieldName
2934+
+ "\" has repeated neighbors of node "
2935+
+ node
2936+
+ " with value "
2937+
+ nbr);
2938+
}
2939+
lastNeighbor = nbr;
2940+
}
2941+
}
2942+
int numNodesOnLayer = nodesIterator.size();
2943+
status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel.set(level, numNodesOnLayer);
2944+
2945+
// Evaluate connectedness at this level by measuring the number of nodes reachable from the
2946+
// entry point
2947+
IntIntHashMap connectedNodes = getConnectedNodesOnLevel(hnswGraph, numNodesOnLayer, level);
2948+
status
2949+
.hnswGraphsStatusByField
2950+
.get(fieldName)
2951+
.connectednessAtLevel
2952+
.set(level, connectedNodes.size() + "/" + numNodesOnLayer);
2953+
}
2954+
}
2955+
}
2956+
2957+
private static IntIntHashMap getConnectedNodesOnLevel(
2958+
HnswGraph hnswGraph, int numNodesOnLayer, int level) throws IOException {
2959+
IntIntHashMap connectedNodes = new IntIntHashMap(numNodesOnLayer);
2960+
int entryPoint = hnswGraph.entryNode();
2961+
Deque<Integer> stack = new ArrayDeque<>();
2962+
stack.push(entryPoint);
2963+
while (!stack.isEmpty()) {
2964+
int node = stack.pop();
2965+
if (connectedNodes.containsKey(node)) {
2966+
continue;
2967+
}
2968+
connectedNodes.put(node, 1);
2969+
hnswGraph.seek(level, node);
2970+
int friendOrd;
2971+
while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
2972+
stack.push(friendOrd);
2973+
}
2974+
}
2975+
return connectedNodes;
2976+
}
2977+
27492978
private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
27502979
KnnVectorsReader vectorsReader = codecReader.getVectorReader();
27512980
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {

0 commit comments

Comments
 (0)