2424import com .apple .foundationdb .linear .Metric ;
2525import com .apple .foundationdb .linear .Quantizer ;
2626import com .apple .foundationdb .linear .RealVector ;
27+ import com .google .common .annotations .VisibleForTesting ;
28+ import com .google .common .base .Preconditions ;
2729
2830import javax .annotation .Nonnull ;
2931import java .util .Comparator ;
3032import java .util .PriorityQueue ;
3133
34+ /**
35+ * Implements the RaBit quantization scheme, a technique for compressing high-dimensional vectors into a compact
36+ * integer-based representation.
37+ * <p>
38+ * This class provides the logic to encode a {@link RealVector} into an {@link EncodedRealVector}.
39+ * The encoding process involves finding an optimal scaling factor, quantizing the vector's components,
40+ * and pre-calculating values that facilitate efficient distance estimation in the quantized space.
41+ * It is configured with a specific {@link Metric} and a number of "extra bits" ({@code numExBits})
42+ * which control the precision of the quantization.
43+ * <p>
44+ * Note that this implementation largely follows this <a href="https://arxiv.org/pdf/2409.09913">paper</a>
45+ * by Jianyang Gao et al. It also mirrors algorithmic similarity, terms, and variable/method naming-conventions of the
46+ * C++ implementation that can be found <a href="https://github.com/VectorDB-NTU/RaBitQ-Library">here</a>.
47+ *
48+ * @see Quantizer
49+ * @see RaBitEstimator
50+ * @see EncodedRealVector
51+ */
3252public final class RaBitQuantizer implements Quantizer {
3353 private static final double EPS = 1e-5 ;
3454 private static final double EPS0 = 1.9 ;
3555 private static final int N_ENUM = 10 ;
3656
37- // Matches kTightStart[] from the C++ (index by ex_bits).
3857 // 0th entry unused; defined up to 8 extra bits in the source.
3958 private static final double [] TIGHT_START = {
4059 0.00 , 0.15 , 0.20 , 0.52 , 0.59 , 0.71 , 0.75 , 0.77 , 0.81
4160 };
4261
43- @ Nonnull
44- private final RealVector centroid ;
4562 final int numExBits ;
4663 @ Nonnull
4764 private final Metric metric ;
4865
49- public RaBitQuantizer (@ Nonnull final Metric metric ,
50- @ Nonnull final RealVector centroid ,
51- final int numExBits ) {
52- this .centroid = centroid ;
66+ /**
67+ * Constructs a new {@code RaBitQuantizer} instance.
68+ * <p>
69+ * This constructor initializes the quantizer with a specific metric and the number of
70+ * extra bits to be used in the quantization process.
71+ *
72+ * @param metric the {@link Metric} to be used for quantization; must not be null.
73+ * @param numExBits the number of extra bits for quantization.
74+ */
75+ public RaBitQuantizer (@ Nonnull final Metric metric , final int numExBits ) {
76+ Preconditions .checkArgument (numExBits > 0 && numExBits < TIGHT_START .length );
77+
5378 this .numExBits = numExBits ;
5479 this .metric = metric ;
5580 }
5681
57- public int getNumDimensions () {
58- return centroid .getNumDimensions ();
59- }
60-
82+ /**
83+ * Creates and returns a new {@link RaBitEstimator} instance.
84+ * <p>
85+ * This method acts as a factory, constructing the estimator based on the
86+ * {@code metric} and {@code numExBits} configuration of this object.
87+ * The {@code @Override} annotation indicates that this is an implementation
88+ * of a method from a superclass or interface.
89+ *
90+ * @return a new, non-null instance of {@link RaBitEstimator}
91+ */
6192 @ Nonnull
6293 @ Override
6394 public RaBitEstimator estimator () {
64- return new RaBitEstimator (metric , centroid , numExBits );
95+ return new RaBitEstimator (metric , numExBits );
6596 }
6697
98+ /**
99+ * Encodes a given {@link RealVector} into its corresponding encoded representation.
100+ * <p>
101+ * This method overrides the parent's {@code encode} method. It delegates the
102+ * core encoding logic to an internal helper method and returns the final
103+ * {@link EncodedRealVector}.
104+ *
105+ * @param data the {@link RealVector} to be encoded; must not be null.
106+ *
107+ * @return the resulting {@link EncodedRealVector}, guaranteed to be non-null.
108+ */
67109 @ Nonnull
68110 @ Override
69111 public EncodedRealVector encode (@ Nonnull final RealVector data ) {
70112 return encodeInternal (data ).getEncodedVector ();
71113 }
72114
73115 /**
74- * Port of ex_bits_code_with_factor:
75- * - params: data & centroid (rotated)
76- * - forms residual internally
77- * - computes shifted signed vector here (sign(r)*(k+0.5))
78- * - applies C++ metric-dependent formulas exactly.
116+ * Encodes a real-valued vector into a quantized representation.
117+ * <p>
118+ * This is an internal method that performs the core encoding logic. It first
119+ * generates a base code using {@link #exBitsCode(RealVector)}, then incorporates
120+ * sign information to create the final code. It precomputes various geometric
121+ * properties (norms, dot products) of the original vector and its quantized
122+ * counterpart to calculate metric-specific scaling and error factors. These
123+ * factors are used for efficient distance calculations with the encoded vector.
124+ *
125+ * @param data the real-valued vector to be encoded. Must not be null.
126+ * @return a {@code Result} object containing the {@link EncodedRealVector} and
127+ * other intermediate values from the encoding process. The result is never null.
128+ *
129+ * @throws IllegalArgumentException if the configured {@code metric} is not supported for encoding.
79130 */
80131 @ Nonnull
132+ @ VisibleForTesting
81133 Result encodeInternal (@ Nonnull final RealVector data ) {
82134 final int dims = data .getNumDimensions ();
83135
@@ -132,10 +184,9 @@ Result encodeInternal(@Nonnull final RealVector data) {
132184 }
133185
134186 /**
135- * Builds per-dimension extra-bit levels using the best t found by bestRescaleFactor() and returns
136- * ipNormInv.
137- * @param residual Rotated residual vector r (same thing the C++ feeds here).
138- * This method internally uses |r| normalized to unit L2.
187+ * Builds per-dimension extra-bit code using the best {@code t} found by {@link #bestRescaleFactor(RealVector)} and
188+ * returns the code, {@code t}, and {@code ipNormInv}.
189+ * @param residual rotated residual vector r.
139190 */
140191 private QuantizeExResult exBitsCode (@ Nonnull final RealVector residual ) {
141192 int dims = residual .getNumDimensions ();
@@ -164,12 +215,11 @@ private QuantizeExResult exBitsCode(@Nonnull final RealVector residual) {
164215 /**
165216 * Method to quantize a vector.
166217 *
167- * @param oAbs absolute values of a L2-normalized residual vector (nonnegative; length = dim)
168- * @return quantized levels (ex-bits), the chosen scale t, and ipNormInv
169- * Notes:
170- * - If the residual is the all-zero vector (or numerically so), this returns zero codes,
171- * t = 0, and ipNormInv = 1 (benign fallback).
172- * - Downstream code (ex_bits_code_with_factor) uses ipNormInv to compute f_rescale_ex, etc.
218+ * @param oAbs absolute values of a L2-normalized residual vector (nonnegative; length = dim)
219+ * @return quantized levels (ex-bits), the chosen scale t, and ipNormInv
220+ * Notes: If the residual is the all-zero vector (or numerically so), this returns zero codes,
221+ * {@code t = 0}, and {@code ipNormInv = 1} (benign fallback). Downstream code uses {@code ipNormInv} to
222+ * compute {@code fRescaleEx}, etc.
173223 */
174224 private QuantizeExResult quantizeEx (@ Nonnull final RealVector oAbs ) {
175225 final int dim = oAbs .getNumDimensions ();
@@ -206,14 +256,28 @@ private QuantizeExResult quantizeEx(@Nonnull final RealVector oAbs) {
206256 }
207257
208258 /**
209- * Method to compute the best factor {@code t}.
210- * @param oAbs absolute values of a (row-wise) normalized residual; length = dim; nonnegative
211- * @return t the rescale factor that maximizes the objective
259+ * Calculates the best rescaling factor {@code t} for a given vector of absolute values.
260+ * <p>
261+ * This method implements an efficient algorithm to find a scaling factor {@code t}
262+ * that maximizes an objective function related to the quantization of the input vector.
263+ * The objective function being maximized is effectively
264+ * {@code sum(u_i * o_i) / sqrt(sum(u_i^2 + u_i))}, where {@code u_i = floor(t * o_i)}
265+ * and {@code o_i} are the components of the input vector {@code oAbs}.
266+ * <p>
267+ * The algorithm performs a sweep over the scaling factor {@code t}. It uses a
268+ * min-priority queue to efficiently jump between critical values of {@code t} where
269+ * the floor of {@code t * o_i} changes for some coordinate {@code i}. The search is
270+ * bounded within a pre-calculated "tight" range {@code [tStart, tEnd]} to ensure
271+ * efficiency.
272+ *
273+ * @param oAbs The vector of absolute values for which to find the best rescale factor.
274+ * Components must be non-negative.
275+ *
276+ * @return The optimal scaling factor {@code t} that maximizes the objective function,
277+ * or 0.0 if the input vector is all zeros.
212278 */
213279 private double bestRescaleFactor (@ Nonnull final RealVector oAbs ) {
214- if (numExBits < 0 || numExBits >= TIGHT_START .length ) {
215- throw new IllegalArgumentException ("numExBits out of supported range" );
216- }
280+ final int numDimensions = oAbs .getNumDimensions ();
217281
218282 // max_o = max(oAbs)
219283 double maxO = 0.0d ;
@@ -232,10 +296,10 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
232296 final double tStart = tEnd * TIGHT_START [numExBits ];
233297
234298 // cur_o_bar[i] = floor(tStart * oAbs[i]), but stored as int
235- final int [] curOB = new int [getNumDimensions () ];
236- double sqrDen = getNumDimensions () * 0.25 ; // Σ (cur^2 + cur) starts from D/4
299+ final int [] curOB = new int [numDimensions ];
300+ double sqrDen = numDimensions * 0.25 ; // Σ (cur^2 + cur) starts from D/4
237301 double numer = 0.0 ;
238- for (int i = 0 ; i < getNumDimensions () ; i ++) {
302+ for (int i = 0 ; i < numDimensions ; i ++) {
239303 int cur = (int ) ((tStart * oAbs .getComponent (i )) + EPS );
240304 curOB [i ] = cur ;
241305 sqrDen += (double ) cur * cur + cur ;
@@ -246,7 +310,7 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
246310 // t_i(k->k+1) = (curOB[i] + 1) / oAbs[i]
247311
248312 final PriorityQueue <Node > pq = new PriorityQueue <>(Comparator .comparingDouble (n -> n .t ));
249- for (int i = 0 ; i < getNumDimensions () ; i ++) {
313+ for (int i = 0 ; i < numDimensions ; i ++) {
250314 final double curOAbs = oAbs .getComponent (i );
251315 if (curOAbs > 0.0 ) {
252316 double tNext = (curOB [i ] + 1 ) / curOAbs ;
@@ -291,6 +355,19 @@ private double bestRescaleFactor(@Nonnull final RealVector oAbs) {
291355 return bestT ;
292356 }
293357
358+ /**
359+ * Computes a new vector containing the element-wise absolute values of the L2-normalized input vector.
360+ * <p>
361+ * This operation is equivalent to first normalizing the vector {@code x} by its L2 norm,
362+ * and then taking the absolute value of each resulting component. If the L2 norm of {@code x}
363+ * is zero or not finite (e.g., {@link Double#POSITIVE_INFINITY}), a new zero vector of the
364+ * same dimension is returned.
365+ *
366+ * @param x the input vector to be normalized and processed. Must not be null.
367+ *
368+ * @return a new {@code RealVector} containing the absolute values of the components of the
369+ * normalized input vector.
370+ */
294371 private static RealVector absOfNormalized (@ Nonnull final RealVector x ) {
295372 double n = x .l2Norm ();
296373 double [] y = new double [x .getNumDimensions ()];
0 commit comments