Skip to content

Commit 04728f3

Browse files
committed
more tests
1 parent 2443459 commit 04728f3

File tree

17 files changed

+476
-172
lines changed

17 files changed

+476
-172
lines changed

ACKNOWLEDGEMENTS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,11 @@ Christian Heina (HALF4J)
232232
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
233233
See the License for the specific language governing permissions and
234234
limitations under the License.
235+
236+
Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond Chi-Wing Wong,
237+
"Practical and Asymptotically Optimal Quantization of High-Dimensional Vectors in Euclidean Space for
238+
Approximate Nearest Neighbor Search",
239+
SIGMOD 2025, available at https://arxiv.org/abs/2409.09913
240+
241+
Yutong Gou, Jianyang Gao, Yuexuan Xu, Jifan Shi and Zhonghao Yang
242+
https://github.com/VectorDB-NTU/RaBitQ-Library/blob/main/LICENSE

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,8 @@ RealVector centroidRot(@Nonnull final FhtKacRotator rotator) {
548548
}
549549

550550
@Nonnull
551-
Quantizer raBitQuantizer(@Nonnull final RealVector centroidRot) {
552-
return new RaBitQuantizer(Metric.EUCLIDEAN_METRIC, centroidRot, getConfig().getRaBitQNumExBits());
551+
Quantizer raBitQuantizer() {
552+
return new RaBitQuantizer(Metric.EUCLIDEAN_METRIC, getConfig().getRaBitQNumExBits());
553553
}
554554

555555
//
@@ -596,7 +596,7 @@ public CompletableFuture<? extends List<? extends NodeReferenceAndNode<? extends
596596
final RealVector centroidRot = centroidRot(rotator);
597597
final RealVector queryVectorRot = rotator.operateTranspose(queryVector);
598598
queryVectorTrans = queryVectorRot.subtract(centroidRot);
599-
quantizer = raBitQuantizer(centroidRot);
599+
quantizer = raBitQuantizer();
600600
} else {
601601
queryVectorTrans = queryVector;
602602
quantizer = Quantizer.noOpQuantizer(Metric.EUCLIDEAN_METRIC);
@@ -1127,7 +1127,7 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
11271127
final RealVector centroidRot = centroidRot(rotator);
11281128
final RealVector newVectorRot = rotator.operateTranspose(newVector);
11291129
newVectorTrans = newVectorRot.subtract(centroidRot);
1130-
quantizer = raBitQuantizer(centroidRot);
1130+
quantizer = raBitQuantizer();
11311131
} else {
11321132
newVectorTrans = newVector;
11331133
quantizer = Quantizer.noOpQuantizer(Metric.EUCLIDEAN_METRIC);
@@ -1228,7 +1228,7 @@ public CompletableFuture<Void> insertBatch(@Nonnull final Transaction transactio
12281228
if (getConfig().isUseRaBitQ()) {
12291229
rotator = new FhtKacRotator(0, getConfig().getNumDimensions(), 10);
12301230
centroidRot = centroidRot(rotator);
1231-
quantizer = raBitQuantizer(centroidRot);
1231+
quantizer = raBitQuantizer();
12321232
} else {
12331233
rotator = null;
12341234
centroidRot = null;
@@ -1913,24 +1913,6 @@ private int insertionLayer() {
19131913
return (int) Math.floor(-Math.log(u) * lambda);
19141914
}
19151915

1916-
/**
1917-
* Logs a message at the INFO level, using a consumer for lazy evaluation.
1918-
* <p>
1919-
* This approach avoids the cost of constructing the log message if the INFO
1920-
* level is disabled. The provided {@link java.util.function.Consumer} will be
1921-
* executed only when {@code logger.isInfoEnabled()} returns {@code true}.
1922-
*
1923-
* @param loggerConsumer the {@link java.util.function.Consumer} that will be
1924-
* accepted if logging is enabled. It receives the
1925-
* {@code Logger} instance and must not be null.
1926-
*/
1927-
@SuppressWarnings("PMD.UnusedPrivateMethod")
1928-
private void info(@Nonnull final Consumer<Logger> loggerConsumer) {
1929-
if (logger.isInfoEnabled()) {
1930-
loggerConsumer.accept(logger);
1931-
}
1932-
}
1933-
19341916
private static class NodeReferenceWithLayer extends NodeReferenceWithVector {
19351917
private final int layer;
19361918

fdb-extensions/src/main/java/com/apple/foundationdb/async/rabitq/RaBitEstimator.java

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,11 @@
3030
public class RaBitEstimator implements Estimator {
3131
@Nonnull
3232
private final Metric metric;
33-
@Nonnull
34-
private final RealVector centroid;
3533
private final int numExBits;
3634

3735
public RaBitEstimator(@Nonnull final Metric metric,
38-
@Nonnull final RealVector centroid,
3936
final int numExBits) {
4037
this.metric = metric;
41-
this.centroid = centroid;
4238
this.numExBits = numExBits;
4339
}
4440

@@ -47,10 +43,6 @@ public Metric getMetric() {
4743
return metric;
4844
}
4945

50-
public int getNumDimensions() {
51-
return centroid.getNumDimensions();
52-
}
53-
5446
public int getNumExBits() {
5547
return numExBits;
5648
}
@@ -78,8 +70,7 @@ private double distance(@Nonnull final RealVector query, // pre-rotated query q
7870
public Result estimateDistanceAndErrorBound(@Nonnull final RealVector query, // pre-rotated query q
7971
@Nonnull final EncodedRealVector encodedVector) {
8072
final double cb = (1 << numExBits) - 0.5;
81-
final RealVector qc = query;
82-
final double gAdd = qc.dot(qc);
73+
final double gAdd = query.dot(query);
8374
final double gError = Math.sqrt(gAdd);
8475
final RealVector totalCode = new DoubleRealVector(encodedVector.getEncodedData());
8576
final RealVector xuc = totalCode.subtract(cb);
@@ -117,7 +108,7 @@ public double getErr() {
117108

118109
@Override
119110
public String toString() {
120-
return "Estimate[" + "distance=" + distance + ", err=" + err + "]";
111+
return "estimate[" + "distance=" + distance + ", err=" + err + "]";
121112
}
122113
}
123114
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/rabitq/RaBitQuantizer.java

Lines changed: 114 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,60 +24,112 @@
2424
import com.apple.foundationdb.linear.Metric;
2525
import com.apple.foundationdb.linear.Quantizer;
2626
import com.apple.foundationdb.linear.RealVector;
27+
import com.google.common.annotations.VisibleForTesting;
28+
import com.google.common.base.Preconditions;
2729

2830
import javax.annotation.Nonnull;
2931
import java.util.Comparator;
3032
import java.util.PriorityQueue;
3133

34+
/**
35+
* Implements the RaBit quantization scheme, a technique for compressing high-dimensional vectors into a compact
36+
* integer-based representation.
37+
* <p>
38+
* This class provides the logic to encode a {@link RealVector} into an {@link EncodedRealVector}.
39+
* The encoding process involves finding an optimal scaling factor, quantizing the vector's components,
40+
* and pre-calculating values that facilitate efficient distance estimation in the quantized space.
41+
* It is configured with a specific {@link Metric} and a number of "extra bits" ({@code numExBits})
42+
* which control the precision of the quantization.
43+
* <p>
44+
* Note that this implementation largely follows this <a href="https://arxiv.org/pdf/2409.09913">paper</a>
45+
* by Jianyang Gao et al. It also mirrors the algorithms, terminology, and variable/method naming conventions of the
46+
* C++ implementation that can be found <a href="https://github.com/VectorDB-NTU/RaBitQ-Library">here</a>.
47+
*
48+
* @see Quantizer
49+
* @see RaBitEstimator
50+
* @see EncodedRealVector
51+
*/
3252
public final class RaBitQuantizer implements Quantizer {
3353
private static final double EPS = 1e-5;
3454
private static final double EPS0 = 1.9;
3555
private static final int N_ENUM = 10;
3656

37-
// Matches kTightStart[] from the C++ (index by ex_bits).
3857
// Indexed by numExBits; 0th entry unused; defined up to 8 extra bits in the C++ reference implementation.
3958
private static final double[] TIGHT_START = {
4059
0.00, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81
4160
};
4261

43-
@Nonnull
44-
private final RealVector centroid;
4562
final int numExBits;
4663
@Nonnull
4764
private final Metric metric;
4865

49-
public RaBitQuantizer(@Nonnull final Metric metric,
50-
@Nonnull final RealVector centroid,
51-
final int numExBits) {
52-
this.centroid = centroid;
66+
/**
67+
* Constructs a new {@code RaBitQuantizer} instance.
68+
* <p>
69+
* This constructor initializes the quantizer with a specific metric and the number of
70+
* extra bits to be used in the quantization process.
71+
*
72+
* @param metric the {@link Metric} to be used for quantization; must not be null.
73+
* @param numExBits the number of extra bits for quantization.
74+
*/
75+
public RaBitQuantizer(@Nonnull final Metric metric, final int numExBits) {
76+
Preconditions.checkArgument(numExBits > 0 && numExBits < TIGHT_START.length);
77+
5378
this.numExBits = numExBits;
5479
this.metric = metric;
5580
}
5681

57-
public int getNumDimensions() {
58-
return centroid.getNumDimensions();
59-
}
60-
82+
/**
83+
* Creates and returns a new {@link RaBitEstimator} instance.
84+
* <p>
85+
* This method acts as a factory, constructing the estimator based on the
86+
* {@code metric} and {@code numExBits} configuration of this object.
87+
* This method implements {@code estimator()} as declared by the
88+
* {@link Quantizer} interface.
89+
*
90+
* @return a new, non-null instance of {@link RaBitEstimator}
91+
*/
6192
@Nonnull
6293
@Override
6394
public RaBitEstimator estimator() {
64-
return new RaBitEstimator(metric, centroid, numExBits);
95+
return new RaBitEstimator(metric, numExBits);
6596
}
6697

98+
/**
99+
* Encodes a given {@link RealVector} into its corresponding encoded representation.
100+
* <p>
101+
* This method implements the {@code encode} method of {@link Quantizer}. It delegates the
102+
* core encoding logic to an internal helper method and returns the final
103+
* {@link EncodedRealVector}.
104+
*
105+
* @param data the {@link RealVector} to be encoded; must not be null.
106+
*
107+
* @return the resulting {@link EncodedRealVector}, guaranteed to be non-null.
108+
*/
67109
@Nonnull
68110
@Override
69111
public EncodedRealVector encode(@Nonnull final RealVector data) {
70112
return encodeInternal(data).getEncodedVector();
71113
}
72114

73115
/**
74-
* Port of ex_bits_code_with_factor:
75-
* - params: data & centroid (rotated)
76-
* - forms residual internally
77-
* - computes shifted signed vector here (sign(r)*(k+0.5))
78-
* - applies C++ metric-dependent formulas exactly.
116+
* Encodes a real-valued vector into a quantized representation.
117+
* <p>
118+
* This is an internal method that performs the core encoding logic. It first
119+
* generates a base code using {@link #exBitsCode(RealVector)}, then incorporates
120+
* sign information to create the final code. It precomputes various geometric
121+
* properties (norms, dot products) of the original vector and its quantized
122+
* counterpart to calculate metric-specific scaling and error factors. These
123+
* factors are used for efficient distance calculations with the encoded vector.
124+
*
125+
* @param data the real-valued vector to be encoded. Must not be null.
126+
* @return a {@code Result} object containing the {@link EncodedRealVector} and
127+
* other intermediate values from the encoding process. The result is never null.
128+
*
129+
* @throws IllegalArgumentException if the configured {@code metric} is not supported for encoding.
79130
*/
80131
@Nonnull
132+
@VisibleForTesting
81133
Result encodeInternal(@Nonnull final RealVector data) {
82134
final int dims = data.getNumDimensions();
83135

@@ -132,10 +184,9 @@ Result encodeInternal(@Nonnull final RealVector data) {
132184
}
133185

134186
/**
135-
* Builds per-dimension extra-bit levels using the best t found by bestRescaleFactor() and returns
136-
* ipNormInv.
137-
* @param residual Rotated residual vector r (same thing the C++ feeds here).
138-
* This method internally uses |r| normalized to unit L2.
187+
* Builds per-dimension extra-bit code using the best {@code t} found by {@link #bestRescaleFactor(RealVector)} and
188+
* returns the code, {@code t}, and {@code ipNormInv}.
189+
* @param residual rotated residual vector r.
139190
*/
140191
private QuantizeExResult exBitsCode(@Nonnull final RealVector residual) {
141192
int dims = residual.getNumDimensions();
@@ -164,12 +215,11 @@ private QuantizeExResult exBitsCode(@Nonnull final RealVector residual) {
164215
/**
165216
* Method to quantize a vector.
166217
*
167-
* @param oAbs absolute values of a L2-normalized residual vector (nonnegative; length = dim)
168-
* @return quantized levels (ex-bits), the chosen scale t, and ipNormInv
169-
* Notes:
170-
* - If the residual is the all-zero vector (or numerically so), this returns zero codes,
171-
* t = 0, and ipNormInv = 1 (benign fallback).
172-
* - Downstream code (ex_bits_code_with_factor) uses ipNormInv to compute f_rescale_ex, etc.
218+
* @param oAbs absolute values of an L2-normalized residual vector (nonnegative; length = dim)
219+
* @return quantized levels (ex-bits), the chosen scale t, and ipNormInv
220+
* Notes: If the residual is the all-zero vector (or numerically so), this returns zero codes,
221+
* {@code t = 0}, and {@code ipNormInv = 1} (benign fallback). Downstream code uses {@code ipNormInv} to
222+
* compute {@code fRescaleEx}, etc.
173223
*/
174224
private QuantizeExResult quantizeEx(@Nonnull final RealVector oAbs) {
175225
final int dim = oAbs.getNumDimensions();
@@ -206,14 +256,28 @@ private QuantizeExResult quantizeEx(@Nonnull final RealVector oAbs) {
206256
}
207257

208258
/**
209-
* Method to compute the best factor {@code t}.
210-
* @param oAbs absolute values of a (row-wise) normalized residual; length = dim; nonnegative
211-
* @return t the rescale factor that maximizes the objective
259+
* Calculates the best rescaling factor {@code t} for a given vector of absolute values.
260+
* <p>
261+
* This method implements an efficient algorithm to find a scaling factor {@code t}
262+
* that maximizes an objective function related to the quantization of the input vector.
263+
* The objective function being maximized is effectively
264+
* {@code sum(u_i * o_i) / sqrt(sum(u_i^2 + u_i))}, where {@code u_i = floor(t * o_i)}
265+
* and {@code o_i} are the components of the input vector {@code oAbs}.
266+
* <p>
267+
* The algorithm performs a sweep over the scaling factor {@code t}. It uses a
268+
* min-priority queue to efficiently jump between critical values of {@code t} where
269+
* the floor of {@code t * o_i} changes for some coordinate {@code i}. The search is
270+
* bounded within a pre-calculated "tight" range {@code [tStart, tEnd]} to ensure
271+
* efficiency.
272+
*
273+
* @param oAbs The vector of absolute values for which to find the best rescale factor.
274+
* Components must be non-negative.
275+
*
276+
* @return The optimal scaling factor {@code t} that maximizes the objective function,
277+
* or 0.0 if the input vector is all zeros.
212278
*/
213279
private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
214-
if (numExBits < 0 || numExBits >= TIGHT_START.length) {
215-
throw new IllegalArgumentException("numExBits out of supported range");
216-
}
280+
final int numDimensions = oAbs.getNumDimensions();
217281

218282
// max_o = max(oAbs)
219283
double maxO = 0.0d;
@@ -232,10 +296,10 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
232296
final double tStart = tEnd * TIGHT_START[numExBits];
233297

234298
// cur_o_bar[i] = floor(tStart * oAbs[i]), but stored as int
235-
final int[] curOB = new int[getNumDimensions()];
236-
double sqrDen = getNumDimensions() * 0.25; // Σ (cur^2 + cur) starts from D/4
299+
final int[] curOB = new int[numDimensions];
300+
double sqrDen = numDimensions * 0.25; // Σ (cur^2 + cur) starts from D/4
237301
double numer = 0.0;
238-
for (int i = 0; i < getNumDimensions(); i++) {
302+
for (int i = 0; i < numDimensions; i++) {
239303
int cur = (int) ((tStart * oAbs.getComponent(i)) + EPS);
240304
curOB[i] = cur;
241305
sqrDen += (double) cur * cur + cur;
@@ -246,7 +310,7 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
246310
// t_i(k->k+1) = (curOB[i] + 1) / oAbs[i]
247311

248312
final PriorityQueue<Node> pq = new PriorityQueue<>(Comparator.comparingDouble(n -> n.t));
249-
for (int i = 0; i < getNumDimensions(); i++) {
313+
for (int i = 0; i < numDimensions; i++) {
250314
final double curOAbs = oAbs.getComponent(i);
251315
if (curOAbs > 0.0) {
252316
double tNext = (curOB[i] + 1) / curOAbs;
@@ -291,6 +355,19 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
291355
return bestT;
292356
}
293357

358+
/**
359+
* Computes a new vector containing the element-wise absolute values of the L2-normalized input vector.
360+
* <p>
361+
* This operation is equivalent to first normalizing the vector {@code x} by its L2 norm,
362+
* and then taking the absolute value of each resulting component. If the L2 norm of {@code x}
363+
* is zero or not finite (e.g., {@link Double#POSITIVE_INFINITY}), a new zero vector of the
364+
* same dimension is returned.
365+
*
366+
* @param x the input vector to be normalized and processed. Must not be null.
367+
*
368+
* @return a new {@code RealVector} containing the absolute values of the components of the
369+
* normalized input vector.
370+
*/
294371
private static RealVector absOfNormalized(@Nonnull final RealVector x) {
295372
double n = x.l2Norm();
296373
double[] y = new double[x.getNumDimensions()];

fdb-extensions/src/main/java/com/apple/foundationdb/half/Half.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
*
3030
* @author Christian Heina ([email protected])
3131
*/
32-
public class Half extends Number implements Comparable<Half> {
32+
public final class Half extends Number implements Comparable<Half> {
3333

3434
/**
3535
* A constant holding the positive infinity of type {@code Half}.

fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ public class ColumnMajorRealMatrix implements RealMatrix {
3434
private final Supplier<Integer> hashCodeSupplier;
3535

3636
public ColumnMajorRealMatrix(@Nonnull final double[][] data) {
37+
Preconditions.checkArgument(data.length > 0);
38+
Preconditions.checkArgument(data[0].length > 0);
3739
this.data = data;
3840
this.hashCodeSupplier = Suppliers.memoize(this::valueBasedHashCode);
3941
}

0 commit comments

Comments
 (0)