diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 808d7b3cc882..78df640dddba 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -37,4 +37,12 @@ private FlatVectorScorerUtil() {} public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } + + /** + * Returns a FlatVectorsScorer that supports the quantized Lucene99 format. Scorers retrieved + * through this method may be optimized on certain platforms. + */ + public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return IMPL.getLucene99ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 0b3c6d19af83..f0ebc9ad7ac5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,10 +18,10 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -70,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; + final FlatVectorsScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -117,8 +117,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = - new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer(); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index c5e9301e9bc4..2f77dd9c7945 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; /** Default provider returning scalar implementations. */ @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { return new PostingDecodingUtil(input); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index f725de389a6a..88a68c6f942f 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -122,6 +122,9 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Returns a FlatVectorsScorer that supports the quantized Lucene99 format. */ + public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorer.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorer.java new file mode 100644 index 000000000000..591120a4867b --- /dev/null +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorer.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.internal.vectorization; + +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; +import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory; +import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction; +import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +class Lucene99MemorySegmentScalarQuantizedScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + + private final VectorSimilarityFunction function; + private final QuantizedByteVectorValues values; + private final MemorySegmentAccessInput input; + private final MemorySegmentScorer scorer; + private final FloatToFloatFunction scaler; + private final float constMultiplier; + private final int vectorByteSize; + private final int entrySize; + private final MemorySegment query; + private final float queryOffset; + private final byte[][] docScratch; + + public Lucene99MemorySegmentScalarQuantizedScorer( + VectorSimilarityFunction function, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input, + float[] target) { + + super(values); + this.function = function; + this.values = values; + this.input = input; + this.scorer = factory(function, values, false); + this.scaler = factory(function); + + ScalarQuantizer quantizer = values.getScalarQuantizer(); + this.constMultiplier = quantizer.getConstantMultiplier(); + this.vectorByteSize = values.getVectorByteLength(); + this.entrySize = vectorByteSize + Float.BYTES; + + byte[] targetBytes = new byte[target.length]; + this.queryOffset = quantizeQuery(target, targetBytes, function, quantizer); + this.query = Arena.ofAuto().allocateFrom(JAVA_BYTE, targetBytes); + + this.docScratch = new byte[1][]; + } + + @Override + public float score(int node) throws IOException { + MemorySegment segment = getSegment(input, entrySize, node, docScratch); + MemorySegment doc = segment.reinterpret(vectorByteSize); + float docOffset = segment.get(JAVA_FLOAT, vectorByteSize); + return scaler.scale(scorer.score(query, doc) * constMultiplier + queryOffset + docOffset); + } +} diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorerSupplier.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorerSupplier.java new file mode 100644 index 000000000000..e5dbc8bfab8f --- /dev/null +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedScorerSupplier.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.internal.vectorization; + +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory; +import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction; +import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; + +class Lucene99MemorySegmentScalarQuantizedScorerSupplier implements RandomVectorScorerSupplier { + + private final VectorSimilarityFunction function; + private final QuantizedByteVectorValues values; + private final MemorySegmentAccessInput input; + private final MemorySegmentScorer scorer; + private final FloatToFloatFunction scaler; + private final float constMultiplier; + private final int vectorByteSize; + private final int entrySize; + + public Lucene99MemorySegmentScalarQuantizedScorerSupplier( + VectorSimilarityFunction function, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) { + + this.function = function; + this.values = values; + this.input = input; + this.scorer = factory(function, values, true); + this.scaler = factory(function); + this.constMultiplier = values.getScalarQuantizer().getConstantMultiplier(); + this.vectorByteSize = values.getVectorByteLength(); + this.entrySize = vectorByteSize + Float.BYTES; + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + + private final MemorySegment[] doc = new MemorySegment[1]; + private final float[] docOffset = new float[1]; + private final byte[][] docScratch = new byte[1][]; + private final byte[][] queryScratch = new byte[1][]; + + @Override + public void setScoringOrdinal(int node) throws IOException { + MemorySegment segment = getSegment(input, entrySize, node, docScratch); + doc[0] = segment.reinterpret(vectorByteSize); + docOffset[0] = segment.get(JAVA_FLOAT, vectorByteSize); + } + + @Override + public float score(int node) throws IOException { + MemorySegment segment = getSegment(input, entrySize, node, queryScratch); + MemorySegment query = segment.reinterpret(vectorByteSize); + float queryOffset = segment.get(JAVA_FLOAT, vectorByteSize); + return scaler.scale( + scorer.score(query, doc[0]) * constMultiplier + queryOffset + docOffset[0]); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new Lucene99MemorySegmentScalarQuantizedScorerSupplier(function, values, input); + } +} diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..b9c91ed3a8ce --- /dev/null +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; + +public class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer { + + public static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE = + new Lucene99MemorySegmentScalarQuantizedVectorScorer(); + + private static final FlatVectorsScorer NON_QUANTIZED_DELEGATE = + Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues values + && values.getSlice() instanceof MemorySegmentAccessInput input) { + return new Lucene99MemorySegmentScalarQuantizedScorerSupplier( + similarityFunction, values, input); + } + // It is possible to get to this branch during initial indexing and flush + return NON_QUANTIZED_DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues values + && values.getSlice() instanceof MemorySegmentAccessInput input) { + checkDimensions(target.length, vectorValues.dimension()); + return new Lucene99MemorySegmentScalarQuantizedScorer( + similarityFunction, values, input, target); + } + // It is possible to get to this branch during initial indexing and flush + return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "()"; + } + + private static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException( + "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } + + static MemorySegment getSegment( + MemorySegmentAccessInput input, int entrySize, int node, byte[][] scratch) + throws IOException { + long pos = (long) entrySize * node; + MemorySegment segment = input.segmentSliceOrNull(pos, entrySize); + if (segment == null) { + if (scratch[0] == null) { + scratch[0] = new byte[entrySize]; + } + input.readBytes(pos, scratch[0], 0, entrySize); + segment = MemorySegment.ofArray(scratch[0]); + } + return segment; + } + + @FunctionalInterface + interface MemorySegmentScorer { + float score(MemorySegment query, MemorySegment doc); + } + + @FunctionalInterface + interface FloatToFloatFunction { + float scale(float score); + } + + static MemorySegmentScorer factory( + VectorSimilarityFunction function, + QuantizedByteVectorValues values, + boolean isScorerSupplier) { + return switch (function) { + case EUCLIDEAN -> { + if (values.getScalarQuantizer().getBits() < 7) { + // TODO + throw new UnsupportedOperationException(); + } + yield PanamaVectorUtilSupport::squareDistance; + } + case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> { + if (values.getScalarQuantizer().getBits() <= 4) { + if (values.getVectorByteLength() != values.dimension()) { + if (isScorerSupplier) { + yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, true, doc, true); + } + yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, true); + } + yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, false); + } + yield PanamaVectorUtilSupport::dotProduct; + } + }; + } + + static FloatToFloatFunction factory(VectorSimilarityFunction function) { + return switch (function) { + case EUCLIDEAN -> score -> (1 / (1f + score)); + case DOT_PRODUCT, COSINE -> score -> Math.max((1f + score) / 2, 0); + case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore; + }; + } +} diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 9aa4bc09e348..9fa9cffdbbdd 100644 --- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -396,49 +396,74 @@ private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit @Override public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { - assert (apacked && bpacked) == false; + return int4DotProduct(MemorySegment.ofArray(a), apacked, MemorySegment.ofArray(b), bpacked); + } + + public static int int4DotProduct( + MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) { int i = 0; int res = 0; - if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? b : a; - if (packed.length >= 32) { + if (apacked && bpacked) { + if (a.byteSize() >= 32) { if (VECTOR_BITSIZE >= 512) { - i += ByteVector.SPECIES_256.loopBound(packed.length); - res += dotProductBody512Int4Packed(unpacked, packed, i); + i += ByteVector.SPECIES_256.loopBound(a.byteSize()); + res += dotProductBody512Int4BothPacked(a, b, i); } else if (VECTOR_BITSIZE == 256) { - i += ByteVector.SPECIES_128.loopBound(packed.length); - res += dotProductBody256Int4Packed(unpacked, packed, i); + i += ByteVector.SPECIES_128.loopBound(a.byteSize()); + res += dotProductBody256Int4BothPacked(a, b, i); } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_64.loopBound(packed.length); - res += dotProductBody128Int4Packed(unpacked, packed, i); + i += ByteVector.SPECIES_64.loopBound(a.byteSize()); + res += dotProductBody128Int4BothPacked(a, b, i); } } // scalar tail - for (; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; + for (; i < a.byteSize(); i++) { + byte packedByte1 = a.get(JAVA_BYTE, i); + byte packedByte2 = b.get(JAVA_BYTE, i); + res += (packedByte1 & 0x0F) * (packedByte2 & 0x0F); + res += ((packedByte1 & 0xFF) >> 4) * ((packedByte2 & 0xFF) >> 4); + } + } else if (apacked || bpacked) { + MemorySegment packed = apacked ? a : b; + MemorySegment unpacked = apacked ? b : a; + if (packed.byteSize() >= 32) { + if (VECTOR_BITSIZE >= 512) { + i += ByteVector.SPECIES_256.loopBound(packed.byteSize()); + res += dotProductBody512Int4SinglePacked(unpacked, packed, i); + } else if (VECTOR_BITSIZE == 256) { + i += ByteVector.SPECIES_128.loopBound(packed.byteSize()); + res += dotProductBody256Int4SinglePacked(unpacked, packed, i); + } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_64.loopBound(packed.byteSize()); + res += dotProductBody128Int4SinglePacked(unpacked, packed, i); + } + } + // scalar tail + for (; i < packed.byteSize(); i++) { + byte packedByte = packed.get(JAVA_BYTE, i); + byte unpacked1 = unpacked.get(JAVA_BYTE, i); + byte unpacked2 = unpacked.get(JAVA_BYTE, i + packed.byteSize()); res += (packedByte & 0x0F) * unpacked2; res += ((packedByte & 0xFF) >> 4) * unpacked1; } } else { if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { return dotProduct(a, b); - } else if (a.length >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); + } else if (a.byteSize() >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.byteSize()); res += int4DotProductBody128(a, b, i); } // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + for (; i < a.byteSize(); i++) { + res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i); } } return res; } - private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody512Int4SinglePacked( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 4096) { @@ -447,9 +472,12 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 4096); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + var vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_256, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -457,7 +485,8 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); acc1 = acc1.add(prod16a); @@ -471,7 +500,38 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody512Int4BothPacked(MemorySegment a, MemorySegment b, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, a, i + j, LITTLE_ENDIAN); + var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, b, i + j, LITTLE_ENDIAN); + + // upper + ByteVector prod8 = va8.and((byte) 0x0F).mul(vb8.and((byte) 0x0F)); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector prod8a = va8.lanewise(LSHR, 4).mul(vb8.lanewise(LSHR, 4)); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private static int dotProductBody256Int4SinglePacked( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { @@ -480,9 +540,12 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + var vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_128, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -490,7 +553,8 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc1 = acc1.add(prod16a); @@ -504,8 +568,39 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } + private static int dotProductBody256Int4BothPacked(MemorySegment a, MemorySegment b, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, a, i + j, LITTLE_ENDIAN); + var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, b, i + j, LITTLE_ENDIAN); + + // upper + ByteVector prod8 = va8.and((byte) 0x0F).mul(vb8.and((byte) 0x0F)); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector prod8a = va8.lanewise(LSHR, 4).mul(vb8.lanewise(LSHR, 4)); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody128Int4SinglePacked( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -514,10 +609,12 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { // packed - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j, LITTLE_ENDIAN); // unpacked ByteVector va8 = - ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + ByteVector.fromMemorySegment( + ByteVector.SPECIES_64, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -526,7 +623,7 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi acc0 = acc0.add(prod16.and((short) 0xFF)); // lower - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j, LITTLE_ENDIAN); prod8 = vb8.lanewise(LSHR, 4).mul(va8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); @@ -540,7 +637,37 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int int4DotProductBody128(byte[] a, byte[] b, int limit) { + private static int dotProductBody128Int4BothPacked(MemorySegment a, MemorySegment b, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { + var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j, LITTLE_ENDIAN); + var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j, LITTLE_ENDIAN); + + // upper + ByteVector prod8 = va8.and((byte) 0x0F).mul(vb8.and((byte) 0x0F)); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_128, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector prod8a = va8.lanewise(LSHR, 4).mul(vb8.lanewise(LSHR, 4)); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_128, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private static int int4DotProductBody128(MemorySegment a, MemorySegment b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -548,15 +675,17 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) { ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); + ByteVector va8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j, LITTLE_ENDIAN); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j, LITTLE_ENDIAN); ByteVector prod8 = va8.mul(vb8); ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc0 = acc0.add(prod16.and((short) 0xFF)); - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); - vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8, LITTLE_ENDIAN); + vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8, LITTLE_ENDIAN); prod8 = va8.mul(vb8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); diff --git a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 2e061ee6f5fc..4b77ee07a40a 100644 --- a/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -67,6 +67,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; } + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return Lucene99MemorySegmentScalarQuantizedVectorScorer.INSTANCE; + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException { if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS