diff --git a/ACKNOWLEDGEMENTS b/ACKNOWLEDGEMENTS index 9284e47647..321141eabd 100644 --- a/ACKNOWLEDGEMENTS +++ b/ACKNOWLEDGEMENTS @@ -216,3 +216,27 @@ Unicode, Inc (ICU4J) Creative Commons Attribution 4.0 License (GeoNames) https://creativecommons.org/licenses/by/4.0/ + +Christian Heina (HALF4J) + + Copyright 2023 Christian Heina + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond Chi-Wing Wong, + "Practical and Asymptotically Optimal Quantization of High-Dimensional Vectors in Euclidean Space for + Approximate Nearest Neighbor Search", + SIGMOD 2025, available at https://arxiv.org/abs/2409.09913 + +Yutong Gou, Jianyang Gao, Yuexuan Xu, Jifan Shi and Zhonghao Yang + https://github.com/VectorDB-NTU/RaBitQ-Library/blob/main/LICENSE diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle index 137e13eb96..9d35bf31f2 100644 --- a/fdb-extensions/fdb-extensions.gradle +++ b/fdb-extensions/fdb-extensions.gradle @@ -42,6 +42,38 @@ dependencies { testFixturesAnnotationProcessor(libs.autoService) } +def siftSmallFile = layout.buildDirectory.file('downloads/siftsmall.tar.gz') +def extractDir = layout.buildDirectory.dir("extracted") + +// Task that downloads the CSV exactly once unless it changed +tasks.register('downloadSiftSmall', de.undercouch.gradle.tasks.download.Download) { + src 
'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz' + dest siftSmallFile.get().asFile + onlyIfModified true + tempAndMove true + retries 3 +} + +tasks.register('extractSiftSmall', Copy) { + dependsOn 'downloadSiftSmall' + from(tarTree(resources.gzip(siftSmallFile))) + into extractDir + + doLast { + println "Extracted files into: ${extractDir.get().asFile}" + fileTree(extractDir).visit { details -> + if (!details.isDirectory()) { + println " - ${details.file}" + } + } + } +} + +test { + dependsOn tasks.named('extractSiftSmall') + inputs.dir extractDir +} + publishing { publications { library(MavenPublication) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java index 563dec11a6..e696512fdd 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java @@ -23,12 +23,14 @@ import com.apple.foundationdb.annotation.API; import com.apple.foundationdb.util.LoggableException; import com.google.common.base.Suppliers; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.ThreadFactoryBuilder; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -42,9 +44,13 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.IntPredicate; +import 
java.util.function.IntUnaryOperator; import java.util.function.Predicate; import java.util.function.Supplier; @@ -1051,6 +1057,93 @@ public static CompletableFuture swallowException(@Nonnull CompletableFutur return result; } + /** + * Method that provides the functionality of a for loop, however, in an asynchronous way. The result of this method + * is a {@link CompletableFuture} that represents the result of the last iteration of the loop body. + * @param startI an integer analogous to the starting value of a loop variable in a for loop + * @param startU an object of some type {@code U} that represents the initial state that is passed to the loop's + * first iteration + * @param conditionPredicate a predicate on the loop variable that must be true before the next iteration is + * entered; analogous to the condition in a for loop + * @param stepFunction a unary operator used for modifying the loop variable after each iteration + * @param body a bi-function to be called for each iteration; this function is initially invoked using + * {@code startI} and {@code startU}; the result of the body is then passed into the next iteration's body + * together with a new value for the loop variable. In this way callers can access state inside an iteration + * that was computed in a previous iteration. + * @param executor the executor + * @param <U> the type of the result of the body {@link BiFunction} + * @return a {@link CompletableFuture} containing the result of the last iteration's body invocation. 
+     */
+    @Nonnull
+    public static <U> CompletableFuture<U> forLoop(final int startI, @Nullable final U startU,
+                                                   @Nonnull final IntPredicate conditionPredicate,
+                                                   @Nonnull final IntUnaryOperator stepFunction,
+                                                   @Nonnull final BiFunction<Integer, U, CompletableFuture<U>> body,
+                                                   @Nonnull final Executor executor) {
+        final AtomicInteger loopVariableAtomic = new AtomicInteger(startI);
+        final AtomicReference<U> lastResultAtomic = new AtomicReference<>(startU); // seeds the first body invocation
+        return whileTrue(() -> {
+            final int loopVariable = loopVariableAtomic.get();
+            if (!conditionPredicate.test(loopVariable)) {
+                return AsyncUtil.READY_FALSE; // loop condition no longer holds: stop iterating
+            }
+            return body.apply(loopVariable, lastResultAtomic.get())
+                    .thenApply(result -> {
+                        loopVariableAtomic.set(stepFunction.applyAsInt(loopVariable)); // step only after the body completed
+                        lastResultAtomic.set(result); // becomes the state handed to the next iteration
+                        return true;
+                    });
+        }, executor).thenApply(ignored -> lastResultAtomic.get()); // still startU if the body was never invoked
+    }
+
+    /**
+     * Method to iterate over some items, for each of which a body is executed asynchronously. The result of each such
+     * execution is then collected in a list and returned as a {@link CompletableFuture} over that list.
+     * @param items the items to iterate over; must not contain {@code null} as items are buffered in an {@link ArrayDeque}
+     * @param body a function to be called for each item
+     * @param parallelism the maximum number of body invocations in flight at any time; values below 1 are treated as 1
+     * @param executor the executor
+     * @param <T> the type of item
+     * @param <U> the type of the result
+     * @return a {@link CompletableFuture} containing a list of results, positioned in the iteration order of {@code items}
+     */
+    @Nonnull
+    @SuppressWarnings("unchecked")
+    public static <T, U> CompletableFuture<List<U>> forEach(@Nonnull final Iterable<T> items,
+                                                            @Nonnull final Function<T, CompletableFuture<U>> body,
+                                                            final int parallelism,
+                                                            @Nonnull final Executor executor) {
+        // this deque is filled exactly once here; afterwards it is only drained, never added to
+        final ArrayDeque<T> toBeProcessed = new ArrayDeque<>();
+        for (final T item : items) {
+            toBeProcessed.addLast(item);
+        }
+
+        final List<CompletableFuture<Void>> working = Lists.newArrayList();
+        final AtomicInteger indexAtomic = new AtomicInteger(0);
+        final Object[] resultArray = new Object[toBeProcessed.size()]; // slot i receives the result of the i-th item
+
+        return whileTrue(() -> {
+            working.removeIf(CompletableFuture::isDone);
+
+            while (working.size() < Math.max(1, parallelism)) { // strict bound: never more than 'parallelism' in flight
+                final T currentItem = toBeProcessed.pollFirst();
+                if (currentItem == null) {
+                    break; // deque exhausted
+                }
+
+                final int index = indexAtomic.getAndIncrement();
+                working.add(body.apply(currentItem)
+                        .thenAccept(result -> resultArray[index] = result));
+            }
+
+            if (working.isEmpty()) {
+                return AsyncUtil.READY_FALSE; // nothing in flight and nothing left to start
+            }
+            return whenAny(working).thenApply(ignored -> true);
+        }, executor).thenApply(ignored -> Arrays.asList((U[])resultArray)); // unchecked cast; every slot was written from a U
+    }
+
     /**
      * A {@code Boolean} function that is always true.
* @param the type of the (ignored) argument to the function diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java index f60c17da63..2623cff1dc 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java @@ -36,7 +36,6 @@ * Storage adapter used for serialization and deserialization of nodes. */ interface StorageAdapter { - /** * Get the {@link RTree.Config} associated with this storage adapter. * @return the configuration used by this storage adapter diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/half/Half.java b/fdb-extensions/src/main/java/com/apple/foundationdb/half/Half.java new file mode 100644 index 0000000000..80fe8eaef5 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/half/Half.java @@ -0,0 +1,853 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +/** + * The {@code Half} class implements half precision (FP16) float-point number according to IEEE 754 standard. + *

+ * In addition, this class provides several methods for converting a {@code Half} to a {@code String} and a + * {@code String} to a {@code Half}, as well as other constants and methods useful when dealing with a {@code Half}. + *

+ * {@code Half} is implemented to provide, as much as possible, the same interface as {@link Float} and {@link Double}. + * + * @author Christian Heina (developer@christianheina.com) + */ +public final class Half extends Number implements Comparable { + + /** + * A constant holding the positive infinity of type {@code Half}. + * + *

+ * It is equal to the value returned by {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x7c00)}. + */ + public static final Half POSITIVE_INFINITY = shortBitsToHalf((short) 0x7c00); + + /** + * A constant holding the negative infinity of type {@code Half}. + * + *

+ * It is equal to the value returned by {@link #shortBitsToHalf(short) shortBitsToHalf((short)0xfc00)}. + */ + public static final Half NEGATIVE_INFINITY = shortBitsToHalf((short) 0xfc00); + + /** + * A constant holding a Not-a-Number (NaN) value of type {@code Half}. + * + *

+ * It is equivalent to the value returned by {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x7e00)}. + */ + @SuppressWarnings("PMD.FieldNamingConventions") + public static final Half NaN = shortBitsToHalf((short) 0x7e00); + + /** + * A constant holding the largest positive finite value of type {@code Half}, + * (2-2<sup>-10</sup>)&middot;2<sup>15</sup>. + * + *

+ * It is equal to {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x7bff)}. + */ + public static final Half MAX_VALUE = shortBitsToHalf((short) 0x7bff); + + /** + * A constant holding the negative finite value of largest magnitude of type {@code Half}, + * -(2-2<sup>-10</sup>)&middot;2<sup>15</sup>. + * + *

+ * It is equal to {@link #shortBitsToHalf(short) shortBitsToHalf((short)0xfbff)}. + */ + public static final Half NEGATIVE_MAX_VALUE = shortBitsToHalf((short) 0xfbff); + + /** + * A constant holding the smallest positive normal value of type {@code Half}, 2<sup>-14</sup>. + * + *

+ * It is equal to {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x0400)}. + */ + public static final Half MIN_NORMAL = shortBitsToHalf((short) 0x0400); // 6.103515625E-5 + + /** + * A constant holding the smallest positive nonzero value of type {@code Half}, 2<sup>-24</sup>. + * + *

+ * It is equal to {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x1)}. + */ + public static final Half MIN_VALUE = shortBitsToHalf((short) 0x1); // 5.9604645E-8 + + /** + * Maximum exponent a finite {@code Half} variable may have. + * + *

+ * It is equal to the value returned by {@code HalfMath.getExponent(Half.MAX_VALUE)}. + */ + public static final int MAX_EXPONENT = 15; + + /** + * Minimum exponent a normalized {@code Half} variable may have. + * + *

+ * It is equal to the value returned by {@code HalfMath.getExponent(Half.MIN_NORMAL)}. + */ + public static final int MIN_EXPONENT = -14; + + /** + * The number of bits used to represent a {@code Half} value. + */ + public static final int SIZE = 16; + + /** + * The number of bytes used to represent a {@code Half} value. + */ + public static final int BYTES = SIZE / Byte.SIZE; + + /** + * A constant holding the positive zero of type {@code Half}. + * + *

+ * It is equal to the value returned by {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x0)}. + */ + public static final Half POSITIVE_ZERO = shortBitsToHalf((short) 0x0); + + /** + * A constant holding the negative zero of type {@code Half}. + * + *

+ * It is equal to the value returned by {@link #shortBitsToHalf(short) shortBitsToHalf((short)0x8000)}. + */ + public static final Half NEGATIVE_ZERO = shortBitsToHalf((short) 0x8000); + + private static final long serialVersionUID = 1682650405628820816L; + + /** + * The value of the half precision floating-point as a float. + * + * @serial + */ + private final float floatRepresentation; + + private Half(float floatRepresentation) { + /* Hidden Constructor */ + super(); + this.floatRepresentation = floatRepresentation; + } + + /** + * Returns the {@code Half} object corresponding to a given bit representation. The argument is considered to be a + * representation of a floating-point value according to the IEEE 754 floating-point "half format" bit layout. + * + *

+ * If the argument is {@code 0x7c00}, the result is positive infinity. + * + *

+ * If the argument is {@code 0xfc00}, the result is negative infinity. + * + *

+ * If the argument is any value in the range {@code 0x7c01} through {@code 0x7fff} or in the range {@code 0xfc01} through + * {@code 0xffff}, the result is a NaN. No IEEE 754 floating-point operation provided by Java can distinguish + * between two NaN values of the same type with different bit patterns. Distinct values of NaN are only + * distinguishable by use of the {@code Half.halfToRawShortBits} method. + * + *

+ * In all other cases, let s, e, and m be three values that can be computed from the argument: + * + *

+ * + *
+     * {
+     *     @code
+     *     int s = ((bits >> 16) == 0) ? 1 : -1;
+     *     int e = ((bits >> 10) & 0x1f);
+     *     int m = (e == 0) ? (bits & 0x3ff) << 1 : (bits & 0x3ff) | 0x400;
+     * }
+     * 
+ * + *
+ * + * Then the floating-point result equals the value of the mathematical expression + * <i>s</i>&middot;<i>m</i>&middot;2<sup><i>e</i>-25</sup>. + * + *

+ * Note that this method may not be able to return a {@code Half} NaN with exactly same bit pattern as the + * {@code short} argument. IEEE 754 distinguishes between two kinds of NaNs, quiet NaNs and signaling NaNs. + * The differences between the two kinds of NaN are generally not visible in Java. Arithmetic operations on + * signaling NaNs turn them into quiet NaNs with a different, but often similar, bit pattern. However, on some + * processors merely copying a signaling NaN also performs that conversion. In particular, copying a signaling NaN + * to return it to the calling method may perform this conversion. So {@code shortBitsToHalf} may not be able to + * return a {@code Half} with a signaling NaN bit pattern. Consequently, for some {@code short} values, + * {@code halfToRawShortBits(shortBitsToHalf(start))} may not equal {@code start}. Moreover, which particular + * bit patterns represent signaling NaNs is platform dependent; although all NaN bit patterns, quiet or signaling, + * must be in the NaN range identified above. + * + * @param shortBits + * a short. + * + * @return the {@code Half} float-point object with the same bit pattern. + */ + public static Half shortBitsToHalf(short shortBits) { + return new Half(halfShortToFloat(shortBits)); + } + + public static float halfShortToFloat(short shortBits) { + int intBits = (int) shortBits; + int exponent = (intBits & HalfConstants.EXP_BIT_MASK) >> 10; + int significand = (intBits & HalfConstants.SIGNIF_BIT_MASK) << 13; + + // Check infinities and NaN + if (exponent == 31) { + // sign | positive infinity integer value | significand + return Float.intBitsToFloat((intBits & HalfConstants.SIGN_BIT_MASK) << 16 | 0x7f800000 | significand); + } + + int v = Float.floatToIntBits((float) significand) >> 23; + // sign | normal | subnormal + return Float.intBitsToFloat( + (intBits & 0x8000) << 16 | (exponent != 0 ? 1 : 0) * ((exponent + 112) << 23 | significand) + | ((exponent == 0 ? 1 : 0) & (significand != 0 ? 
1 : 0)) + * ((v - 37) << 23 | ((significand << (150 - v)) & 0x007FE000))); + } + + /** + * Returns a representation of the specified floating-point value according to the IEEE 754 floating-point "half + * format" bit layout. + * + *

+ * Bit 15 (the bit that is selected by the mask {@code 0x8000}) represents the sign of the floating-point number. + * Bits 14-10 (the bits that are selected by the mask {@code 0x7c00}) represent the exponent. Bits 9-0 (the bits + * that are selected by the mask {@code 0x03ff}) represent the significand (sometimes called the mantissa) of the + * floating-point number. + * + *

+ * If the argument is positive infinity, the result is {@code 0x7c00}. + * + *

+ * If the argument is negative infinity, the result is {@code 0xfc00}. + * + *

+ * If the argument is NaN, the result is {@code 0x7e00}. + * + *

+ * In all cases, the result is a short that, when given to the {@link #shortBitsToHalf(short)} method, will produce + * a floating-point value the same as the argument to {@code halfToShortBits} (except all NaN values are collapsed + * to a single "canonical" NaN value). + * + * @param half + * a Half object. + * + * @return the bits that represent the floating-point number. + */ + public static short halfToShortBits(Half half) { + if (!isNaN(half)) { + return halfToRawShortBits(half); + } + return 0x7e00; + } + + /** + * Returns a representation of the specified floating-point value according to the IEEE 754 floating-point "half + * format" bit layout. + * + *

+ * Bit 15 (the bit that is selected by the mask {@code 0x8000}) represents the sign of the floating-point number. + * Bits 14-10 (the bits that are selected by the mask {@code 0x7c00}) represent the exponent. Bits 9-0 (the bits + * that are selected by the mask {@code 0x03ff}) represent the significand (sometimes called the mantissa) of the + * floating-point number. + * + *

+ * If the argument is positive infinity, the result is {@code 0x7c00}. + * + *

+ * If the argument is negative infinity, the result is {@code 0xfc00}. + * + *

+ * If the argument is NaN, the result is {@code 0x7e00}. + * + *

+ * In all cases, the result is a short that, when given to the {@link #shortBitsToHalf(short)} method, will produce + * a floating-point value the same as the argument to {@code floatToShortBitsCollapseNaN} (except all NaN values are collapsed + * to a single "canonical" NaN value). + * + * @param floatRepresentation + * a float representation as used within a {@code Half} object. + * + * @return the bits that represent the floating-point number. + */ + public static short floatToShortBitsCollapseNaN(final float floatRepresentation) { + if (!Float.isNaN(floatRepresentation)) { + return floatToHalfShortBits(floatRepresentation); + } + return 0x7e00; + } + + /** + * Returns a representation of the specified floating-point value according to the IEEE 754 floating-point "half + * format" bit layout, preserving Not-a-Number (NaN) values. + * + *

+ * Bit 15 (the bit that is selected by the mask {@code 0x8000}) represents the sign of the floating-point number. + * Bits 14-10 (the bits that are selected by the mask {@code 0x7c00}) represent the exponent. Bits 9-0 (the bits + * that are selected by the mask {@code 0x03ff}) represent the significand (sometimes called the mantissa) of the + * floating-point number. + * + *

+ * If the argument is positive infinity, the result is {@code 0x7c00}. + * + *

+ * If the argument is negative infinity, the result is {@code 0xfc00}. + * + *

+ * If the argument is NaN, the result is the integer representing the actual NaN value. Unlike the + * {@code halfToShortBits} method, {@code halfToRawShortBits} does not collapse all the bit patterns encoding a NaN + * to a single "canonical" NaN value. + * + *

+ * In all cases, the result is a short that, when given to the {@link #shortBitsToHalf(short)} method, will produce + * a floating-point value the same as the argument to {@code halfToRawShortBits}. + * + * @param half + * a Half object. + * + * @return the bits that represent the half-point number. + */ + public static short halfToRawShortBits(Half half) { + return floatToHalfShortBits(half.floatRepresentation); + } + + public static short floatToHalfShortBits(float floatValue) { + int intBits = Float.floatToRawIntBits(floatValue); + int exponent = (intBits & 0x7F800000) >> 23; + int significand = intBits & 0x007FFFFF; + + // Check infinities and NaNs + if (exponent > 142) { + // sign | positive infinity short value + return (short) ((intBits & 0x80000000) >> 16 | 0x7c00 | significand >> 13); + } + + // sign | normal | subnormal + return (short) ((intBits & 0x80000000) >> 16 + | (exponent > 112 ? 1 : 0) * ((((exponent - 112) << 10) & 0x7C00) | significand >> 13) + | ((exponent < 113 ? 1 : 0) & (exponent > 101 ? 1 : 0)) + * ((((0x007FF000 + significand) >> (125 - exponent)) + 1) >> 1)); + } + + /** + * Returns the value of the specified number as a {@code short}. + * + * @return the numeric value represented by this object after conversion to type {@code short}. + */ + @Override + public short shortValue() { + if (isInfinite() || floatValue() > Short.MAX_VALUE || floatValue() < Short.MIN_VALUE) { + return ((Float.floatToIntBits(floatValue()) & 0x80000000) >> 31) == 0 ? Short.MAX_VALUE : Short.MIN_VALUE; + } + return (short) floatValue(); + } + + @Override + public int intValue() { + return (int) floatValue(); + } + + @Override + public long longValue() { + return (long) floatValue(); + } + + @Override + public float floatValue() { + return floatRepresentation; + } + + @Override + public double doubleValue() { + return floatValue(); + } + + /** + * Returns the value of the specified number as a {@code byte}. 
+ * + * @return the numeric value represented by this object after conversion to type {@code byte}. + */ + @Override + public byte byteValue() { + return (byte) shortValue(); + } + + /** + * Returns a {@code Half} object represented by the argument string {@code s}. + * + *

+ * If {@code s} is {@code null}, then a {@code NullPointerException} is thrown. + * + *

+ * Leading and trailing whitespace characters in {@code s} are ignored. Whitespace is removed as if by the + * {@link String#trim} method; that is, both ASCII space and control characters are removed. The rest of {@code s} + * should constitute a FloatValue as described by the lexical syntax rules: + * + *

+ *
+ *
FloatValue: + *
Signopt {@code NaN} + *
Signopt {@code Infinity} + *
Signopt FloatingPointLiteral + *
Signopt HexFloatingPointLiteral + *
SignedInteger + *
+ * + *
+ *
HexFloatingPointLiteral: + *
HexSignificand BinaryExponent FloatTypeSuffixopt + *
+ * + *
+ *
HexSignificand: + *
HexNumeral + *
HexNumeral {@code .} + *
{@code 0x} HexDigitsopt {@code .} HexDigits + *
{@code 0X} HexDigitsopt {@code .} HexDigits + *
+ * + *
+ *
BinaryExponent: + *
BinaryExponentIndicator SignedInteger + *
+ * + *
+ *
BinaryExponentIndicator: + *
{@code p} + *
{@code P} + *
+ * + *
+ * + * where Sign, FloatingPointLiteral, HexNumeral, HexDigits, SignedInteger and + * FloatTypeSuffix are as defined in the lexical structure sections of The Java Language + * Specification, except that underscores are not accepted between digits. If {@code s} does not have the + * form of a FloatValue, then a {@code NumberFormatException} is thrown. Otherwise, {@code s} is regarded as + * representing an exact decimal value in the usual "computerized scientific notation" or as an exact hexadecimal + * value; this exact numerical value is then conceptually converted to an "infinitely precise" binary value that is + * then rounded to type {@code half} by the usual round-to-nearest rule of IEEE 754 floating-point arithmetic, which + * includes preserving the sign of a zero value. + * + * Note that the round-to-nearest rule also implies overflow and underflow behaviour; if the exact value of + * {@code s} is large enough in magnitude (greater than or equal to ({@link #MAX_VALUE} + + * {@link HalfMath#ulp(Half half) HalfMath.ulp(MAX_VALUE)}/2), rounding to {@code float} will result in an infinity + * and if the exact value of {@code s} is small enough in magnitude (less than or equal to {@link #MIN_VALUE}/2), + * rounding to float will result in a zero. + * + * Finally, after rounding a {@code Half} object is returned. + * + *

+ * To interpret localized string representations of a floating-point value, use subclasses of + * {@link java.text.NumberFormat}. + * + *

+ * To avoid calling this method on an invalid string and having a {@code NumberFormatException} be thrown, the + * documentation for {@link Double#valueOf Double.valueOf} lists a regular expression which can be used to screen + * the input. + * + * @param s + * the string to be parsed. + * + * @return a {@code Half} object holding the value represented by the {@code String} argument. + * + * @throws NumberFormatException + * if the string does not contain a parsable number. + */ + public static Half valueOf(String s) throws NumberFormatException { + return valueOf(Float.valueOf(s)); + } + + /** + * Returns a {@code Half} instance representing the specified {@code double} value. + * + * @param doubleValue + * a double value. + * + * @return a {@code Half} instance representing {@code doubleValue}. + */ + public static Half valueOf(double doubleValue) { + return valueOf((float) doubleValue); + } + + /** + * Returns a {@code Half} instance representing the specified {@code Double} value. + * + * @param doubleValue + * a double value. + * + * @return a {@code Half} instance representing {@code doubleValue}. + */ + public static Half valueOf(Double doubleValue) { + return valueOf(doubleValue.doubleValue()); + } + + /** + * Returns a {@code Half} instance representing the specified {@code float} value. + * + * @param floatValue + * a float value. + * + * @return a {@code Half} instance representing {@code floatValue}. + */ + public static Half valueOf(float floatValue) { + return new Half(quantizeFloat(floatValue)); + } + + public static float quantizeFloat(final float floatValue) { + // check for infinities + if (floatValue > 65_504.0f || floatValue < -65_504.0f) { + return Half.halfShortToFloat((short) ((Float.floatToIntBits(floatValue) & 0x80000000) >> 16 | 0x7c00)); + } + return Half.halfShortToFloat(floatToShortBitsCollapseNaN(floatValue)); + } + + /** + * Returns a {@code Half} instance representing the specified {@code Float} value. 
+ * + * @param floatValue + * a float value. + * + * @return a {@code Half} instance representing {@code floatValue}. + */ + public static Half valueOf(Float floatValue) { + return valueOf(floatValue.floatValue()); + } + + /** + * Returns a new {@code Half} instance identical to the specified {@code half}. + * + * @param half + * a half instance. + * + * @return a {@code Half} instance representing {@code doubleValue}. + */ + public static Half valueOf(Half half) { + return shortBitsToHalf(halfToRawShortBits(half)); + } + + /** + * Returns {@code true} if the specified number is a Not-a-Number (NaN) value, {@code false} otherwise. + * + * @param half + * the {@code Half} to be tested. + * + * @return {@code true} if the argument is NaN; {@code false} otherwise. + */ + public static boolean isNaN(Half half) { + return Float.isNaN(half.floatRepresentation); + } + + /** + * Returns {@code true} if this {@code Half} value is a Not-a-Number (NaN), {@code false} otherwise. + * + * @return {@code true} if the value represented by this object is NaN; {@code false} otherwise. + */ + public boolean isNaN() { + return isNaN(this); + } + + /** + * Returns {@code true} if the specified {@code Half} is infinitely large in magnitude, {@code false} otherwise. + * + * @param half + * the {@code Half} to be tested. + * + * @return {@code true} if the argument is positive infinity or negative infinity; {@code false} otherwise. + */ + public static boolean isInfinite(Half half) { + return Float.isInfinite(half.floatRepresentation); + } + + /** + * Returns {@code true} if this {@code Half} value is infinitely large in magnitude, {@code false} otherwise. + * + * @return {@code true} if the value represented by this object is positive infinity or negative infinity; + * {@code false} otherwise. 
+ */ + public boolean isInfinite() { + return isInfinite(this); + } + + /** + * Returns {@code true} if the argument is a finite floating-point value; returns {@code false} otherwise (for NaN + * and infinity arguments). + * + * @param half + * the {@code Half} to be tested + * + * @return {@code true} if the argument is a finite floating-point value, {@code false} otherwise. + */ + public static boolean isFinite(Half half) { + return Float.isFinite(half.floatRepresentation); + } + + /** + * Returns {@code true} if the argument is a finite floating-point value; returns {@code false} otherwise (for NaN + * and infinity arguments). + * + * @return {@code true} if the argument is a finite floating-point value, {@code false} otherwise. + */ + public boolean isFinite() { + return isFinite(this); + } + + /** + * Returns a string representation of the {@code half} argument. All characters mentioned below are ASCII + * characters. + *

    + *
  • If the argument is NaN, the result is the string "{@code NaN}". + *
  • Otherwise, the result is a string that represents the sign and magnitude (absolute value) of the argument. If + * the sign is negative, the first character of the result is '{@code -}' ({@code '\u005Cu002D'}); if the sign is + * positive, no sign character appears in the result. As for the magnitude m: + *
      + *
    • If m is infinity, it is represented by the characters {@code "Infinity"}; thus, positive infinity + * produces the result {@code "Infinity"} and negative infinity produces the result {@code "-Infinity"}. + *
    • If m is zero, it is represented by the characters {@code "0.0"}; thus, negative zero produces the + * result {@code "-0.0"} and positive zero produces the result {@code "0.0"}. + *
    • If m is greater than or equal to 10-3 but less than 107, then it is represented + * as the integer part of m, in decimal form with no leading zeroes, followed by '{@code .}' + * ({@code '\u005Cu002E'}), followed by one or more decimal digits representing the fractional part of m. + *
    • If m is less than 10-3 or greater than or equal to 107, then it is represented + * in so-called "computerized scientific notation." Let n be the unique integer such that 10n + * m {@literal <} 10n+1; then let a be the mathematically exact quotient + * of m and 10n so that 1 ≤ a {@literal <} 10. The magnitude is then represented + * as the integer part of a, as a single decimal digit, followed by '{@code .}' ({@code '\u005Cu002E'}), + * followed by decimal digits representing the fractional part of a, followed by the letter '{@code E}' + * ({@code '\u005Cu0045'}), followed by a representation of n as a decimal integer, as produced by the method + * {@link java.lang.Integer#toString(int)}. + * + *
    + *
+ * Handled as a float and number of significant digits is determined by {@link Float#toString(float f) + * Float.toString(floatValue)} using results of {@link #floatValue()} call using {@code half} instance. + * + *

+ * To create localized string representations of a floating-point value, use subclasses of + * {@link java.text.NumberFormat}. + * + * @param half + * the Half to be converted. + * + * @return a string representation of the argument. + */ + public static String toString(Half half) { + // Use float toString for now + // Should have own toString implementation for better result. + return Float.toString(half.floatRepresentation); + } + + @Override + public String toString() { + return toString(this); + } + + /** + * Returns a hexadecimal string representation of the {@code half} argument. All characters mentioned below are + * ASCII characters. + * + *

    + *
  • If the argument is NaN, the result is the string "{@code NaN}". + *
  • Otherwise, the result is a string that represents the sign and magnitude (absolute value) of the argument. If + * the sign is negative, the first character of the result is '{@code -}' ({@code '\u005Cu002D'}); if the sign is + * positive, no sign character appears in the result. As for the magnitude m: + * + *
      + *
    • If m is infinity, it is represented by the string {@code "Infinity"}; thus, positive infinity produces + * the result {@code "Infinity"} and negative infinity produces the result {@code "-Infinity"}. + * + *
    • If m is zero, it is represented by the string {@code "0x0.0p0"}; thus, negative zero produces the + * result {@code "-0x0.0p0"} and positive zero produces the result {@code "0x0.0p0"}. + * + *
    • If m is a {@code half} with a normalized representation, substrings are used to represent the + * significand and exponent fields. The significand is represented by the characters {@code "0x1."} followed by a + * lowercase hexadecimal representation of the rest of the significand as a fraction. Trailing zeros in the + * hexadecimal representation are removed unless all the digits are zero, in which case a single zero is used. Next, + * the exponent is represented by {@code "p"} followed by a decimal string of the unbiased exponent as if produced + * by a call to {@link Integer#toString(int) Integer.toString} on the exponent value. + * + *
    • If m is a {@code half} with a subnormal representation, the significand is represented by the + * characters {@code "0x0."} followed by a hexadecimal representation of the rest of the significand as a fraction. + * Trailing zeros in the hexadecimal representation are removed. Next, the exponent is represented by + * {@code "p-14"}. Note that there must be at least one nonzero digit in a subnormal significand. + * + *
    + * + *
+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Examples
Floating-point ValueHexadecimal String
{@code 1.0}{@code 0x1.0p0}
{@code -1.0}{@code -0x1.0p0}
{@code 2.0}{@code 0x1.0p1}
{@code 3.0}{@code 0x1.8p1}
{@code 0.5}{@code 0x1.0p-1}
{@code 0.25}{@code 0x1.0p-2}
{@code Half.MAX_VALUE}{@code 0x1.ffcp15}
{@code Minimum Normal Value}{@code 0x1.0p-14}
{@code Maximum Subnormal Value}{@code 0x0.ffcp-14}
{@code Half.MIN_VALUE}{@code 0x0.004p-14}
+ * + * @param half + * the {@code Half} to be converted. + * + * @return a hex string representation of the argument. + */ + public static String toHexString(Half half) { + // Check subnormal + if (HalfMath.abs(half).compareTo(Half.MIN_NORMAL) < 0 && !HalfMath.abs(half).equals(Half.POSITIVE_ZERO)) { + String s = Double + .toHexString(Math.scalb((double) half.floatValue(), Double.MIN_EXPONENT - Half.MIN_EXPONENT)); + return s.replaceFirst("p-1022$", "p-14"); + } else { + // double string will be the same as half string + return Double.toHexString(half.floatValue()); + } + } + + @Override + public boolean equals(Object obj) { + return (obj instanceof Half) && Float.valueOf(((Half) obj).floatRepresentation).equals(floatRepresentation); + } + + /** + * Returns a hash code for a {@code Half}; compatible with {@code Half.hashCode()}. + * + * @param half + * the {@code Half} to hash + * + * @return a hash code value for a {@code Half} value. + */ + public static int hashCode(Half half) { + return halfToShortBits(half); + } + + /** + * Returns a hash code for this {@code Half} object. The result is the short bit representation, exactly as produced + * by the method {@link #halfToShortBits(Half)} represented by this {@code Half} object. + * + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return hashCode(this); + } + + /** + * Compares the two specified {@code Half} objects. The sign of the integer value returned is the same as that of + * the integer that would be returned by the call: + * + *
+     * half1.compareTo(half2)
+     * 
+ * + * @param half1 + * the first {@code Half} to compare. + * @param half2 + * the second {@code Half} to compare. + * + * @return the value {@code 0} if {@code half1} is numerically equal to {@code half2}; a value less than {@code 0} + * if {@code half1} is numerically less than {@code half2}; and a value greater than {@code 0} if + * {@code half1} is numerically greater than {@code half2}. + */ + public static int compare(Half half1, Half half2) { + return Float.compare(half1.floatRepresentation, half2.floatRepresentation); + } + + @Override + public int compareTo(Half anotherHalf) { + return compare(this, anotherHalf); + } + + /** + * Adds two {@code Half} values together as per the + operator. + * + * @param a + * the first operand + * @param b + * the second operand + * + * @return the sum of {@code a} and {@code b} + * + * @see java.util.function.BinaryOperator + */ + public static Half sum(Half a, Half b) { + return Half.valueOf(Float.sum(a.floatRepresentation, b.floatRepresentation)); + } + + /** + * Returns the greater of two {@code Half} objects.
+ * Determined using {@link #floatValue() aFloatValue = a.floatValue()} and {@link #floatValue() bFloatValue = + * b.floatValue()} then calling {@link Float#max(float, float) Float.max(aFloatValue, bFloatValue)}. + * + * @param a + * the first operand + * @param b + * the second operand + * + * @return the greater of {@code a} and {@code b} + * + * @see java.util.function.BinaryOperator + */ + public static Half max(Half a, Half b) { + return Half.valueOf(Math.max(a.floatRepresentation, b.floatRepresentation)); + } + + /** + * Returns the smaller of two {@code Half} objects.
+ * Determined using {@link #floatValue() aFloatValue = a.floatValue()} and {@link #floatValue() bFloatValue = + * b.floatValue()} then calling {@link Float#min(float, float) Float.min(aFloatValue, bFloatValue)}. + * + * @param a + * the first operand + * @param b + * the second operand + * + * @return the smaller of {@code a} and {@code b} + * + * @see java.util.function.BinaryOperator + */ + public static Half min(Half a, Half b) { + return Half.valueOf(Math.min(a.floatRepresentation, b.floatRepresentation)); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfConstants.java b/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfConstants.java new file mode 100644 index 0000000000..5e6d74ee22 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfConstants.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +/** + * This class contains additional constants documenting limits of {@code Half}. + *

+ * {@code HalfConstants} is implemented to provide, as much as possible, the same interface as + * {@code jdk.internal.math.FloatConsts}. + * + * @author Christian Heina (developer@christianheina.com) + */ +public class HalfConstants { + /** + * The number of logical bits in the significand of a {@code half} number, including the implicit bit. + */ + public static final int SIGNIFICAND_WIDTH = 11; + + /** + * The exponent the smallest positive {@code half} subnormal value would have if it could be normalized. + */ + public static final int MIN_SUB_EXPONENT = Half.MIN_EXPONENT - (SIGNIFICAND_WIDTH - 1); + + /** + * Bias used in representing a {@code half} exponent. + */ + public static final int EXP_BIAS = 15; + + /** + * Bit mask to isolate the sign bit of a {@code half}. + */ + public static final int SIGN_BIT_MASK = 0x8000; + + /** + * Bit mask to isolate the exponent field of a {@code half}. + */ + public static final int EXP_BIT_MASK = 0x7C00; + + /** + * Bit mask to isolate the significand field of a {@code half}. + */ + public static final int SIGNIF_BIT_MASK = 0x03FF; + + private HalfConstants() { + /* Hidden Constructor */ + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfMath.java b/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfMath.java new file mode 100644 index 0000000000..43f4b1bf6c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/half/HalfMath.java @@ -0,0 +1,123 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +/** + * The class {@code HalfMath} contains methods for performing basic numeric operations on or using {@link Half} objects. + * + * @author Christian Heina (developer@christianheina.com) + */ +public class HalfMath { + private HalfMath() { + /* Hidden Constructor */ + } + + /** + * Returns the size of an ulp of the argument. An ulp, unit in the last place, of a {@code half} value is the + * positive distance between this floating-point value and the {@code half} value next larger in magnitude.
+ * Note that for non-NaN x, ulp(-x) == ulp(x). + * + *

+ * Special Cases: + *

    + *
  • If the argument is NaN, then the result is NaN. + *
  • If the argument is positive or negative infinity, then the result is positive infinity. + *
  • If the argument is positive or negative zero, then the result is {@code Half.MIN_VALUE}. + *
  • If the argument is ±{@code Half.MAX_VALUE}, then the result is equal to 2<sup>5</sup>. + *
+ * + * @param half + * the floating-point value whose ulp is to be returned + * + * @return the size of an ulp of the argument + */ + public static Half ulp(Half half) { + int exp = getExponent(half); + + switch (exp) { + case Half.MAX_EXPONENT + 1: // NaN or infinity values + return abs(half); + case Half.MIN_EXPONENT - 1: // zero or subnormal values + return Half.MIN_VALUE; + default: // Normal values + exp = exp - (HalfConstants.SIGNIFICAND_WIDTH - 1); + if (exp >= Half.MIN_EXPONENT) { + // Normal result + return powerOfTwoH(exp); + } else { + // Subnormal result + return Half.shortBitsToHalf( + (short) (1 << (exp - (Half.MIN_EXPONENT - (HalfConstants.SIGNIFICAND_WIDTH - 1))))); + } + } + } + + /** + * Returns the unbiased exponent used in the representation of a {@code half}. + * + *

+ * Special cases: + * + *

    + *
  • If the argument is NaN or infinite, then the result is {@link Half#MAX_EXPONENT} + 1. + *
  • If the argument is zero or subnormal, then the result is {@link Half#MIN_EXPONENT} - 1. + *
+ * + * @param half + * a {@code half} value + * + * @return the unbiased exponent of the argument + */ + public static int getExponent(Half half) { + return ((Half.halfToRawShortBits(half) & HalfConstants.EXP_BIT_MASK) >> (HalfConstants.SIGNIFICAND_WIDTH - 1)) + - HalfConstants.EXP_BIAS; + } + + /** + * Returns the absolute {@link Half} object of a {@code half} instance. + * + *

+ * Special cases: + *

    + *
  • If the argument is positive zero or negative zero, the result is positive zero. + *
  • If the argument is infinite, the result is positive infinity. + *
  • If the argument is NaN, the result is a "canonical" NaN (preserving Not-a-Number (NaN) signaling). + *
+ * + * @param half + * the argument whose absolute value is to be determined + * + * @return the absolute value of the argument. + */ + public static Half abs(Half half) { + if (half.isNaN()) { + return Half.valueOf(half); + } + return Half.shortBitsToHalf((short) (Half.halfToRawShortBits(half) & 0x7fff)); + } + + /** + * Returns a floating-point power of two in the normal range. + */ + private static Half powerOfTwoH(int n) { + return Half.shortBitsToHalf( + (short) (((n + HalfConstants.EXP_BIAS) << (HalfConstants.SIGNIFICAND_WIDTH - 1)) & HalfConstants.EXP_BIT_MASK)); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/half/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/half/package-info.java new file mode 100644 index 0000000000..1e7ea172fb --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/half/package-info.java @@ -0,0 +1,26 @@ +/* + * package-info.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Package that implements a half precision datatype. This implementation is + * copied from HALF4J + * and was subsequently modified. 
+ */ +package com.apple.foundationdb.half; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AbstractRealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AbstractRealVector.java new file mode 100644 index 0000000000..f7a82c4d4c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AbstractRealVector.java @@ -0,0 +1,216 @@ +/* + * AbstractRealVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Suppliers; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.Objects; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * An abstract base class representing a mathematical vector. + *

+ * This class provides a common framework for vectors of different numerical precisions, + * whose components are all held internally as {@code double} values. It includes common operations like size, + * component access, equality checks, and conversions. Concrete implementations must provide specific logic for + * data type conversions and raw data representation. + */ +public abstract class AbstractRealVector implements RealVector { + @Nonnull + final double[] data; + + @Nonnull + protected Supplier hashCodeSupplier; + + @Nonnull + private final Supplier toRawDataSupplier; + + /** + * Constructs a new RealVector with the given data. + *

+ * This constructor uses the provided array directly as the backing store for the vector. It does not create a + * defensive copy. Therefore, any subsequent modifications to the input array will be reflected in this vector's + * state. The contract of this constructor is that callers do not modify {@code data} after calling the constructor. + * We do not want to copy the array here for performance reasons. + * @param data the components of this vector; must not be {@code null}. Note that this constructor does not + * check for {@code null} itself, so a {@code null} array only fails later when the data is accessed. + */ + protected AbstractRealVector(@Nonnull final double[] data) { + this.data = data; + this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode); + this.toRawDataSupplier = Suppliers.memoize(this::computeRawData); + } + + /** + * Returns the number of elements in the vector. + * @return the number of elements + */ + @Override + public int getNumDimensions() { + return data.length; + } + + /** + * Gets the component of this object at the specified dimension. + *

+ * The dimension is a zero-based index. For a 3D vector, for example, dimension 0 might correspond to the + * x-component, 1 to the y-component, and 2 to the z-component. This method provides direct access to the + * underlying data element. + * @param dimension the zero-based index of the component to retrieve. + * @return the component at the specified dimension as a primitive {@code double}. + * @throws IndexOutOfBoundsException if the {@code dimension} is negative or + * greater than or equal to the number of dimensions of this object. + */ + @Override + public double getComponent(int dimension) { + return data[dimension]; + } + + /** + * Returns the underlying data array. + *

+ * The returned array is guaranteed to be non-null. Note that this method + * returns a direct reference to the internal array, not a copy. + * @return the data array of type {@code double[]}, never {@code null}. + */ + @Nonnull + @Override + public double[] getData() { + return data; + } + + /** + * Gets the raw byte data representation of this object. + *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ + @Nonnull + @Override + public byte[] getRawData() { + return toRawDataSupplier.get(); + } + + /** + * Computes the raw byte data representation of this object. + *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ + @Nonnull + protected abstract byte[] computeRawData(); + + /** + * Compares this vector to the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is an {@code AbstractRealVector} (which excludes + * {@code null}) whose data elements equal those of this object. The component arrays are compared element-wise using + * {@link Arrays#equals(double[], double[])}, which for a flat array matches {@link Objects#deepEquals(Object, Object)}. + * @param o the object to compare with this {@code RealVector} for equality. + * @return {@code true} if the given object is an {@code AbstractRealVector} equivalent to this vector, {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof AbstractRealVector)) { + return false; + } + final AbstractRealVector vector = (AbstractRealVector)o; + return Arrays.equals(data, vector.data); + } + + /** + * Returns a hash code value for this object. The hash code is computed once and memoized. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } + + /** + * Computes a hash code based on the internal {@code data} array. + * @return the computed hash code for this object. + */ + private int computeHashCode() { + return Arrays.hashCode(data); + } + + /** + * Returns a string representation of the object. + *

+ * This method provides a default string representation by calling + * {@link #toString(int)} with a predefined limit of 10 dimensions. + * + * @return a string representation of this object showing at most the first 10 components. + */ + @Override + public String toString() { + return toString(10); + } + + /** + * Generates a string representation of the data array, with an option to limit the number of dimensions shown. + *

+ * If the specified {@code limitDimensions} is less than the actual number of dimensions in the data array, + * the resulting string will be a truncated view, ending with {@code ", ..."} to indicate that more elements exist. + * Otherwise, the method returns a complete string representation of the entire array. + * @param limitDimensions The maximum number of array elements to include in the string. A non-positive + * value will cause an {@link com.google.common.base.VerifyException}. + * @return A string representation of the data array, potentially truncated. + * @throws com.google.common.base.VerifyException if {@code limitDimensions} is not positive + */ + public String toString(final int limitDimensions) { + Verify.verify(limitDimensions > 0); + if (limitDimensions < data.length) { + return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions)) + .mapToObj(String::valueOf) + .collect(Collectors.joining(",")) + ", ...]"; + } else { + return "[" + Arrays.stream(data) + .mapToObj(String::valueOf) + .collect(Collectors.joining(",")) + "]"; + } + } + + @Nonnull + protected static double[] fromInts(@Nonnull final int[] ints) { + final double[] result = new double[ints.length]; + for (int i = 0; i < ints.length; i++) { + result[i] = ints[i]; + } + return result; + } + + @Nonnull + protected static double[] fromLongs(@Nonnull final long[] longs) { + final double[] result = new double[longs.length]; + for (int i = 0; i < longs.length; i++) { + result[i] = longs[i]; + } + return result; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java new file mode 100644 index 0000000000..a6e58ea05d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java @@ -0,0 +1,144 @@ +/* + * ColumnMajorRealMatrix.java + * + * This source file is part of the FoundationDB open source project + * 
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; + +import javax.annotation.Nonnull; +import java.util.Arrays; + +public class ColumnMajorRealMatrix implements RealMatrix { + @Nonnull + private final double[][] data; + @Nonnull + private final Supplier hashCodeSupplier; + + public ColumnMajorRealMatrix(@Nonnull final double[][] data) { + Preconditions.checkArgument(data.length > 0); + Preconditions.checkArgument(data[0].length > 0); + this.data = data; + this.hashCodeSupplier = Suppliers.memoize(this::valueBasedHashCode); + } + + @Nonnull + @Override + public double[][] getData() { + return data; + } + + @Override + public int getRowDimension() { + return data[0].length; + } + + @Override + public int getColumnDimension() { + return data.length; + } + + @Override + public double getEntry(final int row, final int column) { + return data[column][row]; + } + + @Nonnull + public double[] getColumn(final int column) { + return data[column]; + } + + @Nonnull + @Override + public RealMatrix transpose() { + int n = getRowDimension(); + int m = getColumnDimension(); + double[][] result = new double[n][m]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < m; j++) { + result[i][j] = getEntry(i, j); + } + } + return new 
ColumnMajorRealMatrix(result); + } + + @Nonnull + @Override + public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { + Preconditions.checkArgument(getColumnDimension() == otherMatrix.getRowDimension()); + int n = getRowDimension(); + int m = otherMatrix.getColumnDimension(); + int common = getColumnDimension(); + double[][] result = new double[m][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < m; j++) { + for (int k = 0; k < common; k++) { + result[j][i] += getEntry(i, k) * otherMatrix.getEntry(k, j); + } + } + } + return new ColumnMajorRealMatrix(result); + } + + @Nonnull + @Override + public RealMatrix subMatrix(final int startRow, final int lengthRow, final int startColumn, final int lengthColumn) { + final double[][] subData = new double[lengthColumn][lengthRow]; + + for (int j = startColumn; j < startColumn + lengthColumn; j ++) { + System.arraycopy(data[j], startRow, subData[j - startColumn], 0, lengthRow); + } + + return new ColumnMajorRealMatrix(subData); + } + + @Nonnull + @Override + public RowMajorRealMatrix toRowMajor() { + return new RowMajorRealMatrix(transpose().getData()); + } + + @Nonnull + @Override + public ColumnMajorRealMatrix toColumnMajor() { + return this; + } + + @Nonnull + @Override + public RealMatrix quickTranspose() { + return new RowMajorRealMatrix(data); + } + + @Override + public final boolean equals(final Object o) { + if (o instanceof ColumnMajorRealMatrix) { + final ColumnMajorRealMatrix that = (ColumnMajorRealMatrix)o; + return Arrays.deepEquals(data, that.data); + } + return valueEquals(o); + } + + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/DoubleRealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/DoubleRealVector.java new file mode 100644 index 0000000000..c84e92707c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/DoubleRealVector.java @@ -0,0 
+1,142 @@ +/* + * DoubleRealVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Suppliers; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.function.Supplier; + +/** + * A vector class encoding a vector over double components. Conversion to {@link HalfRealVector} is supported and + * memoized. 
+ */ +public class DoubleRealVector extends AbstractRealVector { + @Nonnull + private final Supplier toHalfVectorSupplier; + @Nonnull + private final Supplier toFloatVectorSupplier; + + public DoubleRealVector(@Nonnull final Double[] doubleData) { + this(computeDoubleData(doubleData)); + } + + public DoubleRealVector(@Nonnull final double[] data) { + super(data); + this.toHalfVectorSupplier = Suppliers.memoize(this::computeHalfRealVector); + this.toFloatVectorSupplier = Suppliers.memoize(this::computeFloatRealVector); + } + + public DoubleRealVector(@Nonnull final int[] intData) { + this(fromInts(intData)); + } + + public DoubleRealVector(@Nonnull final long[] longData) { + this(fromLongs(longData)); + } + + @Nonnull + @Override + public HalfRealVector toHalfRealVector() { + return toHalfVectorSupplier.get(); + } + + @Nonnull + public HalfRealVector computeHalfRealVector() { + return new HalfRealVector(data); + } + + @Nonnull + @Override + public FloatRealVector toFloatRealVector() { + return toFloatVectorSupplier.get(); + } + + @Nonnull + private FloatRealVector computeFloatRealVector() { + return new FloatRealVector(data); + } + + @Nonnull + @Override + public DoubleRealVector toDoubleRealVector() { + return this; + } + + @Nonnull + @Override + public RealVector withData(@Nonnull final double[] data) { + return new DoubleRealVector(data); + } + + /** + * Converts this {@link RealVector} of {@code double} precision floating-point numbers into a byte array. + *

+ * This method writes a leading type byte (the ordinal of {@code VectorType.DOUBLE}) and then iterates through the + * components of this vector, serializing each {@code double} as eight big-endian bytes placed sequentially into + * the resulting byte array. The final array's length will be {@code 1 + 8 * getNumDimensions()}. + * @return a new byte array representing the serialized vector data. This array is never null. + */ + @Nonnull + @Override + protected byte[] computeRawData() { + final byte[] vectorBytes = new byte[1 + 8 * getNumDimensions()]; + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + buffer.put((byte)VectorType.DOUBLE.ordinal()); + for (int i = 0; i < getNumDimensions(); i ++) { + buffer.putDouble(getComponent(i)); + } + return vectorBytes; + } + + @Nonnull + private static double[] computeDoubleData(@Nonnull Double[] doubleData) { + double[] result = new double[doubleData.length]; + for (int i = 0; i < doubleData.length; i++) { + result[i] = doubleData[i]; + } + return result; + } + + /** + * Creates a {@link DoubleRealVector} from a byte array. + *

+ * This method expects the first byte to be the ordinal of {@code VectorType.DOUBLE} and interprets the remaining + * bytes as a sequence of big-endian 64-bit double-precision floating-point numbers. Each run of eight bytes is + * converted into a {@code double} value, which then becomes a component of the resulting vector. + * @param vectorBytes the non-null byte array to convert + * @return a new {@link DoubleRealVector} instance created from the byte array + */ + @Nonnull + public static DoubleRealVector fromBytes(@Nonnull final byte[] vectorBytes) { + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + Verify.verify(buffer.get() == VectorType.DOUBLE.ordinal()); + final int numDimensions = vectorBytes.length >> 3; + final double[] vectorComponents = new double[numDimensions]; + for (int i = 0; i < numDimensions; i ++) { + vectorComponents[i] = buffer.getDouble(); + } + return new DoubleRealVector(vectorComponents); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java new file mode 100644 index 0000000000..b11377a688 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java @@ -0,0 +1,46 @@ +/* + * Estimator.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; + +/** + * Interface of an estimator used for calculating the distance between vectors. + *

+ * Implementations of this interface are expected to provide a specific distance + * metric calculation, often used in search or similarity contexts where one + * vector (the query) is compared against many stored vectors. + */ +public interface Estimator { + /** + * Calculates the distance between a pre-rotated and translated query vector and a stored vector. + *

+ * This method is designed to compute the distance metric between two vectors in a high-dimensional space. It is + * crucial that the {@code query} vector has already been appropriately transformed (e.g., rotated and translated) + * to align with the coordinate system of the {@code storedVector} before calling this method. + * + * @param query the pre-rotated and translated query vector, cannot be null. + * @param storedVector the stored vector to which the distance is calculated, cannot be null. + * @return a non-negative {@code double} representing the distance between the two vectors. + */ + double distance(@Nonnull RealVector query, // pre-rotated query q + @Nonnull RealVector storedVector); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java new file mode 100644 index 0000000000..9ff2f776f5 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java @@ -0,0 +1,299 @@ +/* + * FhtKacRotator.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import com.google.common.annotations.VisibleForTesting; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.BitSet; +import java.util.Random; + +/** + * FhtKac-like random orthogonal rotator which implements {@link LinearOperator}. + *

+ * An orthogonal rotator conceptually is an orthogonal matrix that is applied to some vector {@code x} yielding a + * new vector {@code y} that is rotated in some way along its n dimensions. Important to notice here is that such a + * rotation preserves distances/lengths as well as angles between vectors. + *

+ * Practically, we do not want to materialize such a rotator in memory as for {@code n} dimensions that would amount + * to {@code n^2} cells. In addition, multiplying a matrix with a vector is computationally {@code O(n^2)}, which + * is prohibitively expensive for large {@code n}. + * 

+ * We also want to achieve some sort of randomness with these rotations. For small {@code n}, we can start by creating a + * randomly generated matrix and decompose it using QR-decomposition into an orthogonal matrix {@code Q} and an upper + * triangular matrix {@code R}. That matrix {@code Q} is indeed random (see matrix Haar randomness) and what we are + * actually trying to find. For larger {@code n} this approach becomes impractical as {@code Q} needs to be represented + * as a dense matrix which increases memory footprint and makes rotations slow (see above). + * 

+ * The main idea is to use several operators in conjunction: + *

    + *
  • {@code K} which is orthogonal and applies several + * Givens rotations at once.
  • + *
  • {@code D_n} which are random Rademacher diagonal sign matrices, i.e. matrices that are conceptually all + * {@code 0} except for the diagonal which contains any combination of {@code -1} and {@code 1}. These matrices + * are also orthogonal. In particular, each one only flips the signs of the elements of some other matrix when + * applied to it.
  • + *
  • {@code H} a Hadamard matrix of a suitable size. + *
  • + *
+ *

+ * All these linear operators are combined in a way that we eventually compute the result of + *

+ * {@code
+     x' = D_1 H D_2 H ... D_(R-1) H D_(R) H K x
+ * }
+ * 
+ * (for {@code R} rotations). None of these operators require a significant amount of memory (O(R * n) bits for signs). + * They perform the complete rotation in {@code O(R * (n log n))}. + */ +@SuppressWarnings({"checkstyle:MethodName", "checkstyle:MemberName"}) +public final class FhtKacRotator implements LinearOperator { + private final int numDimensions; + private final int rounds; + private final BitSet[] signs; // signs[r] of i bits in {not set: -1, set: +1} + private static final double INV_SQRT2 = 1.0 / Math.sqrt(2.0); + + public FhtKacRotator(final long seed, final int numDimensions, final int rounds) { + if (numDimensions < 2) { + throw new IllegalArgumentException("n must be >= 2"); + } + if (rounds < 1) { + throw new IllegalArgumentException("rounds must be >= 1"); + } + this.numDimensions = numDimensions; + this.rounds = rounds; + + // Pre-generate Rademacher signs for determinism/reuse. + final Random rng = new Random(seed); + this.signs = new BitSet[rounds]; + for (int r = 0; r < rounds; r++) { + final BitSet s = new BitSet(numDimensions); + for (int i = 0; i < numDimensions; i++) { + s.set(i, rng.nextBoolean()); + } + signs[r] = s; + } + } + + @Override + public int getRowDimension() { + return numDimensions; + } + + @Override + public int getColumnDimension() { + return numDimensions; + } + + @Override + public boolean isTransposable() { + return true; + } + + @Nonnull + @Override + public RealVector operate(@Nonnull final RealVector x) { + return new DoubleRealVector(operate(x.getData())); + } + + @Nonnull + private double[] operate(@Nonnull final double[] x) { + if (x.length != numDimensions) { + throw new IllegalArgumentException("dimensionality of x != n"); + } + final double[] y = Arrays.copyOf(x, numDimensions); + + for (int r = 0; r < rounds; r++) { + // 1) Rademacher signs + final BitSet s = signs[r]; + for (int i = 0; i < numDimensions; i++) { + y[i] *= s.get(i) ? 
1 : -1; + } + + // 2) FHT on largest 2^k block; alternate head/tail + int m = largestPow2LE(numDimensions); + int start = ((r & 1) == 0) ? 0 : (numDimensions - m); // head on even rounds, tail on odd + fhtNormalized(y, start, m); + + // 3) π/4 Givens between halves (pair i with i+h) + givensPiOver4(y); + } + return y; + } + + @Nonnull + @Override + public RealVector operateTranspose(@Nonnull final RealVector x) { + return new DoubleRealVector(operateTranspose(x.getData())); + } + + @Nonnull + public double[] operateTranspose(@Nonnull final double[] x) { + if (x.length != numDimensions) { + throw new IllegalArgumentException("dimensionality of x != n"); + } + final double[] y = Arrays.copyOf(x, numDimensions); + + for (int r = rounds - 1; r >= 0; r--) { + // Inverse of step 3: Givens transpose (angle -> -π/4) + givensMinusPiOver4(y); + + // Inverse of step 2: FWHT is its own inverse (orthonormal) + int m = largestPow2LE(numDimensions); + int start = ((r & 1) == 0) ? 0 : (numDimensions - m); + fhtNormalized(y, start, m); + + // Inverse of step 1: Rademacher signs (self-inverse) + final BitSet s = signs[r]; + for (int i = 0; i < numDimensions; i++) { + y[i] *= s.get(i) ? 1 : -1; + } + } + return y; + } + + /** + * Build dense P as double[n][n] (row-major). This method exists for testing purposes only. 
+ */ + @Nonnull + @VisibleForTesting + public RowMajorRealMatrix computeP() { + final double[][] p = new double[numDimensions][numDimensions]; + final double[] e = new double[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + Arrays.fill(e, 0.0); + e[j] = 1.0; + double[] y = operate(e); // column j of P + for (int i = 0; i < numDimensions; i++) { + p[i][j] = y[i]; + } + } + return new RowMajorRealMatrix(p); + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof FhtKacRotator)) { + return false; + } + + final FhtKacRotator rotator = (FhtKacRotator)o; + return numDimensions == rotator.numDimensions && rounds == rotator.rounds && + Arrays.deepEquals(signs, rotator.signs); + } + + @Override + public int hashCode() { + int result = numDimensions; + result = 31 * result + rounds; + result = 31 * result + Arrays.deepHashCode(signs); + return result; + } + + // ----- internals ----- + + private static int largestPow2LE(int n) { + // highest power of two <= n + return 1 << (31 - Integer.numberOfLeadingZeros(n)); + } + + /** + * In-place normalized FHT on y[start ... start+m-1], where m is a power of two. + */ + private static void fhtNormalized(double[] y, int start, int m) { + // Cooley-Tukey style + for (int len = 1; len < m; len <<= 1) { + int step = len << 1; + for (int i = start; i < start + m; i += step) { + for (int j = 0; j < len; j++) { + int a = i + j; + int b = a + len; + double u = y[a]; + double v = y[b]; + y[a] = u + v; + y[b] = u - v; + } + } + } + double scale = 1.0 / Math.sqrt(m); + for (int i = start; i < start + m; i++) { + y[i] *= scale; + } + } + + /** + * Apply π/4 Givens rotation. + *
+     *  {@code
+     *  [u'; v'] = [  cos(π/4)  sin(π/4) ] [u]
+     *             [ -sin(π/4)  cos(π/4) ] [v]
+     *
+     *  Since cos(π/4) = sin(π/4) = 1/sqrt(2) this can be rewritten as
+     *
+     *  [u'; v'] = 1/ sqrt(2) * [  1  1 ] [u]
+     *                          [ -1  1 ] [v]
+     *
+     *  which allows for fast computation. Note that we rotate the incoming vector along many axes at once, the
+     *  two-dimensional example is for illustrative purposes only.
+     *  }
+     *  
+ */ + private static void givensPiOver4(double[] y) { + int h = nHalfFloor(y.length); + for (int i = 0; i < h; i++) { + int j = i + h; + if (j >= y.length) { + break; + } + double u = y[i]; + double v = y[j]; + double up = (u + v) * INV_SQRT2; + double vp = (-u + v) * INV_SQRT2; // -s*u + c*v with c=s + y[i] = up; + y[j] = vp; + } + } + + /** + * Apply transpose (inverse) of the π/4 Givens. + * @see #givensPiOver4(double[]) + */ + private static void givensMinusPiOver4(double[] y) { + int h = nHalfFloor(y.length); + for (int i = 0; i < h; i++) { + int j = i + h; + if (j >= y.length) { + break; + } + double u = y[i]; + double v = y[j]; + double up = (u - v) * INV_SQRT2; // c*u - s*v + double vp = (u + v) * INV_SQRT2; // s*u + c*v + y[i] = up; + y[j] = vp; + } + } + + private static int nHalfFloor(int n) { + return n >>> 1; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FloatRealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FloatRealVector.java new file mode 100644 index 0000000000..e6fdad145c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FloatRealVector.java @@ -0,0 +1,156 @@ +/* + * FloatRealVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import com.apple.foundationdb.half.Half; +import com.google.common.base.Suppliers; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.function.Supplier; + +/** + * A vector class encoding a vector over float components. + */ +public class FloatRealVector extends AbstractRealVector { + @Nonnull + private final Supplier toHalfRealVectorSupplier; + + public FloatRealVector(@Nonnull final Float[] floatData) { + this(computeDoubleData(floatData)); + } + + public FloatRealVector(@Nonnull final float[] floatData) { + this(computeDoubleData(floatData)); + } + + public FloatRealVector(@Nonnull final double[] data) { + super(truncateDoubleData(data)); + this.toHalfRealVectorSupplier = Suppliers.memoize(this::computeHalfRealVector); + } + + public FloatRealVector(@Nonnull final int[] intData) { + this(fromInts(intData)); + } + + public FloatRealVector(@Nonnull final long[] longData) { + this(fromLongs(longData)); + } + + @Nonnull + @Override + public HalfRealVector toHalfRealVector() { + return toHalfRealVectorSupplier.get(); + } + + @Nonnull + public HalfRealVector computeHalfRealVector() { + return new HalfRealVector(data); + } + + @Nonnull + @Override + public FloatRealVector toFloatRealVector() { + return this; + } + + @Nonnull + @Override + public DoubleRealVector toDoubleRealVector() { + return new DoubleRealVector(data); + } + + @Nonnull + @Override + public RealVector withData(@Nonnull final double[] data) { + return new FloatRealVector(data); + } + + /** + * Converts this {@link RealVector} of single precision floating-point numbers into a byte array. + *

+ * This method writes a leading type byte ({@code VectorType.SINGLE.ordinal()}) and then iterates through the + * vector, narrowing each component to a {@code float} and serializing it as four big-endian bytes, placed + * sequentially into the resulting byte array. The final array's length will be {@code 1 + 4 * getNumDimensions()}. + * @return a new byte array representing the serialized vector data. This array is never null. + */ + @Nonnull + @Override + protected byte[] computeRawData() { + final byte[] vectorBytes = new byte[1 + 4 * getNumDimensions()]; + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + buffer.put((byte)VectorType.SINGLE.ordinal()); + for (int i = 0; i < getNumDimensions(); i ++) { + buffer.putFloat((float)data[i]); + } + return vectorBytes; + } + + @Nonnull + private static double[] computeDoubleData(@Nonnull Float[] floatData) { + double[] result = new double[floatData.length]; + for (int i = 0; i < floatData.length; i++) { + result[i] = floatData[i]; + } + return result; + } + + @Nonnull + private static double[] computeDoubleData(@Nonnull float[] floatData) { + double[] result = new double[floatData.length]; + for (int i = 0; i < floatData.length; i++) { + result[i] = floatData[i]; + } + return result; + } + + @Nonnull + private static double[] truncateDoubleData(@Nonnull double[] doubleData) { + double[] result = new double[doubleData.length]; + for (int i = 0; i < doubleData.length; i++) { + result[i] = (float)doubleData[i]; + } + return result; + } + + /** + * Creates a {@link FloatRealVector} from a byte array. + * 

+ * This method interprets the bytes following the leading type byte as a sequence of 32-bit single-precision + * floating-point numbers. Each run of four bytes is converted into a {@code float} value, which then becomes a + * component of the resulting vector. + * @param vectorBytes the non-null byte array to convert + * @return a new {@link FloatRealVector} instance created from the byte array + */ + @Nonnull + public static FloatRealVector fromBytes(@Nonnull final byte[] vectorBytes) { + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + Verify.verify(buffer.get() == VectorType.SINGLE.ordinal()); + final int numDimensions = vectorBytes.length >> 2; + final double[] vectorComponents = new double[numDimensions]; + for (int i = 0; i < numDimensions; i ++) { + vectorComponents[i] = buffer.getFloat(); + } + return new FloatRealVector(vectorComponents); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/HalfRealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/HalfRealVector.java new file mode 100644 index 0000000000..ddac76d55a --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/HalfRealVector.java @@ -0,0 +1,133 @@ +/* + * HalfRealVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import com.apple.foundationdb.half.Half; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A vector class encoding a vector over half components. Conversion to {@link DoubleRealVector} is supported and + * memoized. + */ +public class HalfRealVector extends AbstractRealVector { + public HalfRealVector(@Nonnull final Half[] halfData) { + this(computeDoubleData(halfData)); + } + + public HalfRealVector(@Nonnull final double[] data) { + super(truncateDoubleData(data)); + } + + public HalfRealVector(@Nonnull final int[] intData) { + this(fromInts(intData)); + } + + public HalfRealVector(@Nonnull final long[] longData) { + this(fromLongs(longData)); + } + + @Nonnull + @Override + public HalfRealVector toHalfRealVector() { + return this; + } + + @Nonnull + @Override + public FloatRealVector toFloatRealVector() { + return new FloatRealVector(data); + } + + @Nonnull + @Override + public DoubleRealVector toDoubleRealVector() { + return new DoubleRealVector(data); + } + + @Nonnull + @Override + public RealVector withData(@Nonnull final double[] data) { + return new HalfRealVector(data); + } + + /** + * Converts this {@link RealVector} of {@link Half} precision floating-point numbers into a byte array. + *

+ * This method iterates through the input vector, converting each {@link Half} element into its 16-bit short + * representation. It then serializes this short into two bytes, placing them sequentially into the resulting byte + * array after a leading type byte. The final array's length will be {@code 1 + 2 * getNumDimensions()}. + * @return a new byte array representing the serialized vector data. This array is never null. + */ + @Nonnull + @Override + protected byte[] computeRawData() { + final byte[] vectorBytes = new byte[1 + 2 * getNumDimensions()]; + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + buffer.put((byte)VectorType.HALF.ordinal()); + for (int i = 0; i < getNumDimensions(); i ++) { + buffer.putShort(Half.floatToShortBitsCollapseNaN(Half.quantizeFloat((float)getComponent(i)))); + } + return vectorBytes; + } + + @Nonnull + private static double[] computeDoubleData(@Nonnull Half[] halfData) { + double[] result = new double[halfData.length]; + for (int i = 0; i < halfData.length; i++) { + result[i] = halfData[i].doubleValue(); + } + return result; + } + + @Nonnull + private static double[] truncateDoubleData(@Nonnull double[] doubleData) { + double[] result = new double[doubleData.length]; + for (int i = 0; i < doubleData.length; i++) { + result[i] = Half.valueOf(doubleData[i]).doubleValue(); + } + return result; + } + + /** + * Creates a {@link HalfRealVector} from a byte array. + * 

+ * This method interprets the input byte array as a sequence of 16-bit half-precision floating-point numbers. Each + * consecutive pair of bytes is converted into a {@code Half} value, which then becomes a component of the resulting + * vector. + * @param vectorBytes the non-null byte array to convert + * @return a new {@link HalfRealVector} instance created from the byte array + */ + @Nonnull + public static HalfRealVector fromBytes(@Nonnull final byte[] vectorBytes) { + final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN); + Verify.verify(buffer.get() == VectorType.HALF.ordinal()); + final int numDimensions = vectorBytes.length >> 1; + final double[] vectorComponents = new double[numDimensions]; + for (int i = 0; i < numDimensions; i ++) { + vectorComponents[i] = Half.halfShortToFloat(buffer.getShort()); + } + return new HalfRealVector(vectorComponents); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java new file mode 100644 index 0000000000..f19f02f50b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java @@ -0,0 +1,41 @@ +/* + * LinearOperator.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; + +public interface LinearOperator { + int getRowDimension(); + + int getColumnDimension(); + + default boolean isSquare() { + return getRowDimension() == getColumnDimension(); + } + + boolean isTransposable(); + + @Nonnull + RealVector operate(@Nonnull RealVector vector); + + @Nonnull + RealVector operateTranspose(@Nonnull RealVector vector); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MatrixHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MatrixHelpers.java new file mode 100644 index 0000000000..4379e91a7a --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MatrixHelpers.java @@ -0,0 +1,49 @@ +/* + * MatrixHelpers.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; +import java.util.Random; + +public class MatrixHelpers { + + private MatrixHelpers() { + // nothing + } + + @Nonnull + public static RealMatrix randomOrthogonalMatrix(@Nonnull final Random random, final int dimension) { + return QRDecomposition.decomposeMatrix(randomGaussianMatrix(random, dimension, dimension)).getQ(); + } + + @Nonnull + public static RealMatrix randomGaussianMatrix(@Nonnull final Random random, + final int rowDimension, + final int columnDimension) { + final double[][] resultMatrix = new double[rowDimension][columnDimension]; + for (int row = 0; row < rowDimension; row++) { + for (int column = 0; column < columnDimension; column++) { + resultMatrix[row][column] = random.nextGaussian(); + } + } + return new RowMajorRealMatrix(resultMatrix); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Metric.java new file mode 100644 index 0000000000..980f92d2f6 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Metric.java @@ -0,0 +1,150 @@ +/* + * Metric.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; + +/** + * Represents various distance calculation strategies (metrics) for vectors. + *

+ * Each enum constant holds a specific metric implementation, providing a type-safe way to calculate the distance + * between two points in a multidimensional space. + * + * @see MetricDefinition + */ +public enum Metric implements MetricDefinition { + /** + * Represents the Manhattan distance metric, implemented by {@link MetricDefinition.ManhattanMetric}. + *

+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + * @see MetricDefinition.ManhattanMetric + */ + MANHATTAN_METRIC(new MetricDefinition.ManhattanMetric()), + + /** + * Represents the Euclidean distance metric, implemented by {@link MetricDefinition.EuclideanMetric}. + *

+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + * @see MetricDefinition.EuclideanMetric + */ + EUCLIDEAN_METRIC(new MetricDefinition.EuclideanMetric()), + + /** + * Represents the squared Euclidean distance metric, implemented by {@link MetricDefinition.EuclideanSquareMetric}. + *

+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *

+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + * @see MetricDefinition.EuclideanSquareMetric + */ + EUCLIDEAN_SQUARE_METRIC(new MetricDefinition.EuclideanSquareMetric()), + + /** + * Represents the Cosine distance metric, implemented by {@link MetricDefinition.CosineMetric}. + *

+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} that corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0} while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * @see MetricDefinition.CosineMetric + */ + COSINE_METRIC(new MetricDefinition.CosineMetric()), + + /** + * Dot product similarity, implemented by {@link MetricDefinition.DotProductMetric} + *

+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as several properties of + * true metrics do not hold, for instance this metric can be negative. + * + * @see Dot Product + * @see MetricDefinition.DotProductMetric + */ + DOT_PRODUCT_METRIC(new MetricDefinition.DotProductMetric()); + + @Nonnull + private final MetricDefinition metricDefinition; + + /** + * Constructs a new Metric instance with the specified metric. + * @param metricDefinition the metric to be associated with this Metric instance; must not be null. + */ + Metric(@Nonnull final MetricDefinition metricDefinition) { + this.metricDefinition = metricDefinition; + } + + @Override + public boolean satisfiesZeroSelfDistance() { + return metricDefinition.satisfiesZeroSelfDistance(); + } + + @Override + public boolean satisfiesPositivity() { + return metricDefinition.satisfiesPositivity(); + } + + @Override + public boolean satisfiesSymmetry() { + return metricDefinition.satisfiesSymmetry(); + } + + @Override + public boolean satisfiesTriangleInequality() { + return metricDefinition.satisfiesTriangleInequality(); + } + + @Override + public double distance(@Nonnull final double[] vectorData1, @Nonnull final double[] vectorData2) { + return metricDefinition.distance(vectorData1, vectorData2); + } + + /** + * Calculates a distance between two n-dimensional vectors. + *

+ * The two vectors are represented as arrays of {@link Double} and must be of the + * same length (i.e., have the same number of dimensions). + * + * @param vector1 the first vector. Must not be null. + * @param vector2 the second vector. Must not be null and must have the same + * length as {@code vector1}. + * + * @return the calculated distance as a {@code double}. + * + * @throws IllegalArgumentException if the vectors have different lengths. + * @throws NullPointerException if either {@code vector1} or {@code vector2} is null. + */ + public double distance(@Nonnull RealVector vector1, @Nonnull RealVector vector2) { + return distance(vector1.getData(), vector2.getData()); + } + + @Override + public String toString() { + return metricDefinition.toString(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MetricDefinition.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MetricDefinition.java new file mode 100644 index 0000000000..bde98cf512 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/MetricDefinition.java @@ -0,0 +1,301 @@ +/* + * MetricDefinition.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; + +/** + * Defines a metric for measuring the distance or similarity between n-dimensional vectors. + *

+ * This interface provides a contract for various distance calculation algorithms, such as Euclidean, Manhattan, + * and Cosine distance. Implementations of this interface can be used in algorithms that require a metric for + * comparing data vectors, like clustering or nearest neighbor searches. + */ +interface MetricDefinition { + /** + * Method to be implemented by the specific metric. + * @return {@code true} iff for all {@link RealVector}s {@code x} holds that {@code distance(x, x) == 0} + */ + default boolean satisfiesZeroSelfDistance() { + return true; + } + + /** + * Method to be implemented by the specific metric. + * @return {@code true} iff for all {@link RealVector}s {@code x, y} holds that {@code distance(x, y) >= 0} + */ + default boolean satisfiesPositivity() { + return true; + } + + /** + * Method to be implemented by the specific metric. + * @return {@code true} iff for all {@link RealVector}s {@code x, y} holds that + * {@code distance(x, y) == distance(y, x)} + */ + default boolean satisfiesSymmetry() { + return true; + } + + /** + * Method to be implemented by the specific metric. + * @return {@code true} iff for all {@link RealVector}s {@code x, y, z} holds that + * {@code distance(x, y) + distance(y, z) >= distance(x, z)} + */ + default boolean satisfiesTriangleInequality() { + return true; + } + + /** + * Convenience method that returns if all properties of a metric required to be a true metric are satisfied. + * @return {@code true} iff this metric is a true metric. + */ + default boolean isTrueMetric() { + return satisfiesZeroSelfDistance() && + satisfiesPositivity() && + satisfiesSymmetry() && + satisfiesTriangleInequality(); + } + + @Nonnull + static String toString(@Nonnull final MetricDefinition metricDefinition) { + return metricDefinition.getClass().getSimpleName() + ";" + metricDefinition.isTrueMetric() + " metric"; + } + + /** + * Calculates a distance between two n-dimensional vectors. + *

+ * The two vectors are represented as {@code double} arrays and must be of the + * same length (i.e., have the same number of dimensions). + * + * @param vector1 the first vector. Must not be null. + * @param vector2 the second vector. Must not be null and must have the same + * length as {@code vector1}. + * + * @return the calculated distance as a {@code double}. + * + * @throws IllegalArgumentException if either vector is {@code null}, if the vectors have different lengths, + * or if the vectors are empty; this is what the metrics nested in this interface enforce through + * {@code validate}. + */ + double distance(@Nonnull double[] vector1, @Nonnull double[] vector2); + + /** + * A helper method to validate that vectors can be compared. + * @param vector1 The first vector. + * @param vector2 The second vector. + * @throws IllegalArgumentException if either vector is {@code null}, if the lengths differ, or if the + * vectors are empty. + */ + private static void validate(double[] vector1, double[] vector2) { + if (vector1 == null || vector2 == null) { + throw new IllegalArgumentException("Vectors cannot be null"); + } + if (vector1.length != vector2.length) { + throw new IllegalArgumentException( + "Vectors must have the same dimensionality. Got " + vector1.length + " and " + vector2.length + ); + } + if (vector1.length == 0) { + throw new IllegalArgumentException("Vectors cannot be empty."); + } + } + + /** + * Represents the Manhattan distance metric. + *

+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + */ + final class ManhattanMetric implements MetricDefinition { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + MetricDefinition.validate(vector1, vector2); + double sumOfAbsDiffs = 0.0; + for (int i = 0; i < vector1.length; i++) { + sumOfAbsDiffs += Math.abs(vector1[i] - vector2[i]); + } + return sumOfAbsDiffs; + } + + @Override + @Nonnull + public String toString() { + return MetricDefinition.toString(this); + } + } + + /** + * Represents the Euclidean distance metric. + *

+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + */ + final class EuclideanMetric implements MetricDefinition { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + MetricDefinition.validate(vector1, vector2); + + return Math.sqrt(EuclideanSquareMetric.distanceInternal(vector1, vector2)); + } + + @Override + @Nonnull + public String toString() { + return MetricDefinition.toString(this); + } + } + + /** + * Represents the squared Euclidean distance metric. + *

+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *

+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + */ + final class EuclideanSquareMetric implements MetricDefinition { + @Override + public boolean satisfiesTriangleInequality() { + return false; + } + + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + MetricDefinition.validate(vector1, vector2); + return distanceInternal(vector1, vector2); + } + + private static double distanceInternal(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + double sumOfSquares = 0.0d; + for (int i = 0; i < vector1.length; i++) { + double diff = vector1[i] - vector2[i]; + sumOfSquares += diff * diff; + } + return sumOfSquares; + } + + @Override + @Nonnull + public String toString() { + return MetricDefinition.toString(this); + } + } + + /** + * Represents the Cosine distance metric. + *

+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} that corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0} while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * If either vector has a squared norm of zero the distance is {@code Double.POSITIVE_INFINITY}; if a squared norm + * or the dot product is not finite the distance is {@code Double.NaN}. + * @see MetricDefinition.CosineMetric + */ + final class CosineMetric implements MetricDefinition { + @Override + public boolean satisfiesTriangleInequality() { + return false; + } + + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + MetricDefinition.validate(vector1, vector2); + + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vector1.length; i++) { + normA += vector1[i] * vector1[i]; + normB += vector2[i] * vector2[i]; + } + + // Handle the case of zero-vectors to avoid division by zero + if (normA == 0.0 || normB == 0.0) { + return Double.POSITIVE_INFINITY; + } + + final double dotProduct = DotProductMetric.dotProduct(vector1, vector2); + + if (!Double.isFinite(normA) || !Double.isFinite(normB) || !Double.isFinite(dotProduct)) { + return Double.NaN; + } + + return 1.0d - dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + @Override + @Nonnull + public String toString() { + return MetricDefinition.toString(this); + } + } + + /** + * Dot product similarity. + *

+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can + * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance + * only allows {@link MetricDefinition#distance(double[], double[])} to be called. + * + * @see Dot Product + * @see DotProductMetric + */ + final class DotProductMetric implements MetricDefinition { + @Override + public boolean satisfiesZeroSelfDistance() { + return false; + } + + @Override + public boolean satisfiesPositivity() { + return false; + } + + @Override + public boolean satisfiesTriangleInequality() { + return false; + } + + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + return -dotProduct(vector1, vector2); + } + + public static double dotProduct(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + MetricDefinition.validate(vector1, vector2); + + double product = 0.0d; + for (int i = 0; i < vector1.length; i++) { + product += vector1[i] * vector2[i]; + } + return product; + } + + @Override + @Nonnull + public String toString() { + return MetricDefinition.toString(this); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java new file mode 100644 index 0000000000..da7038e578 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java @@ -0,0 +1,244 @@ +/* + * QRDecomposition.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Preconditions; + +import javax.annotation.Nonnull; +import java.util.function.Supplier; + +/** + * Provides a static method to compute the QR decomposition of a matrix. + *

+ * This class is a utility class and cannot be instantiated. The decomposition + * is performed using the Householder reflection method. The result of the + * decomposition of a matrix A is an orthogonal matrix Q and an upper-triangular + * matrix R such that {@code A = QR}. + */ +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class QRDecomposition { + /** + * Private constructor to prevent instantiation of this utility class. + */ + private QRDecomposition() { + // nothing + } + + /** + * Decomposes a square matrix A into an orthogonal matrix Q and an upper + * triangular matrix R, such that A = QR. + *

+ * This implementation uses the Householder reflection method to perform the + * decomposition. The resulting Q and R matrices are not computed immediately but are + * available through suppliers within the returned {@link Result} object, allowing for + * lazy evaluation. The decomposition is performed on the transpose of the input matrix + * for efficiency. + * + * @param matrix the square matrix to decompose. Must not be null. + * + * @return a {@link Result} object containing suppliers for the Q and R matrices. + * + * @throws IllegalArgumentException if the provided {@code matrix} is not square. + */ + @Nonnull + public static Result decomposeMatrix(@Nonnull final RealMatrix matrix) { + Preconditions.checkArgument(matrix.isSquare()); + + final double[] rDiagonal = new double[matrix.getRowDimension()]; + final double[][] qrt = matrix.toRowMajor().transpose().getData(); + + for (int minor = 0; minor < matrix.getRowDimension(); minor++) { + performHouseholderReflection(minor, qrt, rDiagonal); + } + + return new Result(() -> getQ(qrt, rDiagonal), () -> getR(qrt, rDiagonal)); + } + + /** + * Performs a Householder reflection on a minor of a matrix. + *

+ * This method is a core step in QR decomposition. It transforms the {@code minor}-th + * column of the {@code qrt} matrix into a vector with a single non-zero element {@code a} + * (which becomes the new diagonal element of the R matrix), and applies the same + * transformation to the remaining columns of the minor. The transformation is done in-place. + *

+ * The reflection is defined by a matrix {@code H = I - 2vv'/|v|^2}, where the vector {@code v} + * is derived from the {@code minor}-th column of the matrix. + * + * @param minor the index of the minor matrix to be transformed. + * @param qrt the matrix to be transformed in-place. On exit, this matrix is + * updated to reflect the Householder transformation. + * @param rDiagonal an array where the diagonal element of the R matrix for the + * current {@code minor} will be stored. + */ + private static void performHouseholderReflection(final int minor, final double[][] qrt, + final double[] rDiagonal) { + + final double[] qrtMinor = qrt[minor]; + + /* + * Let x be the first column of the minor, and a^2 = |x|^2. + * x will be in the positions qr[minor][minor] through qr[m][minor]. + * The first column of the transformed minor will be (a, 0, 0, ...)' + * The sign of "a" is chosen to be opposite to the sign of the first + * component of x. Let's find "a": + */ + double xNormSqr = 0; + for (int row = minor; row < qrtMinor.length; row++) { + final double c = qrtMinor[row]; + xNormSqr += c * c; + } + final double a = (qrtMinor[minor] > 0) ? -Math.sqrt(xNormSqr) : Math.sqrt(xNormSqr); + rDiagonal[minor] = a; + + if (a != 0.0) { + /* + * Calculate the normalized reflection vector v and transform + * the first column. We know the norm of v beforehand: v = x-ae + * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> = + * a^2+a^2-2a<x,e> = 2a*(a - <x,e>). + * Here <x, e> is now qr[minor][minor]. + * v = x-ae is stored in the column at qr: + */ + qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor]) + + /* + * Transform the rest of the columns of the minor: + * They will be transformed by the matrix H = I-2vv'/|v|^2. + * If x is a column vector of the minor, then + * Hx = (I-2vv'/|v|^2)x = x-2v*v'x/|v|^2 = x - 2<x,v>/|v|^2 v. + * Therefore, the transformation is easily calculated by + * subtracting the column vector (2<x,v>/|v|^2)v from x. + * + * Let 2<x,v>/|v|^2 = alpha. 
From above, we have + * |v|^2 = -2a*(qr[minor][minor]), so + * alpha = -<x,v>/(a*qr[minor][minor]) + */ + for (int col = minor + 1; col < qrt.length; col++) { + final double[] qrtCol = qrt[col]; + double alpha = 0; + for (int row = minor; row < qrtCol.length; row++) { + alpha -= qrtCol[row] * qrtMinor[row]; + } + alpha /= a * qrtMinor[minor]; + + // Subtract the column vector alpha*v from x. + for (int row = minor; row < qrtCol.length; row++) { + qrtCol[row] -= alpha * qrtMinor[row]; + } + } + } + } + + /** + * Returns the matrix {@code Q} of the decomposition where {@code Q} is an orthogonal matrix. + * @return the {@code Q} matrix + */ + @Nonnull + private static RealMatrix getQ(final double[][] qrt, final double[] rDiagonal) { + final int m = qrt.length; + double[][] q = new double[m][m]; + + for (int minor = m - 1; minor >= 0; minor--) { + final double[] qrtMinor = qrt[minor]; + q[minor][minor] = 1.0d; + if (qrtMinor[minor] != 0.0) { + for (int col = minor; col < m; col++) { + double alpha = 0; + for (int row = minor; row < m; row++) { + alpha -= q[row][col] * qrtMinor[row]; + } + alpha /= rDiagonal[minor] * qrtMinor[minor]; + + for (int row = minor; row < m; row++) { + q[row][col] += -alpha * qrtMinor[row]; + } + } + } + } + return new RowMajorRealMatrix(q); + } + + /** + * Constructs the upper-triangular R matrix from a QRT decomposition's packed storage. + *

+ * This is a helper method that reconstructs the {@code R} matrix from the compact + * representation used by some QR decomposition algorithms. The resulting matrix {@code R} + * is upper-triangular. + *

+ * The upper-triangular elements (where row index {@code i < j}) are extracted + * from the {@code qrt} matrix at transposed indices ({@code qrt[j][i]}). The + * diagonal elements (where {@code i == j}) are taken from the {@code rDiagonal} + * array. All lower-triangular elements (where {@code i > j}) are set to 0.0. + *

+ * + * @param qrt The packed QRT decomposition data. The strict upper-triangular + * part of {@code R} is stored in this matrix. + * @param rDiagonal An array containing the diagonal elements of the R matrix. + * + * @return The reconstructed upper-triangular R matrix as a {@link RealMatrix}. + */ + @Nonnull + private static RealMatrix getR(final double[][] qrt, final double[] rDiagonal) { + final int m = qrt.length; // square in this helper + final double[][] r = new double[m][m]; + + // R is upper-triangular. With Commons-Math style storage: + // - for i < j, R[i][j] is in qrt[j][i] + // - for i == j, R[i][i] comes from rDiag[i] + // - for i > j, R[i][j] = 0 + for (int i = 0; i < m; i++) { + for (int j = 0; j < m; j++) { + if (i < j) { + r[i][j] = qrt[j][i]; + } else if (i == j) { + r[i][j] = rDiagonal[i]; + } else { + r[i][j] = 0.0; + } + } + } + return new RowMajorRealMatrix(r); + } + + /** + * Holder for the two factors of a QR decomposition. Each getter asks its supplier for the matrix when it is + * called, so the factors are only materialized on demand. + */ + @SuppressWarnings("checkstyle:MemberName") + public static class Result { + @Nonnull + private final Supplier<RealMatrix> qSupplier; + @Nonnull + private final Supplier<RealMatrix> rSupplier; + + public Result(@Nonnull final Supplier<RealMatrix> qSupplier, @Nonnull final Supplier<RealMatrix> rSupplier) { + this.qSupplier = qSupplier; + this.rSupplier = rSupplier; + } + + @Nonnull + RealMatrix getQ() { + return qSupplier.get(); + } + + @Nonnull + RealMatrix getR() { + return rSupplier.get(); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java new file mode 100644 index 0000000000..b8018a7320 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java @@ -0,0 +1,87 @@ +/* + * Quantizer.java + * + * This source file is part of the FoundationDB open 
source project + * + * Copyright 2015-2025 Apple Inc. 
and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; + +/** + * Defines the contract for a quantizer, a component responsible for encoding data vectors into a different, ideally + * a more compact, representation. + *

+ * Quantizers are typically used in machine learning and information retrieval to transform raw data into a format that + * is more suitable for processing, such as a compressed representation. + */ +public interface Quantizer { + /** + * Returns the {@code Estimator} instance associated with this object. + *

+ * The estimator is responsible for performing the primary distance estimation or calculation logic. This method + * provides access to that underlying component. + * + * @return the {@link Estimator} instance, which is guaranteed to be non-null. + */ + @Nonnull + Estimator estimator(); + + /** + * Encodes the given data vector into another vector representation. + *

+ * This method transforms the raw input data into a different, quantized format, which is often a vector more + * suitable for processing/storing the data. The specifics of the encoding depend on the implementation of the class. + * + * @param data the input {@link RealVector} to be encoded. Must not be {@code null} and is assumed to have been + * preprocessed, such as by rotation and/or translation. The preprocessing has to align with the requirements + * of the specific quantizer. + * @return the encoded vector representation of the input data, guaranteed to be non-null. + */ + @Nonnull + RealVector encode(@Nonnull RealVector data); + + /** + * Creates a no-op {@code Quantizer} that does not perform any data transformation. + *

+ * The returned quantizer's {@link Quantizer#encode(RealVector)} method acts as an + * identity function, returning the input vector without modification. The + * {@link Quantizer#estimator()} is created directly from the distance function + * of the provided {@link Metric}. This can be useful for baseline comparisons + * or for algorithms that require a {@code Quantizer} but where no quantization + * is desired. + * + * @param metric the {@link Metric} used to build the distance estimator for the quantizer. + * @return a new {@link Quantizer} instance that performs no operation. + */ + @Nonnull + static Quantizer noOpQuantizer(@Nonnull final Metric metric) { + return new Quantizer() { + @Nonnull + @Override + public Estimator estimator() { + return metric::distance; + } + + @Nonnull + @Override + public RealVector encode(@Nonnull final RealVector data) { + return data; + } + }; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java new file mode 100644 index 0000000000..4d9e4638f3 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java @@ -0,0 +1,117 @@ +/* + * RealMatrix.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public interface RealMatrix extends LinearOperator { + @Nonnull + double[][] getData(); + + double getEntry(int row, int column); + + @Override + default boolean isTransposable() { + return true; + } + + @Nonnull + RealMatrix transpose(); + + @Nonnull + @Override + default RealVector operate(@Nonnull final RealVector vector) { + Verify.verify(getColumnDimension() == vector.getNumDimensions()); + final double[] result = new double[getRowDimension()]; + for (int i = 0; i < getRowDimension(); i ++) { + double sum = 0.0d; + for (int j = 0; j < getColumnDimension(); j ++) { + sum += getEntry(i, j) * vector.getComponent(j); + } + result[i] = sum; + } + return new DoubleRealVector(result); + } + + @Nonnull + @Override + default RealVector operateTranspose(@Nonnull final RealVector vector) { + Verify.verify(getRowDimension() == vector.getNumDimensions()); + final double[] result = new double[getColumnDimension()]; + for (int j = 0; j < getColumnDimension(); j ++) { + double sum = 0.0d; + for (int i = 0; i < getRowDimension(); i ++) { + sum += getEntry(i, j) * vector.getComponent(i); + } + result[j] = sum; + } + return new DoubleRealVector(result); + } + + @Nonnull + RealMatrix multiply(@Nonnull RealMatrix otherMatrix); + + @Nonnull + RealMatrix subMatrix(int startRow, int lengthRow, int startColumn, int lengthColumn); + + @Nonnull + RowMajorRealMatrix toRowMajor(); + + @Nonnull + ColumnMajorRealMatrix toColumnMajor(); + + @Nonnull + RealMatrix quickTranspose(); + + default boolean valueEquals(@Nullable final Object o) { + if (!(o instanceof RealMatrix)) { + return false; + } + + final RealMatrix that = (RealMatrix)o; + if (getRowDimension() != that.getRowDimension() || + getColumnDimension() != that.getColumnDimension()) { + return false; + } + + for (int i = 0; i < getRowDimension(); i ++) { + for (int j = 0; j < 
getColumnDimension(); j ++) { + if (getEntry(i, j) != that.getEntry(i, j)) { + return false; + } + } + } + return true; + } + + default int valueBasedHashCode() { + int hashCode = 0; + for (int i = 0; i < getRowDimension(); i ++) { + for (int j = 0; j < getColumnDimension(); j ++) { + hashCode += 31 * Double.hashCode(getEntry(i, j)); + } + } + return hashCode; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java new file mode 100644 index 0000000000..44ff3c826d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java @@ -0,0 +1,192 @@ +/* + * RealVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Preconditions; +import com.apple.foundationdb.half.Half; + +import javax.annotation.Nonnull; + +/** + * An abstract base class representing a mathematical vector. + *

+ * This interface provides a common framework for vectors that are stored at different numerical precisions, + * while components are always exposed as {@code double} values. It includes common operations and functionalities + * like size, component access, arithmetic, and conversions. Concrete implementations must provide specific logic for + * data type conversions and raw data representation. + */ +public interface RealVector { + /** + * Returns the number of elements in the vector, i.e. the number of dimensions. + * @return the number of dimensions + */ + int getNumDimensions(); + + /** + * Gets the component of this object at the specified dimension. + *

+ * The dimension is a zero-based index. For a 3D vector, for example, dimension 0 might correspond to the + * x-component, 1 to the y-component, and 2 to the z-component. This method provides direct access to the + * underlying data element. + * @param dimension the zero-based index of the component to retrieve. + * @return the component at the specified dimension as a primitive {@code double}. + * @throws IndexOutOfBoundsException if the {@code dimension} is negative or + * greater than or equal to the number of dimensions of this object. + */ + double getComponent(int dimension); + + /** + * Returns the underlying data array. + *

+ * The returned array is guaranteed to be non-null. Note that this method + * returns a direct reference to the internal array, not a copy. + * @return the data array of type {@code double[]}, never {@code null}. + */ + @Nonnull + double[] getData(); + + @Nonnull + RealVector withData(@Nonnull double[] data); + + /** + * Gets the raw byte data representation of this object. + *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ + @Nonnull + byte[] getRawData(); + + /** + * Converts this object into a {@code RealVector} of {@link Half} precision floating-point numbers. + *

+ * As this is an abstract method, implementing classes are responsible for defining the specific conversion logic + * from their internal representation to a {@code RealVector} using {@link Half} objects to serialize and + * deserialize the vector. If this object already is a {@code HalfRealVector} this method should return {@code this}. + * @return a non-null {@link HalfRealVector} containing the {@link Half} precision floating-point representation of + * this object. + */ + @Nonnull + HalfRealVector toHalfRealVector(); + + /** + * Converts this object into a {@code RealVector} of single precision floating-point numbers. + *

+ * As this is an abstract method, implementing classes are responsible for defining the specific conversion logic + * from their internal representation to a {@code RealVector} using floating point numbers to serialize and + * deserialize the vector. If this object already is a {@code FloatRealVector} this method should return + * {@code this}. + * @return a non-null {@link FloatRealVector} containing the single precision floating-point representation of + * this object. + */ + @Nonnull + FloatRealVector toFloatRealVector(); + + /** + * Converts this vector into a {@link DoubleRealVector}. + *

+ * This method provides a way to obtain a double-precision floating-point representation of the vector. If the + * vector is already an instance of {@code DoubleRealVector}, this method may return the instance itself. Otherwise, + * it will create a new {@code DoubleRealVector} containing the same elements, which may involve a conversion of the + * underlying data type. + * @return a non-null {@link DoubleRealVector} representation of this vector. + */ + @Nonnull + DoubleRealVector toDoubleRealVector(); + + default double dot(@Nonnull final RealVector other) { + Preconditions.checkArgument(getNumDimensions() == other.getNumDimensions()); + double sum = 0.0d; + final double[] thisData = getData(); + final double[] otherData = other.getData(); + for (int i = 0; i < thisData.length; i++) { + sum += thisData[i] * otherData[i]; + } + return sum; + } + + default double l2Norm() { + return Math.sqrt(dot(this)); + } + + @Nonnull + default RealVector normalize() { + double n = l2Norm(); + final int numDimensions = getNumDimensions(); + double[] y = new double[numDimensions]; + if (n == 0.0 || !Double.isFinite(n)) { + return withData(y); // all zeros + } + double inv = 1.0 / n; + for (int i = 0; i < numDimensions; i++) { + y[i] = getComponent(i) * inv; + } + return withData(y); + } + + @Nonnull + default RealVector add(@Nonnull final RealVector other) { + Preconditions.checkArgument(getNumDimensions() == other.getNumDimensions()); + final double[] result = new double[getNumDimensions()]; + for (int i = 0; i < getNumDimensions(); i ++) { + result[i] = getComponent(i) + other.getComponent(i); + } + return withData(result); + } + + @Nonnull + default RealVector add(final double scalar) { + final double[] result = new double[getNumDimensions()]; + for (int i = 0; i < getNumDimensions(); i ++) { + result[i] = getComponent(i) + scalar; + } + return withData(result); + } + + @Nonnull + default RealVector subtract(@Nonnull final RealVector other) { + 
Preconditions.checkArgument(getNumDimensions() == other.getNumDimensions()); + final double[] result = new double[getNumDimensions()]; + for (int i = 0; i < getNumDimensions(); i ++) { + result[i] = getComponent(i) - other.getComponent(i); + } + return withData(result); + } + + @Nonnull + default RealVector subtract(final double scalar) { + final double[] result = new double[getNumDimensions()]; + for (int i = 0; i < getNumDimensions(); i ++) { + result[i] = getComponent(i) - scalar; + } + return withData(result); + } + + @Nonnull + default RealVector multiply(final double scalar) { + final double[] result = new double[getNumDimensions()]; + for (int i = 0; i < getNumDimensions(); i ++) { + result[i] = getComponent(i) * scalar; + } + return withData(result); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java new file mode 100644 index 0000000000..502c13d081 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java @@ -0,0 +1,144 @@ +/* + * RowMajorRealMatrix.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.linear; + +import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.function.Supplier; + +/** + * A {@link RealMatrix} backed by a {@code double[][]} that is indexed as {@code data[row][column]}. The array handed + * to the constructor is kept by reference (it is not copied) and the hash code is computed on first use and then + * memoized. + */ +public class RowMajorRealMatrix implements RealMatrix { + @Nonnull + private final double[][] data; + @Nonnull + private final Supplier<Integer> hashCodeSupplier; + + public RowMajorRealMatrix(@Nonnull final double[][] data) { + Preconditions.checkArgument(data.length > 0); + Preconditions.checkArgument(data[0].length > 0); + this.data = data; + this.hashCodeSupplier = Suppliers.memoize(this::valueBasedHashCode); + } + + @Nonnull + @Override + public double[][] getData() { + return data; + } + + @Override + public int getRowDimension() { + return data.length; + } + + @Override + public int getColumnDimension() { + return data[0].length; + } + + @Override + public double getEntry(final int row, final int column) { + return data[row][column]; + } + + @Nonnull + public double[] getRow(final int row) { + return data[row]; + } + + @Nonnull + @Override + public RealMatrix transpose() { + int n = getRowDimension(); + int m = getColumnDimension(); + double[][] result = new double[m][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < m; j++) { + result[j][i] = getEntry(i, j); + } + } + return new RowMajorRealMatrix(result); + } + + @Nonnull + @Override + public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { + Preconditions.checkArgument(getColumnDimension() == otherMatrix.getRowDimension()); + final int n = getRowDimension(); + final int m = otherMatrix.getColumnDimension(); + final int common = getColumnDimension(); + double[][] result = new double[n][m]; + for (int i = 0; 
i < n; i++) { + for (int j = 0; j < m; j++) { + for (int k = 0; k < common; k++) { + result[i][j] += getEntry(i, k) * otherMatrix.getEntry(k, j); + } + } + } + return new RowMajorRealMatrix(result); + } + + @Nonnull + @Override + public RealMatrix 
subMatrix(final int startRow, final int lengthRow, final int startColumn, final int lengthColumn) { + final double[][] subData = new double[lengthRow][lengthColumn]; + + for (int i = startRow; i < startRow + lengthRow; i ++) { + System.arraycopy(data[i], startColumn, subData[i - startRow], 0, lengthColumn); + } + + return new RowMajorRealMatrix(subData); + } + + @Nonnull + @Override + public RowMajorRealMatrix toRowMajor() { + return this; + } + + @Nonnull + @Override + public ColumnMajorRealMatrix toColumnMajor() { + return new ColumnMajorRealMatrix(transpose().getData()); + } + + @Nonnull + @Override + public RealMatrix quickTranspose() { + return new ColumnMajorRealMatrix(data); + } + + @Override + public final boolean equals(final Object o) { + if (o instanceof RowMajorRealMatrix) { + final RowMajorRealMatrix that = (RowMajorRealMatrix)o; + return Arrays.deepEquals(data, that.data); + } + return valueEquals(o); + } + + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/StoredVecsIterator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/StoredVecsIterator.java new file mode 100644 index 0000000000..1aab3625e9 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/StoredVecsIterator.java @@ -0,0 +1,151 @@ +/* + * StoredVecsIterator.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.ImmutableList;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.util.List;
+
+/**
+ * Abstract iterator implementation to read the IVecs/FVecs data format that is used by publicly available
+ * embedding datasets. Each record consists of a little-endian 32-bit dimension count followed by that many
+ * 4-byte little-endian components.
+ *
+ * @param <N> the component type of the vectors, which must extend {@link Number}
+ * @param <T> the type of object this iterator creates and uses to represent a stored vector in memory
+ */
+public abstract class StoredVecsIterator<N extends Number, T> extends AbstractIterator<T> {
+    @Nonnull
+    private final FileChannel fileChannel;
+
+    protected StoredVecsIterator(@Nonnull final FileChannel fileChannel) {
+        this.fileChannel = fileChannel;
+    }
+
+    /**
+     * Creates an empty component array of the given size.
+     *
+     * @param size the number of components
+     *
+     * @return a new array of length {@code size}
+     */
+    @Nonnull
+    protected abstract N[] newComponentArray(int size);
+
+    /**
+     * Decodes the next 4-byte component from the buffer, advancing the buffer's position.
+     *
+     * @param byteBuffer a little-endian buffer positioned at the component to decode
+     *
+     * @return the decoded component
+     */
+    @Nonnull
+    protected abstract N toComponent(@Nonnull ByteBuffer byteBuffer);
+
+    /**
+     * Converts the components of one stored vector into the in-memory representation.
+     *
+     * @param components all components of one vector in file order
+     *
+     * @return the object this iterator returns for that vector
+     */
+    @Nonnull
+    protected abstract T toTarget(@Nonnull N[] components);
+
+    /**
+     * Reads the next vector record from the channel.
+     *
+     * @return the next vector, or the end-of-data marker if the channel is exhausted at a record boundary
+     *
+     * @throws RuntimeException wrapping the {@link IOException} if the channel cannot be read, a header is
+     *         truncated or holds an invalid dimension, or the vector data ends prematurely
+     */
+    @Nullable
+    @Override
+    protected T computeNext() {
+        try {
+            final ByteBuffer headerBuf = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+            final int headerBytesRead = readFully(headerBuf);
+            if (headerBytesRead == 0) {
+                // clean end of file at a record boundary
+                return endOfData();
+            }
+            if (headerBytesRead < Integer.BYTES) {
+                throw new IOException("corrupt fvecs file");
+            }
+            headerBuf.flip();
+            final int dims = headerBuf.getInt();
+            // the upper bound keeps the byte-size computation below from overflowing an int
+            if (dims <= 0 || dims > Integer.MAX_VALUE / Integer.BYTES) {
+                throw new IOException("Invalid dimension " + dims + " at position " + (fileChannel.position() - 4));
+            }
+            final ByteBuffer vecBuf = ByteBuffer.allocate(dims * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+            if (readFully(vecBuf) < vecBuf.capacity()) {
+                throw new EOFException("unexpected EOF when reading vector data");
+            }
+            vecBuf.flip();
+            final N[] rawVecData = newComponentArray(dims);
+            for (int i = 0; i < dims; i++) {
+                rawVecData[i] = toComponent(vecBuf);
+            }
+
+            return toTarget(rawVecData);
+        } catch (final IOException ioE) {
+            throw new RuntimeException(ioE);
+        }
+    }
+
+    /**
+     * Reads from the channel until the buffer is full or the end of the channel is reached. A single
+     * {@link FileChannel#read(ByteBuffer)} call is allowed to return fewer bytes than requested, so the header
+     * and the vector data are both read through this loop.
+     *
+     * @param buffer the buffer to fill, starting at its current position
+     *
+     * @return the buffer position after reading, i.e. the number of bytes read for a freshly allocated buffer
+     *
+     * @throws IOException if the channel cannot be read
+     */
+    private int readFully(@Nonnull final ByteBuffer buffer) throws IOException {
+        while (buffer.hasRemaining()) {
+            if (fileChannel.read(buffer) < 0) {
+                break;
+            }
+        }
+        return buffer.position();
+    }
+
+    /**
+     * Iterator to read floating point vectors from a {@link FileChannel} providing an iterator of
+     * {@link DoubleRealVector}s.
+     */
+    public static class StoredFVecsIterator extends StoredVecsIterator<Double, DoubleRealVector> {
+        public StoredFVecsIterator(@Nonnull final FileChannel fileChannel) {
+            super(fileChannel);
+        }
+
+        @Nonnull
+        @Override
+        protected Double[] newComponentArray(final int size) {
+            return new Double[size];
+        }
+
+        @Nonnull
+        @Override
+        protected Double toComponent(@Nonnull final ByteBuffer byteBuffer) {
+            // components are stored as 4-byte floats and widened to double
+            return (double)byteBuffer.getFloat();
+        }
+
+        @Nonnull
+        @Override
+        protected DoubleRealVector toTarget(@Nonnull final Double[] components) {
+            return new DoubleRealVector(components);
+        }
+    }
+
+    /**
+     * Iterator to read vectors from a {@link FileChannel} into a list of integers.
+     */
+    public static class StoredIVecsIterator extends StoredVecsIterator<Integer, List<Integer>> {
+        public StoredIVecsIterator(@Nonnull final FileChannel fileChannel) {
+            super(fileChannel);
+        }
+
+        @Nonnull
+        @Override
+        protected Integer[] newComponentArray(final int size) {
+            return new Integer[size];
+        }
+
+        @Nonnull
+        @Override
+        protected Integer toComponent(@Nonnull final ByteBuffer byteBuffer) {
+            return byteBuffer.getInt();
+        }
+
+        @Nonnull
+        @Override
+        protected List<Integer> toTarget(@Nonnull final Integer[] components) {
+            return ImmutableList.copyOf(components);
+        }
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorType.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorType.java
new file mode 100644
index 0000000000..baee54d921
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorType.java
@@ -0,0 +1,28 @@
+/*
+ * VectorType.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+/**
+ * Tags for the kinds of vector representation.
+ * <p>
+ * The ordinal of a constant is used as a serialized type tag: {@code EncodedRealVector} writes
+ * {@code (byte)VectorType.RABITQ.ordinal()} as the first byte of its raw data and verifies that byte when reading
+ * it back. Reordering the constants or inserting a new one before an existing one therefore changes the
+ * serialized form.
+ * <p>
+ * NOTE(review): {@code HALF}, {@code SINGLE}, and {@code DOUBLE} presumably correspond to
+ * {@code HalfRealVector}, {@code FloatRealVector}, and {@code DoubleRealVector} -- confirm against those classes.
+ */
+public enum VectorType {
+    HALF,
+    SINGLE,
+    DOUBLE,
+    RABITQ
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/package-info.java
new file mode 100644
index 0000000000..34451ab26f
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/package-info.java
@@ -0,0 +1,25 @@
+/*
+ * package-info.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Package that implements basic mathematical objects such as vectors and matrices as well as
+ * operations on these objects.
+ */
+package com.apple.foundationdb.linear;
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/EncodedRealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/EncodedRealVector.java
new file mode 100644
index 0000000000..ba5c62e67f
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/EncodedRealVector.java
@@ -0,0 +1,302 @@
+/*
+ * EncodedRealVector.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc.
and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.rabitq;
+
+import com.apple.foundationdb.linear.DoubleRealVector;
+import com.apple.foundationdb.linear.FloatRealVector;
+import com.apple.foundationdb.linear.HalfRealVector;
+import com.apple.foundationdb.linear.RealVector;
+import com.apple.foundationdb.linear.VectorType;
+import com.google.common.base.Suppliers;
+import com.google.common.base.Verify;
+
+import javax.annotation.Nonnull;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import java.util.function.Supplier;
+
+/**
+ * A quantized vector: one integer code per dimension plus the three factors {@code fAddEx}, {@code fRescaleEx},
+ * and {@code fErrorEx} that {@link RaBitEstimator} combines with a query to estimate a distance and its error
+ * bound. The decoded {@code double} components, the serialized bytes, the hash code, and the half/float
+ * conversions are computed lazily and memoized.
+ */
+@SuppressWarnings("checkstyle:MemberName")
+public class EncodedRealVector implements RealVector {
+    private static final double EPS0 = 1.9d;
+
+    @Nonnull
+    private final int[] encoded;
+    private final double fAddEx;
+    private final double fRescaleEx;
+    private final double fErrorEx;
+
+    @Nonnull
+    private final Supplier<Integer> hashCodeSupplier;
+    @Nonnull
+    private final Supplier<double[]> dataSupplier;
+    @Nonnull
+    private final Supplier<byte[]> rawDataSupplier;
+    @Nonnull
+    private final Supplier<HalfRealVector> toHalfRealVectorSupplier;
+    @Nonnull
+    private final Supplier<FloatRealVector> toFloatRealVectorSupplier;
+
+    /**
+     * Creates an encoded vector. The {@code encoded} array is stored without copying.
+     *
+     * @param numExBits the number of extra bits; each code occupies {@code numExBits + 1} bits when serialized
+     * @param encoded one code per dimension
+     * @param fAddEx additive factor used by the estimator
+     * @param fRescaleEx rescale factor used by the estimator
+     * @param fErrorEx error factor used by the estimator
+     */
+    public EncodedRealVector(final int numExBits, @Nonnull final int[] encoded, final double fAddEx, final double fRescaleEx,
+                             final double fErrorEx) {
+        this.encoded = encoded;
+        this.fAddEx = fAddEx;
+        this.fRescaleEx = fRescaleEx;
+        this.fErrorEx = fErrorEx;
+
+        this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode);
+        this.rawDataSupplier = Suppliers.memoize(() -> computeRawData(numExBits));
+        this.dataSupplier = Suppliers.memoize(() -> computeData(numExBits));
+        this.toHalfRealVectorSupplier = Suppliers.memoize(this::computeHalfRealVector);
+        this.toFloatRealVectorSupplier = Suppliers.memoize(this::computeFloatRealVector);
+    }
+
+    /**
+     * Returns the per-dimension codes (the backing array, not a copy).
+     */
+    @Nonnull
+    public int[] getEncodedData() {
+        return encoded;
+    }
+
+    public double getAddEx() {
+        return fAddEx;
+    }
+
+    public double getRescaleEx() {
+        return fRescaleEx;
+    }
+
+    public double getErrorEx() {
+        return fErrorEx;
+    }
+
+    /**
+     * Two encoded vectors are equal if their codes and all three factors are equal; an encoded vector is never
+     * equal to an object of another type.
+     */
+    @Override
+    public final boolean equals(final Object o) {
+        if (!(o instanceof EncodedRealVector)) {
+            return false;
+        }
+
+        final EncodedRealVector that = (EncodedRealVector)o;
+        return Double.compare(fAddEx, that.fAddEx) == 0 &&
+                Double.compare(fRescaleEx, that.fRescaleEx) == 0 &&
+                Double.compare(fErrorEx, that.fErrorEx) == 0 &&
+                Arrays.equals(encoded, that.encoded);
+    }
+
+    @Override
+    public int hashCode() {
+        return hashCodeSupplier.get();
+    }
+
+    /**
+     * Computes the hash code over the same fields that {@link #equals(Object)} compares.
+     *
+     * @return the hash code
+     */
+    public int computeHashCode() {
+        int result = Arrays.hashCode(encoded);
+        result = 31 * result + Double.hashCode(fAddEx);
+        result = 31 * result + Double.hashCode(fRescaleEx);
+        result = 31 * result + Double.hashCode(fErrorEx);
+        return result;
+    }
+
+    @Override
+    public int getNumDimensions() {
+        return encoded.length;
+    }
+
+    /**
+     * Returns the raw code of one dimension (as opposed to {@link #getComponent(int)}, which returns the decoded
+     * value).
+     *
+     * @param dimension the dimension index
+     *
+     * @return the code stored for that dimension
+     */
+    public int getEncodedComponent(final int dimension) {
+        return encoded[dimension];
+    }
+
+
+    @Override
+    public double getComponent(final int dimension) {
+        return getData()[dimension];
+    }
+
+    /**
+     * Returns the memoized decoded components, see {@link #computeData(int)}.
+     */
+    @Nonnull
+    @Override
+    public double[] getData() {
+        return dataSupplier.get();
+    }
+
+    @Nonnull
+    @Override
+    public RealVector withData(@Nonnull final double[] data) {
+        // we explicitly make this a normal double vector instead of an encoded vector
+        return new DoubleRealVector(data);
+    }
+
+    /**
+     * Decodes the codes into {@code double} components: the codes are centered by {@code 2^numExBits - 0.5} and
+     * scaled by a factor derived from {@code fRescaleEx} and {@code fErrorEx}.
+     *
+     * @param numExBits the number of extra bits the codes were produced with
+     *
+     * @return the decoded components
+     */
+    @Nonnull
+    public double[] computeData(final int numExBits) {
+        final int numDimensions = getNumDimensions();
+        final double cB = (1 << numExBits) - 0.5;
+        final RealVector z = new DoubleRealVector(encoded).subtract(cB);
+        final double normZ = z.l2Norm();
+
+        // Solve for rho and Δx from fErrorEx and fRescaleEx
+        // the divisor is clamped to 1 the same way RaBitQuantizer clamps it when it derives fErrorEx, which
+        // also avoids a division by zero for one-dimensional vectors
+        final double a = (2.0 * EPS0) / Math.sqrt(Math.max(1.0, numDimensions - 1.0));
+        final double denom = a * Math.abs(fRescaleEx) * normZ;
+        Verify.verify(denom != 0.0, "degenerate parameters: denom == 0");
+
+        final double r = Math.min(1.0, (2.0 * Math.abs(fErrorEx)) / denom); // clamp for safety
+        final double rho = Math.sqrt(Math.max(0.0, 1.0 - r * r));
+
+        final double deltaX = -0.5 * fRescaleEx * rho;
+
+        // ô = c + Δx * r
+        return z.multiply(deltaX).getData();
+    }
+
+    /**
+     * Returns the memoized serialized form, see {@link #computeRawData(int)}.
+     */
+    @Nonnull
+    @Override
+    public byte[] getRawData() {
+        return rawDataSupplier.get();
+    }
+
+    /**
+     * Serializes this vector. Layout (big-endian): one byte holding {@code VectorType.RABITQ.ordinal()}, the
+     * three factors as 8-byte doubles, then the codes packed with {@code numExBits + 1} bits each, most
+     * significant bit first, zero-padded to a whole byte.
+     *
+     * @param numExBits the number of extra bits the codes were produced with
+     *
+     * @return the serialized bytes
+     */
+    @Nonnull
+    protected byte[] computeRawData(final int numExBits) {
+        int numBits = getNumDimensions() * (numExBits + 1); // congruency with paper
+        final int length = 25 + // RABITQ (byte) + fAddEx (double) + fRescaleEx (double) + fErrorEx (double)
+                (numBits - 1) / 8 + 1; // snap byte array to the smallest length fitting all bits
+        final byte[] vectorBytes = new byte[length];
+        final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN);
+        buffer.put((byte)VectorType.RABITQ.ordinal());
+        buffer.putDouble(fAddEx);
+        buffer.putDouble(fRescaleEx);
+        buffer.putDouble(fErrorEx);
+        packEncodedComponents(numExBits, buffer);
+        return vectorBytes;
+    }
+
+    /**
+     * Appends the codes to the buffer as a contiguous bit stream of {@code numExBits + 1} bits per code. A
+     * partially filled last byte is written with its unused low bits set to zero.
+     */
+    private void packEncodedComponents(final int numExBits, @Nonnull final ByteBuffer buffer) {
+        // big-endian
+        final int bitsPerComponent = numExBits + 1; // congruency with paper
+        int remainingBitsInByte = 8;
+        byte currentByte = 0;
+        for (int i = 0; i < getNumDimensions(); i++) {
+            final int component = getEncodedComponent(i);
+            int remainingBitsInComponent = bitsPerComponent;
+
+            while (remainingBitsInComponent > 0) {
+                final int remainingMask = (1 << remainingBitsInComponent) - 1;
+                final int remainingComponent = component & remainingMask;
+
+                if (remainingBitsInComponent <= remainingBitsInByte) {
+                    // the rest of this code fits into the current byte
+                    currentByte = (byte)(currentByte | (remainingComponent << (remainingBitsInByte - remainingBitsInComponent)));
+                    remainingBitsInByte -= remainingBitsInComponent;
+                    if (remainingBitsInByte == 0) {
+                        remainingBitsInByte = 8;
+                        buffer.put(currentByte);
+                        currentByte = 0;
+                    }
+                    break;
+                }
+
+                // remainingBitsInComponent > remainingBitsInByte: fill up the current byte with the high bits
+                currentByte = (byte)(currentByte | (remainingComponent >> (remainingBitsInComponent - remainingBitsInByte)));
+                remainingBitsInComponent -= remainingBitsInByte;
+                remainingBitsInByte = 8;
+                buffer.put(currentByte);
+                currentByte = 0;
+            }
+        }
+
+        if (remainingBitsInByte < 8) {
+            buffer.put(currentByte);
+        }
+    }
+
+    @Nonnull
+    @Override
+    public HalfRealVector toHalfRealVector() {
+        return toHalfRealVectorSupplier.get();
+    }
+
+    @Nonnull
+    private HalfRealVector computeHalfRealVector() {
+        return new HalfRealVector(getData());
+    }
+
+    @Nonnull
+    @Override
+    public FloatRealVector toFloatRealVector() {
+        return toFloatRealVectorSupplier.get();
+    }
+
+    @Nonnull
+    private FloatRealVector computeFloatRealVector() {
+        return new FloatRealVector(getData());
+    }
+
+    @Nonnull
+    @Override
+    public DoubleRealVector toDoubleRealVector() {
+        return new DoubleRealVector(getData());
+    }
+
+    /**
+     * Deserializes a vector from the layout written by {@link #computeRawData(int)}. The number of dimensions
+     * and extra bits are not part of the serialized form and must be supplied by the caller.
+     *
+     * @param vectorBytes the serialized bytes
+     * @param numDimensions the number of codes to unpack
+     * @param numExBits the number of extra bits the codes were produced with
+     *
+     * @return the deserialized vector
+     */
+    @Nonnull
+    public static EncodedRealVector fromBytes(@Nonnull final byte[] vectorBytes,
+                                              final int numDimensions,
+                                              final int numExBits) {
+        final ByteBuffer buffer = ByteBuffer.wrap(vectorBytes).order(ByteOrder.BIG_ENDIAN);
+        Verify.verify(buffer.get() == VectorType.RABITQ.ordinal());
+
+        final double fAddEx = buffer.getDouble();
+        final double fRescaleEx = buffer.getDouble();
+        final double fErrorEx = buffer.getDouble();
+        final int[] components = unpackComponents(buffer, numDimensions, numExBits);
+        return new EncodedRealVector(numExBits, components, fAddEx, fRescaleEx, fErrorEx);
+    }
+
+    /**
+     * Reads {@code numDimensions} codes of {@code numExBits + 1} bits each from the bit stream written by
+     * {@link #packEncodedComponents(int, ByteBuffer)}.
+     */
+    @Nonnull
+    private static int[] unpackComponents(@Nonnull final ByteBuffer buffer,
+                                          final int numDimensions,
+                                          final int numExBits) {
+        int[] result = new int[numDimensions];
+
+        // big-endian
+        final int bitsPerComponent = numExBits + 1; // congruency with paper
+        int remainingBitsInByte = 8;
+        byte currentByte = buffer.get();
+        for (int i = 0; i < numDimensions; i++) {
+            int remainingBitsForComponent = bitsPerComponent;
+
+            while (remainingBitsForComponent > 0) {
+                // masking also strips the sign extension of the byte
+                final int mask = (1 << remainingBitsInByte) - 1;
+                int maskedByte = currentByte & mask;
+
+                if (remainingBitsForComponent <= remainingBitsInByte) {
+                    result[i] |= maskedByte >> (remainingBitsInByte - remainingBitsForComponent);
+
+                    remainingBitsInByte -= remainingBitsForComponent;
+                    if (remainingBitsInByte == 0) {
+                        remainingBitsInByte = 8;
+                        // do not read past the last byte once the final code has been consumed
+                        currentByte = (i + 1 == numDimensions) ? 0 : buffer.get();
+                    }
+                    break;
+                }
+
+                // remainingBitsForComponent > remainingBitsInByte
+                result[i] |= maskedByte << (remainingBitsForComponent - remainingBitsInByte);
+                remainingBitsForComponent -= remainingBitsInByte;
+                remainingBitsInByte = 8;
+                currentByte = buffer.get();
+            }
+        }
+        return result;
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java
new file mode 100644
index 0000000000..2cc299e3b1
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java
@@ -0,0 +1,115 @@
+/*
+ * RaBitEstimator.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.rabitq;
+
+import com.apple.foundationdb.linear.DoubleRealVector;
+import com.apple.foundationdb.linear.Estimator;
+import com.apple.foundationdb.linear.Metric;
+import com.apple.foundationdb.linear.RealVector;
+
+import javax.annotation.Nonnull;
+
+/**
+ * {@link Estimator} that estimates the distance between a plain vector and an {@link EncodedRealVector} directly
+ * from the encoded form, and falls back to the configured {@link Metric} when both or neither of the vectors are
+ * encoded.
+ */
+public class RaBitEstimator implements Estimator {
+    @Nonnull
+    private final Metric metric;
+    private final int numExBits;
+
+    /**
+     * Creates an estimator.
+     *
+     * @param metric the metric to estimate
+     * @param numExBits the number of extra bits the encoded vectors were produced with
+     */
+    public RaBitEstimator(@Nonnull final Metric metric,
+                          final int numExBits) {
+        this.metric = metric;
+        this.numExBits = numExBits;
+    }
+
+    @Nonnull
+    public Metric getMetric() {
+        return metric;
+    }
+
+    public int getNumExBits() {
+        return numExBits;
+    }
+
+    /**
+     * Returns the distance between the two vectors. If exactly one of them is an {@link EncodedRealVector}, the
+     * distance is estimated from the encoded form; otherwise the regular metric is applied.
+     *
+     * @param query the first vector
+     * @param storedVector the second vector
+     *
+     * @return the (estimated) distance
+     */
+    @Override
+    public double distance(@Nonnull final RealVector query,
+                           @Nonnull final RealVector storedVector) {
+        if (!(query instanceof EncodedRealVector) && storedVector instanceof EncodedRealVector) {
+            // only use the estimator if the first (by convention) vector is not encoded, but the second is
+            return distance(query, (EncodedRealVector)storedVector);
+        }
+        if (query instanceof EncodedRealVector && !(storedVector instanceof EncodedRealVector)) {
+            return distance(storedVector, (EncodedRealVector)query);
+        }
+        // use the regular metric for all other cases
+        return metric.distance(query, storedVector);
+    }
+
+    private double distance(@Nonnull final RealVector query, // pre-rotated query q
+                            @Nonnull final EncodedRealVector encodedVector) {
+        return estimateDistanceAndErrorBound(query, encodedVector).getDistance();
+    }
+
+    /**
+     * Estimates the distance between a plain query and an encoded vector together with an error bound.
+     *
+     * @param query the pre-rotated query
+     * @param encodedVector the encoded vector
+     *
+     * @return the estimate and its error bound
+     *
+     * @throws UnsupportedOperationException if the configured metric is not one of the dot-product, squared
+     *         Euclidean, or Euclidean metrics
+     */
+    @Nonnull
+    public Result estimateDistanceAndErrorBound(@Nonnull final RealVector query, // pre-rotated query q
+                                                @Nonnull final EncodedRealVector encodedVector) {
+        final double cb = (1 << numExBits) - 0.5;
+        final double gAdd = query.dot(query);
+        final double gError = Math.sqrt(gAdd);
+        final RealVector totalCode = new DoubleRealVector(encodedVector.getEncodedData());
+        final RealVector xuc = totalCode.subtract(cb);
+        final double dot = query.dot(xuc);
+
+        final double estimate = encodedVector.getAddEx() + gAdd + encodedVector.getRescaleEx() * dot;
+        final double errorBound = encodedVector.getErrorEx() * gError;
+
+        switch (metric) {
+            case DOT_PRODUCT_METRIC:
+            case EUCLIDEAN_SQUARE_METRIC:
+                return new Result(estimate, errorBound);
+            case EUCLIDEAN_METRIC:
+                // the estimated squared distance can come out slightly negative for near-identical vectors;
+                // clamp it so that the square root yields 0 rather than NaN
+                return new Result(Math.sqrt(Math.max(0.0, estimate)), Math.sqrt(errorBound));
+            default:
+                throw new UnsupportedOperationException("metric not supported by quantizer");
+        }
+    }
+
+    /**
+     * An estimated distance together with its error bound.
+     */
+    public static class Result {
+        private final double distance;
+        private final double err;
+
+        public Result(final double distance, final double err) {
+            this.distance = distance;
+            this.err = err;
+        }
+
+        public double getDistance() {
+            return distance;
+        }
+
+        public double getErr() {
+            return err;
+        }
+
+        @Override
+        public String toString() {
+            return "estimate[" + "distance=" + distance + ", err=" + err + "]";
+        }
+    }
+}
+
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java
new file mode 100644
index 0000000000..6204d2b909
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java
@@ -0,0 +1,437 @@
+/*
+ * RaBitQuantizer.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc.
and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.rabitq; + +import com.apple.foundationdb.linear.DoubleRealVector; +import com.apple.foundationdb.linear.Metric; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.linear.RealVector; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import javax.annotation.Nonnull; +import java.util.Comparator; +import java.util.PriorityQueue; + +/** + * Implements the RaBit quantization scheme, a technique for compressing high-dimensional vectors into a compact + * integer-based representation. + *

+ * This class provides the logic to encode a {@link RealVector} into an {@link EncodedRealVector}. + * The encoding process involves finding an optimal scaling factor, quantizing the vector's components, + * and pre-calculating values that facilitate efficient distance estimation in the quantized space. + * It is configured with a specific {@link Metric} and a number of "extra bits" ({@code numExBits}) + * which control the precision of the quantization. + *

+ * Note that this implementation largely follows <a href="https://arxiv.org/abs/2409.09913">this paper</a> + * by Jianyang Gao et al. It also mirrors algorithmic similarity, terms, and variable/method naming-conventions of the + * C++ implementation that can be found <a href="https://github.com/VectorDB-NTU/RaBitQ-Library">here</a>. + * + * @see Quantizer + * @see RaBitEstimator + * @see EncodedRealVector + */ +public final class RaBitQuantizer implements Quantizer { + private static final double EPS = 1e-5; + private static final double EPS0 = 1.9; + private static final int N_ENUM = 10; + + // 0th entry unused; defined up to 8 extra bits in the source. + private static final double[] TIGHT_START = { + 0.00, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81 + }; + + final int numExBits; + @Nonnull + private final Metric metric; + + /** + * Constructs a new {@code RaBitQuantizer} instance. + *

+ * This constructor initializes the quantizer with a specific metric and the number of + * extra bits to be used in the quantization process. + * + * @param metric the {@link Metric} to be used for quantization; must not be null. + * @param numExBits the number of extra bits for quantization. + */ + public RaBitQuantizer(@Nonnull final Metric metric, final int numExBits) { + Preconditions.checkArgument(numExBits > 0 && numExBits < TIGHT_START.length); + Preconditions.checkArgument( + metric == Metric.EUCLIDEAN_METRIC || + metric == Metric.EUCLIDEAN_SQUARE_METRIC || + metric == Metric.DOT_PRODUCT_METRIC); + + this.numExBits = numExBits; + this.metric = metric; + } + + /** + * Creates and returns a new {@link RaBitEstimator} instance. + *

+ * This method acts as a factory, constructing the estimator based on the + * {@code metric} and {@code numExBits} configuration of this object. + * The {@code @Override} annotation indicates that this is an implementation + * of a method from a superclass or interface. + * + * @return a new, non-null instance of {@link RaBitEstimator} + */ + @Nonnull + @Override + public RaBitEstimator estimator() { + return new RaBitEstimator(metric, numExBits); + } + + /** + * Encodes a given {@link RealVector} into its corresponding encoded representation. + *

+ * This method overrides the parent's {@code encode} method. It delegates the + * core encoding logic to an internal helper method and returns the final + * {@link EncodedRealVector}. + * + * @param data the {@link RealVector} to be encoded; must not be null. The vector must be pre-rotated and + * translated. + * + * @return the resulting {@link EncodedRealVector}, guaranteed to be non-null. + */ + @Nonnull + @Override + public EncodedRealVector encode(@Nonnull final RealVector data) { + return encodeInternal(data).getEncodedVector(); + } + + /** + * Encodes a real-valued vector into a quantized representation. + *

+ * This is an internal method that performs the core encoding logic. It first + * generates a base code using {@link #exBitsCode(RealVector)}, then incorporates + * sign information to create the final code. It precomputes various geometric + * properties (norms, dot products) of the original vector and its quantized + * counterpart to calculate metric-specific scaling and error factors. These + * factors are used for efficient distance calculations with the encoded vector. + * + * @param data the real-valued vector to be encoded. Must not be null. + * @return a {@code Result} object containing the {@link EncodedRealVector} and + * other intermediate values from the encoding process. The result is never null. + * + * @throws IllegalArgumentException if the configured {@code metric} is not supported for encoding. + */ + @Nonnull + @VisibleForTesting + Result encodeInternal(@Nonnull final RealVector data) { + final int dims = data.getNumDimensions(); + + QuantizeExResult base = exBitsCode(data); + int[] signedCode = base.code; + double ipInv = base.ipNormInv; + + int[] totalCode = new int[dims]; + for (int i = 0; i < dims; i++) { + int sgn = (data.getComponent(i) >= 0.0) ? +1 : 0; + totalCode[i] = signedCode[i] + (sgn << numExBits); + } + + final double cb = -(((1 << numExBits) - 0.5)); + double[] xuCbData = new double[dims]; + for (int i = 0; i < dims; i++) { + xuCbData[i] = totalCode[i] + cb; + } + final RealVector xuCb = new DoubleRealVector(xuCbData); + + // 5) Precompute all needed values + final double residualL2Sqr = data.dot(data); + final double residualL2Norm = Math.sqrt(residualL2Sqr); + final double ipResidualXuCb = data.dot(xuCb); + final double xuCbNorm = xuCb.l2Norm(); + final double xuCbNormSqr = xuCbNorm * xuCbNorm; + + final double ipResidualXuCbSafe = + (ipResidualXuCb == 0.0) ? 
Double.POSITIVE_INFINITY : ipResidualXuCb; + + double tmpError = residualL2Norm * EPS0 * + Math.sqrt(((residualL2Sqr * xuCbNormSqr) / (ipResidualXuCbSafe * ipResidualXuCbSafe) - 1.0) + / (Math.max(1, dims - 1))); + + double fAddEx; + double fRescaleEx; + double fErrorEx; + + if (metric == Metric.EUCLIDEAN_SQUARE_METRIC || metric == Metric.EUCLIDEAN_METRIC) { + fAddEx = residualL2Sqr; + fRescaleEx = ipInv * (-2.0 * residualL2Norm); + fErrorEx = 2.0 * tmpError; + } else if (metric == Metric.DOT_PRODUCT_METRIC) { + fAddEx = 1.0; + fRescaleEx = ipInv * (-1.0 * residualL2Norm); + fErrorEx = tmpError; + } else { + throw new IllegalArgumentException("Unsupported metric"); + } + + return new Result(new EncodedRealVector(numExBits, totalCode, fAddEx, fRescaleEx, fErrorEx), base.t, ipInv); + } + + /** + * Builds per-dimension extra-bit code using the best {@code t} found by {@link #bestRescaleFactor(RealVector)} and + * returns the code, {@code t}, and {@code ipNormInv}. + * @param residual rotated residual vector r. + */ + private QuantizeExResult exBitsCode(@Nonnull final RealVector residual) { + int dims = residual.getNumDimensions(); + + // oAbs = |r| normalized (RaBitQ does this before quantizeEx) + final RealVector oAbs = absOfNormalized(residual); + + final QuantizeExResult q = quantizeEx(oAbs); + + int[] k = q.code; + // revert codes for negative dims + int[] signed = new int[dims]; + int mask = (1 << numExBits) - 1; + for (int j = 0; j < dims; ++j) { + if (residual.getComponent(j) < 0) { + int tmp = k[j]; + signed[j] = (~tmp) & mask; + } else { + signed[j] = k[j]; + } + } + + return new QuantizeExResult(signed, q.t, q.ipNormInv); + } + + /** + * Method to quantize a vector. 
+ * + * @param oAbs absolute values of a L2-normalized residual vector (nonnegative; length = dim) + * @return quantized levels (ex-bits), the chosen scale t, and ipNormInv + * Notes: If the residual is the all-zero vector (or numerically so), this returns zero codes, + * {@code t = 0}, and {@code ipNormInv = 1} (benign fallback). Downstream code uses {@code ipNormInv} to + * compute {@code fRescaleEx}, etc. + */ + private QuantizeExResult quantizeEx(@Nonnull final RealVector oAbs) { + final int dim = oAbs.getNumDimensions(); + final int maxLevel = (1 << numExBits) - 1; + + // Choose t via the sweep. + double t = bestRescaleFactor(oAbs); + // ipNorm = sum_i ( (k_i + 0.5) * |r_i| ) + double ipNorm = 0.0; + + // Build per-coordinate integer levels: k_i = floor(t * |r_i|) + int[] code = new int[dim]; + for (int i = 0; i < dim; i++) { + int k = (int) Math.floor(t * oAbs.getComponent(i) + EPS); + if (k > maxLevel) { + k = maxLevel; + } + code[i] = k; + ipNorm += (k + 0.5) * oAbs.getComponent(i); + } + + // ipNormInv = 1 / ipNorm, with a benign fallback. + double ipNormInv; + if (ipNorm > 0.0 && Double.isFinite(ipNorm)) { + ipNormInv = 1.0 / ipNorm; + if (!Double.isFinite(ipNormInv) || ipNormInv == 0.0) { + ipNormInv = 1.0; // extremely defensive + } + } else { + ipNormInv = 1.0; // fallback used in the C++ source + } + + return new QuantizeExResult(code, t, ipNormInv); + } + + /** + * Calculates the best rescaling factor {@code t} for a given vector of absolute values. + *

+ * This method implements an efficient algorithm to find a scaling factor {@code t} + * that maximizes an objective function related to the quantization of the input vector. + * The objective function being maximized is effectively + * {@code sum(u_i * o_i) / sqrt(sum(u_i^2 + u_i))}, where {@code u_i = floor(t * o_i)} + * and {@code o_i} are the components of the input vector {@code oAbs}. + *

+ * The algorithm performs a sweep over the scaling factor {@code t}. It uses a + * min-priority queue to efficiently jump between critical values of {@code t} where + * the floor of {@code t * o_i} changes for some coordinate {@code i}. The search is + * bounded within a pre-calculated "tight" range {@code [tStart, tEnd]} to ensure + * efficiency. + * + * @param oAbs The vector of absolute values for which to find the best rescale factor. + * Components must be non-negative. + * + * @return The optimal scaling factor {@code t} that maximizes the objective function, + * or 0.0 if the input vector is all zeros. + */ + private double bestRescaleFactor(@Nonnull final RealVector oAbs) { + final int numDimensions = oAbs.getNumDimensions(); + + // max_o = max(oAbs) + double maxO = 0.0d; + for (double v : oAbs.getData()) { + if (v > maxO) { + maxO = v; + } + } + if (maxO <= 0.0) { + return 0.0; // all zeros: nothing to scale + } + + // t_end and a "tight" t_start as in the C++ code + final int maxLevel = (1 << numExBits) - 1; + final double tEnd = ((maxLevel) + N_ENUM) / maxO; + final double tStart = tEnd * TIGHT_START[numExBits]; + + // cur_o_bar[i] = floor(tStart * oAbs[i]), but stored as int + final int[] curOB = new int[numDimensions]; + double sqrDen = numDimensions * 0.25; // Σ (cur^2 + cur) starts from D/4 + double numer = 0.0; + for (int i = 0; i < numDimensions; i++) { + int cur = (int) ((tStart * oAbs.getComponent(i)) + EPS); + curOB[i] = cur; + sqrDen += (double) cur * cur + cur; + numer += (cur + 0.5) * oAbs.getComponent(i); + } + + // Min-heap keyed by next threshold t at which coordinate "i" increments: + // t_i(k->k+1) = (curOB[i] + 1) / oAbs[i] + + final PriorityQueue pq = new PriorityQueue<>(Comparator.comparingDouble(n -> n.t)); + for (int i = 0; i < numDimensions; i++) { + final double curOAbs = oAbs.getComponent(i); + if (curOAbs > 0.0) { + double tNext = (curOB[i] + 1) / curOAbs; + pq.add(new Node(tNext, i)); + } + } + + double maxIp = 0.0; + double 
bestT = 0.0; + + while (!pq.isEmpty()) { + final Node node = pq.poll(); + final double curT = node.t; + final int i = node.idx; + + // increment cur_o_bar[i] + curOB[i]++; + final int u = curOB[i]; + + // update denominator and numerator: + // sqrDen += 2*u; numer += oAbs[i] + sqrDen += 2.0 * u; + numer += oAbs.getComponent(i); + + // objective value + final double curIp = numer / Math.sqrt(sqrDen); + if (curIp > maxIp) { + maxIp = curIp; + bestT = curT; + } + + // schedule next threshold for this coordinate, unless we've hit max level + if (u < maxLevel) { + final double oi = oAbs.getComponent(i); + final double tNext = (u + 1) / oi; + if (tNext < tEnd) { + pq.add(new Node(tNext, i)); + } + } + } + + return bestT; + } + + /** + * Computes a new vector containing the element-wise absolute values of the L2-normalized input vector. + *

+ * This operation is equivalent to first normalizing the vector {@code x} by its L2 norm, + * and then taking the absolute value of each resulting component. If the L2 norm of {@code x} + * is zero or not finite (e.g., {@link Double#POSITIVE_INFINITY}), a new zero vector of the + * same dimension is returned. + * + * @param x the input vector to be normalized and processed. Must not be null. + * + * @return a new {@code RealVector} containing the absolute values of the components of the + * normalized input vector. + */ + private static RealVector absOfNormalized(@Nonnull final RealVector x) { + double n = x.l2Norm(); + double[] y = new double[x.getNumDimensions()]; + if (n == 0.0 || !Double.isFinite(n)) { + return new DoubleRealVector(y); // all zeros + } + double inv = 1.0 / n; + for (int i = 0; i < x.getNumDimensions(); i++) { + y[i] = Math.abs(x.getComponent(i) * inv); + } + return new DoubleRealVector(y); + } + + @SuppressWarnings("checkstyle:MemberName") + public static final class Result { + public EncodedRealVector encodedVector; + public final double t; + public final double ipNormInv; + + public Result(@Nonnull final EncodedRealVector encodedVector, double t, double ipNormInv) { + this.encodedVector = encodedVector; + this.t = t; + this.ipNormInv = ipNormInv; + } + + public EncodedRealVector getEncodedVector() { + return encodedVector; + } + + public double getT() { + return t; + } + + public double getIpNormInv() { + return ipNormInv; + } + } + + @SuppressWarnings("checkstyle:MemberName") + private static final class QuantizeExResult { + public final int[] code; // k_i = floor(t * oAbs[i]) in [0, 2^exBits - 1] + public final double t; // chosen global scale + public final double ipNormInv; // 1 / sum_i ( (k_i + 0.5) * oAbs[i] ) + + public QuantizeExResult(int[] code, double t, double ipNormInv) { + this.code = code; + this.t = t; + this.ipNormInv = ipNormInv; + } + } + + @SuppressWarnings("checkstyle:MemberName") + private static final class Node { + 
private final double t; + private final int idx; + + Node(double t, int idx) { + this.t = t; + this.idx = idx; + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/package-info.java new file mode 100644 index 0000000000..df00c483a1 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/package-info.java @@ -0,0 +1,24 @@ +/* + * package-info.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RaBitQ implementation. + */ +package com.apple.foundationdb.rabitq; diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfConstantsTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfConstantsTest.java new file mode 100644 index 0000000000..74a3f84fe0 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfConstantsTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Unit test for {@link HalfConstants}. + * + * @author Christian Heina (developer@christianheina.com) + */ +public class HalfConstantsTest { + @Test + public void constantsTest() { + Assertions.assertEquals(HalfConstants.SIGNIFICAND_WIDTH, 11); + Assertions.assertEquals(HalfConstants.MIN_SUB_EXPONENT, -24); + Assertions.assertEquals(HalfConstants.EXP_BIAS, 15); + + Assertions.assertEquals(HalfConstants.SIGN_BIT_MASK, 0x8000); + Assertions.assertEquals(HalfConstants.EXP_BIT_MASK, 0x7C00); + Assertions.assertEquals(HalfConstants.SIGNIF_BIT_MASK, 0x03FF); + + // Check all bits filled and no overlap + Assertions.assertEquals((HalfConstants.SIGN_BIT_MASK | HalfConstants.EXP_BIT_MASK | HalfConstants.SIGNIF_BIT_MASK), 65535); + Assertions.assertEquals((HalfConstants.SIGN_BIT_MASK & HalfConstants.EXP_BIT_MASK), 0); + Assertions.assertEquals((HalfConstants.SIGN_BIT_MASK & HalfConstants.SIGNIF_BIT_MASK), 0); + Assertions.assertEquals((HalfConstants.EXP_BIT_MASK & HalfConstants.SIGNIF_BIT_MASK), 0); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfMathTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfMathTest.java new file mode 100644 index 0000000000..b45abe0599 --- /dev/null +++ 
b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfMathTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Unit test for {@link HalfMath}. 
+ * + * @author Christian Heina (developer@christianheina.com) + */ +public class HalfMathTest { + private static final Half LARGEST_SUBNORMAL = Half.shortBitsToHalf((short) 0x3ff); + + @Test + public void ulpTest() { + // Special cases + Assertions.assertEquals(Half.NaN, HalfMath.ulp(Half.NaN)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, HalfMath.ulp(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, HalfMath.ulp(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(Half.POSITIVE_ZERO)); + Assertions.assertEquals(HalfMath.ulp(Half.MAX_VALUE), Half.valueOf(Math.pow(2, 5))); + Assertions.assertEquals(HalfMath.ulp(Half.NEGATIVE_MAX_VALUE), Half.valueOf(Math.pow(2, 5))); + + // Regular cases + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(Half.MIN_NORMAL)); + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(LARGEST_SUBNORMAL)); + + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(Half.shortBitsToHalf((short) 0x7ff))); + Assertions.assertEquals(Half.MIN_VALUE, HalfMath.ulp(Half.shortBitsToHalf((short) 0x7ff))); + } + + @Test + public void getExponentTest() { + // Special cases + Assertions.assertEquals(Half.MAX_EXPONENT + 1, HalfMath.getExponent(Half.NaN)); + Assertions.assertEquals(Half.MAX_EXPONENT + 1, HalfMath.getExponent(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(Half.MAX_EXPONENT + 1, HalfMath.getExponent(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.MIN_EXPONENT - 1, HalfMath.getExponent(Half.POSITIVE_ZERO)); + Assertions.assertEquals(Half.MIN_EXPONENT - 1, HalfMath.getExponent(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(Half.MIN_EXPONENT - 1, HalfMath.getExponent(Half.MIN_VALUE)); + Assertions.assertEquals(Half.MIN_EXPONENT - 1, HalfMath.getExponent(LARGEST_SUBNORMAL)); + + // Regular cases + Assertions.assertEquals(-13, HalfMath.getExponent(Half.valueOf(0.0002f))); + 
Assertions.assertEquals(-9, HalfMath.getExponent(Half.valueOf(0.002f))); + Assertions.assertEquals(-6, HalfMath.getExponent(Half.valueOf(0.02f))); + Assertions.assertEquals(-3, HalfMath.getExponent(Half.valueOf(0.2f))); + Assertions.assertEquals(1, HalfMath.getExponent(Half.valueOf(2.0f))); + Assertions.assertEquals(4, HalfMath.getExponent(Half.valueOf(20.0f))); + Assertions.assertEquals(7, HalfMath.getExponent(Half.valueOf(200.0f))); + Assertions.assertEquals(10, HalfMath.getExponent(Half.valueOf(2000.0f))); + Assertions.assertEquals(14, HalfMath.getExponent(Half.valueOf(20000.0f))); + } + + @Test + public void absTest() { + // Special cases + Assertions.assertEquals(Half.POSITIVE_INFINITY, HalfMath.abs(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, HalfMath.abs(Half.NEGATIVE_INFINITY)); + + Assertions.assertEquals(Half.NaN, HalfMath.abs(Half.NaN)); + Assertions.assertEquals(HalfMath.abs(Half.shortBitsToHalf((short) 0x7e04)), Half.shortBitsToHalf((short) 0x7e04)); + Assertions.assertEquals(HalfMath.abs(Half.shortBitsToHalf((short) 0x7fff)), Half.shortBitsToHalf((short) 0x7fff)); + + // Regular cases + Assertions.assertEquals(Half.POSITIVE_ZERO, HalfMath.abs(Half.POSITIVE_ZERO)); + Assertions.assertEquals(Half.POSITIVE_ZERO, HalfMath.abs(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(Half.MAX_VALUE, HalfMath.abs(Half.MAX_VALUE)); + Assertions.assertEquals(Half.MAX_VALUE, HalfMath.abs(Half.NEGATIVE_MAX_VALUE)); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfTest.java new file mode 100644 index 0000000000..d66799092f --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/half/HalfTest.java @@ -0,0 +1,501 @@ +/* + * Copyright 2023 Christian Heina + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications Copyright 2015-2025 Apple Inc. and the FoundationDB project authors. + * This source file is part of the FoundationDB open source project + */ + +package com.apple.foundationdb.half; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Unit test for {@link Half}. + * + * @author Christian Heina (developer@christianheina.com) + */ +public class HalfTest { + private static final short POSITIVE_INFINITY_SHORT_VALUE = (short) 0x7c00; + private static final short NEGATIVE_INFINITY_SHORT_VALUE = (short) 0xfc00; + private static final short NaN_SHORT_VALUE = (short) 0x7e00; + private static final short MAX_VALUE_SHORT_VALUE = (short) 0x7bff; + private static final short MIN_NORMAL_SHORT_VALUE = (short) 0x0400; + private static final short MIN_VALUE_SHORT_VALUE = (short) 0x1; + private static final int MAX_EXPONENT = 15; + private static final int MIN_EXPONENT = -14; + private static final int SIZE = 16; + private static final int BYTES = 2; + private static final short POSITIVE_ZERO_SHORT_VALUE = (short) 0x0; + private static final short NEGATIVE_ZERO_SHORT_VALUE = (short) 0x8000; + + private static final short LOWEST_ABOVE_ONE_SHORT_VALUE = (short) 0x3c01; + private static final Half LOWEST_ABOVE_ONE = Half.shortBitsToHalf(LOWEST_ABOVE_ONE_SHORT_VALUE); + private static final short NEGATIVE_MAX_VALUE_SHORT_VALUE = (short) 0xfbff; + private static final Half NEGATIVE_MAX_VALUE = Half.shortBitsToHalf(NEGATIVE_MAX_VALUE_SHORT_VALUE); + + @Test + public void publicStaticClassVariableTest() { + 
Assertions.assertEquals(POSITIVE_INFINITY_SHORT_VALUE, Half.halfToShortBits(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.shortBitsToHalf(POSITIVE_INFINITY_SHORT_VALUE)); + + Assertions.assertEquals(NEGATIVE_INFINITY_SHORT_VALUE, Half.halfToShortBits(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.shortBitsToHalf(NEGATIVE_INFINITY_SHORT_VALUE)); + + Assertions.assertEquals(NaN_SHORT_VALUE, Half.halfToShortBits(Half.NaN)); + Assertions.assertEquals(Half.NaN, Half.shortBitsToHalf(NaN_SHORT_VALUE)); + + Assertions.assertEquals(MAX_VALUE_SHORT_VALUE, Half.halfToShortBits(Half.MAX_VALUE)); + Assertions.assertEquals(Half.MAX_VALUE, Half.shortBitsToHalf(MAX_VALUE_SHORT_VALUE)); + + Assertions.assertEquals(NEGATIVE_MAX_VALUE_SHORT_VALUE, Half.halfToShortBits(Half.NEGATIVE_MAX_VALUE)); + Assertions.assertEquals(Half.NEGATIVE_MAX_VALUE, Half.shortBitsToHalf(NEGATIVE_MAX_VALUE_SHORT_VALUE)); + + Assertions.assertEquals(MIN_NORMAL_SHORT_VALUE, Half.halfToShortBits(Half.MIN_NORMAL)); + Assertions.assertEquals(Half.MIN_NORMAL, Half.shortBitsToHalf(MIN_NORMAL_SHORT_VALUE)); + Assertions.assertEquals(Half.MIN_NORMAL.doubleValue(), Math.pow(2, -14)); + + Assertions.assertEquals(MIN_VALUE_SHORT_VALUE, Half.halfToShortBits(Half.MIN_VALUE)); + Assertions.assertEquals(Half.MIN_VALUE, Half.shortBitsToHalf(MIN_VALUE_SHORT_VALUE)); + Assertions.assertEquals(Half.MIN_VALUE.doubleValue(), Math.pow(2, -24)); + + Assertions.assertEquals(Half.MAX_EXPONENT, MAX_EXPONENT); + Assertions.assertEquals(Half.MAX_EXPONENT, Math.getExponent(Half.MAX_VALUE.floatValue())); + + Assertions.assertEquals(Half.MIN_EXPONENT, MIN_EXPONENT); + Assertions.assertEquals(Half.MIN_EXPONENT, Math.getExponent(Half.MIN_NORMAL.floatValue())); + + Assertions.assertEquals(Half.SIZE, SIZE); + Assertions.assertEquals(Half.BYTES, BYTES); + + Assertions.assertEquals(POSITIVE_ZERO_SHORT_VALUE, Half.halfToShortBits(Half.POSITIVE_ZERO)); + 
Assertions.assertEquals(Half.POSITIVE_ZERO, Half.shortBitsToHalf(POSITIVE_ZERO_SHORT_VALUE)); + + Assertions.assertEquals(NEGATIVE_ZERO_SHORT_VALUE, Half.halfToShortBits(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.shortBitsToHalf(NEGATIVE_ZERO_SHORT_VALUE)); + } + + @Test + public void shortBitsToHalfTest() { + Assertions.assertEquals(Float.POSITIVE_INFINITY, Half.shortBitsToHalf(POSITIVE_INFINITY_SHORT_VALUE).floatValue()); + Assertions.assertEquals(Float.NEGATIVE_INFINITY, Half.shortBitsToHalf(NEGATIVE_INFINITY_SHORT_VALUE).floatValue()); + Assertions.assertEquals(Float.NaN, Half.shortBitsToHalf(NaN_SHORT_VALUE).floatValue()); + Assertions.assertEquals(65504f, Half.shortBitsToHalf(MAX_VALUE_SHORT_VALUE).floatValue()); + Assertions.assertEquals(6.103515625e-5f, Half.shortBitsToHalf(MIN_NORMAL_SHORT_VALUE).floatValue()); + Assertions.assertEquals(5.9604645e-8f, Half.shortBitsToHalf(MIN_VALUE_SHORT_VALUE).floatValue()); + Assertions.assertEquals(0f, Half.shortBitsToHalf(POSITIVE_ZERO_SHORT_VALUE).floatValue()); + Assertions.assertEquals(-0f, Half.shortBitsToHalf(NEGATIVE_ZERO_SHORT_VALUE).floatValue()); + + Assertions.assertEquals(1.00097656f, Half.shortBitsToHalf(LOWEST_ABOVE_ONE_SHORT_VALUE).floatValue()); + } + + @Test + public void halfToShortBitsTest() { + Assertions.assertEquals(POSITIVE_INFINITY_SHORT_VALUE, Half.halfToShortBits(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(NEGATIVE_INFINITY_SHORT_VALUE, Half.halfToShortBits(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(NaN_SHORT_VALUE, Half.halfToShortBits(Half.NaN)); + Assertions.assertEquals(NaN_SHORT_VALUE, Half.halfToShortBits(Half.shortBitsToHalf((short) 0x7e04))); + Assertions.assertEquals(NaN_SHORT_VALUE, Half.halfToShortBits(Half.shortBitsToHalf((short) 0x7fff))); + Assertions.assertEquals(MAX_VALUE_SHORT_VALUE, Half.halfToShortBits(Half.MAX_VALUE)); + Assertions.assertEquals(MIN_NORMAL_SHORT_VALUE, Half.halfToShortBits(Half.MIN_NORMAL)); + 
Assertions.assertEquals(MIN_VALUE_SHORT_VALUE, Half.halfToShortBits(Half.MIN_VALUE)); + Assertions.assertEquals(POSITIVE_ZERO_SHORT_VALUE, Half.halfToShortBits(Half.POSITIVE_ZERO)); + Assertions.assertEquals(NEGATIVE_ZERO_SHORT_VALUE, Half.halfToShortBits(Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(LOWEST_ABOVE_ONE_SHORT_VALUE, Half.halfToShortBits(LOWEST_ABOVE_ONE)); + } + + @Test + public void halfToRawShortBitsTest() { + Assertions.assertEquals(POSITIVE_INFINITY_SHORT_VALUE, Half.halfToRawShortBits(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(NEGATIVE_INFINITY_SHORT_VALUE, Half.halfToRawShortBits(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(NaN_SHORT_VALUE, Half.halfToRawShortBits(Half.NaN)); + Assertions.assertEquals((short) 0x7e04, Half.halfToRawShortBits(Half.shortBitsToHalf((short) 0x7e04))); + Assertions.assertEquals((short) 0x7fff, Half.halfToRawShortBits(Half.shortBitsToHalf((short) 0x7fff))); + Assertions.assertEquals((short) 0x7e00, Half.halfToRawShortBits(Half.valueOf(Float.intBitsToFloat(0x7fe00000)))); + Assertions.assertEquals((short) 0x7e00, Half.halfToRawShortBits(Half.valueOf(Float.intBitsToFloat(0x7fe00001)))); + Assertions.assertEquals(MAX_VALUE_SHORT_VALUE, Half.halfToRawShortBits(Half.MAX_VALUE)); + Assertions.assertEquals(MIN_NORMAL_SHORT_VALUE, Half.halfToRawShortBits(Half.MIN_NORMAL)); + Assertions.assertEquals(MIN_VALUE_SHORT_VALUE, Half.halfToRawShortBits(Half.MIN_VALUE)); + Assertions.assertEquals(POSITIVE_ZERO_SHORT_VALUE, Half.halfToRawShortBits(Half.POSITIVE_ZERO)); + Assertions.assertEquals(NEGATIVE_ZERO_SHORT_VALUE, Half.halfToRawShortBits(Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(LOWEST_ABOVE_ONE_SHORT_VALUE, Half.halfToRawShortBits(LOWEST_ABOVE_ONE)); + } + + @Test + public void shortValueTest() { + Assertions.assertEquals(Short.MAX_VALUE, Half.POSITIVE_INFINITY.shortValue()); + Assertions.assertEquals(Short.MIN_VALUE, Half.NEGATIVE_INFINITY.shortValue()); + Assertions.assertEquals((short) 0, 
Half.NaN.shortValue()); + Assertions.assertEquals(Short.MAX_VALUE, Half.MAX_VALUE.shortValue()); + Assertions.assertEquals(Short.MIN_VALUE, NEGATIVE_MAX_VALUE.shortValue()); + Assertions.assertEquals((short) 0, Half.MIN_NORMAL.shortValue()); + Assertions.assertEquals((short) 0, Half.MIN_VALUE.shortValue()); + Assertions.assertEquals((short) 0, Half.POSITIVE_ZERO.shortValue()); + Assertions.assertEquals((short) 0, Half.NEGATIVE_ZERO.shortValue()); + + Assertions.assertEquals((short) 1, LOWEST_ABOVE_ONE.shortValue()); + } + + @Test + public void intValueTest() { + Assertions.assertEquals(Integer.MAX_VALUE, Half.POSITIVE_INFINITY.intValue()); + Assertions.assertEquals(Integer.MIN_VALUE, Half.NEGATIVE_INFINITY.intValue()); + Assertions.assertEquals(0, Half.NaN.intValue()); + Assertions.assertEquals(65504, Half.MAX_VALUE.intValue()); + Assertions.assertEquals(0, Half.MIN_NORMAL.intValue()); + Assertions.assertEquals(0, Half.MIN_VALUE.intValue()); + Assertions.assertEquals(0, Half.POSITIVE_ZERO.intValue()); + Assertions.assertEquals(0, Half.NEGATIVE_ZERO.intValue()); + + Assertions.assertEquals(1, LOWEST_ABOVE_ONE.intValue()); + } + + @Test + public void longValueTest() { + Assertions.assertEquals(Long.MAX_VALUE, Half.POSITIVE_INFINITY.longValue()); + Assertions.assertEquals(Long.MIN_VALUE, Half.NEGATIVE_INFINITY.longValue()); + Assertions.assertEquals(0, Half.NaN.longValue()); + Assertions.assertEquals(65504, Half.MAX_VALUE.longValue()); + Assertions.assertEquals(0, Half.MIN_NORMAL.longValue()); + Assertions.assertEquals(0, Half.MIN_VALUE.longValue()); + Assertions.assertEquals(0, Half.POSITIVE_ZERO.longValue()); + Assertions.assertEquals(0, Half.NEGATIVE_ZERO.longValue()); + + Assertions.assertEquals(1, LOWEST_ABOVE_ONE.longValue()); + } + + @Test + public void floatValueTest() { + Assertions.assertEquals(Float.POSITIVE_INFINITY, Half.POSITIVE_INFINITY.floatValue()); + Assertions.assertEquals(Float.NEGATIVE_INFINITY, Half.NEGATIVE_INFINITY.floatValue()); + 
Assertions.assertEquals(Float.NaN, Half.NaN.floatValue()); + Assertions.assertEquals(65504f, Half.MAX_VALUE.floatValue()); + Assertions.assertEquals(6.103515625e-5f, Half.MIN_NORMAL.floatValue()); + Assertions.assertEquals(5.9604645e-8f, Half.MIN_VALUE.floatValue()); + Assertions.assertEquals(0f, Half.POSITIVE_ZERO.floatValue()); + Assertions.assertEquals(-0f, Half.NEGATIVE_ZERO.floatValue()); + + Assertions.assertEquals(1.00097656f, LOWEST_ABOVE_ONE.floatValue()); + } + + @Test + public void doubleValueTest() { + Assertions.assertEquals(Double.POSITIVE_INFINITY, Half.POSITIVE_INFINITY.doubleValue()); + Assertions.assertEquals(Double.NEGATIVE_INFINITY, Half.NEGATIVE_INFINITY.doubleValue()); + Assertions.assertEquals(Double.NaN, Half.NaN.doubleValue()); + Assertions.assertEquals(65504d, Half.MAX_VALUE.doubleValue()); + Assertions.assertEquals(6.103515625e-5d, Half.MIN_NORMAL.doubleValue()); + Assertions.assertEquals(5.9604644775390625E-8d, Half.MIN_VALUE.doubleValue()); + Assertions.assertEquals(0d, Half.POSITIVE_ZERO.doubleValue()); + Assertions.assertEquals(-0d, Half.NEGATIVE_ZERO.doubleValue()); + + Assertions.assertEquals(1.0009765625d, LOWEST_ABOVE_ONE.doubleValue()); + } + + @Test + public void byteValueTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY.byteValue(), Float.valueOf(Float.POSITIVE_INFINITY).byteValue()); + Assertions.assertEquals(Half.NEGATIVE_INFINITY.byteValue(), Float.valueOf(Float.NEGATIVE_INFINITY).byteValue()); + Assertions.assertEquals(Half.NaN.byteValue(), Float.valueOf(Float.NaN).byteValue()); + Assertions.assertEquals(Half.MAX_VALUE.byteValue(), Float.valueOf(Float.MAX_VALUE).byteValue()); + Assertions.assertEquals(Half.MIN_NORMAL.byteValue(), Float.valueOf(Float.MIN_NORMAL).byteValue()); + Assertions.assertEquals(Half.MIN_VALUE.byteValue(), Float.valueOf(Float.MIN_VALUE).byteValue()); + Assertions.assertEquals(Half.POSITIVE_ZERO.byteValue(), Float.valueOf(0.0f).byteValue()); + 
Assertions.assertEquals(Half.NEGATIVE_ZERO.byteValue(), Float.valueOf(-0.0f).byteValue()); + } + + @Test + public void valueOfStringTest() { + // Decmial values + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.valueOf("Infinity")); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.valueOf("-Infinity")); + Assertions.assertEquals(Half.NaN, Half.valueOf("NaN")); + Assertions.assertEquals(Half.MAX_VALUE, Half.valueOf("65504")); + Assertions.assertEquals(Half.MIN_NORMAL, Half.valueOf("6.103515625e-5")); + Assertions.assertEquals(Half.MIN_VALUE, Half.valueOf("5.9604645e-8")); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.valueOf("0")); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.valueOf("-0")); + + Assertions.assertEquals(LOWEST_ABOVE_ONE, Half.valueOf("1.00097656f")); + + // Hex values + Assertions.assertEquals(Half.valueOf("0x1.0p0"), Half.valueOf(1.0f)); + Assertions.assertEquals(Half.valueOf("-0x1.0p0"), Half.valueOf(-1.0f)); + Assertions.assertEquals(Half.valueOf("0x1.0p1"), Half.valueOf(2.0f)); + Assertions.assertEquals(Half.valueOf("0x1.8p1"), Half.valueOf(3.0f)); + Assertions.assertEquals(Half.valueOf("0x1.0p-1"), Half.valueOf(0.5f)); + Assertions.assertEquals(Half.valueOf("0x1.0p-2"), Half.valueOf(0.25f)); + Assertions.assertEquals(Half.valueOf("0x0.ffcp-14"), Half.shortBitsToHalf((short) 0x3ff)); + } + + @Test + public void valueOfStringNumberFormatExceptionTest() { + Assertions.assertThrows(NumberFormatException.class, () -> Half.valueOf("ABC")); + } + + @Test + public void valueOfStringNullPointerExceptionTest() { + Assertions.assertThrows(NullPointerException.class, () -> Half.valueOf((String)null)); + } + + @Test + public void valueOfDoubleTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.valueOf(Double.valueOf(Double.POSITIVE_INFINITY))); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.valueOf(Double.valueOf(Double.NEGATIVE_INFINITY))); + Assertions.assertEquals(Half.NaN, 
Half.valueOf(Double.valueOf(Double.NaN))); + Assertions.assertEquals(Half.MAX_VALUE, Half.valueOf(Double.valueOf(65504d))); + Assertions.assertEquals(Half.MIN_NORMAL, Half.valueOf(Double.valueOf(6.103515625e-5d))); + Assertions.assertEquals(Half.MIN_VALUE, Half.valueOf(Double.valueOf(5.9604644775390625E-8d))); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.valueOf(Double.valueOf(0d))); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.valueOf(Double.valueOf(-0d))); + + Assertions.assertEquals(LOWEST_ABOVE_ONE, Half.valueOf(Double.valueOf(1.0009765625d))); + } + + @Test + public void valueOfFloatTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.valueOf(Float.valueOf(Float.POSITIVE_INFINITY))); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.valueOf(Float.valueOf(Float.NEGATIVE_INFINITY))); + Assertions.assertEquals(Half.NaN, Half.valueOf(Float.valueOf(Float.NaN))); + Assertions.assertEquals(Half.MAX_VALUE, Half.valueOf(Float.valueOf(65504f))); + Assertions.assertEquals(Half.MIN_NORMAL, Half.valueOf(Float.valueOf(6.103515625e-5f))); + Assertions.assertEquals(Half.MIN_VALUE, Half.valueOf(Float.valueOf(5.9604645e-8f))); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.valueOf(Float.valueOf(0f))); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.valueOf(Float.valueOf(-0f))); + + Assertions.assertEquals(LOWEST_ABOVE_ONE, Half.valueOf(Float.valueOf(1.00097656f))); + } + + @Test + public void valueOfHalfTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.valueOf(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.valueOf(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.NaN, Half.valueOf(Half.NaN)); + Assertions.assertEquals(Half.MAX_VALUE, Half.valueOf(Half.MAX_VALUE)); + Assertions.assertEquals(Half.MIN_NORMAL, Half.valueOf(Half.MIN_NORMAL)); + Assertions.assertEquals(Half.MIN_VALUE, Half.valueOf(Half.MIN_VALUE)); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.valueOf(Half.POSITIVE_ZERO)); 
+ Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.valueOf(Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(LOWEST_ABOVE_ONE, Half.valueOf(LOWEST_ABOVE_ONE)); + } + + @Test + public void isNaNTest() { + Assertions.assertFalse(Half.POSITIVE_INFINITY.isNaN()); + Assertions.assertFalse(Half.NEGATIVE_INFINITY.isNaN()); + Assertions.assertTrue(Half.NaN.isNaN()); + Assertions.assertFalse(Half.MAX_VALUE.isNaN()); + Assertions.assertFalse(Half.MIN_NORMAL.isNaN()); + Assertions.assertFalse(Half.MIN_VALUE.isNaN()); + Assertions.assertFalse(Half.POSITIVE_ZERO.isNaN()); + Assertions.assertFalse(Half.NEGATIVE_ZERO.isNaN()); + + Assertions.assertFalse(LOWEST_ABOVE_ONE.isNaN()); + } + + @Test + public void isInfiniteTest() { + Assertions.assertTrue(Half.POSITIVE_INFINITY.isInfinite()); + Assertions.assertTrue(Half.NEGATIVE_INFINITY.isInfinite()); + Assertions.assertFalse(Half.NaN.isInfinite()); + Assertions.assertFalse(Half.MAX_VALUE.isInfinite()); + Assertions.assertFalse(Half.MIN_NORMAL.isInfinite()); + Assertions.assertFalse(Half.MIN_VALUE.isInfinite()); + Assertions.assertFalse(Half.POSITIVE_ZERO.isInfinite()); + Assertions.assertFalse(Half.NEGATIVE_ZERO.isInfinite()); + + Assertions.assertFalse(LOWEST_ABOVE_ONE.isInfinite()); + } + + @Test + public void isFiniteTest() { + Assertions.assertFalse(Half.POSITIVE_INFINITY.isFinite()); + Assertions.assertFalse(Half.NEGATIVE_INFINITY.isFinite()); + Assertions.assertFalse(Half.NaN.isFinite()); + Assertions.assertTrue(Half.MAX_VALUE.isFinite()); + Assertions.assertTrue(Half.MIN_NORMAL.isFinite()); + Assertions.assertTrue(Half.MIN_VALUE.isFinite()); + Assertions.assertTrue(Half.POSITIVE_ZERO.isFinite()); + Assertions.assertTrue(Half.NEGATIVE_ZERO.isFinite()); + + Assertions.assertTrue(LOWEST_ABOVE_ONE.isFinite()); + } + + @Test + public void toStringTest() { + Assertions.assertEquals("Infinity", Half.POSITIVE_INFINITY.toString()); + Assertions.assertEquals("-Infinity", Half.NEGATIVE_INFINITY.toString()); + 
Assertions.assertEquals("NaN", Half.NaN.toString()); + Assertions.assertEquals("65504.0", Half.MAX_VALUE.toString()); + Assertions.assertEquals("6.1035156e-5", Half.MIN_NORMAL.toString().toLowerCase()); + Assertions.assertEquals("5.9604645e-8", Half.MIN_VALUE.toString().toLowerCase()); + Assertions.assertEquals("0.0", Half.POSITIVE_ZERO.toString()); + Assertions.assertEquals("-0.0", Half.NEGATIVE_ZERO.toString()); + + Assertions.assertEquals("1.0009766", LOWEST_ABOVE_ONE.toString().toLowerCase()); + } + + @Test + public void toHexStringTest() { + Assertions.assertEquals("Infinity", Half.toHexString(Half.POSITIVE_INFINITY)); + Assertions.assertEquals("-Infinity", Half.toHexString(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals("NaN", Half.toHexString(Half.NaN)); + Assertions.assertEquals("0x1.ffcp15", Half.toHexString(Half.MAX_VALUE)); + Assertions.assertEquals("0x1.0p-14", Half.toHexString(Half.MIN_NORMAL).toLowerCase()); + Assertions.assertEquals("0x0.004p-14", Half.toHexString(Half.MIN_VALUE).toLowerCase()); + Assertions.assertEquals("0x0.0p0", Half.toHexString(Half.POSITIVE_ZERO)); + Assertions.assertEquals("-0x0.0p0", Half.toHexString(Half.NEGATIVE_ZERO)); + + Assertions.assertEquals("0x1.004p0", Half.toHexString(LOWEST_ABOVE_ONE)); + + Assertions.assertEquals("0x1.0p0", Half.toHexString(Half.valueOf(1.0f))); + Assertions.assertEquals("-0x1.0p0", Half.toHexString(Half.valueOf(-1.0f))); + Assertions.assertEquals("0x1.0p1", Half.toHexString(Half.valueOf(2.0f))); + Assertions.assertEquals("0x1.8p1", Half.toHexString(Half.valueOf(3.0f))); + Assertions.assertEquals("0x1.0p-1", Half.toHexString(Half.valueOf(0.5f))); + Assertions.assertEquals("0x1.0p-2", Half.toHexString(Half.valueOf(0.25f))); + Assertions.assertEquals("0x0.ffcp-14", Half.toHexString(Half.shortBitsToHalf((short) 0x3ff))); + } + + @SuppressWarnings("EqualsWithItself") + @Test + public void equalsTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.POSITIVE_INFINITY); + 
Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.NEGATIVE_INFINITY); + Assertions.assertEquals(Half.NaN, Half.NaN); + Assertions.assertEquals(Half.MAX_VALUE, Half.MAX_VALUE); + Assertions.assertEquals(Half.MIN_NORMAL, Half.MIN_NORMAL); + Assertions.assertEquals(Half.MIN_VALUE, Half.MIN_VALUE); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.POSITIVE_ZERO); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.NEGATIVE_ZERO); + + Assertions.assertEquals(LOWEST_ABOVE_ONE, LOWEST_ABOVE_ONE); + + Assertions.assertNotEquals(Half.POSITIVE_INFINITY, Half.NEGATIVE_INFINITY); + Assertions.assertNotEquals(Half.NEGATIVE_INFINITY, Half.POSITIVE_INFINITY); + Assertions.assertNotEquals(Half.NaN, Half.POSITIVE_INFINITY); + Assertions.assertNotEquals(Half.MAX_VALUE, Half.NaN); + Assertions.assertNotEquals(Half.MIN_NORMAL, Half.MIN_VALUE); + Assertions.assertNotEquals(Half.MIN_VALUE, Half.POSITIVE_ZERO); + Assertions.assertNotEquals(Half.POSITIVE_ZERO, Half.NEGATIVE_ZERO); + Assertions.assertNotEquals(Half.NEGATIVE_ZERO, Half.POSITIVE_ZERO); + + Assertions.assertNotEquals(LOWEST_ABOVE_ONE, Half.MIN_NORMAL); + + Assertions.assertNotEquals(null, LOWEST_ABOVE_ONE); + + // Additional NaN tests + Assertions.assertEquals(Half.NaN, Half.NaN); + Assertions.assertEquals(Half.NaN, Half.shortBitsToHalf((short)0x7e04)); + Assertions.assertEquals(Half.NaN, Half.shortBitsToHalf((short)0x7fff)); + } + + @Test + public void hashCodeTest() { + Assertions.assertEquals(POSITIVE_INFINITY_SHORT_VALUE, Half.POSITIVE_INFINITY.hashCode()); + Assertions.assertEquals(NEGATIVE_INFINITY_SHORT_VALUE, Half.NEGATIVE_INFINITY.hashCode()); + Assertions.assertEquals(NaN_SHORT_VALUE, Half.NaN.hashCode()); + Assertions.assertEquals(MAX_VALUE_SHORT_VALUE, Half.MAX_VALUE.hashCode()); + Assertions.assertEquals(MIN_NORMAL_SHORT_VALUE, Half.MIN_NORMAL.hashCode()); + Assertions.assertEquals(MIN_VALUE_SHORT_VALUE, Half.MIN_VALUE.hashCode()); + Assertions.assertEquals(POSITIVE_ZERO_SHORT_VALUE, 
Half.POSITIVE_ZERO.hashCode()); + Assertions.assertEquals(NEGATIVE_ZERO_SHORT_VALUE, Half.NEGATIVE_ZERO.hashCode()); + + Assertions.assertEquals(LOWEST_ABOVE_ONE_SHORT_VALUE, LOWEST_ABOVE_ONE.hashCode()); + } + + @SuppressWarnings("EqualsWithItself") + @Test + public void compareToTest() { + // Left + Assertions.assertEquals(1, Half.POSITIVE_INFINITY.compareTo(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(1, Half.MAX_VALUE.compareTo(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(1, Half.NaN.compareTo(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(1, Half.MAX_VALUE.compareTo(Half.MIN_NORMAL)); + Assertions.assertEquals(1, Half.MIN_NORMAL.compareTo(Half.MIN_VALUE)); + Assertions.assertEquals(1, Half.MIN_VALUE.compareTo(Half.POSITIVE_ZERO)); + Assertions.assertEquals(1, Half.POSITIVE_ZERO.compareTo(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(1, LOWEST_ABOVE_ONE.compareTo(Half.NEGATIVE_ZERO)); + + // Right + Assertions.assertEquals(-1, Half.NEGATIVE_INFINITY.compareTo(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(-1, Half.NEGATIVE_INFINITY.compareTo(Half.MAX_VALUE)); + Assertions.assertEquals(-1, Half.NEGATIVE_INFINITY.compareTo(Half.NaN)); + Assertions.assertEquals(-1, Half.MIN_NORMAL.compareTo(Half.MAX_VALUE)); + Assertions.assertEquals(-1, Half.MIN_VALUE.compareTo(Half.MIN_NORMAL)); + Assertions.assertEquals(-1, Half.POSITIVE_ZERO.compareTo(Half.MIN_VALUE)); + Assertions.assertEquals(-1, Half.NEGATIVE_ZERO.compareTo(Half.POSITIVE_ZERO)); + Assertions.assertEquals(-1, Half.NEGATIVE_ZERO.compareTo(LOWEST_ABOVE_ONE)); + + // Equals + Assertions.assertEquals(0, Half.POSITIVE_INFINITY.compareTo(Half.POSITIVE_INFINITY)); + Assertions.assertEquals(0, Half.NEGATIVE_INFINITY.compareTo(Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(0, Half.NaN.compareTo(Half.NaN)); + Assertions.assertEquals(0, Half.MAX_VALUE.compareTo(Half.MAX_VALUE)); + Assertions.assertEquals(0, Half.MIN_NORMAL.compareTo(Half.MIN_NORMAL)); + Assertions.assertEquals(0, 
Half.MIN_VALUE.compareTo(Half.MIN_VALUE)); + Assertions.assertEquals(0, Half.POSITIVE_ZERO.compareTo(Half.POSITIVE_ZERO)); + Assertions.assertEquals(0, Half.NEGATIVE_ZERO.compareTo(Half.NEGATIVE_ZERO)); + Assertions.assertEquals(0, LOWEST_ABOVE_ONE.compareTo(LOWEST_ABOVE_ONE)); + } + + @Test + public void sumTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.sum(Half.POSITIVE_INFINITY, Half.MAX_VALUE)); + Assertions.assertEquals(Half.NaN, Half.sum(Half.POSITIVE_INFINITY, Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.NaN, Half.sum(Half.NaN, Half.NaN)); + Assertions.assertEquals(Half.NaN, Half.sum(Half.NaN, Half.MAX_VALUE)); + Assertions.assertEquals(Half.sum(Half.MIN_NORMAL, Half.MIN_VALUE), Half.valueOf(6.109476E-5f)); + Assertions.assertEquals(Half.MIN_VALUE, Half.sum(Half.MIN_VALUE, Half.POSITIVE_ZERO)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.sum(Half.MAX_VALUE, LOWEST_ABOVE_ONE)); + Assertions.assertEquals( + Half.NEGATIVE_INFINITY, + Half.sum(Half.valueOf(-Half.MAX_VALUE.floatValue()), Half.valueOf(-LOWEST_ABOVE_ONE.floatValue()))); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.sum(Half.POSITIVE_ZERO, Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(Half.NaN, Half.sum(Half.NaN, LOWEST_ABOVE_ONE)); + } + + @Test + public void maxTest() { + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.max(Half.POSITIVE_INFINITY, Half.MAX_VALUE)); + Assertions.assertEquals(Half.POSITIVE_INFINITY, Half.max(Half.POSITIVE_INFINITY, Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.NaN, Half.max(Half.NaN, Half.NaN)); + Assertions.assertEquals(Half.NaN, Half.max(Half.NaN, Half.MAX_VALUE)); + Assertions.assertEquals(Half.MIN_NORMAL, Half.max(Half.MIN_NORMAL, Half.MIN_VALUE)); + Assertions.assertEquals(Half.MIN_VALUE, Half.max(Half.MIN_VALUE, Half.POSITIVE_ZERO)); + Assertions.assertEquals(Half.MAX_VALUE, Half.max(Half.MAX_VALUE, LOWEST_ABOVE_ONE)); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.max(Half.POSITIVE_ZERO, 
Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(Half.NaN, Half.max(Half.NaN, LOWEST_ABOVE_ONE)); + } + + @Test + public void minTest() { + Assertions.assertEquals(Half.MAX_VALUE, Half.min(Half.POSITIVE_INFINITY, Half.MAX_VALUE)); + Assertions.assertEquals(Half.NEGATIVE_INFINITY, Half.min(Half.POSITIVE_INFINITY, Half.NEGATIVE_INFINITY)); + Assertions.assertEquals(Half.NaN, Half.min(Half.NaN, Half.NaN)); + Assertions.assertEquals(Half.NaN, Half.min(Half.NaN, Half.MAX_VALUE)); + Assertions.assertEquals(Half.MIN_VALUE, Half.min(Half.MIN_NORMAL, Half.MIN_VALUE)); + Assertions.assertEquals(Half.POSITIVE_ZERO, Half.min(Half.MIN_VALUE, Half.POSITIVE_ZERO)); + Assertions.assertEquals(LOWEST_ABOVE_ONE, Half.min(Half.MAX_VALUE, LOWEST_ABOVE_ONE)); + Assertions.assertEquals(Half.NEGATIVE_ZERO, Half.min(Half.POSITIVE_ZERO, Half.NEGATIVE_ZERO)); + + Assertions.assertEquals(Half.NaN, Half.min(Half.NaN, LOWEST_ABOVE_ONE)); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/half/MoreHalfTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/half/MoreHalfTest.java new file mode 100644 index 0000000000..817de5f627 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/half/MoreHalfTest.java @@ -0,0 +1,117 @@ +/* + * MoreHalfTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.apple.foundationdb.half;
+
+import com.apple.test.RandomizedTestUtils;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nonnull;
+import java.util.Random;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.within;
+
+/**
+ * Randomized and boundary-value tests for converting {@code double}/{@code float} values to {@link Half} and back.
+ */
+public class MoreHalfTest {
+    /** Smallest positive normal half-precision magnitude, 2^-14. */
+    private static final double HALF_MIN_NORMAL = Math.scalb(1.0, -14);
+    /** Relative error tolerated for inputs whose magnitude lies in the normal half range. */
+    private static final double REL_BOUND = Math.scalb(1.0, -10); // 2^-10
+    /** Spacing between adjacent subnormal half values, 2^-24; absolute error tolerated below the normal range. */
+    private static final double HALF_SUBNORMAL_ULP = Math.scalb(1.0, -24);
+
+    /**
+     * Provides the seeds for the randomized tests of this class.
+     * @return a stream of seeds
+     */
+    @Nonnull
+    private static Stream<Long> randomSeeds() {
+        return RandomizedTestUtils.randomSeeds(12345, 987654, 423, 18378195);
+    }
+
+    /**
+     * Converts random doubles to {@link Half} and back and checks sign preservation, overflow to infinity,
+     * the conversion error, and that converting an already converted value again does not change it.
+     * @param seed a seed for the random number generator
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeeds")
+    void roundTripTest(final long seed) {
+        final Random rnd = new Random(seed);
+        for (int i = 0; i < 10_000; i ++) {
+            // uniform in [0, 1) or [-2 * HALF_MAX, 2 * HALF_MAX)
+            double x = (i % 2 == 1)
+                    ? rnd.nextDouble()
+                    : ((rnd.nextDouble() * 2 - 1) * 2 * Half.MAX_VALUE.doubleValue());
+            double y = Half.valueOf(x).doubleValue();
+
+            Assertions.assertThat(sameSign(x, y)).isTrue();
+
+            if (Math.abs(x) > Half.MAX_VALUE.doubleValue()) {
+                if (x > 0) {
+                    Assertions.assertThat(y).isEqualTo(Double.POSITIVE_INFINITY);
+                } else {
+                    Assertions.assertThat(y).isEqualTo(Double.NEGATIVE_INFINITY);
+                }
+            } else if (Math.abs(x) >= HALF_MIN_NORMAL) {
+                Assertions.assertThat((y - x) / x).isCloseTo(0.0d, within(REL_BOUND));
+            } else {
+                //
+                // Below the normal range the half values are evenly spaced, so the absolute error is bounded by
+                // one such step no matter how small x is. A tolerance derived from the exponent of x would grow
+                // as x shrinks and could never be violated.
+                //
+                Assertions.assertThat(y - x).isCloseTo(0.0d, within(HALF_SUBNORMAL_ULP));
+            }
+
+            // a value that is already representable as a half must survive another round trip unchanged
+            double z = Half.valueOf(y).doubleValue();
+            Assertions.assertThat(z).isEqualTo(y);
+        }
+    }
+
+    /**
+     * Method to return if two doubles have the same sign including {@code Inf, -Inf, NaN, +-0}, etc.
+     * @param a a double
+     * @param b another double
+     * @return {@code true} iff {@code a} and {@code b} have the same sign
+     */
+    private static boolean sameSign(final double a, final double b) {
+        // extract raw bit representations
+        long bitsA = Double.doubleToRawLongBits(a);
+        long bitsB = Double.doubleToRawLongBits(b);
+        // the sign bit is the most significant bit (bit 63)
+        return ((bitsA ^ bitsB) & 0x8000_0000_0000_0000L) == 0;
+    }
+
+    /**
+     * Checks that the smallest {@code float} above {@code Half.MAX_VALUE} converts to an infinite {@link Half}.
+     */
+    @SuppressWarnings("checkstyle:AbbreviationAsWordInName")
+    @Test
+    void conversionRoundingTestMaxValue() {
+        final float smallestFloatGreaterThanHalfMax = Math.nextUp(Half.MAX_VALUE.floatValue());
+        Assertions.assertThat(Half.valueOf(smallestFloatGreaterThanHalfMax)).matches(h -> h.isInfinite());
+    }
+
+    /**
+     * Probes the rounding of floats around {@code 2^-25}, i.e. half way between zero and {@code 2^-24}.
+     */
+    @SuppressWarnings("checkstyle:AbbreviationAsWordInName")
+    @Test
+    void conversionRoundingTestMinValue() {
+        //
+        // This conversion in Half implements round-to-nearest, ties-to-even for the half subnormal case by adding
+        // 0x007FF000 (which is 0x00800000 - 0x00001000), then shifting and final rounding. A consequence of this
+        // particular rounding scheme is:
+        // At the boundary we’re probing (mid = 2^-25, float exponent = 102), the expression
+        // ((0x007FF000 + significand) >> 23) stays 0 until significand >= 0x00001000 (4096 float-ULPs at that
+        // exponent), which means Math.nextUp(2^-25f) (only 1 ULP above the midpoint) still rounds to 0 in this
+        // implementation.
+        //
+        final float midF = Math.scalb(1.0f, -25); // 2^-25 exact in float
+        final int bits = Float.floatToRawIntBits(midF) + 0x00001000; // +4096 ULPs at this exponent
+        final float aboveF = Float.intBitsToFloat(bits);
+
+        final Half h0 = Half.valueOf(Math.nextDown(midF)); // -> 0.0
+        final Half h1 = Half.valueOf(aboveF); // -> 2^-24
+
+        Assertions.assertThat(h0.doubleValue()).isEqualTo(0.0);
+        Assertions.assertThat(h1.doubleValue()).isEqualTo(Math.scalb(1.0, -24));
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java
new file mode 100644
index 0000000000..9b44987174
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java
@@ -0,0 +1,104 @@
+/*
+ * FhtKacRotatorTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.apple.test.RandomizedTestUtils;
+import com.google.common.collect.ImmutableSet;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nonnull;
+import java.util.Random;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.within;
+
+/**
+ * Tests for {@link FhtKacRotator}: applying it and its transpose, reproducibility for a given seed, and
+ * orthogonality of the matrix returned by {@code computeP()}.
+ */
+public class FhtKacRotatorTest {
+    /**
+     * Provides every combination of a random seed and a number of dimensions.
+     * @return a stream of {@code (seed, numDimensions)} arguments
+     */
+    @Nonnull
+    private static Stream<Arguments> randomSeedsWithNumDimensions() {
+        return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L)
+                .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768, 1000).stream()
+                        .map(numDimensions -> Arguments.of(seed, numDimensions)));
+    }
+
+    /**
+     * Applies the rotator and then its transpose to a random vector and expects to get the original vector back.
+     * @param seed a seed for the rotator and the random vector
+     * @param numDimensions the number of dimensions of the vector
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testSimpleRotationAndBack(final long seed, final int numDimensions) {
+        final FhtKacRotator rotator = new FhtKacRotator(seed, numDimensions, 10);
+
+        final Random random = new Random(seed);
+        final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions);
+        final RealVector y = rotator.operate(x);
+        final RealVector z = rotator.operateTranspose(y);
+
+        Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(x, z)).isCloseTo(0, within(2E-10));
+    }
+
+    /**
+     * Two rotators created from identical arguments must be equal, have equal hash codes, and rotate a vector to
+     * equal results.
+     * @param seed a seed for the rotators and the random vector
+     * @param numDimensions the number of dimensions of the vector
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testRotationIsStable(final long seed, final int numDimensions) {
+        final FhtKacRotator rotator1 = new FhtKacRotator(seed, numDimensions, 10);
+        final FhtKacRotator rotator2 = new FhtKacRotator(seed, numDimensions, 10);
+        Assertions.assertThat(rotator1.hashCode()).isEqualTo(rotator2.hashCode());
+        Assertions.assertThat(rotator1).isEqualTo(rotator2);
+
+        final Random random = new Random(seed);
+        final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions);
+        final RealVector rotatedByFirst = rotator1.operate(x);
+        final RealVector rotatedBySecond = rotator2.operate(x);
+
+        Assertions.assertThat(rotatedByFirst).isEqualTo(rotatedBySecond);
+    }
+
+    /**
+     * Applies {@code operateTranspose} to the {@code j}-th row of {@code computeP()} (held as a column of the
+     * transposed data) and expects the {@code j}-th unit vector.
+     * @param seed a seed for the rotator
+     * @param numDimensions the number of dimensions
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testOrthogonality(final long seed, final int numDimensions) {
+        final FhtKacRotator rotator = new FhtKacRotator(seed, numDimensions, 10);
+        final ColumnMajorRealMatrix p = new ColumnMajorRealMatrix(rotator.computeP().transpose().getData());
+
+        for (int j = 0; j < numDimensions; j ++) {
+            final RealVector rotated = rotator.operateTranspose(new DoubleRealVector(p.getColumn(j)));
+            for (int i = 0; i < numDimensions; i++) {
+                double expected = (i == j) ? 1.0 : 0.0;
+                Assertions.assertThat(Math.abs(rotated.getComponent(i) - expected))
+                        .isCloseTo(0, within(2E-14));
+            }
+        }
+    }
+
+    /**
+     * Multiplies the transpose of {@code computeP()} with {@code computeP()} and expects the identity matrix.
+     * @param seed a seed for the rotator
+     * @param dimensionality the number of dimensions
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testOrthogonalityWithP(final long seed, final int dimensionality) {
+        final FhtKacRotator rotator = new FhtKacRotator(seed, dimensionality, 10);
+        final RealMatrix p = rotator.computeP();
+        final RealMatrix product = p.transpose().multiply(p);
+
+        for (int i = 0; i < dimensionality; i++) {
+            for (int j = 0; j < dimensionality; j++) {
+                double expected = (i == j) ? 1.0 : 0.0;
+                Assertions.assertThat(Math.abs(product.getEntry(i, j) - expected))
+                        .isCloseTo(0, within(2E-14));
+            }
+        }
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java
new file mode 100644
index 0000000000..cfb087eacd
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java
@@ -0,0 +1,177 @@
+/*
+ * MetricTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.apple.test.RandomizedTestUtils;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.assertj.core.api.Assertions;
+import org.assertj.core.data.Offset;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nonnull;
+import java.util.Random;
+import java.util.stream.Stream;
+
+import static com.apple.foundationdb.linear.Metric.COSINE_METRIC;
+import static com.apple.foundationdb.linear.Metric.DOT_PRODUCT_METRIC;
+import static com.apple.foundationdb.linear.Metric.EUCLIDEAN_METRIC;
+import static com.apple.foundationdb.linear.Metric.EUCLIDEAN_SQUARE_METRIC;
+import static com.apple.foundationdb.linear.Metric.MANHATTAN_METRIC;
+
+/**
+ * Tests for {@link Metric}: hand-computed distances for a fixed pair of vectors and randomized checks of the
+ * properties each metric claims to satisfy.
+ */
+public class MetricTest {
+    /**
+     * Provides each metric together with two fixed vectors and the expected distance between them.
+     * @return a stream of {@code (metric, v1, v2, expectedDistance)} arguments
+     */
+    static Stream<Arguments> metricAndExpectedDistance() {
+        // Distance between (1.0, 2.0) and (4.0, 6.0)
+        final RealVector v1 = v(1.0, 2.0);
+        final RealVector v2 = v(4.0, 6.0);
+        return Stream.of(
+                Arguments.of(MANHATTAN_METRIC, v1, v2, 7.0), // |1 - 4| + |2 - 6| = 7
+                Arguments.of(EUCLIDEAN_METRIC, v1, v2, 5.0), // sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5
+                Arguments.of(EUCLIDEAN_SQUARE_METRIC, v1, v2, 25.0), // (1-4)^2 + (2-6)^2 = 9 + 16 = 25
+                Arguments.of(COSINE_METRIC, v1, v2, 0.007722), // 1 - ((1 * 4) + (2 * 6)) / (sqrt(1^2 + 2^2) * sqrt(4^2 + 6^2)) = 1 - 16 / (sqrt(5) * sqrt(52)) ≈ 1 - 0.992278 ≈ 0.007722
+                Arguments.of(DOT_PRODUCT_METRIC, v1, v2, -16.0) // -((1 * 4) + (2 * 6)) = -(4 + 12) = -16
+        );
+    }
+
+    /**
+     * Compares the distance computed by a metric with a hand-computed expectation.
+     * @param metric the metric under test
+     * @param v1 the first vector
+     * @param v2 the second vector
+     * @param expectedDistance the expected distance between {@code v1} and {@code v2}
+     */
+    @ParameterizedTest
+    @MethodSource("metricAndExpectedDistance")
+    void basicMetricTest(@Nonnull final Metric metric, @Nonnull final RealVector v1, @Nonnull final RealVector v2, final double expectedDistance) {
+        Assertions.assertThat(metric.distance(v1, v2)).isCloseTo(expectedDistance, Offset.offset(2E-4d));
+    }
+
+    /**
+     * Provides every combination of a random seed, a metric, and a number of dimensions.
+     * @return a stream of {@code (seed, metric, numDimensions)} arguments
+     */
+    @Nonnull
+    private static Stream<Arguments> randomSeedsWithMetrics() {
+        return RandomizedTestUtils.randomSeeds(12345, 987654, 423, 18378195)
+                .flatMap(seed ->
+                        Sets.cartesianProduct(
+                                ImmutableSet.of(MANHATTAN_METRIC,
+                                        EUCLIDEAN_METRIC,
+                                        EUCLIDEAN_SQUARE_METRIC,
+                                        COSINE_METRIC,
+                                        DOT_PRODUCT_METRIC), ImmutableSet.of(3, 5, 128, 768)).stream()
+                                .map(metricsAndNumDimensions ->
+                                        Arguments.of(seed, metricsAndNumDimensions.get(0), metricsAndNumDimensions.get(1))));
+    }
+
+    /**
+     * Checks, for every metric, that {@code toString()} mentions the simple name of its definition class and
+     * that {@code isTrueMetric()} returns the expected value.
+     */
+    @Test
+    void testMetricDefinitionBasics() {
+        Assertions.assertThat(MANHATTAN_METRIC.toString()).contains(MetricDefinition.ManhattanMetric.class.getSimpleName());
+        Assertions.assertThat(MANHATTAN_METRIC.isTrueMetric()).isTrue();
+        Assertions.assertThat(EUCLIDEAN_METRIC.toString()).contains(MetricDefinition.EuclideanMetric.class.getSimpleName());
+        Assertions.assertThat(EUCLIDEAN_METRIC.isTrueMetric()).isTrue();
+        Assertions.assertThat(EUCLIDEAN_SQUARE_METRIC.toString()).contains(MetricDefinition.EuclideanSquareMetric.class.getSimpleName());
+        Assertions.assertThat(EUCLIDEAN_SQUARE_METRIC.isTrueMetric()).isFalse();
+        Assertions.assertThat(COSINE_METRIC.toString()).contains(MetricDefinition.CosineMetric.class.getSimpleName());
+        Assertions.assertThat(COSINE_METRIC.isTrueMetric()).isFalse();
+        Assertions.assertThat(DOT_PRODUCT_METRIC.toString()).contains(MetricDefinition.DotProductMetric.class.getSimpleName());
+        Assertions.assertThat(DOT_PRODUCT_METRIC.isTrueMetric()).isFalse();
+    }
+
+    /**
+     * Draws random vector triples and checks zero self-distance, positivity, symmetry, and the triangle
+     * inequality. Each property is only enforced if the metric reports that it satisfies it.
+     * @param seed a seed for the random number generator
+     * @param metric the metric under test
+     * @param numDimensions the number of dimensions of the random vectors
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithMetrics")
+    void basicPropertyTest(final long seed, @Nonnull final Metric metric, final int numDimensions) {
+        final Random random = new Random(seed);
+
+        for (int i = 0; i < 1000; i ++) {
+            // either use vectors that draw from [0, 1) or use the entire full double range
+            final RealVector x = (i % 2 == 1) ? randomV(random, numDimensions) : randomVFull(random, numDimensions);
+            final RealVector y = (i % 2 == 1) ? randomV(random, numDimensions) : randomVFull(random, numDimensions);
+            final RealVector z = (i % 2 == 1) ? randomV(random, numDimensions) : randomVFull(random, numDimensions);
+
+            final double distanceXX = metric.distance(x, x);
+            final double distanceXY = metric.distance(x, y);
+            final double distanceYX = metric.distance(y, x);
+            final double distanceYZ = metric.distance(y, z);
+            final double distanceXZ = metric.distance(x, z);
+
+            if (!Double.isFinite(distanceXX) || !Double.isFinite(distanceXY)) {
+                //
+                // Some metrics are numerically unstable across the entire numerical range.
+                // For instance COSINE_METRIC may return Double.NaN or Double.POSITIVE_INFINITY which is ok.
+                //
+                continue;
+            }
+
+            // in each check below the first alternative passes if the metric does not claim the property
+            Assertions.assertThat(distanceXX).satisfiesAnyOf(
+                    d -> Assertions.assertThat(metric.satisfiesZeroSelfDistance()).isFalse(),
+                    d -> Assertions.assertThat(d).isCloseTo(0, Offset.offset(2E-10d)));
+
+            Assertions.assertThat(distanceXY).satisfiesAnyOf(
+                    d -> Assertions.assertThat(metric.satisfiesPositivity()).isFalse(),
+                    d -> Assertions.assertThat(d).isGreaterThanOrEqualTo(0));
+
+            Assertions.assertThat(distanceXY).satisfiesAnyOf(
+                    d -> Assertions.assertThat(metric.satisfiesSymmetry()).isFalse(),
+                    d -> Assertions.assertThat(d).isCloseTo(distanceYX, Offset.offset(2E-10d)));
+
+            // the exact comparison is tried first; the tolerance-based comparison absorbs rounding errors
+            Assertions.assertThat(distanceXY).satisfiesAnyOf(
+                    d -> Assertions.assertThat(metric.satisfiesTriangleInequality()).isFalse(),
+                    d -> Assertions.assertThat(d + distanceYZ).isGreaterThanOrEqualTo(distanceXZ),
+                    d -> Assertions.assertThat(triangleHolds(distanceXY, distanceYZ, distanceXZ, numDimensions * 3)).isTrue());
+        }
+    }
+
+    /**
+     * Checks the triangle inequality {@code distanceXZ <= distanceXY + distanceYZ} up to a tolerance that scales
+     * with the magnitude of the three distances and the given number of floating point operations.
+     * @param distanceXY the distance between {@code x} and {@code y}
+     * @param distanceYZ the distance between {@code y} and {@code z}
+     * @param distanceXZ the distance between {@code x} and {@code z}
+     * @param numOperations the number of operations used to scale the tolerance
+     * @return {@code true} iff the triangle inequality holds within the computed tolerance
+     */
+    static boolean triangleHolds(double distanceXY, double distanceYZ, double distanceXZ, int numOperations) {
+        double u = 0x1p-53; // relative error ~1.11e-16
+        double gamma = (numOperations * u) / (1 - numOperations * u); // ~ n*u
+        double scale = distanceXY + distanceYZ + distanceXZ; // magnitude to scale tol
+        double tol = 4 * gamma * scale + (Math.ulp(distanceXY + distanceYZ) + Math.ulp(distanceXZ)); // small guard
+        return distanceXZ <= distanceXY + distanceYZ + tol;
+    }
+
+    /**
+     * Shorthand to create a vector from its components.
+     * @param components the components of the vector
+     * @return a new {@link DoubleRealVector}
+     */
+    @Nonnull
+    @SuppressWarnings("checkstyle:MethodName")
+    private static RealVector v(final double... components) {
+        return new DoubleRealVector(components);
+    }
+
+    /**
+     * Creates a vector whose components are drawn from {@code [0, 1)}.
+     * @param random the random number generator to use
+     * @param numDimensions the number of dimensions of the vector
+     * @return a new random vector
+     */
+    @Nonnull
+    @SuppressWarnings("checkstyle:MethodName")
+    private static RealVector randomV(@Nonnull final Random random, final int numDimensions) {
+        final double[] components = new double[numDimensions];
+        for (int i = 0; i < numDimensions; i++) {
+            components[i] = random.nextDouble();
+        }
+        return new DoubleRealVector(components);
+    }
+
+    /**
+     * Creates a vector whose components are arbitrary finite doubles.
+     * @param random the random number generator to use
+     * @param numDimensions the number of dimensions of the vector
+     * @return a new random vector
+     */
+    @Nonnull
+    @SuppressWarnings("checkstyle:MethodName")
+    private static RealVector randomVFull(@Nonnull final Random random, final int numDimensions) {
+        final double[] components = new double[numDimensions];
+        for (int i = 0; i < numDimensions; i++) {
+            components[i] = randomDouble(random);
+        }
+        return new DoubleRealVector(components);
+    }
+
+    /**
+     * Reinterprets random 64-bit patterns as doubles until a finite one is found.
+     * @param random the random number generator to use
+     * @return a finite double
+     */
+    private static double randomDouble(@Nonnull final Random random) {
+        double d;
+        do {
+            d = Double.longBitsToDouble(random.nextLong());
+        } while (!Double.isFinite(d));
+        return d;
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java
new file mode 100644
index 0000000000..3dfe46f2d1
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java
@@ -0,0 +1,100 @@
+/*
+ * QRDecompositionTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.apple.test.RandomizedTestUtils;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nonnull;
+import java.util.Random;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.within;
+
+/**
+ * Tests for {@link QRDecomposition}: the product {@code Q * R} must reproduce the decomposed matrix.
+ */
+@SuppressWarnings("checkstyle:AbbreviationAsWordInName")
+public class QRDecompositionTest {
+    /**
+     * Provides every combination of a random seed and a number of dimensions.
+     * @return a stream of {@code (seed, numDimensions)} arguments
+     */
+    @Nonnull
+    private static Stream<Arguments> randomSeedsWithNumDimensions() {
+        return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L)
+                .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768).stream()
+                        .map(numDimensions -> Arguments.of(seed, numDimensions)));
+    }
+
+    /**
+     * Decomposes a random orthogonal matrix and checks that {@code Q * R} equals the original matrix.
+     * @param seed a seed for the random number generator
+     * @param numDimensions the number of rows and columns of the matrix
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testQREqualsM(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        final RealMatrix m = MatrixHelpers.randomOrthogonalMatrix(random, numDimensions);
+        final QRDecomposition.Result result = QRDecomposition.decomposeMatrix(m);
+        assertEntriesClose(result.getQ().multiply(result.getR()), m);
+    }
+
+    /**
+     * Decomposes the {@code Q} of a first decomposition again and checks that the resulting {@code R} is, up to
+     * the sign of its entries, the identity matrix.
+     * @param seed a seed for the random number generator
+     * @param numDimensions the number of rows and columns of the matrix
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testRepeatedQR(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        final RealMatrix m = MatrixHelpers.randomOrthogonalMatrix(random, numDimensions);
+        final QRDecomposition.Result firstResult = QRDecomposition.decomposeMatrix(m);
+        final QRDecomposition.Result secondResult = QRDecomposition.decomposeMatrix(firstResult.getQ());
+
+        final RealMatrix r = secondResult.getR();
+        for (int i = 0; i < r.getRowDimension(); i++) {
+            for (int j = 0; j < r.getColumnDimension(); j++) {
+                assertThat(Math.abs(r.getEntry(i, j))).isCloseTo((i == j) ? 1.0d : 0.0d, within(2E-14));
+            }
+        }
+    }
+
+    /**
+     * Decomposes a matrix whose last three rows are all zeros and checks that {@code Q * R} still equals the
+     * original matrix.
+     */
+    @Test
+    void testZeroes() {
+        double[][] mData = new double[5][5];
+
+        // Fill the top-left 2x5 however you like (non-zero):
+        mData[0] = new double[] { 1, 2, 3, 4, 5 };
+        mData[1] = new double[] {-1, 0, 7, 8, 9 };
+
+        // Make rows 2..4 all zeros:
+        mData[2] = new double[] { 0, 0, 0, 0, 0 };
+        mData[3] = new double[] { 0, 0, 0, 0, 0 };
+        mData[4] = new double[] { 0, 0, 0, 0, 0 };
+
+        // => For any minor k ≥ 2, the column segment x (rows k...end) is all zeros => a == 0.0 branch taken.
+        final RealMatrix m = new RowMajorRealMatrix(mData);
+        final QRDecomposition.Result result = QRDecomposition.decomposeMatrix(m);
+        assertEntriesClose(result.getQ().multiply(result.getR()), m);
+    }
+
+    /**
+     * Asserts that every entry of {@code actual} is within {@code 2E-14} of the corresponding entry of
+     * {@code expected}. The dimensions of {@code actual} determine which entries are visited.
+     * @param actual the computed matrix
+     * @param expected the matrix to compare against
+     */
+    private static void assertEntriesClose(@Nonnull final RealMatrix actual, @Nonnull final RealMatrix expected) {
+        for (int i = 0; i < actual.getRowDimension(); i++) {
+            for (int j = 0; j < actual.getColumnDimension(); j++) {
+                assertThat(actual.getEntry(i, j)).isCloseTo(expected.getEntry(i, j), within(2E-14));
+            }
+        }
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java
new file mode 100644
index 0000000000..439f41ffcf
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java
@@ -0,0 +1,193 @@
+/*
+ * RealMatrixTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.apple.test.RandomizedTestUtils; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.util.Random; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; + +public class RealMatrixTest { + @Nonnull + private static Stream randomSeedsWithNumDimensions() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) + .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768, 1000).stream() + .map(numDimensions -> Arguments.of(seed, numDimensions))); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testTranspose(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final int numRows = random.nextInt(numDimensions) + 1; + final int numColumns = random.nextInt(numDimensions) + 1; + final RealMatrix matrix = MatrixHelpers.randomGaussianMatrix(random, numRows, numColumns); + final RealMatrix otherMatrix = flip(matrix); + assertThat(otherMatrix).isEqualTo(matrix); + final RealMatrix anotherMatrix = flip(otherMatrix); + assertThat(anotherMatrix).isEqualTo(otherMatrix); + assertThat(anotherMatrix).isEqualTo(matrix); + assertThat(anotherMatrix.getClass()).isSameAs(matrix.getClass()); + } + + @ParameterizedTest + 
@MethodSource("randomSeedsWithNumDimensions") + void testQuickTranspose(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final int numRows = random.nextInt(numDimensions) + 1; + final int numColumns = random.nextInt(numDimensions) + 1; + final RealMatrix matrix = MatrixHelpers.randomGaussianMatrix(random, numRows, numColumns); + final RealMatrix otherMatrix = matrix.quickTranspose().transpose(); + assertThat(otherMatrix).isEqualTo(matrix); + final RealMatrix anotherMatrix = matrix.quickTranspose().quickTranspose(); + assertThat(anotherMatrix).isEqualTo(matrix); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testDifferentMajor(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final int numRows = random.nextInt(numDimensions) + 1; + final int numColumns = random.nextInt(numDimensions) + 1; + final RealMatrix matrix = MatrixHelpers.randomGaussianMatrix(random, numRows, numColumns); + assertThat(matrix).isInstanceOf(RowMajorRealMatrix.class); + final RealMatrix otherMatrix = matrix.toColumnMajor(); + assertThat(otherMatrix.toColumnMajor()).isSameAs(otherMatrix); + assertThat(otherMatrix.hashCode()).isEqualTo(matrix.hashCode()); + assertThat(otherMatrix).isEqualTo(matrix); + final RealMatrix anotherMatrix = otherMatrix.toRowMajor(); + assertThat(anotherMatrix.toRowMajor()).isSameAs(anotherMatrix); + assertThat(anotherMatrix.hashCode()).isEqualTo(matrix.hashCode()); + assertThat(anotherMatrix).isEqualTo(matrix); + } + + + @Nonnull + private static RealMatrix flip(@Nonnull final RealMatrix matrix) { + assertThat(matrix) + .satisfiesAnyOf(m -> assertThat(m).isInstanceOf(RowMajorRealMatrix.class), + m -> assertThat(m).isInstanceOf(ColumnMajorRealMatrix.class)); + final double[][] data = matrix.transpose().getData(); + if (matrix instanceof RowMajorRealMatrix) { + return new ColumnMajorRealMatrix(data); + } else { + return new RowMajorRealMatrix(data); + } + } + + 
@ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testOperateAndBack(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        final RealMatrix matrix = MatrixHelpers.randomOrthogonalMatrix(random, numDimensions);
+        assertThat(matrix.isTransposable()).isTrue();
+        final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions);
+        // the transpose of an orthogonal matrix is its inverse, so applying the matrix and then its transpose
+        // has to take us back to x (up to rounding)
+        final RealVector y = matrix.operate(x);
+        final RealVector z = matrix.operateTranspose(y);
+        assertThat(Metric.EUCLIDEAN_METRIC.distance(x, z)).isCloseTo(0, within(2E-10));
+    }
+
+    /**
+     * Tests the multiplication of two (dense) row-major matrices. We don't want to use just a fixed set of matrices,
+     * but rather want to run the multiplication for a lot of different arguments. The goal is to not replicate
+     * the actual multiplication algorithm for the test. We employ the following trick. We first generate a random
+     * orthogonal matrix of {@code d} dimensions. For that matrix, let's call it {@code r}, {@code r*r^T = I_d} which
+     * exercises the multiplication code path and whose result is easily verifiable. The problem is, though, that
+     * doing just that would only test the multiplication of square matrices. In order to similarly test the
+     * multiplication of arbitrary rectangular matrices, we can still start out from a square matrix, but then
+     * randomly cut off the last {@code k} rows (and respectively {@code l} columns from the transpose) and only then
+     * multiply the results. The final resulting matrix is of dimensionality {@code d - k} by {@code d - l} which is a
+     * rectangular identity matrix, i.e.:
+     * {@code (I_(d-k, d-l))_ij = 1, if i == j, 0 otherwise}.
+     *
+     * @param seed a seed for a random number generator
+     * @param d the number of dimensions we should use when generating the random orthogonal matrix
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testMultiplyRowMajorMatrix(final long seed, final int d) {
+        final Random random = new Random(seed);
+        final RealMatrix r = MatrixHelpers.randomOrthogonalMatrix(random, d);
+        assertMultiplyMxMT(d, random, r);
+    }
+
+    /**
+     * Same check as {@link #testMultiplyRowMajorMatrix(long, int)}, but the random orthogonal matrix is first passed
+     * through {@link #flip(RealMatrix)} so that the multiplication starts out from the other storage class.
+     *
+     * @param seed a seed for a random number generator
+     * @param d the number of dimensions we should use when generating the random orthogonal matrix
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testMultiplyColumnMajorMatrix(final long seed, final int d) {
+        final Random random = new Random(seed);
+        final RealMatrix r = flip(MatrixHelpers.randomOrthogonalMatrix(random, d));
+        assertMultiplyMxMT(d, random, r);
+    }
+
+    /**
+     * Cuts a random number of trailing rows off {@code r} and a random number of trailing columns off its transpose,
+     * multiplies the two sub-matrices and asserts that the product is a rectangular identity matrix.
+     *
+     * @param d the number of rows and columns of {@code r}
+     * @param random the source of randomness used to pick how many rows and columns are cut off
+     * @param r the matrix to multiply with its own transpose; the callers in this class pass orthogonal matrices
+     */
+    private static void assertMultiplyMxMT(final int d, @Nonnull final Random random, @Nonnull final RealMatrix r) {
+        final int k = random.nextInt(d);
+        final int l = random.nextInt(d);
+
+        final int numResultRows = d - k;
+        final int numResultColumns = d - l;
+
+        final RealMatrix m = r.subMatrix(0, numResultRows, 0, d);
+        final RealMatrix mT = r.transpose().subMatrix(0, d, 0, numResultColumns);
+
+        final RealMatrix product = m.multiply(mT);
+
+        assertThat(product)
+                .satisfies(p -> assertThat(p.getRowDimension()).isEqualTo(numResultRows),
+                        p -> assertThat(p.getColumnDimension()).isEqualTo(numResultColumns));
+
+        for (int i = 0; i < product.getRowDimension(); i++) {
+            for (int j = 0; j < product.getColumnDimension(); j++) {
+                final double expected = (i == j) ? 1.0 : 0.0;
+                // assert on the entry itself so that a failure reports the offending value
+                assertThat(product.getEntry(i, j)).isCloseTo(expected, within(2E-14));
+            }
+        }
+    }
+
+    /**
+     * Multiplies a random {@code k x l} row-major matrix with a random {@code l x k} column-major matrix and compares
+     * every entry of the product with the dot product of the corresponding row and column.
+     *
+     * @param seed a seed for a random number generator
+     * @param d the (inclusive) upper bound for the randomly chosen {@code k} and {@code l}
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testMultiplyMatrix2(final long seed, final int d) {
+        final Random random = new Random(seed);
+        final int k = random.nextInt(d) + 1;
+        final int l = random.nextInt(d) + 1;
+
+        final RowMajorRealMatrix m1 = MatrixHelpers.randomGaussianMatrix(random, k, l).toRowMajor();
+        final ColumnMajorRealMatrix m2 = MatrixHelpers.randomGaussianMatrix(random, l, k).toColumnMajor();
+
+        final RealMatrix product = m1.multiply(m2);
+
+        // (k x l) * (l x k) is (k x k); without this check an empty product would pass the loops below vacuously
+        assertThat(product)
+                .satisfies(p -> assertThat(p.getRowDimension()).isEqualTo(k),
+                        p -> assertThat(p.getColumnDimension()).isEqualTo(k));
+
+        for (int i = 0; i < product.getRowDimension(); i++) {
+            for (int j = 0; j < product.getColumnDimension(); j++) {
+                final double expected = new DoubleRealVector(m1.getRow(i)).dot(new DoubleRealVector(m2.getColumn(j)));
+                assertThat(product.getEntry(i, j)).isCloseTo(expected, within(2E-14));
+            }
+        }
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealVectorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealVectorTest.java
new file mode 100644
index 0000000000..55571a1a38
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealVectorTest.java
@@ -0,0 +1,267 @@
+/*
+ * RealVectorTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.apple.foundationdb.half.Half;
+import com.apple.test.RandomizedTestUtils;
+import com.google.common.collect.ImmutableSet;
+import org.assertj.core.api.Assertions;
+import org.assertj.core.data.Offset;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nonnull;
+import java.util.Random;
+import java.util.stream.Stream;
+
+/**
+ * Tests for the double, float and half precision vector classes: conversions between precisions, alternative
+ * constructors, serialization, norms, normalization, addition, equality and dot products.
+ */
+public class RealVectorTest {
+    /**
+     * Provides the arguments for the parameterized tests: every random seed combined with every number of
+     * dimensions.
+     */
+    @Nonnull
+    private static Stream<Arguments> randomSeedsWithNumDimensions() {
+        return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L)
+                .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768, 1000).stream()
+                        .map(numDimensions -> Arguments.of(seed, numDimensions)));
+    }
+
+    /**
+     * Converts random vectors between the three precisions and checks that converting to the precision a vector
+     * already has, and round trips through a higher precision, give back an equal vector.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testPrecisionRoundTrips(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        for (int i = 0; i < 1000; i++) {
+            final DoubleRealVector doubleVector = createRandomDoubleVector(random, numDimensions);
+            Assertions.assertThat(doubleVector.toDoubleRealVector()).isEqualTo(doubleVector);
+
+            final FloatRealVector floatVector = doubleVector.toFloatRealVector();
+            Assertions.assertThat(floatVector.toFloatRealVector()).isEqualTo(floatVector);
+            Assertions.assertThat(floatVector.toDoubleRealVector().toFloatRealVector()).isEqualTo(floatVector);
+
+            final HalfRealVector halfVector = floatVector.toHalfRealVector();
+            Assertions.assertThat(halfVector.toHalfRealVector()).isEqualTo(halfVector);
+            Assertions.assertThat(halfVector.toFloatRealVector().toHalfRealVector()).isEqualTo(halfVector);
+
+            final HalfRealVector halfVector2 = doubleVector.toHalfRealVector();
+            Assertions.assertThat(halfVector2).isEqualTo(halfVector);
+            Assertions.assertThat(halfVector2.toDoubleRealVector().toHalfRealVector()).isEqualTo(halfVector2);
+
+            Assertions.assertThat(halfVector2.toDoubleRealVector().toFloatRealVector())
+                    .isEqualTo(doubleVector.toHalfRealVector().toFloatRealVector());
+        }
+    }
+
+    /**
+     * Builds the vector {@code (-3, 0, 2)} through the boxed-array, {@code int[]} and {@code long[]} constructors of
+     * each vector class and checks the three components.
+     */
+    @Test
+    void testAlternativeConstructors() {
+        Assertions.assertThat(new DoubleRealVector(new Double[] {-3.0d, 0.0d, 2.0d}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new DoubleRealVector(new int[] {-3, 0, 2}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new DoubleRealVector(new long[] {-3L, 0L, 2L}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new FloatRealVector(new Float[] {-3.0f, 0.0f, 2.0f}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new FloatRealVector(new int[] {-3, 0, 2}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new FloatRealVector(new long[] {-3L, 0L, 2L}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new HalfRealVector(new Half[] {Half.valueOf(-3.0d), Half.valueOf(0.0d), Half.valueOf(2.0d)}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new HalfRealVector(new int[] {-3, 0, 2}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+
+        Assertions.assertThat(new HalfRealVector(new long[] {-3L, 0L, 2L}))
+                .satisfies(vector -> Assertions.assertThat(vector.getComponent(0)).isCloseTo(-3.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(1)).isCloseTo(0.0d, Offset.offset(2E-14)),
+                        vector -> Assertions.assertThat(vector.getComponent(2)).isCloseTo(2.0d, Offset.offset(2E-14)));
+    }
+
+    /**
+     * Checks for each vector class that {@code fromBytes(getRawData())} gives back an equal vector of the same class.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testDirectSerializationDeserialization(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+
+        final DoubleRealVector doubleVector = createRandomDoubleVector(random, numDimensions);
+        RealVector deserializedVector = DoubleRealVector.fromBytes(doubleVector.getRawData());
+        Assertions.assertThat(deserializedVector).isInstanceOf(DoubleRealVector.class);
+        Assertions.assertThat(deserializedVector).isEqualTo(doubleVector);
+
+        final FloatRealVector floatVector = createRandomFloatVector(random, numDimensions);
+        deserializedVector = FloatRealVector.fromBytes(floatVector.getRawData());
+        Assertions.assertThat(deserializedVector).isInstanceOf(FloatRealVector.class);
+        Assertions.assertThat(deserializedVector).isEqualTo(floatVector);
+
+        final HalfRealVector halfVector = createRandomHalfVector(random, numDimensions);
+        deserializedVector = HalfRealVector.fromBytes(halfVector.getRawData());
+        Assertions.assertThat(deserializedVector).isInstanceOf(HalfRealVector.class);
+        Assertions.assertThat(deserializedVector).isEqualTo(halfVector);
+    }
+
+    /**
+     * Checks that the L2 norm of a vector is close to its Euclidean distance from the zero vector. The tolerance is
+     * widened for the lower precisions.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testNorm(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        final DoubleRealVector zeroVector = new DoubleRealVector(new double[numDimensions]);
+        for (int i = 0; i < 1000; i++) {
+            final DoubleRealVector doubleVector = createRandomDoubleVector(random, numDimensions);
+            Assertions.assertThat(doubleVector.l2Norm())
+                    .isCloseTo(Metric.EUCLIDEAN_METRIC.distance(doubleVector, zeroVector), Offset.offset(2E-14));
+
+            final FloatRealVector floatVector = createRandomFloatVector(random, numDimensions);
+            Assertions.assertThat(floatVector.l2Norm())
+                    .isCloseTo(Metric.EUCLIDEAN_METRIC.distance(floatVector, zeroVector), Offset.offset(2E-5));
+
+            final HalfRealVector halfVector = createRandomHalfVector(random, numDimensions);
+            Assertions.assertThat(halfVector.l2Norm())
+                    .isCloseTo(Metric.EUCLIDEAN_METRIC.distance(halfVector, zeroVector), Offset.offset(2E-2));
+        }
+    }
+
+    /**
+     * Checks that scaling a normalized vector by the L2 norm of the original vector gives back (approximately) the
+     * original vector.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testNormalize(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        for (int i = 0; i < 1000; i++) {
+            final DoubleRealVector doubleVector = createRandomDoubleVector(random, numDimensions);
+            RealVector normalizedVector = doubleVector.normalize();
+            Assertions.assertThat(normalizedVector.multiply(doubleVector.l2Norm()))
+                    .satisfies(v -> Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(doubleVector, v))
+                            .isCloseTo(0, Offset.offset(2E-14)));
+            final FloatRealVector floatVector = createRandomFloatVector(random, numDimensions);
+            normalizedVector = floatVector.normalize();
+            Assertions.assertThat(normalizedVector.multiply(floatVector.l2Norm()))
+                    .satisfies(v -> Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(floatVector, v))
+                            .isCloseTo(0, Offset.offset(2E-5)));
+            final HalfRealVector halfVector = createRandomHalfVector(random, numDimensions);
+            normalizedVector = halfVector.normalize();
+            Assertions.assertThat(normalizedVector.multiply(halfVector.l2Norm()))
+                    .satisfies(v -> Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(halfVector, v))
+                            .isCloseTo(0, Offset.offset(2E-2)));
+        }
+    }
+
+    /**
+     * Checks that {@code v + v} is close to {@code 2 * v} and that adding a scalar and then its negation gives back
+     * (approximately) the original vector.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testAdd(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        for (int i = 0; i < 1000; i++) {
+            final DoubleRealVector doubleVector = createRandomDoubleVector(random, numDimensions);
+
+            Assertions.assertThat(doubleVector.add(doubleVector))
+                    .satisfies(v ->
+                            Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(doubleVector.multiply(2.0d), v))
+                                    .isCloseTo(0, Offset.offset(2E-14)));
+
+            Assertions.assertThat(doubleVector.add(1.0d).add(-1.0d))
+                    .satisfies(v ->
+                            Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(doubleVector, v))
+                                    .isCloseTo(0, Offset.offset(2E-14)));
+        }
+    }
+
+    /**
+     * Checks that a half precision vector is equal to, and has the same hash code as, its conversions to float and
+     * double precision.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    @SuppressWarnings("AssertBetweenInconvertibleTypes")
+    void testEqualityAndHashCode(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        for (int i = 0; i < 1000; i++) {
+            final HalfRealVector halfVector = createRandomHalfVector(random, numDimensions);
+            Assertions.assertThat(halfVector.toDoubleRealVector().hashCode()).isEqualTo(halfVector.hashCode());
+            Assertions.assertThat(halfVector.toDoubleRealVector()).isEqualTo(halfVector);
+            Assertions.assertThat(halfVector.toFloatRealVector().hashCode()).isEqualTo(halfVector.hashCode());
+            Assertions.assertThat(halfVector.toFloatRealVector()).isEqualTo(halfVector);
+        }
+    }
+
+    /**
+     * Checks for each precision that the dot product is symmetric and close to the negated distance reported by
+     * {@code Metric.DOT_PRODUCT_METRIC}.
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensions")
+    void testDot(final long seed, final int numDimensions) {
+        final Random random = new Random(seed);
+        for (int i = 0; i < 1000; i++) {
+            final DoubleRealVector doubleVector1 = createRandomDoubleVector(random, numDimensions);
+            final DoubleRealVector doubleVector2 = createRandomDoubleVector(random, numDimensions);
+            double dot = doubleVector1.dot(doubleVector2);
+            Assertions.assertThat(dot).isEqualTo(doubleVector2.dot(doubleVector1));
+            Assertions.assertThat(dot)
+                    .isCloseTo(-Metric.DOT_PRODUCT_METRIC.distance(doubleVector1, doubleVector2), Offset.offset(2E-14));
+
+            final FloatRealVector floatVector1 = createRandomFloatVector(random, numDimensions);
+            final FloatRealVector floatVector2 = createRandomFloatVector(random, numDimensions);
+            dot = floatVector1.dot(floatVector2);
+            Assertions.assertThat(dot).isEqualTo(floatVector2.dot(floatVector1));
+            Assertions.assertThat(dot)
+                    .isCloseTo(-Metric.DOT_PRODUCT_METRIC.distance(floatVector1, floatVector2), Offset.offset(2E-14));
+
+            final HalfRealVector halfVector1 = createRandomHalfVector(random, numDimensions);
+            final HalfRealVector halfVector2 = createRandomHalfVector(random, numDimensions);
+            dot = halfVector1.dot(halfVector2);
+            Assertions.assertThat(dot).isEqualTo(halfVector2.dot(halfVector1));
+            Assertions.assertThat(dot)
+                    .isCloseTo(-Metric.DOT_PRODUCT_METRIC.distance(halfVector1, halfVector2), Offset.offset(2E-14));
+        }
+    }
+
+    /**
+     * Creates a double precision vector from {@link #createRandomVectorData(Random, int)}.
+     */
+    @Nonnull
+    public static DoubleRealVector createRandomDoubleVector(@Nonnull final Random random, final int numDimensions) {
+        return new DoubleRealVector(createRandomVectorData(random, numDimensions));
+    }
+
+    /**
+     * Creates a float precision vector from {@link #createRandomVectorData(Random, int)}.
+     */
+    @Nonnull
+    public static FloatRealVector createRandomFloatVector(@Nonnull final Random random, final int numDimensions) {
+        return new FloatRealVector(createRandomVectorData(random, numDimensions));
+    }
+
+    /**
+     * Creates a half precision vector from {@link #createRandomVectorData(Random, int)}.
+     */
+    @Nonnull
+    public static HalfRealVector createRandomHalfVector(@Nonnull final Random random, final int numDimensions) {
+        return new HalfRealVector(createRandomVectorData(random, numDimensions));
+    }
+
+    /**
+     * Creates {@code numDimensions} random components, each a {@code nextDouble()} value with a random sign, i.e. a
+     * value in the open interval {@code (-1, 1)}.
+     *
+     * @param random the source of randomness; two values are drawn per component (magnitude first, then sign)
+     * @param numDimensions the number of components to create
+     * @return a new array holding the components
+     */
+    @Nonnull
+    public static double[] createRandomVectorData(@Nonnull final Random random, final int numDimensions) {
+        final double[] components = new double[numDimensions];
+        for (int d = 0; d < numDimensions; d++) {
+            components[d] = random.nextDouble() * (random.nextBoolean() ? -1 : 1);
+        }
+        return components;
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java
new file mode 100644
index 0000000000..cd9c18684f
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java
@@ -0,0 +1,64 @@
+/*
+ * StoredVecsIteratorTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import com.google.common.collect.ImmutableSet;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Reads the query vectors and the ground truth of the siftsmall data set through the stored-vecs iterators and checks
+ * the basic shape of what comes back.
+ */
+public class StoredVecsIteratorTest {
+    /** Number of dimensions of every SIFT vector. */
+    private static final int NUM_DIMENSIONS = 128;
+    /** Number of query vectors (and ground truth records) in siftsmall. */
+    private static final int NUM_QUERIES = 100;
+    /** Number of nearest neighbors listed per query in the ground truth. */
+    private static final int NUM_NEIGHBORS_PER_QUERY = 100;
+    /** Number of base vectors in siftsmall; ground truth entries are indexes into that set. */
+    private static final int NUM_BASE_VECTORS = 10_000;
+
+    /**
+     * Iterates the query file and the ground truth file in lock step. Every query vector has to have
+     * {@link #NUM_DIMENSIONS} dimensions, every ground truth record has to hold {@link #NUM_NEIGHBORS_PER_QUERY}
+     * distinct indexes into the base set, and both files have to hold exactly {@link #NUM_QUERIES} records.
+     *
+     * @throws IOException if one of the files cannot be opened or closed
+     */
+    @SuppressWarnings("checkstyle:AbbreviationAsWordInName")
+    @Test
+    void readSIFT() throws IOException {
+        // NOTE(review): presumably the directory the Gradle task extractSiftSmall extracts into; the paths are
+        // relative, so they also depend on the working directory of the test run -- confirm
+        final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs");
+        final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs");
+
+        int numRecordsRead = 0;
+        try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ);
+                 final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) {
+            final var queryIterator = new StoredVecsIterator.StoredFVecsIterator(queryChannel);
+            final Iterator<List<Integer>> groundTruthIterator = new StoredVecsIterator.StoredIVecsIterator(groundTruthChannel);
+
+            Assertions.assertThat(queryIterator.hasNext()).isEqualTo(groundTruthIterator.hasNext());
+
+            while (queryIterator.hasNext()) {
+                final HalfRealVector queryVector = queryIterator.next().toHalfRealVector();
+                Assertions.assertThat(queryVector.getNumDimensions()).isEqualTo(NUM_DIMENSIONS);
+
+                // copying into a set drops duplicates, so the size check also proves the indexes are distinct
+                final Set<Integer> groundTruthIndices = ImmutableSet.copyOf(groundTruthIterator.next());
+                Assertions.assertThat(groundTruthIndices).hasSize(NUM_NEIGHBORS_PER_QUERY);
+
+                Assertions.assertThat(groundTruthIndices)
+                        .allSatisfy(index -> Assertions.assertThat(index).isBetween(0, NUM_BASE_VECTORS - 1));
+                numRecordsRead++;
+            }
+            // both files have to run out at the same time
+            Assertions.assertThat(groundTruthIterator.hasNext()).isFalse();
+        }
+        Assertions.assertThat(numRecordsRead).isEqualTo(NUM_QUERIES);
+    }
+}
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java
new file mode 100644
index 0000000000..83a12b9b57
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java
@@ -0,0 +1,240 @@
+/*
+ * RaBitQuantizerTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.rabitq;
+
+import com.apple.foundationdb.linear.DoubleRealVector;
+import com.apple.foundationdb.linear.FhtKacRotator;
+import com.apple.foundationdb.linear.Metric;
+import com.apple.foundationdb.linear.RealVector;
+import com.apple.foundationdb.linear.RealVectorTest;
+import com.apple.test.RandomizedTestUtils;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.assertj.core.api.Assertions;
+import org.assertj.core.data.Offset;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Random;
+import java.util.stream.Stream;
+
+import static com.apple.foundationdb.linear.RealVectorTest.createRandomDoubleVector;
+
+public class RaBitQuantizerTest {
+    private static final Logger logger = LoggerFactory.getLogger(RaBitQuantizerTest.class);
+
+    /**
+     * Provides the arguments for the randomized tests: every random seed combined with every pair of (number of
+     * dimensions, number of extra bits).
+     */
+    @Nonnull
+    private static Stream<Arguments> randomSeedsWithNumDimensionsAndNumExBits() {
+        return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L)
+                .flatMap(seed ->
+                        Sets.cartesianProduct(ImmutableSet.of(3, 5, 10, 128, 768, 1000),
+                                        ImmutableSet.of(3, 4, 5, 6, 7, 8))
+                                .stream()
+                                .map(arguments -> Arguments.of(seed, arguments.get(0), arguments.get(1))));
+    }
+
+    /**
+     * Encodes a random vector and checks that the re-centered encoded components point into (almost) the same
+     * direction as the original vector.
+     * @param seed a seed
+     * @param numDimensions the number of dimensions
+     * @param numExBits the number of bits per dimension used for encoding
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensionsAndNumExBits")
+    void basicEncodeTest(final long seed, final int numDimensions, final int numExBits) {
+        final Random random = new Random(seed);
+        final RealVector v = createRandomDoubleVector(random, numDimensions);
+
+        final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits);
+        final EncodedRealVector encodedVector = quantizer.encode(v);
+
+        // v and the re-centered encoded vector should be pointing into the same direction
+        final double[] reCenteredData = new double[numDimensions];
+        final double cb = -((1 << numExBits) - 0.5);
+        for (int i = 0; i < numDimensions; i++) {
+            reCenteredData[i] = (double)encodedVector.getEncodedComponent(i) + cb;
+        }
+        final RealVector reCentered = new DoubleRealVector(reCenteredData);
+
+        // normalize both vectors so their dot product should be 1
+        final RealVector vBar = v.normalize();
+        final RealVector reCenteredBar = reCentered.normalize();
+        Assertions.assertThat(vBar.dot(reCenteredBar)).isCloseTo(1, Offset.offset(0.01));
+    }
+
+    /**
+     * Create a random vector {@code v}, encode it into {@code encodedV} and estimate the distance between {@code v} and
+     * {@code encodedV} which should be very close to {@code 0}.
+     * @param seed a seed
+     * @param numDimensions the number of dimensions
+     * @param numExBits the number of bits per dimension used for encoding
+     */
+    @ParameterizedTest
+    @MethodSource("randomSeedsWithNumDimensionsAndNumExBits")
+    void basicEncodeWithEstimationTest(final long seed, final int numDimensions, final int numExBits) {
+        final Random random = new Random(seed);
+        final RealVector v = createRandomDoubleVector(random, numDimensions);
+        final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits);
+        final EncodedRealVector encodedV = quantizer.encode(v);
+        final RaBitEstimator estimator = quantizer.estimator();
+        final double estimatedDistance = estimator.distance(v, encodedV);
+        Assertions.assertThat(estimatedDistance).isCloseTo(0.0d, Offset.offset(0.01));
+    }
+
+    /**
+     * Provides hand-picked two-dimensional cases; each entry holds a centroid, a vector to encode, a query vector
+     * and the expected squared Euclidean distance between the latter two.
+     */
+    @Nonnull
+    private static Stream<Arguments> estimationArgs() {
+        return Stream.of(
+                Arguments.of(new double[]{0.5d, 0.5d}, new double[]{1.0d, 1.0d}, new double[]{-1.0d, 1.0d}, 4.0d),
+                Arguments.of(new double[]{0.0d, 0.0d}, new double[]{1.0d, 0.0d}, new double[]{0.0d, 1.0d}, 2.0d),
+                Arguments.of(new double[]{0.0d, 0.0d}, new double[]{0.0d, 0.0d}, new double[]{1.0d, 1.0d}, 2.0d)
+        );
+    }
+
+    @ParameterizedTest
+
@MethodSource("estimationArgs") + void basicEncodeWithEstimationTestSpecialValues(final double[] centroidData, final double[] vData, + final double[] qData, final double expectedDistance) { + final RealVector v = new DoubleRealVector(vData); + final RealVector q = new DoubleRealVector(qData); + + final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, 7); + final EncodedRealVector encodedVector = quantizer.encode(v); + final RaBitEstimator estimator = quantizer.estimator(); + final RaBitEstimator.Result estimatedDistanceResult = estimator.estimateDistanceAndErrorBound(q, encodedVector); + logger.info("estimated distance result = {}", estimatedDistanceResult); + Assertions.assertThat(estimatedDistanceResult.getDistance()) + .isCloseTo(expectedDistance, Offset.offset(0.01d)); + + final EncodedRealVector encodedVector2 = quantizer.encode(v); + Assertions.assertThat(encodedVector2.hashCode()).isEqualTo(encodedVector.hashCode()); + Assertions.assertThat(encodedVector2).isEqualTo(encodedVector); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensionsAndNumExBits") + void encodeManyWithEstimationsTest(final long seed, final int numDimensions, final int numExBits) { + final Random random = new Random(seed); + final FhtKacRotator rotator = new FhtKacRotator(seed, numDimensions, 10); + final int numRounds = 500; + int numEstimationWithinBounds = 0; + int numEstimationBetter = 0; + double sumRelativeError = 0.0d; + for (int round = 0; round < numRounds; round ++) { + RealVector v = null; + RealVector q = null; + RealVector sum = null; + final int numVectorsForCentroid = 10; + for (int i = 0; i < numVectorsForCentroid; i++) { + if (q == null) { + if (v != null) { + q = v; + } + } + + v = RealVectorTest.createRandomDoubleVector(random, numDimensions); + if (sum == null) { + sum = v; + } else { + sum.add(v); + } + } + Objects.requireNonNull(v); + Objects.requireNonNull(q); + + final RealVector centroid = sum.multiply(1.0d / 
numVectorsForCentroid); + + logger.trace("q = {}", q); + logger.trace("v = {}", v); + logger.trace("centroid = {}", centroid); + + final RealVector centroidRot = rotator.operateTranspose(centroid); + final RealVector qTrans = rotator.operateTranspose(q).subtract(centroidRot); + final RealVector vTrans = rotator.operateTranspose(v).subtract(centroidRot); + + logger.trace("qTrans = {}", qTrans); + logger.trace("vTrans = {}", vTrans); + logger.trace("centroidRot = {}", centroidRot); + + final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits); + final RaBitQuantizer.Result resultV = quantizer.encodeInternal(vTrans); + final EncodedRealVector encodedV = resultV.encodedVector; + logger.trace("fAddEx vor v = {}", encodedV.getAddEx()); + logger.trace("fRescaleEx vor v = {}", encodedV.getRescaleEx()); + logger.trace("fErrorEx vor v = {}", encodedV.getErrorEx()); + + final EncodedRealVector encodedQ = quantizer.encode(qTrans); + final RaBitEstimator estimator = quantizer.estimator(); + final RealVector reconstructedQ = rotator.operate(encodedQ.add(centroidRot)); + final RealVector reconstructedV = rotator.operate(encodedV.add(centroidRot)); + final RaBitEstimator.Result estimatedDistance = estimator.estimateDistanceAndErrorBound(qTrans, encodedV); + logger.trace("estimated ||qRot - vRot||^2 = {}", estimatedDistance); + final double trueDistance = Metric.EUCLIDEAN_SQUARE_METRIC.distance(vTrans, qTrans); + logger.trace("true ||qRot - vRot||^2 = {}", trueDistance); + if (trueDistance >= estimatedDistance.getDistance() - estimatedDistance.getErr() && + trueDistance < estimatedDistance.getDistance() + estimatedDistance.getErr()) { + numEstimationWithinBounds++; + } + logger.trace("reconstructed q = {}", reconstructedQ); + logger.trace("reconstructed v = {}", reconstructedV); + logger.trace("true ||qDec - vDec||^2 = {}", Metric.EUCLIDEAN_SQUARE_METRIC.distance(reconstructedV, reconstructedQ)); + final double reconstructedDistance = 
Metric.EUCLIDEAN_SQUARE_METRIC.distance(reconstructedV, q); + logger.trace("true ||q - vDec||^2 = {}", reconstructedDistance); + double error = Math.abs(estimatedDistance.getDistance() - trueDistance); + if (error < Math.abs(reconstructedDistance - trueDistance)) { + numEstimationBetter ++; + } + sumRelativeError += error / trueDistance; + } + logger.info("estimator within bounds = {}%", String.format(Locale.ROOT, "%.2f", (double)numEstimationWithinBounds * 100.0d / numRounds)); + logger.info("estimator better than reconstructed distance = {}%", String.format(Locale.ROOT, "%.2f", (double)numEstimationBetter * 100.0d / numRounds)); + logger.info("relative error = {}%", String.format(Locale.ROOT, "%.2f", sumRelativeError * 100.0d / numRounds)); + + Assertions.assertThat((double)numEstimationWithinBounds / numRounds).isGreaterThan(0.9); + Assertions.assertThat((double)numEstimationBetter / numRounds).isBetween(0.3, 0.7); + Assertions.assertThat(sumRelativeError / numRounds).isLessThan(0.1d); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensionsAndNumExBits") + void serializationRoundTripTest(final long seed, final int numDimensions, final int numExBits) { + final Random random = new Random(seed); + final RealVector v = createRandomDoubleVector(random, numDimensions); + final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits); + final EncodedRealVector encodedVector = quantizer.encode(v); + final byte[] rawData = encodedVector.getRawData(); + final EncodedRealVector deserialized = EncodedRealVector.fromBytes(rawData, numDimensions, numExBits); + Assertions.assertThat(deserialized).isEqualTo(encodedVector); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensionsAndNumExBits") + void precisionTest(final long seed, final int numDimensions, final int numExBits) { + final Random random = new Random(seed); + final RealVector v = createRandomDoubleVector(random, numDimensions); + final RaBitQuantizer 
quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits); + final EncodedRealVector encodedVector = quantizer.encode(v); + final DoubleRealVector reconstructedDoubleVector = encodedVector.toDoubleRealVector(); + Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(encodedVector.toFloatRealVector(), + reconstructedDoubleVector.toFloatRealVector())).isCloseTo(0, Offset.offset(0.1)); + Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(encodedVector.toHalfRealVector(), + reconstructedDoubleVector.toHalfRealVector())).isCloseTo(0, Offset.offset(0.1)); + } +} diff --git a/gradle/codequality/suppressions.xml b/gradle/codequality/suppressions.xml index 0d6f5ffcfa..7e5264601c 100644 --- a/gradle/codequality/suppressions.xml +++ b/gradle/codequality/suppressions.xml @@ -16,4 +16,6 @@ files=".*[\\/]generated[\\/].*"/> +