diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 0bf2bac5db3a..4a9e41c0c093 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -44,6 +44,8 @@ API Changes * GITHUB#15131: Restrict visibility of TieredMergePolicy.score() API (Trevor McCulloch) +* GITHUB#15187: Restrict visibility of PerFieldKnnVectorsFormat.FieldsReader (Simon Cooper) + New Features --------------------- * GITHUB#14097: Binary partitioning merge policy over float-valued vector field. (Mike Sokolov) diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index d88e2d121d9c..b36f153eaeb4 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -39,7 +39,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -592,9 +591,7 @@ static boolean shouldRecomputeQuantiles( private static QuantizedVectorsReader getQuantizedKnnVectorsReader( KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); if (vectorsReader instanceof QuantizedVectorsReader reader) { return reader; } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 1b23e1108d93..fb2ca112a0ab 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -30,7 +30,6 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -275,9 +274,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); if (knnVectorsReader instanceof Lucene99HnswVectorsReader hnswReader) { assertNotNull(hnswReader.getQuantizationState("f")); QuantizedByteVectorValues quantizedByteVectorValues = diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java index f16c95c92e15..e2019719792f 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99HnswScalarQuantizedVectorsFormat.java @@ -22,7 +22,6 @@ import java.io.IOException; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -52,9 +51,7 @@ public void testSimpleOffHeapSize() throws IOException { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 71fbff929ab3..9829fe4d5095 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -29,7 +29,6 @@ import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; @@ -217,9 +216,7 @@ private RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction function, LeafReader leafReader, float[] vector) throws IOException { if (leafReader instanceof CodecReader codecReader) { KnnVectorsReader format = codecReader.getVectorReader(); - if (format instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldFormat) { - format = perFieldFormat.getFieldReader("field"); - } + format = format.unwrapReaderForField("field"); if (format instanceof Lucene99HnswVectorsReader hnswFormat) { OffHeapQuantizedByteVectorValues quantizedByteVectorReader = (OffHeapQuantizedByteVectorValues) hnswFormat.getQuantizedVectorValues("field"); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index 609d2496d54e..ee9765f2ac0e 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -31,7 +31,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -195,9 +194,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); if (knnVectorsReader instanceof Lucene99ScalarQuantizedVectorsReader quantizedReader) { assertNotNull(quantizedReader.getQuantizationState("f")); QuantizedByteVectorValues quantizedByteVectorValues = @@ -259,9 +256,7 @@ public void testReadQuantizedVectorWithEmptyRawVectors() throws Exception { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader(vectorFieldName); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField(vectorFieldName); if (knnVectorsReader instanceof Lucene99ScalarQuantizedVectorsReader quantizedReader) { FloatVectorValues floatVectorValues = quantizedReader.getFloatVectorValues(vectorFieldName); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt7HnswBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt7HnswBackwardsCompatibility.java index 5858eab0e885..7a874ba5500d 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt7HnswBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt7HnswBackwardsCompatibility.java @@ -25,7 +25,6 @@ import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; @@ -151,13 +150,10 @@ public void testIndexIsReallyQuantized() throws Exception { try (DirectoryReader reader = DirectoryReader.open(directory)) { for (LeafReaderContext leafContext : reader.leaves()) { KnnVectorsReader knnVectorsReader = ((CodecReader) leafContext.reader()).getVectorReader(); - assertTrue( - "expected PerFieldKnnVectorsFormat.FieldsReader but got: " + knnVectorsReader, - knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader); - KnnVectorsReader forField = - ((PerFieldKnnVectorsFormat.FieldsReader) knnVectorsReader) - .getFieldReader(KNN_VECTOR_FIELD); + KnnVectorsReader forField = knnVectorsReader.unwrapReaderForField(KNN_VECTOR_FIELD); + assertNotSame( + "Expected unwrapped field but got: " + knnVectorsReader, knnVectorsReader, forField); assertTrue(forField instanceof Lucene99HnswVectorsReader); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index 99b5c81fdff6..3eab9264895d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -49,6 +49,14 @@ protected KnnVectorsReader() {} */ public abstract void checkIntegrity() throws IOException; + /** + * If this reader wraps another for {@code field}, return the underlying reader, else return + * {@code this} + */ + public KnnVectorsReader unwrapReaderForField(String field) { + return this; + } + /** * Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java index d881fdd5b28d..cb0a2e41da17 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java @@ -39,7 +39,6 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -552,9 +551,7 @@ public void close() throws IOException { } static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); if (vectorsReader instanceof Lucene102BinaryQuantizedVectorsReader reader) { return reader.getCentroid(fieldName); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index c7a2f7a54de5..4238aed03600 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -35,7 +35,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -494,9 +493,7 @@ public void close() throws IOException { } static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { return reader.getCentroid(fieldName); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 6bf0a8888589..e58f57e75915 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -191,7 +191,7 @@ public long ramBytesUsed() { } /** VectorReader that can wrap multiple delegate readers, selected by field. */ - public static class FieldsReader extends KnnVectorsReader implements HnswGraphProvider { + private static class FieldsReader extends KnnVectorsReader implements HnswGraphProvider { private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final FieldInfos fieldInfos; @@ -259,17 +259,10 @@ public void finishMerge() throws IOException { } } - /** - * Return the underlying VectorReader for the given field - * - * @param field the name of a numeric vector field - */ - public KnnVectorsReader getFieldReader(String field) { - final FieldInfo info = fieldInfos.fieldInfo(field); - if (info == null) { - return null; - } - return fields.get(info.number); + @Override + public KnnVectorsReader unwrapReaderForField(String field) { + FieldInfo fi = fieldInfos.fieldInfo(field); + return fi != null ? fields.get(fi.number) : this; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index 9ec132a6a4e1..dc3242722636 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -55,7 +55,6 @@ import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.DocumentStoredFieldVisitor; import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus; @@ -2914,7 +2913,7 @@ public static Status.HnswGraphsStatus testHnswGraphs( if (fieldInfos.hasVectorValues()) { for (FieldInfo fieldInfo : fieldInfos) { if (fieldInfo.hasVectorValues()) { - KnnVectorsReader fieldReader = getFieldReaderForName(vectorsReader, fieldInfo.name); + KnnVectorsReader fieldReader = vectorsReader.unwrapReaderForField(fieldInfo.name); if (fieldReader instanceof HnswGraphProvider graphProvider) { HnswGraph hnswGraph = graphProvider.getGraph(fieldInfo.name); testHnswGraph(hnswGraph, fieldInfo.name, status); @@ -2944,15 +2943,6 @@ public static Status.HnswGraphsStatus testHnswGraphs( return status; } - private static KnnVectorsReader getFieldReaderForName( - KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - return fieldsReader.getFieldReader(fieldName); - } else { - return vectorsReader; - } - } - private static void printHnswInfo( PrintStream infoStream, Map fieldsStatus) { for (Map.Entry entry : fieldsStatus.entrySet()) { @@ -3091,9 +3081,7 @@ private static IntIntHashMap getConnectedNodesOnLevel( private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) { KnnVectorsReader vectorsReader = codecReader.getVectorReader(); - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) { - vectorsReader = perFieldReader.getFieldReader(fieldName); - } + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); return (vectorsReader instanceof FlatVectorsReader) == false; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index 23c4276a0df1..599605364a64 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -29,7 +29,6 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -198,9 +197,7 @@ public void checkIntegrity() { public Map getOffHeapByteSize(FieldInfo fieldInfo) { SegmentReader segmentReader = segmentReader(reader); var vectorsReader = segmentReader.getVectorReader(); - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - vectorsReader = fieldsReader.getFieldReader(fieldInfo.name); - } + vectorsReader = vectorsReader.unwrapReaderForField(fieldInfo.name); return vectorsReader.getOffHeapByteSize(fieldInfo); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java index f316048a04c0..ab1ea556acc7 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java @@ -25,7 +25,6 @@ import java.util.List; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; @@ -236,8 +235,7 @@ public static boolean graphIsRooted(IndexReader reader, String vectorField) thro for (LeafReaderContext ctx : reader.leaves()) { CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader()); KnnVectorsReader vectorsReader = - ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader()) - .getFieldReader(vectorField); + codecReader.getVectorReader().unwrapReaderForField(vectorField); if (vectorsReader instanceof HnswGraphProvider) { HnswGraph graph = ((HnswGraphProvider) vectorsReader).getGraph(vectorField); if (isRooted(graph) == false) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java index 31799f096690..e58ab9b5ef78 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java @@ -31,7 +31,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -163,9 +162,7 @@ public void testSimpleOffHeapSize() throws IOException { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index 7f01f1bbe852..6164b0062f02 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -31,7 +31,6 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -165,9 +164,7 @@ public void testSimpleOffHeapSize() throws IOException { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java index dccc51c35a76..f41298e0046c 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java @@ -27,7 +27,6 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; @@ -86,9 +85,7 @@ public void testSimpleOffHeapSize() throws IOException { LeafReader r = getOnlyLeafReader(reader); if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 7b05a018afac..d8b196e04df2 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -30,10 +30,10 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; @@ -374,17 +374,15 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect try (DirectoryReader dr = DirectoryReader.open(iw)) { for (LeafReaderContext ctx : dr.leaves()) { LeafReader reader = ctx.reader(); - PerFieldKnnVectorsFormat.FieldsReader perFieldReader = - (PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) reader).getVectorReader(); - if (perFieldReader == null) { + KnnVectorsReader knnFieldReader = ((CodecReader) reader).getVectorReader(); + if (knnFieldReader == null) { continue; } - Lucene99HnswVectorsReader vectorReader = - (Lucene99HnswVectorsReader) perFieldReader.getFieldReader(vectorField); - if (vectorReader == null) { + KnnVectorsReader vectorReader = knnFieldReader.unwrapReaderForField(vectorField); + if (!(vectorReader instanceof HnswGraphProvider graphProvider)) { continue; } - HnswGraph graphValues = vectorReader.getGraph(vectorField); + HnswGraph graphValues = graphProvider.getGraph(vectorField); FloatVectorValues vectorValues = reader.getFloatVectorValues(vectorField); if (vectorValues == null) { assert graphValues == null; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index be099398ba6a..fb0f507b7893 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -41,7 +41,6 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.NumericDocValuesField; @@ -270,9 +269,7 @@ public void testReadWrite() throws IOException { assertVectorsEqual(v3, values); HnswGraph graphValues = ((Lucene99HnswVectorsReader) - ((PerFieldKnnVectorsFormat.FieldsReader) - ((CodecReader) ctx.reader()).getVectorReader()) - .getFieldReader("field")) + ((CodecReader) ctx.reader()).getVectorReader().unwrapReaderForField("field")) .getGraph("field"); assertGraphEqual(hnsw, graphValues); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index eaff7a0ca9d5..37116ea82dad 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -2142,9 +2142,7 @@ protected void assertOffHeapByteSize(LeafReader r, String fieldName) throws IOEx if (r instanceof CodecReader codecReader) { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader(fieldName); - } + knnVectorsReader = knnVectorsReader.unwrapReaderForField(fieldName); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); long totalByteSize = offHeap.values().stream().mapToLong(Long::longValue).sum(); if (knnVectorsReader instanceof SimpleTextKnnVectorsReader) {