|
26 | 26 | import java.nio.file.Path;
|
27 | 27 | import java.nio.file.Paths;
|
28 | 28 | import java.text.NumberFormat;
|
| 29 | +import java.util.ArrayDeque; |
29 | 30 | import java.util.ArrayList;
|
30 | 31 | import java.util.Arrays;
|
| 32 | +import java.util.Collections; |
| 33 | +import java.util.Deque; |
31 | 34 | import java.util.HashMap;
|
32 | 35 | import java.util.Iterator;
|
33 | 36 | import java.util.List;
|
|
52 | 55 | import org.apache.lucene.codecs.StoredFieldsReader;
|
53 | 56 | import org.apache.lucene.codecs.TermVectorsReader;
|
54 | 57 | import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
| 58 | +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; |
55 | 59 | import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
56 | 60 | import org.apache.lucene.document.Document;
|
57 | 61 | import org.apache.lucene.document.DocumentStoredFieldVisitor;
|
58 | 62 | import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
|
59 | 63 | import org.apache.lucene.index.PointValues.IntersectVisitor;
|
60 | 64 | import org.apache.lucene.index.PointValues.Relation;
|
| 65 | +import org.apache.lucene.internal.hppc.IntIntHashMap; |
61 | 66 | import org.apache.lucene.search.DocIdSetIterator;
|
62 | 67 | import org.apache.lucene.search.FieldExistsQuery;
|
63 | 68 | import org.apache.lucene.search.KnnCollector;
|
|
74 | 79 | import org.apache.lucene.store.Lock;
|
75 | 80 | import org.apache.lucene.util.ArrayUtil;
|
76 | 81 | import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
|
| 82 | +import org.apache.lucene.util.BitSet; |
77 | 83 | import org.apache.lucene.util.Bits;
|
78 | 84 | import org.apache.lucene.util.BytesRef;
|
79 | 85 | import org.apache.lucene.util.BytesRefBuilder;
|
|
91 | 97 | import org.apache.lucene.util.automaton.ByteRunAutomaton;
|
92 | 98 | import org.apache.lucene.util.automaton.CompiledAutomaton;
|
93 | 99 | import org.apache.lucene.util.automaton.Operations;
|
| 100 | +import org.apache.lucene.util.hnsw.HnswGraph; |
94 | 101 |
|
95 | 102 | /**
|
96 | 103 | * Basic tool and API to check the health of an index and write a new segments file that removes
|
@@ -249,6 +256,9 @@ public static class SegmentInfoStatus {
|
249 | 256 | /** Status of vectors */
|
250 | 257 | public VectorValuesStatus vectorValuesStatus;
|
251 | 258 |
|
| 259 | + /** Status of HNSW graph */ |
| 260 | + public HnswGraphsStatus hnswGraphsStatus; |
| 261 | + |
252 | 262 | /** Status of soft deletes */
|
253 | 263 | public SoftDeletesStatus softDeletesStatus;
|
254 | 264 |
|
@@ -406,6 +416,32 @@ public static final class VectorValuesStatus {
|
406 | 416 | public Throwable error;
|
407 | 417 | }
|
408 | 418 |
|
| 419 | + /** Status from testing a single HNSW graph */ |
| 420 | + public static final class HnswGraphStatus { |
| 421 | + |
| 422 | + HnswGraphStatus() {} |
| 423 | + |
| 424 | + /** Number of nodes at each level */ |
| 425 | + public List<Integer> numNodesAtLevel; |
| 426 | + |
| 427 | + /** Connectedness at each level represented as a fraction */ |
| 428 | + public List<String> connectednessAtLevel; |
| 429 | + } |
| 430 | + |
| 431 | + /** Status from testing all HNSW graphs */ |
| 432 | + public static final class HnswGraphsStatus { |
| 433 | + |
| 434 | + HnswGraphsStatus() { |
| 435 | + this.hnswGraphsStatusByField = new HashMap<>(); |
| 436 | + } |
| 437 | + |
| 438 | + /** Status of the HNSW graph keyed with field name */ |
| 439 | + public Map<String, HnswGraphStatus> hnswGraphsStatusByField; |
| 440 | + |
| 441 | + /** Exception thrown during term index test (null on success) */ |
| 442 | + public Throwable error; |
| 443 | + } |
| 444 | + |
409 | 445 | /** Status from testing index sort */
|
410 | 446 | public static final class IndexSortStatus {
|
411 | 447 | IndexSortStatus() {}
|
@@ -1085,6 +1121,9 @@ private Status.SegmentInfoStatus testSegment(
|
1085 | 1121 | // Test FloatVectorValues and ByteVectorValues
|
1086 | 1122 | segInfoStat.vectorValuesStatus = testVectors(reader, infoStream, failFast);
|
1087 | 1123 |
|
| 1124 | + // Test HNSW graph |
| 1125 | + segInfoStat.hnswGraphsStatus = testHnswGraphs(reader, infoStream, failFast); |
| 1126 | + |
1088 | 1127 | // Test Index Sort
|
1089 | 1128 | if (indexSort != null) {
|
1090 | 1129 | segInfoStat.indexSortStatus = testSort(reader, indexSort, infoStream, failFast);
|
@@ -2746,6 +2785,196 @@ public static Status.VectorValuesStatus testVectors(
|
2746 | 2785 | return status;
|
2747 | 2786 | }
|
2748 | 2787 |
|
| 2788 | + /** Test the HNSW graph. */ |
| 2789 | + public static Status.HnswGraphsStatus testHnswGraphs( |
| 2790 | + CodecReader reader, PrintStream infoStream, boolean failFast) throws IOException { |
| 2791 | + if (infoStream != null) { |
| 2792 | + infoStream.print(" test: hnsw graphs........."); |
| 2793 | + } |
| 2794 | + long startNS = System.nanoTime(); |
| 2795 | + Status.HnswGraphsStatus status = new Status.HnswGraphsStatus(); |
| 2796 | + KnnVectorsReader vectorsReader = reader.getVectorReader(); |
| 2797 | + FieldInfos fieldInfos = reader.getFieldInfos(); |
| 2798 | + |
| 2799 | + try { |
| 2800 | + if (fieldInfos.hasVectorValues()) { |
| 2801 | + for (FieldInfo fieldInfo : fieldInfos) { |
| 2802 | + if (fieldInfo.hasVectorValues()) { |
| 2803 | + KnnVectorsReader fieldReader = getFieldReaderForName(vectorsReader, fieldInfo.name); |
| 2804 | + if (fieldReader instanceof HnswGraphProvider graphProvider) { |
| 2805 | + HnswGraph hnswGraph = graphProvider.getGraph(fieldInfo.name); |
| 2806 | + testHnswGraph(hnswGraph, fieldInfo.name, status); |
| 2807 | + } |
| 2808 | + } |
| 2809 | + } |
| 2810 | + } |
| 2811 | + msg( |
| 2812 | + infoStream, |
| 2813 | + String.format( |
| 2814 | + Locale.ROOT, |
| 2815 | + "OK [%d fields] [took %.3f sec]", |
| 2816 | + status.hnswGraphsStatusByField.size(), |
| 2817 | + nsToSec(System.nanoTime() - startNS))); |
| 2818 | + printHnswInfo(infoStream, status.hnswGraphsStatusByField); |
| 2819 | + } catch (Exception e) { |
| 2820 | + if (failFast) { |
| 2821 | + throw IOUtils.rethrowAlways(e); |
| 2822 | + } |
| 2823 | + msg(infoStream, "ERROR: " + e); |
| 2824 | + status.error = e; |
| 2825 | + if (infoStream != null) { |
| 2826 | + e.printStackTrace(infoStream); |
| 2827 | + } |
| 2828 | + } |
| 2829 | + |
| 2830 | + return status; |
| 2831 | + } |
| 2832 | + |
| 2833 | + private static KnnVectorsReader getFieldReaderForName( |
| 2834 | + KnnVectorsReader vectorsReader, String fieldName) { |
| 2835 | + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { |
| 2836 | + return fieldsReader.getFieldReader(fieldName); |
| 2837 | + } else { |
| 2838 | + return vectorsReader; |
| 2839 | + } |
| 2840 | + } |
| 2841 | + |
| 2842 | + private static void printHnswInfo( |
| 2843 | + PrintStream infoStream, Map<String, CheckIndex.Status.HnswGraphStatus> fieldsStatus) { |
| 2844 | + for (Map.Entry<String, CheckIndex.Status.HnswGraphStatus> entry : fieldsStatus.entrySet()) { |
| 2845 | + String fieldName = entry.getKey(); |
| 2846 | + CheckIndex.Status.HnswGraphStatus status = entry.getValue(); |
| 2847 | + msg(infoStream, " hnsw field name: " + fieldName); |
| 2848 | + |
| 2849 | + int numLevels = Math.min(status.numNodesAtLevel.size(), status.connectednessAtLevel.size()); |
| 2850 | + for (int level = numLevels - 1; level >= 0; level--) { |
| 2851 | + int numNodes = status.numNodesAtLevel.get(level); |
| 2852 | + String connectedness = status.connectednessAtLevel.get(level); |
| 2853 | + msg( |
| 2854 | + infoStream, |
| 2855 | + String.format( |
| 2856 | + Locale.ROOT, |
| 2857 | + " level %d: %d nodes, %s connected", |
| 2858 | + level, |
| 2859 | + numNodes, |
| 2860 | + connectedness)); |
| 2861 | + } |
| 2862 | + } |
| 2863 | + } |
| 2864 | + |
| 2865 | + private static void testHnswGraph( |
| 2866 | + HnswGraph hnswGraph, String fieldName, Status.HnswGraphsStatus status) |
| 2867 | + throws IOException, CheckIndexException { |
| 2868 | + if (hnswGraph != null) { |
| 2869 | + status.hnswGraphsStatusByField.put(fieldName, new Status.HnswGraphStatus()); |
| 2870 | + final int numLevels = hnswGraph.numLevels(); |
| 2871 | + status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel = |
| 2872 | + new ArrayList<>(Collections.nCopies(numLevels, null)); |
| 2873 | + status.hnswGraphsStatusByField.get(fieldName).connectednessAtLevel = |
| 2874 | + new ArrayList<>(Collections.nCopies(numLevels, null)); |
| 2875 | + // Perform checks on each level of the HNSW graph |
| 2876 | + for (int level = numLevels - 1; level >= 0; level--) { |
| 2877 | + // Collect BitSet of all nodes on this level |
| 2878 | + BitSet nodesOnThisLevel = new FixedBitSet(hnswGraph.size()); |
| 2879 | + HnswGraph.NodesIterator nodesIterator = hnswGraph.getNodesOnLevel(level); |
| 2880 | + while (nodesIterator.hasNext()) { |
| 2881 | + nodesOnThisLevel.set(nodesIterator.nextInt()); |
| 2882 | + } |
| 2883 | + |
| 2884 | + nodesIterator = hnswGraph.getNodesOnLevel(level); |
| 2885 | + // Perform checks on each node on the level |
| 2886 | + while (nodesIterator.hasNext()) { |
| 2887 | + int node = nodesIterator.nextInt(); |
| 2888 | + if (node < 0 || node > hnswGraph.size() - 1) { |
| 2889 | + throw new CheckIndexException( |
| 2890 | + "Field \"" |
| 2891 | + + fieldName |
| 2892 | + + "\" has node: " |
| 2893 | + + node |
| 2894 | + + " not in the expected range [0, " |
| 2895 | + + (hnswGraph.size() - 1) |
| 2896 | + + "]"); |
| 2897 | + } |
| 2898 | + |
| 2899 | + // Perform checks on the node's neighbors |
| 2900 | + hnswGraph.seek(level, node); |
| 2901 | + int nbr, lastNeighbor = -1, firstNeighbor = -1; |
| 2902 | + while ((nbr = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) { |
| 2903 | + if (!nodesOnThisLevel.get(nbr)) { |
| 2904 | + throw new CheckIndexException( |
| 2905 | + "Field \"" |
| 2906 | + + fieldName |
| 2907 | + + "\" has node: " |
| 2908 | + + node |
| 2909 | + + " with a neighbor " |
| 2910 | + + nbr |
| 2911 | + + " which is not on its level (" |
| 2912 | + + level |
| 2913 | + + ")"); |
| 2914 | + } |
| 2915 | + if (firstNeighbor == -1) { |
| 2916 | + firstNeighbor = nbr; |
| 2917 | + } |
| 2918 | + if (nbr < lastNeighbor) { |
| 2919 | + throw new CheckIndexException( |
| 2920 | + "Field \"" |
| 2921 | + + fieldName |
| 2922 | + + "\" has neighbors out of order for node " |
| 2923 | + + node |
| 2924 | + + ": " |
| 2925 | + + nbr |
| 2926 | + + "<" |
| 2927 | + + lastNeighbor |
| 2928 | + + " 1st=" |
| 2929 | + + firstNeighbor); |
| 2930 | + } else if (nbr == lastNeighbor) { |
| 2931 | + throw new CheckIndexException( |
| 2932 | + "Field \"" |
| 2933 | + + fieldName |
| 2934 | + + "\" has repeated neighbors of node " |
| 2935 | + + node |
| 2936 | + + " with value " |
| 2937 | + + nbr); |
| 2938 | + } |
| 2939 | + lastNeighbor = nbr; |
| 2940 | + } |
| 2941 | + } |
| 2942 | + int numNodesOnLayer = nodesIterator.size(); |
| 2943 | + status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel.set(level, numNodesOnLayer); |
| 2944 | + |
| 2945 | + // Evaluate connectedness at this level by measuring the number of nodes reachable from the |
| 2946 | + // entry point |
| 2947 | + IntIntHashMap connectedNodes = getConnectedNodesOnLevel(hnswGraph, numNodesOnLayer, level); |
| 2948 | + status |
| 2949 | + .hnswGraphsStatusByField |
| 2950 | + .get(fieldName) |
| 2951 | + .connectednessAtLevel |
| 2952 | + .set(level, connectedNodes.size() + "/" + numNodesOnLayer); |
| 2953 | + } |
| 2954 | + } |
| 2955 | + } |
| 2956 | + |
| 2957 | + private static IntIntHashMap getConnectedNodesOnLevel( |
| 2958 | + HnswGraph hnswGraph, int numNodesOnLayer, int level) throws IOException { |
| 2959 | + IntIntHashMap connectedNodes = new IntIntHashMap(numNodesOnLayer); |
| 2960 | + int entryPoint = hnswGraph.entryNode(); |
| 2961 | + Deque<Integer> stack = new ArrayDeque<>(); |
| 2962 | + stack.push(entryPoint); |
| 2963 | + while (!stack.isEmpty()) { |
| 2964 | + int node = stack.pop(); |
| 2965 | + if (connectedNodes.containsKey(node)) { |
| 2966 | + continue; |
| 2967 | + } |
| 2968 | + connectedNodes.put(node, 1); |
| 2969 | + hnswGraph.seek(level, node); |
| 2970 | + int friendOrd; |
| 2971 | + while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) { |
| 2972 | + stack.push(friendOrd); |
| 2973 | + } |
| 2974 | + } |
| 2975 | + return connectedNodes; |
| 2976 | + } |
| 2977 | + |
2749 | 2978 | private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
|
2750 | 2979 | KnnVectorsReader vectorsReader = codecReader.getVectorReader();
|
2751 | 2980 | if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
|
|
0 commit comments