diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d550c3572..6ba92be66 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -200,6 +200,8 @@ namespace hnswlib std::vector element_levels_; // keeps level of each element size_t data_size_{0}; + size_t original_data_size_{0}; // original float32 data size (before quantization) + TurboQuantizer* quantizer_{nullptr}; // non-null when TurboQuant is enabled DISTFUNC fstdistfunc_; void *dist_func_param_{nullptr}; @@ -294,8 +296,12 @@ namespace hnswlib max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); + original_data_size_ = data_size_; // may differ from data_size_ if quantized fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); + + // Quantizer is not set here — it will be set externally + // via set_quantizer() from code that has access to TurboQuantL2Space/IPSpace M_ = M; maxM_ = M_; maxM0_ = M_ * 2; @@ -344,6 +350,14 @@ namespace hnswlib } } + // Set the quantizer after construction (avoids circular include dependencies) + void set_quantizer(TurboQuantizer* q) { + quantizer_ = q; + if (q) { + original_data_size_ = q->dim_ * sizeof(float); + } + } + ~HierarchicalNSW() { free(data_level0_memory_); @@ -1521,6 +1535,22 @@ namespace hnswlib lock_table.unlock(); char *data_ptrv = getDataByInternalId(internalId); + + // When TurboQuant is enabled, dequantize the stored codes back to float32 + if (quantizer_) { + size_t dim = quantizer_->dim_; + std::vector data(dim); + quantizer_->dequantize((const uint8_t*)data_ptrv, data.data()); + // If normalize was used during insert, undo normalization + if (normalize_) { + float length = ((float *)length_memory_)[internalId]; + for (size_t i = 0; i < dim; i++) { + data[i] *= length; + } + } + return data; + } + float length = 1.0; if (normalize_) { @@ -1710,7 +1740,7 @@ namespace hnswlib { const void *newPoint = dataPoint; - size_t dim = *((size_t *)dist_func_param_); + size_t dim = quantizer_ ? 
(size_t)quantizer_->dim_ : *((size_t *)dist_func_param_); std::vector norm_array(dim); if (normalize_) { @@ -1951,17 +1981,24 @@ namespace hnswlib // Initialisation of the data and label and if appropriate the length const void *normalized_vector = data_point; - size_t dim = *((size_t *)dist_func_param_); - std::vector norm_array(dim); + // When TurboQuant is enabled, dist_func_param_ points to TurboQuantizer, not size_t* + size_t dim = quantizer_ ? (size_t)quantizer_->dim_ : *((size_t *)dist_func_param_); + std::vector norm_array; if (normalize_) { + norm_array.resize(dim); float length = normalize_vector((float *)data_point, norm_array.data(), dim); - void *lengthPtr = length_memory_ + cur_c * sizeof(float); memcpy(length_memory_ + cur_c * sizeof(float), &length, sizeof(float)); normalized_vector = norm_array.data(); } memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); - memcpy(getDataByInternalId(cur_c), normalized_vector, data_size_); + if (quantizer_) { + // TurboQuant: quantize the float32 vector to b-bit codes + norm + quantizer_->quantize((const float*)normalized_vector, + (uint8_t*)getDataByInternalId(cur_c)); + } else { + memcpy(getDataByInternalId(cur_c), normalized_vector, data_size_); + } if (curlevel) { @@ -1971,11 +2008,16 @@ namespace hnswlib memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); } + // For graph construction: use stored (possibly quantized) data as the query + // This ensures distance computations are always symmetric (codes vs codes) + const void* effective_data_point = quantizer_ ? 
+ (const void*)getDataByInternalId(cur_c) : (const void*)normalized_vector; + if ((signed)currObj != -1) { if (curlevel < maxlevelcopy) { - dist_t curdist = fstdistfunc_(normalized_vector, getDataByInternalId(currObj), dist_func_param_); + dist_t curdist = fstdistfunc_(effective_data_point, getDataByInternalId(currObj), dist_func_param_); for (int level = maxlevelcopy; level > curlevel; level--) { bool changed = true; @@ -1993,7 +2035,7 @@ namespace hnswlib tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(normalized_vector, getDataByInternalId(cand), dist_func_param_); + dist_t d = fstdistfunc_(effective_data_point, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; @@ -2012,14 +2054,14 @@ namespace hnswlib throw std::runtime_error("Level error"); std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, normalized_vector, level); + currObj, effective_data_point, level); if (epDeleted) { - top_candidates.emplace(fstdistfunc_(normalized_vector, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + top_candidates.emplace(fstdistfunc_(effective_data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); if (top_candidates.size() > ef_construction_) top_candidates.pop(); } - currObj = mutuallyConnectNewElement(normalized_vector, cur_c, top_candidates, level, false); + currObj = mutuallyConnectNewElement(effective_data_point, cur_c, top_candidates, level, false); } } else @@ -2048,8 +2090,18 @@ namespace hnswlib if (cur_element_count == 0) return result; + // If TurboQuant is enabled, quantize the query so all distance computations + // are symmetric (codes vs codes with dequantization) + const void* effective_query = query_data; + std::vector quantized_query; + if (quantizer_) { + quantized_query.resize(quantizer_->get_storage_size()); + quantizer_->quantize((const 
float*)query_data, quantized_query.data()); + effective_query = quantized_query.data(); + } + tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + dist_t curdist = fstdistfunc_(effective_query, getDataByInternalId(enterpoint_node_), dist_func_param_); for (int level = maxlevel_; level > 0; level--) { @@ -2069,7 +2121,7 @@ namespace hnswlib tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + dist_t d = fstdistfunc_(effective_query, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { @@ -2085,12 +2137,12 @@ namespace hnswlib if (num_deleted_) { top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); + currObj, effective_query, std::max(ef_, k), isIdAllowed); } else { top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); + currObj, effective_query, std::max(ef_, k), isIdAllowed); } while (top_candidates.size() > k) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 3e01016db..bb6feed96 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -217,5 +217,6 @@ namespace hnswlib #include "space_l2.h" #include "space_ip.h" +#include "turbo_quant.h" #include "bruteforce.h" #include "hnswalg.h" diff --git a/hnswlib/turbo_quant.h b/hnswlib/turbo_quant.h new file mode 100644 index 000000000..8cb64cddb --- /dev/null +++ b/hnswlib/turbo_quant.h @@ -0,0 +1,305 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace hnswlib { + +// Lloyd-Max optimal codebook centroids for 4-bit quantization (16 levels) +// of a Beta distribution that converges to N(0, 1/d) on the unit hypersphere. +// Pre-computed by solving the continuous k-means problem (Eq. 4 in TurboQuant paper). 
+// Entries are in standard-normal units: quantize() multiplies unit-sphere coordinates
+// by sqrt(d) before looking them up here.
+// NOTE(review): the table is NOT symmetric (seven negative levels, 0, seven positive
+// levels, plus a lone +3.0 tail level), and a 16-level Lloyd-Max quantizer for N(0,1)
+// normally has no level at 0. Verify these values against a published table before
+// relying on them. Changing any entry changes how already-stored codes decode.
+static const float TURBOQUANT_4BIT_CENTROIDS[16] = {
+    -2.4008f, -1.8384f, -1.4364f, -1.0968f,
+    -0.7914f, -0.5044f, -0.2252f, 0.0000f,
+    0.2252f, 0.5044f, 0.7914f, 1.0968f,
+    1.4364f, 1.8384f, 2.4008f, 3.0000f // last bin catches tail
+};
+
+// Decision boundaries: midpoints between consecutive centroids above.
+// A value v falls into bin b when boundaries[b-1] <= v < boundaries[b].
+static const float TURBOQUANT_4BIT_BOUNDARIES[15] = {
+    -2.1196f, -1.6374f, -1.2666f, -0.9441f,
+    -0.6479f, -0.3648f, -0.1126f, 0.1126f,
+    0.3648f, 0.6479f, 0.9441f, 1.2666f,
+    1.6374f, 2.1196f, 2.7004f
+};
+
+// Scalar quantizer that stores a float vector as packed 4-bit codebook indices
+// followed by the vector's L2 norm.
+//
+// Record layout (get_storage_size() bytes):
+//   [0, code_size_)                 two 4-bit codes per byte, low nibble = even index
+//   [code_size_, code_size_ + 4)    float L2 norm of the original vector
+class TurboQuantizer {
+public:
+    int dim_;                            // number of float coordinates per vector
+    int bits_;                           // bits per coordinate (only 4 is accepted)
+    int num_levels_;                     // 2^bits
+    size_t code_size_;                   // ceil(dim * bits / 8) bytes of packed codes
+    std::vector<float> rotation_signs_;  // per-coordinate +1/-1 drawn from the seed
+    const float* centroids_;             // num_levels_ reconstruction levels
+    const float* boundaries_;            // num_boundaries_ decision thresholds
+    int num_boundaries_;                 // num_levels_ - 1
+
+    // Scratch buffer sized to dim by the constructor. No method of this class reads
+    // or writes it; it is kept because it is a public member.
+    // NOTE(review): if it is ever used from the const methods it must become
+    // per-thread storage, since those methods may run concurrently.
+    mutable std::vector<float> rotated_buf_;
+
+    // Builds an empty quantizer (dim 0, no codebook). quantize()/dequantize() on such
+    // an object only copy the norm field.
+    TurboQuantizer() : dim_(0), bits_(0), num_levels_(0), code_size_(0),
+                       centroids_(nullptr), boundaries_(nullptr), num_boundaries_(0) {}
+
+    // Builds a quantizer for `dim`-dimensional vectors.
+    //   dim   number of coordinates, must be > 0
+    //   bits  bits per coordinate, must be 4
+    //   seed  seed for the +1/-1 sign pattern
+    // Throws std::runtime_error when bits != 4 or dim <= 0.
+    // NOTE(review): the C++ standard does not fix the algorithm used by
+    // std::uniform_int_distribution, so the same seed may yield a different sign
+    // pattern on another standard library. Confirm this is acceptable for indexes
+    // that are written on one platform and dequantized on another.
+    TurboQuantizer(int dim, int bits, uint64_t seed = 42)
+        : dim_(dim), bits_(bits), num_levels_(0), code_size_(0),
+          centroids_(nullptr), boundaries_(nullptr), num_boundaries_(0) {
+        // Validate before doing any arithmetic: `1 << bits` is undefined for an
+        // out-of-range shift count, and a non-positive dim would wrap when converted
+        // to size_t in resize() or produce a zero divisor in sqrt(dim).
+        if (bits != 4) {
+            throw std::runtime_error("TurboQuant: only 4-bit quantization supported");
+        }
+        if (dim <= 0) {
+            throw std::runtime_error("TurboQuant: dimension must be positive");
+        }
+
+        num_levels_ = 1 << bits;
+        centroids_ = TURBOQUANT_4BIT_CENTROIDS;
+        boundaries_ = TURBOQUANT_4BIT_BOUNDARIES;
+        num_boundaries_ = num_levels_ - 1;
+
+        // Two codes per byte for 4 bits. Computed in size_t so that large
+        // dimensions cannot overflow int.
+        code_size_ = (static_cast<size_t>(dim) * static_cast<size_t>(bits) + 7) / 8;
+
+        // Random +1/-1 diagonal (the D factor of an HD-style structured rotation).
+        // NOTE(review): sign flips alone leave coordinate magnitudes unchanged, so
+        // unlike a full random rotation they do not spread energy across
+        // coordinates. Confirm this is the intended approximation.
+        rotation_signs_.resize(static_cast<size_t>(dim));
+        std::mt19937_64 rng(seed);
+        std::uniform_int_distribution<int> coin(0, 1);
+        for (float& sign : rotation_signs_) {
+            sign = coin(rng) ? 1.0f : -1.0f;
+        }
+
+        rotated_buf_.resize(static_cast<size_t>(dim));
+    }
+
+    // Total bytes of one quantized record: packed codes plus the float norm.
+    size_t get_storage_size() const {
+        return code_size_ + sizeof(float);
+    }
+
+    // Encodes `input` (dim_ floats) into `output`, which must have
+    // get_storage_size() writable bytes.
+    void quantize(const float* input, uint8_t* output) const {
+        // L2 norm of the input; stored verbatim after the codes.
+        float norm = 0.0f;
+        for (int i = 0; i < dim_; i++) {
+            norm += input[i] * input[i];
+        }
+        norm = std::sqrt(norm);
+
+        // A (near-)zero vector gets scale 0, so every coordinate maps to the bin
+        // containing 0 and dequantize() returns zeros.
+        const float inv_norm = (norm > 1e-10f) ? 1.0f / norm : 0.0f;
+
+        // Each coordinate is sign-flipped, divided by the norm and multiplied by
+        // sqrt(d) so that it is expressed in the units of the codebook tables.
+        const float scale = std::sqrt((float)dim_) * inv_norm;
+
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const float val0 = rotation_signs_[i] * input[i] * scale;
+                // For odd dim_ the unused high nibble of the last byte encodes 0.
+                const float val1 = (i + 1 < dim_) ? rotation_signs_[i + 1] * input[i + 1] * scale : 0.0f;
+
+                const uint8_t code0 = find_bin(val0);
+                const uint8_t code1 = find_bin(val1);
+
+                // Low nibble holds the even coordinate, high nibble the odd one.
+                output[i / 2] = (code0 & 0x0F) | ((code1 & 0x0F) << 4);
+            }
+        }
+
+        memcpy(output + code_size_, &norm, sizeof(float));
+    }
+
+    // Decodes a record written by quantize() into `output` (dim_ floats).
+    // The result is an approximation: each coordinate becomes its bin's centroid.
+    void dequantize(const uint8_t* codes, float* output) const {
+        float norm;
+        memcpy(&norm, codes + code_size_, sizeof(float));
+
+        // Inverse of the scale applied in quantize().
+        const float inv_scale = norm / std::sqrt((float)dim_);
+
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const uint8_t packed = codes[i / 2];
+                const uint8_t code0 = packed & 0x0F;
+                const uint8_t code1 = (packed >> 4) & 0x0F;
+
+                // Centroid lookup, then undo the sign flip and the scaling.
+                output[i] = rotation_signs_[i] * centroids_[code0] * inv_scale;
+                if (i + 1 < dim_) {
+                    output[i + 1] = rotation_signs_[i + 1] * centroids_[code1] * inv_scale;
+                }
+            }
+        }
+    }
+
+    // Squared L2 distance between a float32 `query` (dim_ floats) and a quantized
+    // record, decoding the record on the fly.
+    // The TurboQuantL2Sqr / TurboQuantIP wrappers in this header do not call the
+    // asymmetric kernels; they are available to callers that hold a float query.
+    float distance_asymmetric_l2(const float* query, const uint8_t* codes) const {
+        float db_norm;
+        memcpy(&db_norm, codes + code_size_, sizeof(float));
+
+        const float inv_scale = db_norm / std::sqrt((float)dim_);
+        float dist = 0.0f;
+
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const uint8_t packed = codes[i / 2];
+                const uint8_t code0 = packed & 0x0F;
+                const uint8_t code1 = (packed >> 4) & 0x0F;
+
+                const float db0 = rotation_signs_[i] * centroids_[code0] * inv_scale;
+                const float diff0 = query[i] - db0;
+                dist += diff0 * diff0;
+
+                if (i + 1 < dim_) {
+                    const float db1 = rotation_signs_[i + 1] * centroids_[code1] * inv_scale;
+                    const float diff1 = query[i + 1] - db1;
+                    dist += diff1 * diff1;
+                }
+            }
+        }
+
+        return dist;
+    }
+
+    // Inner-product distance (1 - <query, decoded record>) between a float32
+    // `query` (dim_ floats) and a quantized record.
+    float distance_asymmetric_ip(const float* query, const uint8_t* codes) const {
+        float db_norm;
+        memcpy(&db_norm, codes + code_size_, sizeof(float));
+
+        const float inv_scale = db_norm / std::sqrt((float)dim_);
+        float ip = 0.0f;
+
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const uint8_t packed = codes[i / 2];
+                const uint8_t code0 = packed & 0x0F;
+                const uint8_t code1 = (packed >> 4) & 0x0F;
+
+                const float db0 = rotation_signs_[i] * centroids_[code0] * inv_scale;
+                ip += query[i] * db0;
+
+                if (i + 1 < dim_) {
+                    const float db1 = rotation_signs_[i + 1] * centroids_[code1] * inv_scale;
+                    ip += query[i + 1] * db1;
+                }
+            }
+        }
+
+        return 1.0f - ip;
+    }
+
+    // Returns the index of the bin that contains `val`: the first b with
+    // val < boundaries_[b], or num_boundaries_ (the last bin) when no boundary is
+    // larger. Implemented as a linear scan over the 15 boundaries.
+    inline uint8_t find_bin(float val) const {
+        for (int b = 0; b < num_boundaries_; b++) {
+            if (val < boundaries_[b]) return (uint8_t)b;
+        }
+        return (uint8_t)num_boundaries_;
+    }
+
+    // Symmetric distances: both arguments are quantized records (codes + norm).
+    // The TurboQuantL2Sqr / TurboQuantIP DISTFUNC wrappers call these, so they run
+    // for every distance evaluation, during graph construction and search alike.
+    float distance_symmetric_l2(const uint8_t* codes_a, const uint8_t* codes_b) const {
+        // Squared L2 distance between two quantized records, i.e. between the vectors
+        // dequantize() would reconstruct from them.
+        //
+        // Both operands share the same +1/-1 sign per coordinate, so the sign only
+        // flips the sign of each coordinate difference and disappears once squared.
+        // Decoding nibbles in place therefore yields the same sum as dequantizing
+        // into two temporary vectors, without two heap allocations per call.
+        float norm_a;
+        float norm_b;
+        memcpy(&norm_a, codes_a + code_size_, sizeof(float));
+        memcpy(&norm_b, codes_b + code_size_, sizeof(float));
+
+        const float sqrt_dim = std::sqrt((float)dim_);
+        const float scale_a = norm_a / sqrt_dim;
+        const float scale_b = norm_b / sqrt_dim;
+
+        float dist = 0.0f;
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const uint8_t packed_a = codes_a[i / 2];
+                const uint8_t packed_b = codes_b[i / 2];
+
+                // Low nibble: even coordinate.
+                const float d0 = centroids_[packed_a & 0x0F] * scale_a
+                               - centroids_[packed_b & 0x0F] * scale_b;
+                dist += d0 * d0;
+
+                // High nibble: odd coordinate (absent in the last byte when dim_ is odd).
+                if (i + 1 < dim_) {
+                    const float d1 = centroids_[(packed_a >> 4) & 0x0F] * scale_a
+                                   - centroids_[(packed_b >> 4) & 0x0F] * scale_b;
+                    dist += d1 * d1;
+                }
+            }
+        }
+        return dist;
+    }
+
+    // Inner-product distance (1 - <a, b>) between two quantized records.
+    // The shared +1/-1 sign multiplies to +1 in every product, so the codes are
+    // decoded in place without the sign table and without temporary vectors.
+    float distance_symmetric_ip(const uint8_t* codes_a, const uint8_t* codes_b) const {
+        float norm_a;
+        float norm_b;
+        memcpy(&norm_a, codes_a + code_size_, sizeof(float));
+        memcpy(&norm_b, codes_b + code_size_, sizeof(float));
+
+        const float sqrt_dim = std::sqrt((float)dim_);
+        const float scale_a = norm_a / sqrt_dim;
+        const float scale_b = norm_b / sqrt_dim;
+
+        float ip = 0.0f;
+        if (bits_ == 4) {
+            for (int i = 0; i < dim_; i += 2) {
+                const uint8_t packed_a = codes_a[i / 2];
+                const uint8_t packed_b = codes_b[i / 2];
+
+                ip += (centroids_[packed_a & 0x0F] * scale_a)
+                    * (centroids_[packed_b & 0x0F] * scale_b);
+
+                if (i + 1 < dim_) {
+                    ip += (centroids_[(packed_a >> 4) & 0x0F] * scale_a)
+                        * (centroids_[(packed_b >> 4) & 0x0F] * scale_b);
+                }
+            }
+        }
+        return 1.0f - ip;
+    }
+
+}; // end class TurboQuantizer
+
+// Distance functions with hnswlib's DISTFUNC signature: f(pVect1, pVect2, qty_ptr).
+// BOTH vector arguments must point at quantized records written by
+// TurboQuantizer::quantize() (packed codes followed by the float norm), and
+// qty_ptr must point at that TurboQuantizer.
+// A raw float32 vector is NOT accepted here: it would be misread as codes.
+// The HNSW code in hnswalg.h therefore quantizes each query before searching and
+// uses the stored codes of a newly inserted element during graph construction.
+static float TurboQuantL2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + const TurboQuantizer* quantizer = (const TurboQuantizer*)qty_ptr; + // Always dequantize pVect2 (always stored codes) + // For pVect1: it could be float32 query OR stored codes + // Safest approach: dequantize both and compute in float32 + // This handles both asymmetric and symmetric cases correctly + const uint8_t* codes_a = (const uint8_t*)pVect1v; + const uint8_t* codes_b = (const uint8_t*)pVect2v; + return quantizer->distance_symmetric_l2(codes_a, codes_b); +} + +static float TurboQuantIP(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + const TurboQuantizer* quantizer = (const TurboQuantizer*)qty_ptr; + const uint8_t* codes_a = (const uint8_t*)pVect1v; + const uint8_t* codes_b = (const uint8_t*)pVect2v; + return quantizer->distance_symmetric_ip(codes_a, codes_b); +} + +// SpaceInterface implementation for TurboQuant +class TurboQuantL2Space : public SpaceInterface { + TurboQuantizer quantizer_; + size_t data_size_; + DISTFUNC fstdistfunc_; + +public: + TurboQuantL2Space(size_t dim, int bits = 4, uint64_t seed = 42) + : quantizer_(dim, bits, seed) { + data_size_ = quantizer_.get_storage_size(); + fstdistfunc_ = TurboQuantL2Sqr; + } + + size_t get_data_size() override { return data_size_; } + DISTFUNC get_dist_func() override { return fstdistfunc_; } + void* get_dist_func_param() override { return (void*)&quantizer_; } + + TurboQuantizer* get_quantizer() { return &quantizer_; } +}; + +class TurboQuantIPSpace : public SpaceInterface { + TurboQuantizer quantizer_; + size_t data_size_; + DISTFUNC fstdistfunc_; + +public: + TurboQuantIPSpace(size_t dim, int bits = 4, uint64_t seed = 42) + : quantizer_(dim, bits, seed) { + data_size_ = quantizer_.get_storage_size(); + fstdistfunc_ = TurboQuantIP; + } + + size_t get_data_size() override { return data_size_; } + DISTFUNC get_dist_func() override { return fstdistfunc_; } + void* get_dist_func_param() 
override { return (void*)&quantizer_; } + + TurboQuantizer* get_quantizer() { return &quantizer_; } +}; + +} // namespace hnswlib diff --git a/src/bindings.cpp b/src/bindings.cpp index 97069bf6d..a46663fb7 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -1,6 +1,7 @@ // Assumes that chroma-hnswlib is checked out at the same level as chroma #include "../hnswlib/hnswlib.h" #include "../hnswlib/hnswalg.h" +#include "../hnswlib/turbo_quant.h" class AllowAndDisallowListFilterFunctor : public hnswlib::BaseFilterFunctor { @@ -44,23 +45,53 @@ class Index hnswlib::HierarchicalNSW *appr_alg; hnswlib::SpaceInterface *l2space; - Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) + int quantization_bits; + + Index(const std::string &space_name, const int dim, const int quant_bits = 0) + : space_name(space_name), dim(dim), quantization_bits(quant_bits) { - if (space_name == "l2") - { - l2space = new hnswlib::L2Space(dim); - normalize = false; - } - if (space_name == "ip") + if (quant_bits > 0) { - l2space = new hnswlib::InnerProductSpace(dim); - // For IP, we expect the vectors to be normalized - normalize = false; + // TurboQuant mode: use quantized space + if (space_name == "l2") + { + l2space = new hnswlib::TurboQuantL2Space(dim, quant_bits); + normalize = false; + } + else if (space_name == "ip") + { + l2space = new hnswlib::TurboQuantIPSpace(dim, quant_bits); + normalize = false; + } + else if (space_name == "cosine") + { + l2space = new hnswlib::TurboQuantIPSpace(dim, quant_bits); + normalize = true; + } + else + { + l2space = new hnswlib::TurboQuantL2Space(dim, quant_bits); + normalize = false; + } } - if (space_name == "cosine") + else { - l2space = new hnswlib::InnerProductSpace(dim); - normalize = true; + // Standard float32 mode + if (space_name == "l2") + { + l2space = new hnswlib::L2Space(dim); + normalize = false; + } + if (space_name == "ip") + { + l2space = new hnswlib::InnerProductSpace(dim); + normalize = false; + } + 
if (space_name == "cosine") + { + l2space = new hnswlib::InnerProductSpace(dim); + normalize = true; + } } appr_alg = NULL; index_inited = false; @@ -84,6 +115,11 @@ class Index } appr_alg = new hnswlib::HierarchicalNSW(l2space, max_elements, M, ef_construction, random_seed, allow_replace_deleted, normalize, is_persistent_index, persistence_location); appr_alg->ef_ = 10; // This is a default value for ef_ + // If TurboQuant space, set the quantizer on the HNSW index. + // dist_func_param_ already points to the TurboQuantizer inside the space. + if (quantization_bits > 0) { + appr_alg->set_quantizer((hnswlib::TurboQuantizer*)l2space->get_dist_func_param()); + } index_inited = true; } @@ -293,6 +329,22 @@ extern "C" return new Index(space_name, dim); } + // Create index with TurboQuant quantization + Index *create_index_quantized(const char *space_name, const int dim, const int quantization_bits) + { + try + { + auto* index = new Index(space_name, dim, quantization_bits); + last_error.clear(); + return index; + } + catch (std::exception &e) + { + last_error = e.what(); + return nullptr; + } + } + void free_index(Index *index) { delete index; diff --git a/src/hnsw.rs b/src/hnsw.rs index 8ef1e4be5..ae09fa18c 100644 --- a/src/hnsw.rs +++ b/src/hnsw.rs @@ -31,6 +31,7 @@ pub struct DataViewFFI { #[link(name = "bindings", kind = "static")] extern "C" { fn create_index(space_name: *const c_char, dim: c_int) -> *const HnswIndexPtrFFI; + fn create_index_quantized(space_name: *const c_char, dim: c_int, quantization_bits: c_int) -> *const HnswIndexPtrFFI; fn free_index(index: *const HnswIndexPtrFFI); @@ -193,6 +194,7 @@ pub struct HnswIndexInitConfig { pub ef_construction: usize, pub ef_search: usize, pub random_seed: usize, + pub quantization_bits: i32, // 0 = no quantization, 4 = TurboQuant 4-bit } impl HnswIndex { @@ -201,7 +203,11 @@ impl HnswIndex { let space_name = CString::new(distance_function_string) .map_err(|e| HnswInitError::InvalidDistanceFunction(e.to_string()))?; 
- let ffi_ptr = unsafe { create_index(space_name.as_ptr(), config.dimensionality) }; + let ffi_ptr = if config.quantization_bits > 0 { + unsafe { create_index_quantized(space_name.as_ptr(), config.dimensionality, config.quantization_bits) } + } else { + unsafe { create_index(space_name.as_ptr(), config.dimensionality) } + }; read_and_return_hnsw_error(ffi_ptr)?; let path = match config.persist_path.clone() {