Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ namespace hnswlib
std::vector<int> element_levels_; // keeps level of each element

size_t data_size_{0};
size_t original_data_size_{0}; // original float32 data size (before quantization)
TurboQuantizer* quantizer_{nullptr}; // non-null when TurboQuant is enabled

DISTFUNC<dist_t> fstdistfunc_;
void *dist_func_param_{nullptr};
Expand Down Expand Up @@ -294,8 +296,12 @@ namespace hnswlib
max_elements_ = max_elements;
num_deleted_ = 0;
data_size_ = s->get_data_size();
original_data_size_ = data_size_; // may differ from data_size_ if quantized
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();

// Quantizer is not set here — it will be set externally
// via set_quantizer() from code that has access to TurboQuantL2Space/IPSpace
M_ = M;
maxM_ = M_;
maxM0_ = M_ * 2;
Expand Down Expand Up @@ -344,6 +350,14 @@ namespace hnswlib
}
}

// Set the quantizer after construction (avoids circular include dependencies)
void set_quantizer(TurboQuantizer* q) {
quantizer_ = q;
if (q) {
original_data_size_ = q->dim_ * sizeof(float);
}
}

~HierarchicalNSW()
{
free(data_level0_memory_);
Expand Down Expand Up @@ -1521,6 +1535,22 @@ namespace hnswlib
lock_table.unlock();

char *data_ptrv = getDataByInternalId(internalId);

// When TurboQuant is enabled, dequantize the stored codes back to float32
if (quantizer_) {
size_t dim = quantizer_->dim_;
std::vector<data_t> data(dim);
quantizer_->dequantize((const uint8_t*)data_ptrv, data.data());
// If normalize was used during insert, undo normalization
if (normalize_) {
float length = ((float *)length_memory_)[internalId];
for (size_t i = 0; i < dim; i++) {
data[i] *= length;
}
}
return data;
}

float length = 1.0;
if (normalize_)
{
Expand Down Expand Up @@ -1710,7 +1740,7 @@ namespace hnswlib
{
const void *newPoint = dataPoint;

size_t dim = *((size_t *)dist_func_param_);
size_t dim = quantizer_ ? (size_t)quantizer_->dim_ : *((size_t *)dist_func_param_);
std::vector<float> norm_array(dim);
if (normalize_)
{
Expand Down Expand Up @@ -1951,17 +1981,24 @@ namespace hnswlib

// Initialisation of the data and label and if appropriate the length
const void *normalized_vector = data_point;
size_t dim = *((size_t *)dist_func_param_);
std::vector<float> norm_array(dim);
// When TurboQuant is enabled, dist_func_param_ points to TurboQuantizer, not size_t*
size_t dim = quantizer_ ? (size_t)quantizer_->dim_ : *((size_t *)dist_func_param_);
std::vector<float> norm_array;
if (normalize_)
{
norm_array.resize(dim);
float length = normalize_vector((float *)data_point, norm_array.data(), dim);
void *lengthPtr = length_memory_ + cur_c * sizeof(float);
memcpy(length_memory_ + cur_c * sizeof(float), &length, sizeof(float));
normalized_vector = norm_array.data();
}
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
memcpy(getDataByInternalId(cur_c), normalized_vector, data_size_);
if (quantizer_) {
// TurboQuant: quantize the float32 vector to b-bit codes + norm
quantizer_->quantize((const float*)normalized_vector,
(uint8_t*)getDataByInternalId(cur_c));
} else {
memcpy(getDataByInternalId(cur_c), normalized_vector, data_size_);
}

if (curlevel)
{
Expand All @@ -1971,11 +2008,16 @@ namespace hnswlib
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

// For graph construction: use stored (possibly quantized) data as the query
// This ensures distance computations are always symmetric (codes vs codes)
const void* effective_data_point = quantizer_ ?
(const void*)getDataByInternalId(cur_c) : (const void*)normalized_vector;

if ((signed)currObj != -1)
{
if (curlevel < maxlevelcopy)
{
dist_t curdist = fstdistfunc_(normalized_vector, getDataByInternalId(currObj), dist_func_param_);
dist_t curdist = fstdistfunc_(effective_data_point, getDataByInternalId(currObj), dist_func_param_);
for (int level = maxlevelcopy; level > curlevel; level--)
{
bool changed = true;
Expand All @@ -1993,7 +2035,7 @@ namespace hnswlib
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(normalized_vector, getDataByInternalId(cand), dist_func_param_);
dist_t d = fstdistfunc_(effective_data_point, getDataByInternalId(cand), dist_func_param_);
if (d < curdist)
{
curdist = d;
Expand All @@ -2012,14 +2054,14 @@ namespace hnswlib
throw std::runtime_error("Level error");

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, normalized_vector, level);
currObj, effective_data_point, level);
if (epDeleted)
{
top_candidates.emplace(fstdistfunc_(normalized_vector, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
top_candidates.emplace(fstdistfunc_(effective_data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
top_candidates.pop();
}
currObj = mutuallyConnectNewElement(normalized_vector, cur_c, top_candidates, level, false);
currObj = mutuallyConnectNewElement(effective_data_point, cur_c, top_candidates, level, false);
}
}
else
Expand Down Expand Up @@ -2048,8 +2090,18 @@ namespace hnswlib
if (cur_element_count == 0)
return result;

// If TurboQuant is enabled, quantize the query so all distance computations
// are symmetric (codes vs codes with dequantization)
const void* effective_query = query_data;
std::vector<uint8_t> quantized_query;
if (quantizer_) {
quantized_query.resize(quantizer_->get_storage_size());
quantizer_->quantize((const float*)query_data, quantized_query.data());
effective_query = quantized_query.data();
}

tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
dist_t curdist = fstdistfunc_(effective_query, getDataByInternalId(enterpoint_node_), dist_func_param_);

for (int level = maxlevel_; level > 0; level--)
{
Expand All @@ -2069,7 +2121,7 @@ namespace hnswlib
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
dist_t d = fstdistfunc_(effective_query, getDataByInternalId(cand), dist_func_param_);

if (d < curdist)
{
Expand All @@ -2085,12 +2137,12 @@ namespace hnswlib
if (num_deleted_)
{
top_candidates = searchBaseLayerST<true, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, effective_query, std::max(ef_, k), isIdAllowed);
}
else
{
top_candidates = searchBaseLayerST<false, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, effective_query, std::max(ef_, k), isIdAllowed);
}

while (top_candidates.size() > k)
Expand Down
1 change: 1 addition & 0 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,6 @@ namespace hnswlib

#include "space_l2.h"
#include "space_ip.h"
#include "turbo_quant.h"
#include "bruteforce.h"
#include "hnswalg.h"
Loading