From 3071abecfe62fa97e33fbe3423e333134f0571ec Mon Sep 17 00:00:00 2001 From: Robert Escriva Date: Fri, 20 Mar 2026 17:45:36 -0700 Subject: [PATCH 1/3] [BUG} skip deleted nodes in checks Avoid including deleted elements when validating inbound connection counts in the HNSW graph check. Tests live in the chroma-core/chroma GitHub repo. Co-authored-by: AI --- hnswlib/hnswalg.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 3fcb99448..b05b9fa60 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -2131,10 +2131,11 @@ namespace hnswlib } if (cur_element_count > 1) { - int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + int min1 = INT_MAX, max1 = 0; for (int i = 0; i < cur_element_count; i++) { - // This should always be true regardless the data is corrupted or not + if (isMarkedDeleted(i)) + continue; assert(inbound_connections_num[i] > 0); min1 = std::min(inbound_connections_num[i], min1); max1 = std::max(inbound_connections_num[i], max1); From a59a553ddcc975f269c92626bcaf14c5410cee3e Mon Sep 17 00:00:00 2001 From: Robert Escriva Date: Mon, 23 Mar 2026 08:14:53 -0700 Subject: [PATCH 2/3] #include --- hnswlib/hnswalg.h | 1 + 1 file changed, 1 insertion(+) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index b05b9fa60..d550c3572 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace hnswlib From 58f886d91eecb920effabcec0e5964b9459997ce Mon Sep 17 00:00:00 2001 From: Robert Escriva Date: Mon, 23 Mar 2026 10:29:08 -0700 Subject: [PATCH 3/3] test it --- CMakeLists.txt | 4 +- tests/cpp/integrity_test.cpp | 114 +++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/integrity_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 704750dbc..4ce4b106d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,5 +62,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(api_tests tests/cpp/api_test.cpp) target_link_libraries(api_tests hnswlib) -endif() + add_executable(integrity_test tests/cpp/integrity_test.cpp) + target_link_libraries(integrity_test hnswlib) +endif() diff --git a/tests/cpp/integrity_test.cpp b/tests/cpp/integrity_test.cpp new file mode 100644 index 000000000..b6afd77b8 --- /dev/null +++ b/tests/cpp/integrity_test.cpp @@ -0,0 +1,114 @@ +#include "../../hnswlib/hnswlib.h" + +#include + +#include +#include +#include +#include +#include + +namespace +{ + std::vector inbound_counts(const hnswlib::HierarchicalNSW &index) + { + std::vector inbound(index.cur_element_count, 0); + for (size_t i = 0; i < index.cur_element_count; ++i) + { + for (int level = 0; level <= index.element_levels_[i]; ++level) + { + hnswlib::linklistsizeint *link_list = index.get_linklist_at_level(i, level); + int size = index.getListCount(link_list); + hnswlib::tableint *neighbors = reinterpret_cast(link_list + 1); + for (int j = 0; j < size; ++j) + { + inbound[neighbors[j]]++; + } + } + } + return inbound; + } + + void remove_inbound_references(hnswlib::HierarchicalNSW &index, hnswlib::tableint target) + { + for (size_t i = 0; i < index.cur_element_count; ++i) + { + for (int level = 0; level <= index.element_levels_[i]; ++level) + { + hnswlib::linklistsizeint *link_list = index.get_linklist_at_level(i, level); + int size = index.getListCount(link_list); + hnswlib::tableint *neighbors = reinterpret_cast(link_list + 1); + + int write = 0; + for (int read = 0; read < size; ++read) + { + if (neighbors[read] != target) + { + neighbors[write++] = neighbors[read]; + } + } + index.setListCount(link_list, write); + } + } + } + + void testCheckIntegritySkipsDeletedNodesInInboundStats() + { + const int d = 8; + const int n = 32; + std::mt19937 rng(123); + std::uniform_real_distribution distrib(0.0f, 1.0f); + std::vector data(n * d); + for (float &value : data) + { + value = distrib(rng); + } + + hnswlib::L2Space space(d); + hnswlib::HierarchicalNSW index(&space, n, 8, 40, 17); + for (int i = 0; i < n; ++i) + { + index.addPoint(data.data() + i * d, i); + } + + std::vector before = inbound_counts(index); + for (int count : before) + { + assert(count > 0); + } + + index.markDelete(0); + remove_inbound_references(index, 0); + + std::vector after = inbound_counts(index); + assert(after[0] == 0); + + int expected_min = after[1]; + int expected_max = after[1]; + for (int i = 1; i < n; ++i) + { + assert(after[i] > 0); + expected_min = std::min(expected_min, after[i]); + expected_max = std::max(expected_max, after[i]); + } + + std::ostringstream captured; + std::streambuf *old = std::cout.rdbuf(captured.rdbuf()); + index.checkIntegrity(); + std::cout.rdbuf(old); + + const std::string output = captured.str(); + const std::string expected_line = + "Min inbound: " + std::to_string(expected_min) + ", Max inbound:" + std::to_string(expected_max); + assert(output.find(expected_line) != std::string::npos); + assert(output.find("Min inbound: 0") == std::string::npos); + } +} // namespace + +int main() +{ + std::cout << "Testing ..." << std::endl; + testCheckIntegritySkipsDeletedNodesInInboundStats(); + std::cout << "Test testCheckIntegritySkipsDeletedNodesInInboundStats ok" << std::endl; + return 0; +}