From 48bbed6cddfaac9b7974b64115b3fd3af860bd0a Mon Sep 17 00:00:00 2001 From: yoshoku Date: Mon, 21 Mar 2022 20:35:20 +0900 Subject: [PATCH 01/41] add throw statement to errors in BruteforceSearch --- hnswlib/bruteforce.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 24260400..969ad4c8 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -23,7 +23,7 @@ namespace hnswlib { size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxElements * size_per_element_); if (data_ == nullptr) - std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); cur_element_count = 0; } @@ -140,7 +140,7 @@ namespace hnswlib { size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxelements_ * size_per_element_); if (data_ == nullptr) - std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); input.read(data_, maxelements_ * size_per_element_); From 5f074d5bcb5fdde8bbc02f926981553fc3225951 Mon Sep 17 00:00:00 2001 From: yoshoku Date: Mon, 21 Mar 2022 20:58:09 +0900 Subject: [PATCH 02/41] remove unnecessary new operators --- python_bindings/bindings.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 12f38e2e..3050d972 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -96,7 +96,7 @@ class Index { l2space = new hnswlib::InnerProductSpace(dim); normalize=true; } else { - throw new std::runtime_error("Space name must be one of l2, ip, or cosine."); + throw std::runtime_error("Space name must be one of l2, ip, or cosine."); } appr_alg = NULL; ep_added = true; @@ -129,7 +129,7 @@ class Index { void init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) { if (appr_alg) { - throw new std::runtime_error("The index is already initiated."); + throw std::runtime_error("The index is already initiated."); } cur_l = 0; appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); @@ -668,7 +668,7 @@ class BFIndex { space = new hnswlib::InnerProductSpace(dim); normalize=true; } else { - throw new std::runtime_error("Space name must be one of l2, ip, or cosine."); + throw std::runtime_error("Space name must be one of l2, ip, or cosine."); } alg = NULL; index_inited = false; @@ -693,7 +693,7 @@ class BFIndex { void init_new_index(const size_t maxElements) { if (alg) { - throw new std::runtime_error("The index is already initiated."); + throw std::runtime_error("The index is already initiated."); } cur_l = 0; alg = new hnswlib::BruteforceSearch(space, maxElements); From d51c3122570363a9ee045e8af9a44721a11ba186 Mon Sep 17 00:00:00 2001 From: Paul-Louis NECH <1821404+PLNech@users.noreply.github.com> Date: Mon, 2 May 2022 17:28:23 +0200 Subject: [PATCH 03/41] chore(ALGO_PARAMS.md): Fix typo --- ALGO_PARAMS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ALGO_PARAMS.md b/ALGO_PARAMS.md index b0a6b7ad..0d5133f3 100644 --- a/ALGO_PARAMS.md +++ b/ALGO_PARAMS.md @@ -27,5 +27,5 @@ ef_construction leads to longer construction, but better index quality. At some not improve the quality of the index. One way to check if the selection of ef_construction was ok is to measure a recall for M nearest neighbor search when ```ef``` =```ef_construction```: if the recall is lower than 0.9, than there is room for improvement. -* ```num_elements``` - defines the maximum number of elements in the index. The index can be extened by saving/loading(load_index +* ```num_elements``` - defines the maximum number of elements in the index. The index can be extended by saving/loading (load_index function has a parameter which defines the new maximum number of elements). From 492e15e469c3c4c70de2e806ca205ea4bd4106d2 Mon Sep 17 00:00:00 2001 From: Yuriy Korzhenevskiy Date: Thu, 2 Jun 2022 13:48:17 +0300 Subject: [PATCH 04/41] Highlight code --- TESTING_RECALL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TESTING_RECALL.md b/TESTING_RECALL.md index 23a6f654..29136ec8 100644 --- a/TESTING_RECALL.md +++ b/TESTING_RECALL.md @@ -27,7 +27,7 @@ max_elements defines the maximum number of elements that can be stored in the st ### measuring recall example -``` +```python import hnswlib import numpy as np From fb3a699efbd9890f52e00ebb9549563cad6aa8fc Mon Sep 17 00:00:00 2001 From: MasterAler Date: Wed, 8 Jun 2022 14:40:08 +0300 Subject: [PATCH 05/41] fix global linkage --- hnswlib/hnswlib.h | 12 ++++++------ hnswlib/space_ip.h | 8 ++++---- hnswlib/space_l2.h | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 58eb7607..61029e90 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -16,20 +16,20 @@ #include #include #include "cpu_x86.h" -void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) { +static void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) { __cpuidex(out, eax, ecx); } -__int64 xgetbv(unsigned int x) { +static __int64 xgetbv(unsigned int x) { return _xgetbv(x); } #else #include #include #include -void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { +static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); } -uint64_t xgetbv(unsigned int index) { +static uint64_t xgetbv(unsigned int index) { uint32_t eax, edx; __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); return ((uint64_t)edx << 32) | eax; @@ -51,7 +51,7 @@ uint64_t xgetbv(unsigned int index) { // Adapted from https://github.com/Mysticial/FeatureDetector #define _XCR_XFEATURE_ENABLED_MASK 0 -bool AVXCapable() { +static bool AVXCapable() { int cpuInfo[4]; // CPU support @@ -78,7 +78,7 @@ bool AVXCapable() { return HW_AVX && avxSupported; } -bool AVX512Capable() { +static bool AVX512Capable() { if (!AVXCapable()) return false; int cpuInfo[4]; diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index b4266f78..d45a4c66 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -281,10 +281,10 @@ namespace hnswlib { #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; - DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; - DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; - DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; + static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; + static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; + static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; static float InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 44135370..355cc7b8 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -144,7 +144,7 @@ namespace hnswlib { #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; + static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; static float L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { From 632da8f2d12b606cc98fee09ba1f913811bfffc8 Mon Sep 17 00:00:00 2001 From: James Melville Date: Sat, 30 Jul 2022 11:22:50 -0700 Subject: [PATCH 06/41] add missing quote --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bcb6775..d5c9b998 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,7 @@ Please make pull requests against the `develop` branch. When making changes please run tests (and please add a test to `python_bindings/tests` in case there is new functionality): ```bash -python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py +python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" ``` From d3197c5f06ea3bfc0f303d2fe610f085d57a16ca Mon Sep 17 00:00:00 2001 From: James Melville Date: Sat, 30 Jul 2022 11:41:19 -0700 Subject: [PATCH 07/41] initialize fields in constructor --- hnswlib/bruteforce.h | 8 ++++++-- hnswlib/hnswalg.h | 25 ++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 24260400..691424bc 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -8,10 +8,14 @@ namespace hnswlib { template class BruteforceSearch : public AlgorithmInterface { public: - BruteforceSearch(SpaceInterface *s) { + BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), + cur_element_count(0), size_per_element_(0), data_size_(0), + dist_func_param_(nullptr) { } - BruteforceSearch(SpaceInterface *s, const std::string &location) { + BruteforceSearch(SpaceInterface *s, const std::string &location) : + data_(nullptr), maxelements_(0), cur_element_count(0), size_per_element_(0), + data_size_(0), dist_func_param_(nullptr) { loadIndex(location, s); } diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e95e0b52..7054839b 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -17,15 +17,34 @@ namespace hnswlib { class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; - HierarchicalNSW(SpaceInterface *s) { + HierarchicalNSW(SpaceInterface *s) : + max_elements_(0), cur_element_count(0), size_data_per_element_(0), + size_links_per_element_(0), num_deleted_(0), M_(0), maxM_(0), maxM0_(0), + ef_construction_(0), mult_(0.0), revSize_(0.0), maxlevel_(0), + visited_list_pool_(nullptr), enterpoint_node_(0), size_links_level0_(0), + offsetData_(0), offsetLevel0_(0), + data_level0_memory_(nullptr), + linkLists_(nullptr), data_size_(0), label_offset_(0), + dist_func_param_(nullptr), metric_distance_computations(0), metric_hops(0), + ef_(0){ } - HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { + HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) : + max_elements_(0), cur_element_count(0), size_data_per_element_(0), + size_links_per_element_(0), num_deleted_(0), M_(0), maxM_(0), maxM0_(0), + ef_construction_(0), mult_(0.0), revSize_(0.0), maxlevel_(0), + visited_list_pool_(nullptr), enterpoint_node_(0), size_links_level0_(0), + offsetData_(0), offsetLevel0_(0), + data_level0_memory_(nullptr), + linkLists_(nullptr), data_size_(0), label_offset_(0), + dist_func_param_(nullptr), metric_distance_computations(0), metric_hops(0), + ef_(0) { loadIndex(location, s, max_elements); } HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { + link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements), + metric_distance_computations(0), metric_hops(0) { max_elements_ = max_elements; num_deleted_ = 0; From 25c738386e14d03134e480dbc0d5b334d3078a0a Mon Sep 17 00:00:00 2001 From: James Melville Date: Wed, 3 Aug 2022 22:58:33 -0700 Subject: [PATCH 08/41] direct member initialize fields --- hnswlib/hnswalg.h | 72 +++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 46 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7054839b..8060683c 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -17,34 +17,15 @@ namespace hnswlib { class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; - HierarchicalNSW(SpaceInterface *s) : - max_elements_(0), cur_element_count(0), size_data_per_element_(0), - size_links_per_element_(0), num_deleted_(0), M_(0), maxM_(0), maxM0_(0), - ef_construction_(0), mult_(0.0), revSize_(0.0), maxlevel_(0), - visited_list_pool_(nullptr), enterpoint_node_(0), size_links_level0_(0), - offsetData_(0), offsetLevel0_(0), - data_level0_memory_(nullptr), - linkLists_(nullptr), data_size_(0), label_offset_(0), - dist_func_param_(nullptr), metric_distance_computations(0), metric_hops(0), - ef_(0){ + HierarchicalNSW(SpaceInterface *s) { } - HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) : - max_elements_(0), cur_element_count(0), size_data_per_element_(0), - size_links_per_element_(0), num_deleted_(0), M_(0), maxM_(0), maxM0_(0), - ef_construction_(0), mult_(0.0), revSize_(0.0), maxlevel_(0), - visited_list_pool_(nullptr), enterpoint_node_(0), size_links_level0_(0), - offsetData_(0), offsetLevel0_(0), - data_level0_memory_(nullptr), - linkLists_(nullptr), data_size_(0), label_offset_(0), - dist_func_param_(nullptr), metric_distance_computations(0), metric_hops(0), - ef_(0) { + HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { loadIndex(location, s, max_elements); } HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements), - metric_distance_computations(0), metric_hops(0) { + link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { max_elements_ = max_elements; num_deleted_ = 0; @@ -104,22 +85,21 @@ namespace hnswlib { delete visited_list_pool_; } - size_t max_elements_; - size_t cur_element_count; - size_t size_data_per_element_; - size_t size_links_per_element_; - size_t num_deleted_; + size_t max_elements_{0}; + size_t cur_element_count{0}; + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + size_t num_deleted_{0}; + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; - size_t M_; - size_t maxM_; - size_t maxM0_; - size_t ef_construction_; + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; - double mult_, revSize_; - int maxlevel_; - - VisitedListPool *visited_list_pool_; + VisitedListPool *visited_list_pool_{nullptr}; std::mutex cur_element_count_guard_; std::vector link_list_locks_; @@ -127,20 +107,20 @@ namespace hnswlib { // Locks to prevent race condition during update/insert of an element at same time. // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. std::vector link_list_update_locks_; - tableint enterpoint_node_; + tableint enterpoint_node_{0}; - size_t size_links_level0_; - size_t offsetData_, offsetLevel0_; + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}; - char *data_level0_memory_; - char **linkLists_; + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; std::vector element_levels_; - size_t data_size_; + size_t data_size_{0}; - size_t label_offset_; + size_t label_offset_{0}; DISTFUNC fstdistfunc_; - void *dist_func_param_; + void *dist_func_param_{nullptr}; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -253,8 +233,8 @@ namespace hnswlib { return top_candidates; } - mutable std::atomic metric_distance_computations; - mutable std::atomic metric_hops; + mutable std::atomic metric_distance_computations{0}; + mutable std::atomic metric_hops{0}; template std::priority_queue, std::vector>, CompareByFirst> @@ -523,7 +503,7 @@ namespace hnswlib { } std::mutex global; - size_t ef_; + size_t ef_{0}; void setEf(size_t ef) { ef_ = ef; From 406731d84ebfded6ba4f9b11c81c9b2a270495e0 Mon Sep 17 00:00:00 2001 From: Jianshu_Zhao <38149286+jianshu93@users.noreply.github.com> Date: Mon, 8 Aug 2022 22:34:34 -0400 Subject: [PATCH 09/41] Add rust implementation A beautiful Rust implementation. I have run benchmarks for various dataset with very good performance and recall. More metric distance are supported including hamming, Hellinger distance, Jensen-Shannon distance between probability distributions (f32 and f64). There is A Trait to enable the user to implement its own distances. It takes as data slices of types T satisfying T:Serialize+Clone+Send+Sync. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bcb6775..5ec6ebed 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ https://github.com/dbaranchuk/ivf-hnsw * Java bindings using Java Native Access: https://github.com/stepstone-tech/hnswlib-jna * .Net implementation: https://github.com/microsoft/HNSW.Net * CUDA implementation: https://github.com/js1010/cuhnsw - +* Rust implementation for memory and thread safety purposes and There is A Trait to enable the user to implement its own distances. It takes as data slices of types T satisfying T:Serialize+Clone+Send+Sync.: https://github.com/jean-pierreBoth/hnswlib-rs ### 200M SIFT test reproduction To download and extract the bigann dataset (from root directory): From 765c4ab4ba00e2e8a54b349c5df1f028b08953ed Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Mon, 15 Aug 2022 15:55:57 +0530 Subject: [PATCH 10/41] Filter elements with an optional filtering function. --- CMakeLists.txt | 3 + examples/searchKnnWithFilter_test.cpp | 95 +++++++++++++++++++++++++++ hnswlib/bruteforce.h | 18 +++-- hnswlib/hnswalg.h | 16 ++--- hnswlib/hnswlib.h | 20 ++++-- 5 files changed, 131 insertions(+), 21 deletions(-) create mode 100644 examples/searchKnnWithFilter_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e2f3d716..e42d6cee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) target_link_libraries(searchKnnCloserFirst_test hnswlib) + add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp) + target_link_libraries(searchKnnWithFilter_test hnswlib) + add_executable(main main.cpp sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp new file mode 100644 index 00000000..290054d3 --- /dev/null +++ b/examples/searchKnnWithFilter_test.cpp @@ -0,0 +1,95 @@ +// This is a test file for testing the filtering feature + +#include "../hnswlib/hnswlib.h" + +#include + +#include +#include + +namespace +{ + +using idx_t = hnswlib::labeltype; + +bool pickIdsDivisibleByThree(unsigned int ep_id) { + return ep_id % 3 == 0; +} + +bool pickIdsDivisibleBySeven(unsigned int ep_id) { + return ep_id % 7 == 0; +} + +template +void test(filter_func_t filter_func, size_t div_num) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + alg_brute->addPoint(data.data() + d * i, i); + alg_hnsw->addPoint(data.data() + d * i, i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + delete alg_brute; + delete alg_hnsw; +} + +} // namespace + +int main() { + std::cout << "Testing ..." << std::endl; + test(pickIdsDivisibleByThree, 3); + test(pickIdsDivisibleBySeven, 7); + std::cout << "Test ok" << std::endl; + + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index f8e0aeb3..c16c19a0 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -5,8 +5,8 @@ #include namespace hnswlib { - template - class BruteforceSearch : public AlgorithmInterface { + template + class BruteforceSearch : public AlgorithmInterface { public: BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), cur_element_count(0), size_per_element_(0), data_size_(0), @@ -92,20 +92,24 @@ namespace hnswlib { std::priority_queue> - searchKnn(const void *query_data, size_t k) const { + searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if(isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); + } } dist_t lastdist = topResults.top().first; for (int i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if(isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); + } if (topResults.size() > k) topResults.pop(); lastdist = topResults.top().first; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 8060683c..7ca41f9a 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,8 +13,8 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; - template - class HierarchicalNSW : public AlgorithmInterface { + template + class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; HierarchicalNSW(SpaceInterface *s) { @@ -238,7 +238,7 @@ namespace hnswlib { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -247,7 +247,7 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if (!has_deletions || !isMarkedDeleted(ep_id)) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +307,7 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if (!has_deletions || !isMarkedDeleted(candidate_id)) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) @@ -1111,7 +1111,7 @@ namespace hnswlib { }; std::priority_queue> - searchKnn(const void *query_data, size_t k) const { + searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1148,11 +1148,11 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> top_candidates; if (num_deleted_) { top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); + currObj, query_data, std::max(ef_, k), isIdAllowed); } else{ top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); + currObj, query_data, std::max(ef_, k), isIdAllowed); } while (top_candidates.size() > k) { diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 61029e90..fc48af29 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,6 +116,10 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; + bool allowAllIds(unsigned int ep_id) { + return true; + } + template class pairGreater { public: @@ -137,6 +141,7 @@ namespace hnswlib { template using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + using FILTERFUNC = bool(*)(unsigned int); template class SpaceInterface { @@ -151,28 +156,31 @@ namespace hnswlib { virtual ~SpaceInterface() {} }; - template + template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; - virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; + + virtual std::priority_queue> + searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k) const; + searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const; virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ } }; - template + template std::vector> - AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k) const { + AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + filter_func_t isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k); + auto ret = searchKnn(query_data, k, isIdAllowed); { size_t sz = ret.size(); result.resize(sz); From ad3440c83555d9a76eef0e23bc6505c86b026716 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 19 Aug 2022 09:02:11 +0530 Subject: [PATCH 11/41] Filter function should be sent the label and not the internal ID. --- examples/searchKnnWithFilter_test.cpp | 12 ++++++------ hnswlib/hnswalg.h | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 290054d3..71a055dd 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -21,7 +21,7 @@ bool pickIdsDivisibleBySeven(unsigned int ep_id) { } template -void test(filter_func_t filter_func, size_t div_num) { +void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -40,15 +40,15 @@ void test(filter_func_t filter_func, size_t div_num) { for (idx_t i = 0; i < nq * d; ++i) { query[i] = distrib(rng); } - hnswlib::L2Space space(d); hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { - alg_brute->addPoint(data.data() + d * i, i); - alg_hnsw->addPoint(data.data() + d * i, i); + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); } // test searchKnnCloserFirst of BruteforceSearch with filtering @@ -87,8 +87,8 @@ void test(filter_func_t filter_func, size_t div_num) { int main() { std::cout << "Testing ..." << std::endl; - test(pickIdsDivisibleByThree, 3); - test(pickIdsDivisibleBySeven, 7); + test(pickIdsDivisibleByThree, 3, 17); + test(pickIdsDivisibleBySeven, 7, 17); std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7ca41f9a..d319aa7e 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -247,7 +247,7 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +307,7 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) From 4f6dcc38e8af8068cff455852e57a833d9ef6a22 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 19 Aug 2022 09:13:25 +0530 Subject: [PATCH 12/41] Ensure that results are not empty when reading from top results. --- examples/searchKnnWithFilter_test.cpp | 77 ++++++++++++++++++++++++--- hnswlib/bruteforce.h | 7 ++- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 71a055dd..b048baf7 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -12,16 +12,20 @@ namespace using idx_t = hnswlib::labeltype; -bool pickIdsDivisibleByThree(unsigned int ep_id) { - return ep_id % 3 == 0; +bool pickIdsDivisibleByThree(unsigned int label_id) { + return label_id % 3 == 0; } -bool pickIdsDivisibleBySeven(unsigned int ep_id) { - return ep_id % 7 == 0; +bool pickIdsDivisibleBySeven(unsigned int label_id) { + return label_id % 7 == 0; +} + +bool pickNothing(unsigned int label_id) { + return false; } template -void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -83,12 +87,71 @@ void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) { delete alg_hnsw; } +template +void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + delete alg_brute; + delete alg_hnsw; +} + } // namespace int main() { std::cout << "Testing ..." << std::endl; - test(pickIdsDivisibleByThree, 3, 17); - test(pickIdsDivisibleBySeven, 7, 17); + + // some of the elements are filtered + test_some_filtering(pickIdsDivisibleByThree, 3, 17); + test_some_filtering(pickIdsDivisibleBySeven, 7, 17); + + // all of the elements are filtered + test_none_filtering(pickNothing, 17); + std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index c16c19a0..3de18eeb 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -102,7 +102,7 @@ namespace hnswlib { topResults.push(std::pair(dist, label)); } } - dist_t lastdist = topResults.top().first; + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; for (int i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { @@ -112,7 +112,10 @@ namespace hnswlib { } if (topResults.size() > k) topResults.pop(); - lastdist = topResults.top().first; + + if (!topResults.empty()) { + lastdist = topResults.top().first; + } } } From c5be3f5079422940ac04475e1834b8b588347ad0 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 21 Aug 2022 19:06:26 +0200 Subject: [PATCH 13/41] Update port git_tester.py on Windows --- examples/git_tester.py | 64 ++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/examples/git_tester.py b/examples/git_tester.py index aaf70c82..990f7eba 100644 --- a/examples/git_tester.py +++ b/examples/git_tester.py @@ -1,34 +1,44 @@ +import os + +from sys import platform from pydriller import Repository -import os -import datetime -os.system("cp examples/speedtest.py examples/speedtest2.py") # the file has to be outside of git -for idx, commit in enumerate(Repository('.', from_tag="v0.6.0").traverse_commits()): - name=commit.msg.replace('\n', ' ').replace('\r', ' ') - print(idx, commit.hash, name) +if platform == "win32": + copy_cmd = "copy" + rm_dir_cmd = "rmdir /s /q" +else: + copy_cmd = "cp" + rm_dir_cmd = "rm -rf" +speedtest_src_path = os.path.join("examples", "speedtest.py") +speedtest_path = os.path.join("examples", "speedtest2.py") +os.system(f"{copy_cmd} {speedtest_src_path} {speedtest_path}") # the file has to be outside of git + +commits = list(Repository('.', from_tag="v0.6.0").traverse_commits()) +print("Found commits:") +for idx, commit in enumerate(commits): + name = commit.msg.replace('\n', ' ').replace('\r', ' ') + print(idx, commit.hash, name) -for commit in Repository('.', from_tag="v0.6.0").traverse_commits(): - - name=commit.msg.replace('\n', ' ').replace('\r', ' ') - print(commit.hash, name) - - os.system(f"git checkout {commit.hash}; rm -rf build; ") +for commit in commits: + name = commit.msg.replace('\n', ' ').replace('\r', ' ') + print("\nProcessing", commit.hash, name) + + os.system(f"git checkout {commit.hash}") + os.system(f"{rm_dir_cmd} build") print("\n\n--------------------\n\n") - ret=os.system("python -m pip install .") - print(ret) - - if ret != 0: - print ("build failed!!!!") - print ("build failed!!!!") - print ("build failed!!!!") - print ("build failed!!!!") - continue - - os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 64 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 24') - os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 24') + ret = os.system("python -m pip install .") + print("Install result:", ret) + if ret != 0: + print("build failed!!!!") + print("build failed!!!!") + print("build failed!!!!") + print("build failed!!!!") + continue + os.system(f'python {speedtest_path} -n "{name}" -d 4 -t 1') + os.system(f'python {speedtest_path} -n "{name}" -d 64 -t 1') + os.system(f'python {speedtest_path} -n "{name}" -d 128 -t 1') + os.system(f'python {speedtest_path} -n "{name}" -d 4 -t 24') + os.system(f'python {speedtest_path} -n "{name}" -d 128 -t 24') From 1c833a73f504ab383bb7c31a036b71bf5e53a861 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 25 Aug 2022 15:56:08 +0530 Subject: [PATCH 14/41] Make allowAllIds static. --- hnswlib/hnswlib.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index fc48af29..0b6f84a6 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,7 +116,7 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; - bool allowAllIds(unsigned int ep_id) { + static bool allowAllIds(unsigned int ep_id) { return true; } From aaee13a931c9b4320cfc88e5ab549199d405f787 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 26 Aug 2022 16:12:26 +0530 Subject: [PATCH 15/41] Use functor for filtering. --- examples/searchKnnWithFilter_test.cpp | 27 +++++++++++++++++++++------ hnswlib/bruteforce.h | 4 ++-- hnswlib/hnswalg.h | 6 +++--- hnswlib/hnswlib.h | 20 +++++++++++--------- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index b048baf7..9219be03 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -25,7 +25,7 @@ bool pickNothing(unsigned int label_id) { } template -void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -46,8 +46,8 @@ void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -88,7 +88,7 @@ void test_some_filtering(filter_func_t filter_func, size_t div_num, size_t label } template -void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { +void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -109,8 +109,8 @@ void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -142,6 +142,17 @@ void test_none_filtering(filter_func_t filter_func, size_t label_id_start) { } // namespace +class CustomFilterFunctor: public hnswlib::FilterFunctor { + std::unordered_set allowed_values; + +public: + explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} + + constexpr bool operator()(unsigned int id) const { + return allowed_values.count(id) != 0; + } +}; + int main() { std::cout << "Testing ..." << std::endl; @@ -152,6 +163,10 @@ int main() { // all of the elements are filtered test_none_filtering(pickNothing, 17); + // functor style which can capture context + CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); + test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); + std::cout << "Test ok" << std::endl; return 0; diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 3de18eeb..a56f75f1 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -5,7 +5,7 @@ #include namespace hnswlib { - template + template class BruteforceSearch : public AlgorithmInterface { public: BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), @@ -92,7 +92,7 @@ namespace hnswlib { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (int i = 0; i < k; i++) { diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d319aa7e..23fddcc1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,7 +13,7 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; - template + template class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; @@ -238,7 +238,7 @@ namespace hnswlib { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -1111,7 +1111,7 @@ namespace hnswlib { }; std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { std::priority_queue> result; if (cur_element_count == 0) return result; diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 0b6f84a6..b1c88df5 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,9 +116,13 @@ static bool AVX512Capable() { namespace hnswlib { typedef size_t labeltype; - static bool allowAllIds(unsigned int ep_id) { - return true; - } + // This can be extended to store state for filtering (e.g. from a std::set) + struct FilterFunctor { + template + bool operator()(Args&&...) { return true; } + }; + + FilterFunctor allowAllIds; template class pairGreater { @@ -141,8 +145,6 @@ namespace hnswlib { template using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); - using FILTERFUNC = bool(*)(unsigned int); - template class SpaceInterface { public: @@ -156,17 +158,17 @@ namespace hnswlib { virtual ~SpaceInterface() {} }; - template + template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; virtual std::priority_queue> - searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0; + searchKnn(const void*, size_t, filter_func_t& isIdAllowed=allowAllIds) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const; + searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const; virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ @@ -176,7 +178,7 @@ namespace hnswlib { template std::vector> AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - filter_func_t isIdAllowed) const { + filter_func_t& isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first From b87f6230dbe59e874b3099cfcab689b42e887a20 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 27 Aug 2022 13:20:23 +0530 Subject: [PATCH 16/41] Explicitly check for filter functor being default. --- hnswlib/hnswalg.h | 6 ++++-- hnswlib/hnswlib.h | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 23fddcc1..d7fd385f 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -247,7 +247,8 @@ namespace hnswlib { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) { + bool is_filter_disabled = std::is_same::value; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -307,7 +308,8 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) + is_filter_disabled = std::is_same::value; + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index b1c88df5..d8997044 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -122,7 +122,7 @@ namespace hnswlib { bool operator()(Args&&...) { return true; } }; - FilterFunctor allowAllIds; + static FilterFunctor allowAllIds; template class pairGreater { From e8da5a09955fb8f899ff360616e6b9f570d38677 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sat, 27 Aug 2022 12:02:08 +0200 Subject: [PATCH 17/41] Refactoring --- examples/git_tester.py | 14 +++++++------- python_bindings/bindings.cpp | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/git_tester.py b/examples/git_tester.py index 990f7eba..39a3af6c 100644 --- a/examples/git_tester.py +++ b/examples/git_tester.py @@ -11,8 +11,8 @@ rm_dir_cmd = "rm -rf" speedtest_src_path = os.path.join("examples", "speedtest.py") -speedtest_path = os.path.join("examples", "speedtest2.py") -os.system(f"{copy_cmd} {speedtest_src_path} {speedtest_path}") # the file has to be outside of git +speedtest_copy_path = os.path.join("examples", "speedtest2.py") +os.system(f"{copy_cmd} {speedtest_src_path} {speedtest_copy_path}") # the file has to be outside of git commits = list(Repository('.', from_tag="v0.6.0").traverse_commits()) print("Found commits:") @@ -37,8 +37,8 @@ print("build failed!!!!") continue - os.system(f'python {speedtest_path} -n "{name}" -d 4 -t 1') - os.system(f'python {speedtest_path} -n "{name}" -d 64 -t 1') - os.system(f'python {speedtest_path} -n "{name}" -d 128 -t 1') - os.system(f'python {speedtest_path} -n "{name}" -d 4 -t 24') - os.system(f'python {speedtest_path} -n "{name}" -d 128 -t 24') + os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 1') + os.system(f'python {speedtest_copy_path} -n "{name}" -d 64 -t 1') + os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 1') + os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 24') + os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 24') diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3050d972..a72b5b21 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -155,7 +155,7 @@ class Index { void loadIndex(const std::string &path_to_index, size_t max_elements) { if (appr_alg) { - std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; + std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; delete appr_alg; } appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); @@ -768,7 +768,7 @@ class BFIndex { void loadIndex(const std::string &path_to_index, size_t max_elements) { if (alg) { - std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; + std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; delete alg; } alg = new hnswlib::BruteforceSearch(space, path_to_index); From 74bf4a335ea61b302e80a1cb059fe816d3d520cd Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sat, 27 Aug 2022 19:44:03 +0200 Subject: [PATCH 18/41] Add cpp tests to CI --- .github/workflows/build.yml | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7dfba102..f8fde085 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,15 +3,15 @@ name: HNSW CI on: [push, pull_request] jobs: - test: + test_python: runs-on: ${{matrix.os}} strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ['3.6', '3.7', '3.8', '3.9'] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -20,3 +20,31 @@ jobs: - name: Test run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" + + test_cpp: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Build + run: | + mkdir build + cd build + cmake .. + make + + - name: Prepare test data + run: | + pip install numpy + cd examples + python update_gen_data.py + + - name: Test + run: | + cd build + ./searchKnnCloserFirst_test + ./test_updates + ./test_updates update From bdd022035b5ffcf7ba1bd322bcccd6bc527f3ab1 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 28 Aug 2022 10:24:23 +0200 Subject: [PATCH 19/41] Use shutil --- examples/git_tester.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/git_tester.py b/examples/git_tester.py index 39a3af6c..be3b8a25 100644 --- a/examples/git_tester.py +++ b/examples/git_tester.py @@ -1,18 +1,13 @@ import os +import shutil from sys import platform from pydriller import Repository -if platform == "win32": - copy_cmd = "copy" - rm_dir_cmd = "rmdir /s /q" -else: - copy_cmd = "cp" - rm_dir_cmd = "rm -rf" speedtest_src_path = os.path.join("examples", "speedtest.py") speedtest_copy_path = os.path.join("examples", "speedtest2.py") -os.system(f"{copy_cmd} {speedtest_src_path} {speedtest_copy_path}") # the file has to be outside of git +shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git commits = list(Repository('.', from_tag="v0.6.0").traverse_commits()) print("Found commits:") @@ -24,8 +19,9 @@ name = commit.msg.replace('\n', ' ').replace('\r', ' ') print("\nProcessing", commit.hash, name) + if os.path.exists("build"): + shutil.rmtree("build") os.system(f"git checkout {commit.hash}") - os.system(f"{rm_dir_cmd} build") print("\n\n--------------------\n\n") ret = os.system("python -m pip install .") print("Install result:", ret) From f0dedf3956de1762fa0b0611ac03b4c29d236bf5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 15:05:41 +0530 Subject: [PATCH 20/41] Remove duplicate assignment. --- hnswlib/hnswalg.h | 1 - 1 file changed, 1 deletion(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d7fd385f..57fba444 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -308,7 +308,6 @@ namespace hnswlib { _MM_HINT_T0);//////////////////////// #endif - is_filter_disabled = std::is_same::value; if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); From e4705fd3f09dd56d05278aebb0f6e1a25383e4f5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 15:11:22 +0530 Subject: [PATCH 21/41] Add search with filter test to CI. --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f8fde085..5e0c1f9d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -46,5 +46,6 @@ jobs: run: | cd build ./searchKnnCloserFirst_test + ./searchKnnWithFilter_test ./test_updates ./test_updates update From 7f419eaaa36c83b22e623a99c7be0d22ce47f4f4 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 28 Aug 2022 18:01:53 +0530 Subject: [PATCH 22/41] Remove constexpr for functor in test. --- examples/searchKnnWithFilter_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 9219be03..ead0c6fd 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -148,7 +148,7 @@ class CustomFilterFunctor: public hnswlib::FilterFunctor { public: explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} - constexpr bool operator()(unsigned int id) const { + bool operator()(unsigned int id) { return allowed_values.count(id) != 0; } }; From f7d33662fcae799d9050fc4384466304aea4e535 Mon Sep 17 00:00:00 2001 From: Alexander Vieth Date: Mon, 29 Aug 2022 11:14:14 +0200 Subject: [PATCH 23/41] USE_SSE with msvc compilers --- hnswlib/hnswlib.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 58eb7607..41579df6 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -1,6 +1,6 @@ #pragma once #ifndef NO_MANUAL_VECTORIZATION -#ifdef __SSE__ +#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) #define USE_SSE #ifdef __AVX__ #define USE_AVX From 23f53517a82978defdec04e406b391c38222a538 Mon Sep 17 00:00:00 2001 From: Alexander Vieth Date: Mon, 29 Aug 2022 11:26:51 +0200 Subject: [PATCH 24/41] Remove inclusion of cpu_x86.h --- hnswlib/hnswlib.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 41579df6..cda87e59 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -15,8 +15,7 @@ #ifdef _MSC_VER #include #include -#include "cpu_x86.h" -void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) { +void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { __cpuidex(out, eax, ecx); } __int64 xgetbv(unsigned int x) { From e8b3e449e4e8fa9e5a28526e633ffaab0c8bafd8 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 28 Aug 2022 11:46:38 +0200 Subject: [PATCH 25/41] Add cpp tests for Windows in CI --- .github/workflows/build.yml | 17 +++++++++++++++-- examples/updates_test.cpp | 25 ++++++++++++++++--------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f8fde085..219efec3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,7 +22,10 @@ jobs: run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" test_cpp: - runs-on: ubuntu-latest + runs-on: ${{matrix.os}} + strategy: + matrix: + os: [ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -34,17 +37,27 @@ jobs: mkdir build cd build cmake .. - make + if [ "$RUNNER_OS" == "Linux" ]; then + make + elif [ "$RUNNER_OS" == "Windows" ]; then + cmake --build ./ --config Release + fi + shell: bash - name: Prepare test data run: | pip install numpy cd examples python update_gen_data.py + shell: bash - name: Test run: | cd build + if [ "$RUNNER_OS" == "Windows" ]; then + cp ./Release/* ./ + fi ./searchKnnCloserFirst_test ./test_updates ./test_updates update + shell: bash diff --git a/examples/updates_test.cpp b/examples/updates_test.cpp index c8775877..d4cc995b 100644 --- a/examples/updates_test.cpp +++ b/examples/updates_test.cpp @@ -1,5 +1,7 @@ #include "../hnswlib/hnswlib.h" #include + + class StopW { std::chrono::steady_clock::time_point time_begin; @@ -22,6 +24,7 @@ class StopW } }; + /* * replacement for the openmp '#pragma omp parallel for' directive * only handles a subset of functionality (no reductions etc) @@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn std::rethrow_exception(lastException); } } - - } @@ -94,7 +95,7 @@ std::vector load_batch(std::string path, int size) assert(sizeof(datatype) == 4); std::ifstream file; - file.open(path); + file.open(path, std::ios::binary); if (!file.is_open()) { std::cout << "Cannot open " << path << "\n"; @@ -107,6 +108,7 @@ std::vector load_batch(std::string path, int size) return batch; } + template static float test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, @@ -137,6 +139,7 @@ test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW< return 1.0f * correct / total; } + static void test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, std::vector> &answers, size_t k) @@ -155,6 +158,8 @@ test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalN efs.push_back(i); } std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; + + bool test_passed = false; for (size_t ef : efs) { appr_alg.setEf(ef); @@ -171,20 +176,24 @@ test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalN std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"< 0.99) { + test_passed = true; std::cout << "Recall is over 0.99! "< Date: Tue, 6 Sep 2022 15:37:09 +0530 Subject: [PATCH 26/41] Add check for is_filter_disabled. --- hnswlib/bruteforce.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index a56f75f1..33130273 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -93,12 +93,14 @@ namespace hnswlib { std::priority_queue> searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { + assert(k <= cur_element_count); std::priority_queue> topResults; if (cur_element_count == 0) return topResults; + bool is_filter_disabled = std::is_same::value; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if(isIdAllowed(label)) { + if(is_filter_disabled || isIdAllowed(label)) { topResults.push(std::pair(dist, label)); } } @@ -107,7 +109,7 @@ namespace hnswlib { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if(isIdAllowed(label)) { + if(is_filter_disabled || isIdAllowed(label)) { topResults.push(std::pair(dist, label)); } if (topResults.size() > k) From c9897b0f730c48428b587a791cbf901a4550add8 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 6 Sep 2022 16:52:54 +0530 Subject: [PATCH 27/41] Add assert header. --- hnswlib/bruteforce.h | 1 + 1 file changed, 1 insertion(+) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 33130273..9fe97c09 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace hnswlib { template From 6d28ec0d43ec669154dc51540921c4306b1636ef Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 18 Sep 2022 12:54:55 +0300 Subject: [PATCH 28/41] Refactoring (#410) --- .gitignore | 2 +- examples/searchKnnCloserFirst_test.cpp | 10 +- examples/searchKnnWithFilter_test.cpp | 19 +- examples/updates_test.cpp | 125 +- hnswlib/bruteforce.h | 282 ++-- hnswlib/hnswalg.h | 1913 ++++++++++++------------ hnswlib/hnswlib.h | 130 +- hnswlib/space_ip.h | 597 ++++---- hnswlib/space_l2.h | 538 ++++--- hnswlib/visited_list_pool.h | 109 +- main.cpp | 2 +- python_bindings/bindings.cpp | 805 +++++----- sift_1b.cpp | 72 +- sift_test.cpp | 68 +- 14 files changed, 2320 insertions(+), 2352 deletions(-) diff --git a/.gitignore b/.gitignore index dab30385..a338107c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ hnswlib.cpython*.so var/ .idea/ .vscode/ - +.vs/ diff --git a/examples/searchKnnCloserFirst_test.cpp b/examples/searchKnnCloserFirst_test.cpp index cc1392c8..d87102cd 100644 --- a/examples/searchKnnCloserFirst_test.cpp +++ b/examples/searchKnnCloserFirst_test.cpp @@ -10,8 +10,7 @@ #include #include -namespace -{ +namespace { using idx_t = hnswlib::labeltype; @@ -20,7 +19,7 @@ void test() { idx_t n = 100; idx_t nq = 10; size_t k = 10; - + std::vector data(n * d); std::vector query(nq * d); @@ -34,7 +33,6 @@ void test() { for (idx_t i = 0; i < nq * d; ++i) { query[i] = distrib(rng); } - hnswlib::L2Space space(d); hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); @@ -68,12 +66,12 @@ void test() { gd.pop(); } } - + delete alg_brute; delete alg_hnsw; } -} // namespace +} // namespace int main() { std::cout << "Testing ..." << std::endl; diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index ead0c6fd..4aee49b0 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -7,8 +7,7 @@ #include #include -namespace -{ +namespace { using idx_t = hnswlib::labeltype; @@ -30,7 +29,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe idx_t n = 100; idx_t nq = 10; size_t k = 10; - + std::vector data(n * d); std::vector query(nq * d); @@ -46,8 +45,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -82,7 +81,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe gd.pop(); } } - + delete alg_brute; delete alg_hnsw; } @@ -109,8 +108,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -140,12 +139,12 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { delete alg_hnsw; } -} // namespace +} // namespace class CustomFilterFunctor: public hnswlib::FilterFunctor { std::unordered_set allowed_values; -public: + public: explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} bool operator()(unsigned int id) { diff --git a/examples/updates_test.cpp b/examples/updates_test.cpp index d4cc995b..8e4ac644 100644 --- a/examples/updates_test.cpp +++ b/examples/updates_test.cpp @@ -2,24 +2,20 @@ #include -class StopW -{ +class StopW { std::chrono::steady_clock::time_point time_begin; -public: - StopW() - { + public: + StopW() { time_begin = std::chrono::steady_clock::now(); } - float getElapsedTimeMicro() - { + float getElapsedTimeMicro() { std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); return (std::chrono::duration_cast(time_end - time_begin).count()); } - void reset() - { + void reset() { time_begin = std::chrono::steady_clock::now(); } }; @@ -88,16 +84,14 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn template -std::vector load_batch(std::string path, int size) -{ +std::vector load_batch(std::string path, int size) { std::cout << "Loading " << path << "..."; // float or int32 (python) assert(sizeof(datatype) == 4); std::ifstream file; file.open(path, std::ios::binary); - if (!file.is_open()) - { + if (!file.is_open()) { std::cout << "Cannot open " << path << "\n"; exit(1); } @@ -112,26 +106,17 @@ std::vector load_batch(std::string path, int size) template static float test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, - std::vector> &answers, size_t K) -{ + std::vector> &answers, size_t K) { size_t correct = 0; size_t total = 0; - //uncomment to test in parallel mode: - - - for (int i = 0; i < qsize; i++) - { + for (int i = 0; i < qsize; i++) { std::priority_queue> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K); total += K; - while (result.size()) - { - if (answers[i].find(result.top().second) != answers[i].end()) - { + while (result.size()) { + if (answers[i].find(result.top().second) != answers[i].end()) { correct++; - } - else - { + } else { } result.pop(); } @@ -141,31 +126,32 @@ test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW< static void -test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, - std::vector> &answers, size_t k) -{ +test_vs_recall( + std::vector &queries, + size_t qsize, + hnswlib::HierarchicalNSW &appr_alg, + size_t vecdim, + std::vector> &answers, + size_t k) { + std::vector efs = {1}; - for (int i = k; i < 30; i++) - { + for (int i = k; i < 30; i++) { efs.push_back(i); } - for (int i = 30; i < 400; i+=10) - { + for (int i = 30; i < 400; i+=10) { efs.push_back(i); } - for (int i = 1000; i < 100000; i += 5000) - { + for (int i = 1000; i < 100000; i += 5000) { efs.push_back(i); } std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; bool test_passed = false; - for (size_t ef : efs) - { + for (size_t ef : efs) { appr_alg.setEf(ef); - appr_alg.metric_hops=0; - appr_alg.metric_distance_computations=0; + appr_alg.metric_hops = 0; + appr_alg.metric_distance_computations = 0; StopW stopw = StopW(); float recall = test_approx(queries, qsize, appr_alg, vecdim, answers, k); @@ -173,44 +159,37 @@ test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalN float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize); float hops_per_query = appr_alg.metric_hops / (1.0f * qsize); - std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"< 0.99) - { + std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n"; + if (recall > 0.99) { test_passed = true; - std::cout << "Recall is over 0.99! "<2){ - std::cout<<"Usage ./test_updates [update]\n"; + } else if (argc > 2) { + std::cout << "Usage ./test_updates [update]\n"; exit(1); } @@ -224,8 +203,7 @@ int main(int argc, char **argv) { std::ifstream configfile; configfile.open(path + "/config.txt"); - if (!configfile.is_open()) - { + if (!configfile.is_open()) { std::cout << "Cannot open config.txt\n"; return 1; } @@ -245,10 +223,9 @@ int main(int argc, char **argv) StopW stopw = StopW(); - if (update) - { + if (update) { std::cout << "Update iteration 0\n"; - + ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); @@ -259,14 +236,13 @@ int main(int argc, char **argv) }); appr_alg.checkIntegrity(); - for (int b = 1; b < dummy_data_multiplier; b++) - { + for (int b = 1; b < dummy_data_multiplier; b++) { std::cout << "Update iteration " << b << "\n"; char cpath[1024]; sprintf(cpath, "batch_dummy_%02d.bin", b); std::vector dummy_batchb = load_batch(path + cpath, N * d); - - ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { + + ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); appr_alg.checkIntegrity(); @@ -275,31 +251,28 @@ int main(int argc, char **argv) std::cout << "Inserting final elements\n"; std::vector final_batch = load_batch(path + "batch_final.bin", N * d); - + stopw.reset(); ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(final_batch.data() + i * d), i); }); - std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; + std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; std::cout << "Running tests\n"; std::vector queries_batch = load_batch(path + "queries.bin", N_queries * d); std::vector gt = load_batch(path + "gt.bin", N_queries * K); std::vector> answers(N_queries); - for (int i = 0; i < N_queries; i++) - { - for (int j = 0; j < K; j++) - { + for (int i = 0; i < N_queries; i++) { + for (int j = 0; j < K; j++) { answers[i].insert(gt[i * K + j]); } } - for (int i = 0; i < 3; i++) - { + for (int i = 0; i < 3; i++) { std::cout << "Test iteration " << i << "\n"; test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K); } return 0; -}; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 9fe97c09..ec2ef350 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -6,161 +6,163 @@ #include namespace hnswlib { - template - class BruteforceSearch : public AlgorithmInterface { - public: - BruteforceSearch(SpaceInterface *s) : data_(nullptr), maxelements_(0), - cur_element_count(0), size_per_element_(0), data_size_(0), - dist_func_param_(nullptr) { - - } - BruteforceSearch(SpaceInterface *s, const std::string &location) : - data_(nullptr), maxelements_(0), cur_element_count(0), size_per_element_(0), - data_size_(0), dist_func_param_(nullptr) { - loadIndex(location, s); - } - - BruteforceSearch(SpaceInterface *s, size_t maxElements) { - maxelements_ = maxElements; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxElements * size_per_element_); - if (data_ == nullptr) - throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); - cur_element_count = 0; - } - - ~BruteforceSearch() { - free(data_); - } - - char *data_; - size_t maxelements_; - size_t cur_element_count; - size_t size_per_element_; - - size_t data_size_; - DISTFUNC fstdistfunc_; - void *dist_func_param_; - std::mutex index_lock; - - std::unordered_map dict_external_to_internal; - - void addPoint(const void *datapoint, labeltype label) { - - int idx; - { - std::unique_lock lock(index_lock); - - - - auto search=dict_external_to_internal.find(label); - if (search != dict_external_to_internal.end()) { - idx=search->second; - } - else{ - if (cur_element_count >= maxelements_) { - throw std::runtime_error("The number of elements exceeds the specified limit\n"); - } - idx=cur_element_count; - dict_external_to_internal[label] = idx; - cur_element_count++; +template +class BruteforceSearch : public AlgorithmInterface { + public: + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + + BruteforceSearch(SpaceInterface *s) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + } + + + BruteforceSearch(SpaceInterface *s, const std::string &location) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + loadIndex(location, s); + } + + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + + ~BruteforceSearch() { + free(data_); + } + + + void addPoint(const void *datapoint, labeltype label) { + int idx; + { + std::unique_lock lock(index_lock); + + auto search = dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx = search->second; + } else { + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); } + idx = cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; } - memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); - memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); - - - - - }; - - void removePoint(labeltype cur_external) { - size_t cur_c=dict_external_to_internal[cur_external]; - - dict_external_to_internal.erase(cur_external); - - labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); - dict_external_to_internal[label]=cur_c; - memcpy(data_ + size_per_element_ * cur_c, - data_ + size_per_element_ * (cur_element_count-1), - data_size_+sizeof(labeltype)); - cur_element_count--; - } - - - std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { - assert(k <= cur_element_count); - std::priority_queue> topResults; - if (cur_element_count == 0) return topResults; - bool is_filter_disabled = std::is_same::value; - for (int i = 0; i < k; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if(is_filter_disabled || isIdAllowed(label)) { - topResults.push(std::pair(dist, label)); - } + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + } + + + void removePoint(labeltype cur_external) { + size_t cur_c = dict_external_to_internal[cur_external]; + + dict_external_to_internal.erase(cur_external); + + labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label] = cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + assert(k <= cur_element_count); + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + bool is_filter_disabled = std::is_same::value; + for (int i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if (is_filter_disabled || isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); } - dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; - for (int i = k; i < cur_element_count; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - if (dist <= lastdist) { - labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if(is_filter_disabled || isIdAllowed(label)) { - topResults.push(std::pair(dist, label)); - } - if (topResults.size() > k) - topResults.pop(); - - if (!topResults.empty()) { - lastdist = topResults.top().first; - } + } + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; + for (int i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if (is_filter_disabled || isIdAllowed(label)) { + topResults.push(std::pair(dist, label)); } + if (topResults.size() > k) + topResults.pop(); + if (!topResults.empty()) { + lastdist = topResults.top().first; + } } - return topResults; - }; - - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, maxelements_); - writeBinaryPOD(output, size_per_element_); - writeBinaryPOD(output, cur_element_count); + } + return topResults; + } - output.write(data_, maxelements_ * size_per_element_); - output.close(); - } + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; - void loadIndex(const std::string &location, SpaceInterface *s) { + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + output.write(data_, maxelements_ * size_per_element_); - std::ifstream input(location, std::ios::binary); - std::streampos position; + output.close(); + } - readBinaryPOD(input, maxelements_); - readBinaryPOD(input, size_per_element_); - readBinaryPOD(input, cur_element_count); - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxelements_ * size_per_element_); - if (data_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + std::streampos position; - input.read(data_, maxelements_ * size_per_element_); + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); - input.close(); + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); - } + input.read(data_, maxelements_ * size_per_element_); - }; -} + input.close(); + } +}; +} // namespace hnswlib diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 57fba444..32b173e1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -10,931 +10,1039 @@ #include namespace hnswlib { - typedef unsigned int tableint; - typedef unsigned int linklistsizeint; - - template - class HierarchicalNSW : public AlgorithmInterface { - public: - static const tableint max_update_element_locks = 65536; - HierarchicalNSW(SpaceInterface *s) { - } +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; + +template +class HierarchicalNSW : public AlgorithmInterface { + public: + static const tableint max_update_element_locks = 65536; + static const unsigned char DELETE_MARK = 0x01; + + size_t max_elements_{0}; + size_t cur_element_count{0}; + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + mutable std::atomic num_deleted_{0}; // number of deleted elements + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; + size_t ef_{ 0 }; + + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; + + VisitedListPool *visited_list_pool_{nullptr}; + + // Locks to prevent race condition during update/insert of an element at same time. + // Note: Locks for additions can also be used to prevent this race condition + // if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. + std::vector link_list_update_locks_; + + std::mutex global; + std::mutex cur_element_count_guard_; + std::vector link_list_locks_; + + tableint enterpoint_node_{0}; + + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; + + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; + std::vector element_levels_; // keeps level of each element + + size_t data_size_{0}; + + DISTFUNC fstdistfunc_; + void *dist_func_param_{nullptr}; + std::mutex label_lookup_lock; + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; + + mutable std::atomic metric_distance_computations{0}; + mutable std::atomic metric_hops{0}; + + + HierarchicalNSW(SpaceInterface *s) { + } + + + HierarchicalNSW( + SpaceInterface *s, + const std::string &location, + bool nmslib = false, + size_t max_elements = 0) { + loadIndex(location, s, max_elements); + } + + + HierarchicalNSW( + SpaceInterface *s, + size_t max_elements, + size_t M = 16, + size_t ef_construction = 200, + size_t random_seed = 100) + : link_list_locks_(max_elements), + link_list_update_locks_(max_update_element_locks), + element_levels_(max_elements) { + max_elements_ = max_elements; + num_deleted_ = 0; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction, M_); + ef_ = 10; + + level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + // initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } - HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { - loadIndex(location, s, max_elements); - } - HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { - max_elements_ = max_elements; - - num_deleted_ = 0; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - M_ = M; - maxM_ = M_; - maxM0_ = M_ * 2; - ef_construction_ = std::max(ef_construction,M_); - ef_ = 10; - - level_generator_.seed(random_seed); - update_probability_generator_.seed(random_seed + 1); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); - offsetData_ = size_links_level0_; - label_offset_ = size_links_level0_ + data_size_; - offsetLevel0_ = 0; - - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); - - cur_element_count = 0; - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - //initializations for special treatment of the first node - enterpoint_node_ = -1; - maxlevel_ = -1; - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; + ~HierarchicalNSW() { + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); } + free(linkLists_); + delete visited_list_pool_; + } - struct CompareByFirst { - constexpr bool operator()(std::pair const &a, - std::pair const &b) const noexcept { - return a.first < b.first; - } - }; - ~HierarchicalNSW() { - - free(data_level0_memory_); - for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) - free(linkLists_[i]); - } - free(linkLists_); - delete visited_list_pool_; + struct CompareByFirst { + constexpr bool operator()(std::pair const& a, + std::pair const& b) const noexcept { + return a.first < b.first; } + }; - size_t max_elements_{0}; - size_t cur_element_count{0}; - size_t size_data_per_element_{0}; - size_t size_links_per_element_{0}; - size_t num_deleted_{0}; - size_t M_{0}; - size_t maxM_{0}; - size_t maxM0_{0}; - size_t ef_construction_{0}; - double mult_{0.0}, revSize_{0.0}; - int maxlevel_{0}; + void setEf(size_t ef) { + ef_ = ef; + } - VisitedListPool *visited_list_pool_{nullptr}; - std::mutex cur_element_count_guard_; + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } - std::vector link_list_locks_; - // Locks to prevent race condition during update/insert of an element at same time. - // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. - std::vector link_list_update_locks_; - tableint enterpoint_node_{0}; + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } - size_t size_links_level0_{0}; - size_t offsetData_{0}, offsetLevel0_{0}; - char *data_level0_memory_{nullptr}; - char **linkLists_{nullptr}; - std::vector element_levels_; + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } - size_t data_size_{0}; - size_t label_offset_{0}; - DISTFUNC fstdistfunc_; - void *dist_func_param_{nullptr}; - std::unordered_map label_lookup_; + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } - std::default_random_engine level_generator_; - std::default_random_engine update_probability_generator_; - inline labeltype getExternalLabel(tableint internal_id) const { - labeltype return_label; - memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); - return return_label; - } + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } - inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); - } - inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); - } + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; - inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); - } + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; - int getRandomLevel(double reverse_size) { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; - return (int) r; + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); } + visited_array[ep_id] = visited_array_tag; - - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayer(tableint ep_id, const void *data_point, int layer) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidateSet; - - dist_t lowerBound; - if (!isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - top_candidates.emplace(dist, ep_id); - lowerBound = dist; - candidateSet.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidateSet.emplace(-lowerBound, ep_id); + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { + break; } - visited_array[ep_id] = visited_array_tag; + candidateSet.pop(); - while (!candidateSet.empty()) { - std::pair curr_el_pair = candidateSet.top(); - if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { - break; - } - candidateSet.pop(); + tableint curNodeNum = curr_el_pair.second; - tableint curNodeNum = curr_el_pair.second; + std::unique_lock lock(link_list_locks_[curNodeNum]); - std::unique_lock lock(link_list_locks_[curNodeNum]); - - int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); + int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); - tableint *datal = (tableint *) (data + 1); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); #endif - for (size_t j = 0; j < size; j++) { - tableint candidate_id = *(datal + j); + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); #endif - if (visited_array[candidate_id] == visited_array_tag) continue; - visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { - candidateSet.emplace(-dist1, candidate_id); + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); #endif - if (!isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist1, candidate_id); + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } - visited_list_pool_->releaseVisitedList(vl); - - return top_candidates; + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + bool is_filter_disabled = std::is_same::value; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); } - mutable std::atomic metric_distance_computations{0}; - mutable std::atomic metric_hops{0}; - - template - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidate_set; - - dist_t lowerBound; - bool is_filter_disabled = std::is_same::value; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - visited_array[ep_id] = visited_array_tag; - - while (!candidate_set.empty()) { + visited_array[ep_id] = visited_array_tag; - std::pair current_node_pair = candidate_set.top(); + while (!candidate_set.empty()) { + std::pair current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) { - break; - } - candidate_set.pop(); + if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) { + break; + } + candidate_set.pop(); - tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); - if(collect_metrics){ - metric_hops++; - metric_distance_computations+=size; - } + if (collect_metrics) { + metric_hops++; + metric_distance_computations+=size; + } #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); - _mm_prefetch((char *) (data + 2), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); #endif - for (size_t j = 1; j <= size; j++) { - int candidate_id = *(data + j); + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, - _MM_HINT_T0);//////////// + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// #endif - if (!(visited_array[candidate_id] == visited_array_tag)) { - - visited_array[candidate_id] = visited_array_tag; + if (!(visited_array[candidate_id] == visited_array_tag)) { + visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { - candidate_set.emplace(-dist, candidate_id); + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + - offsetLevel0_,/////////// - _MM_HINT_T0);//////////////////////// + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) - top_candidates.emplace(dist, candidate_id); + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) + top_candidates.emplace(dist, candidate_id); - if (top_candidates.size() > ef) - top_candidates.pop(); + if (top_candidates.size() > ef) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } } - - visited_list_pool_->releaseVisitedList(vl); - return top_candidates; } - void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { - if (top_candidates.size() < M) { - return; - } + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } - std::priority_queue> queue_closest; - std::vector> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } - while (queue_closest.size()) { - if (return_list.size() >= M) + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_); + if (curdist < dist_to_query) { + good = false; break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - - for (std::pair second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), - dist_func_param_);; - if (curdist < dist_to_query) { - good = false; - break; - } - } - if (good) { - return_list.push_back(curent_pair); } } - - for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); + if (good) { + return_list.push_back(curent_pair); } } + for (std::pair curent_pair : return_list) { + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } - linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); - }; - linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { - return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); - }; + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level, bool isUpdate) { - size_t Mcurmax = level ? maxM_ : maxM0_; - getNeighborsByHeuristic2(top_candidates, M_); - if (top_candidates.size() > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - std::vector selectedNeighbors; - selectedNeighbors.reserve(M_); - while (top_candidates.size() > 0) { - selectedNeighbors.push_back(top_candidates.top().second); - top_candidates.pop(); - } + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + } - tableint next_closest_entry_point = selectedNeighbors.back(); - { - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + } - if (*ll_cur && !isUpdate) { - throw std::runtime_error("The newly inserted element should have blank link list"); - } - setListCount(ll_cur,selectedNeighbors.size()); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx] && !isUpdate) - throw std::runtime_error("Possible memory corruption"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - data[idx] = selectedNeighbors[idx]; + tableint mutuallyConnectNewElement( + const void *data_point, + tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, + bool isUpdate) { + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - } - } + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + tableint next_closest_entry_point = selectedNeighbors.back(); + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur && !isUpdate) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur, selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx] && !isUpdate) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); - std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + data[idx] = selectedNeighbors[idx]; + } + } - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - size_t sz_link_list_other = getListCount(ll_other); + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); - if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); - if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); + size_t sz_link_list_other = getListCount(ll_other); - tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); - bool is_cur_c_present = false; - if (isUpdate) { - for (size_t j = 0; j < sz_link_list_other; j++) { - if (data[j] == cur_c) { - is_cur_c_present = true; - break; - } + tableint *data = (tableint *) (ll_other + 1); + + bool is_cur_c_present = false; + if (isUpdate) { + for (size_t j = 0; j < sz_link_list_other; j++) { + if (data[j] == cur_c) { + is_cur_c_present = true; + break; } } + } - // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. - if (!is_cur_c_present) { - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); - - for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); - } + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); - getNeighborsByHeuristic2(candidates, Mcurmax); + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } + getNeighborsByHeuristic2(candidates, Mcurmax); - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; - } + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; } - if (indx >= 0) { - data[indx] = cur_c; - } */ } + if (indx >= 0) { + data[indx] = cur_c; + } */ } } - - return next_closest_entry_point; } - std::mutex global; - size_t ef_{0}; + return next_closest_entry_point; + } - void setEf(size_t ef) { - ef_ = ef; - } + void resizeIndex(size_t new_max_elements) { + if (new_max_elements < cur_element_count) + throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - std::priority_queue> searchKnnInternal(void *query_data, int k) { - std::priority_queue> top_candidates; - if (cur_element_count == 0) return top_candidates; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + delete visited_list_pool_; + visited_list_pool_ = new VisitedListPool(1, new_max_elements); - for (size_t level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - int *data; - data = (int *) get_linklist(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + element_levels_.resize(new_max_elements); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } + std::vector(new_max_elements).swap(link_list_locks_); - if (num_deleted_) { - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } - else{ - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } + // Reallocate base layer + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; - while (top_candidates.size() > k) { - top_candidates.pop(); - } - return top_candidates; - }; + // Reallocate all other layers + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; - void resizeIndex(size_t new_max_elements){ - if (new_max_elements(new_max_elements).swap(link_list_locks_); + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - // Reallocate base layer - char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - data_level0_memory_ = data_level0_memory_new; + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0, input.end); + std::streampos total_filesize = input.tellg(); + input.seekg(0, input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements = max_elements_i; + if (max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos = input.tellg(); + + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() >= total_filesize) { + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } - // Reallocate all other layers - char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); - if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); - linkLists_ = linkLists_new; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize, input.cur); + } + } - max_elements_ = new_max_elements; + // throw exception if it either corrupted or old index + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + /// Optional check end + + input.seekg(pos, input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + std::vector(max_update_element_locks).swap(link_list_update_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)] = i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } } - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); - writeBinaryPOD(output, label_offset_); - writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); - - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); - - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); + for (size_t i = 0; i < cur_element_count; i++) { + if (isMarkedDeleted(i)) { + num_deleted_ += 1; } - output.close(); } - void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { - std::ifstream input(location, std::ios::binary); + input.close(); - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); + return; + } - // get file size: - input.seekg(0,input.end); - std::streampos total_filesize=input.tellg(); - input.seekg(0,input.beg); - readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); + template + std::vector getDataByLabel(labeltype label) const { + tableint label_internal; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + label_internal = search->second; + + char* data_ptrv = getDataByInternalId(label_internal); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + */ + void markDelete(labeltype label) { + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + markDeletedInternal(internalId); + } + + + /** + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + void markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + } else { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } - size_t max_elements = max_elements_i; - if(max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); - readBinaryPOD(input, label_offset_); - readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); + /** + * Remove the deleted mark of the node, does NOT really change the current graph. + */ + void unmarkDelete(labeltype label) { + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + unmarkDeletedInternal(internalId); + } + + + /** + * Remove the deleted mark of the node. + */ + void unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + } else { + throw std::runtime_error("The requested to undelete element is not deleted"); + } + } - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); + /** + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; + return *ll_cur & DELETE_MARK; + } - auto pos=input.tellg(); + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } - /// Optional - check if index is ok: - input.seekg(cur_element_count * size_data_per_element_,input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if(input.tellg() < 0 || input.tellg()>=total_filesize){ - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize,input.cur); - } - } - // throw exception if it either corrupted or old index - if(input.tellg()!=total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); + /** + * Adds point. Updates the point if it is already in the index + */ + void addPoint(const void *data_point, labeltype label) { + addPoint(data_point, label, -1); + } - input.clear(); - /// Optional check end + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - input.seekg(pos,input.beg); + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + sCand.insert(internalId); - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - std::vector(max_elements).swap(link_list_locks_); - std::vector(max_update_element_locks).swap(link_list_update_locks_); + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); - visited_list_pool_ = new VisitedListPool(1, max_elements); + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { - label_lookup_[getExternalLabel(i)]=i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; + sNeigh.insert(elOneHop); - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); } } - for (size_t i = 0; i < cur_element_count; i++) { - if(isMarkedDeleted(i)) - num_deleted_ += 1; - } + for (auto&& neigh : sNeigh) { + // if (neigh == internalId) + // continue; - input.close(); + std::priority_queue, std::vector>, CompareByFirst> candidates; + size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 + size_t elementsToKeep = std::min(ef_construction_, size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; - return; - } + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); + candidates.emplace(distance, cand); + } + } + } - template - std::vector getDataByLabel(labeltype label) const - { - tableint label_c; - auto search = label_lookup_.find(label); - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { - throw std::runtime_error("Label not found"); - } - label_c = search->second; - - char* data_ptrv = getDataByInternalId(label_c); - size_t dim = *((size_t *) dist_func_param_); - std::vector data; - data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { - data.push_back(*data_ptr); - data_ptr += 1; - } - return data; - } + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); - static const unsigned char DELETE_MARK = 0x01; - // static const unsigned char REUSE_MARK = 0x10; - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - * @param label - */ - void markDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + size_t candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); + } + } } - tableint internalId = search->second; - markDeletedInternal(internalId); } - /** - * Uses the first 8 bits of the memory for the linked list to store the mark, - * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. - * @param internalId - */ - void markDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (!isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - num_deleted_ += 1; - } - else - { - throw std::runtime_error("The requested to delete element is already deleted"); - } - } + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + } - /** - * Remove the deleted mark of the node, does NOT really change the current graph. - * @param label - */ - void unmarkDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - unmarkDeletedInternal(internalId); - } - /** - * Remove the deleted mark of the node. - * @param internalId - */ - void unmarkDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur &= ~DELETE_MARK; - num_deleted_ -= 1; - } - else - { - throw std::runtime_error("The requested to undelete element is not deleted"); + void repairConnectionsForUpdate( + const void *dataPoint, + tableint entryPointInternalId, + tableint dataPointInternalId, + int dataPointLevel, + int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif + for (int i = 0; i < size; i++) { +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } } } - /** - * Checks the first 8 bits of the memory to see if the element is marked deleted. - * @param internalId - * @return - */ - bool isMarkedDeleted(tableint internalId) const { - unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; - return *ll_cur & DELETE_MARK; - } + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); - unsigned short int getListCount(linklistsizeint * ptr) const { - return *((unsigned short int *)ptr); - } + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); - void setListCount(linklistsizeint * ptr, unsigned short int size) const { - *((unsigned short int*)(ptr))=*((unsigned short int *)&size); - } + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if (topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); - void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label,-1); + topCandidates.pop(); + } + + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); + } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + } } + } - void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { - // update the feature vector associated with existing point with new vector - memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - - int maxLevelCopy = maxlevel_; - tableint entryPointCopy = enterpoint_node_; - // If point to be updated is entry point and graph just contains single element then just return. - if (entryPointCopy == internalId && cur_element_count == 1) - return; - - int elemLevel = element_levels_[internalId]; - std::uniform_real_distribution distribution(0.0, 1.0); - for (int layer = 0; layer <= elemLevel; layer++) { - std::unordered_set sCand; - std::unordered_set sNeigh; - std::vector listOneHop = getConnectionsWithLock(internalId, layer); - if (listOneHop.size() == 0) - continue; - sCand.insert(internalId); + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll, size * sizeof(tableint)); + return result; + } - for (auto&& elOneHop : listOneHop) { - sCand.insert(elOneHop); - if (distribution(update_probability_generator_) > updateNeighborProbability) - continue; + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. + std::unique_lock templock_curr(cur_element_count_guard_); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + templock_curr.unlock(); - sNeigh.insert(elOneHop); + std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); - std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); - for (auto&& elTwoHop : listTwoHop) { - sCand.insert(elTwoHop); - } + if (isMarkedDeleted(existingInternalId)) { + unmarkDeletedInternal(existingInternalId); } + updatePoint(data_point, existingInternalId, 1.0); - for (auto&& neigh : sNeigh) { - // if (neigh == internalId) - // continue; - - std::priority_queue, std::vector>, CompareByFirst> candidates; - size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 - size_t elementsToKeep = std::min(ef_construction_, size); - for (auto&& cand : sCand) { - if (cand == neigh) - continue; - - dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); - if (candidates.size() < elementsToKeep) { - candidates.emplace(distance, cand); - } else { - if (distance < candidates.top().first) { - candidates.pop(); - candidates.emplace(distance, cand); - } - } - } + return existingInternalId; + } - // Retrieve neighbours using heuristic and set connections. - getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); - - { - std::unique_lock lock(link_list_locks_[neigh]); - linklistsizeint *ll_cur; - ll_cur = get_linklist_at_level(neigh, layer); - size_t candSize = candidates.size(); - setListCount(ll_cur, candSize); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < candSize; idx++) { - data[idx] = candidates.top().second; - candidates.pop(); - } - } - } + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); } - repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); - }; + cur_c = cur_element_count; + cur_element_count++; + label_lookup_[label] = cur_c; + } - void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { - tableint currObj = entryPointInternalId; - if (dataPointLevel < maxLevel) { - dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxLevel; level > dataPointLevel; level--) { + // Take update lock to prevent race conditions on an element with insertion/update at the same time. + std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + if (curlevel < maxlevelcopy) { + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { bool changed = true; while (changed) { changed = false; unsigned int *data; std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist_at_level(currObj,level); + data = get_linklist(currObj, level); int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); -#endif for (int i = 0; i < size; i++) { -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); -#endif tableint cand = datal[i]; - dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; @@ -945,262 +1053,121 @@ namespace hnswlib { } } - if (dataPointLevel > maxLevel) - throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); - - for (int level = dataPointLevel; level >= 0; level--) { - std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( - currObj, dataPoint, level); - - std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; - while (topCandidates.size() > 0) { - if (topCandidates.top().second != dataPointInternalId) - filteredTopCandidates.push(topCandidates.top()); - - topCandidates.pop(); - } - - // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. - // To prevent self loops, the `topCandidates` is filtered and thus can be empty. - if (filteredTopCandidates.size() > 0) { - bool epDeleted = isMarkedDeleted(entryPointInternalId); - if (epDeleted) { - filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); - if (filteredTopCandidates.size() > ef_construction_) - filteredTopCandidates.pop(); - } - - currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); } + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; } - std::vector getConnectionsWithLock(tableint internalId, int level) { - std::unique_lock lock(link_list_locks_[internalId]); - unsigned int *data = get_linklist_at_level(internalId, level); - int size = getListCount(data); - std::vector result(size); - tableint *ll = (tableint *) (data + 1); - memcpy(result.data(), ll,size * sizeof(tableint)); - return result; - }; - - tableint addPoint(const void *data_point, labeltype label, int level) { - - tableint cur_c = 0; - { - // Checking if the element with the same label already exists - // if so, updating it *instead* of creating a new element. - std::unique_lock templock_curr(cur_element_count_guard_); - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - tableint existingInternalId = search->second; - templock_curr.unlock(); - - std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); - - if (isMarkedDeleted(existingInternalId)) { - unmarkDeletedInternal(existingInternalId); - } - updatePoint(data_point, existingInternalId, 1.0); - - return existingInternalId; - } - - if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); - }; - - cur_c = cur_element_count; - cur_element_count++; - label_lookup_[label] = cur_c; - } - - // Take update lock to prevent race conditions on an element with insertion/update at the same time. - std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); - std::unique_lock lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; - - element_levels_[cur_c] = curlevel; - - - std::unique_lock templock(global); - int maxlevelcopy = maxlevel_; - if (curlevel <= maxlevelcopy) - templock.unlock(); - tableint currObj = enterpoint_node_; - tableint enterpoint_copy = enterpoint_node_; + // Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + } - memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + std::priority_queue> + searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; - // Initialisation of the data and label - memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; - if (curlevel) { - linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if (linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); - memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); - } + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; - if ((signed)currObj != -1) { - - if (curlevel < maxlevelcopy) { - - dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxlevelcopy; level > curlevel; level--) { - - - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist(currObj,level); - int size = getListCount(data); - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - } - - bool epDeleted = isMarkedDeleted(enterpoint_copy); - for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { - if (level > maxlevelcopy || level < 0) // possible? - throw std::runtime_error("Level error"); - - std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, data_point, level); - if (epDeleted) { - top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } - currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } - - - } else { - // Do nothing for the first element - enterpoint_node_ = 0; - maxlevel_ = curlevel; - } + } - //Releasing lock for the maximum level - if (curlevel > maxlevelcopy) { - enterpoint_node_ = cur_c; - maxlevel_ = curlevel; - } - return cur_c; - }; - - std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { - std::priority_queue> result; - if (cur_element_count == 0) return result; - - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); - - for (int level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - - data = (unsigned int *) get_linklist(currObj, level); - int size = getListCount(data); - metric_hops++; - metric_distance_computations+=size; - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); - - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (num_deleted_) { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } else { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); - } - else{ - top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); - } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + } - while (top_candidates.size() > k) { - top_candidates.pop(); - } - while (top_candidates.size() > 0) { - std::pair rez = top_candidates.top(); - result.push(std::pair(rez.first, getExternalLabel(rez.second))); - top_candidates.pop(); - } - return result; - }; - - void checkIntegrity(){ - int connections_checked=0; - std::vector inbound_connections_num(cur_element_count,0); - for(int i = 0;i < cur_element_count; i++){ - for(int l = 0;l <= element_levels_[i]; l++){ - linklistsizeint *ll_cur = get_linklist_at_level(i,l); - int size = getListCount(ll_cur); - tableint *data = (tableint *) (ll_cur + 1); - std::unordered_set s; - for (int j=0; j 0); - assert(data[j] < cur_element_count); - assert (data[j] != i); - inbound_connections_num[data[j]]++; - s.insert(data[j]); - connections_checked++; - } - assert(s.size() == size); + void checkIntegrity() { + int connections_checked = 0; + std::vector inbound_connections_num(cur_element_count, 0); + for (int i = 0; i < cur_element_count; i++) { + for (int l = 0; l <= element_levels_[i]; l++) { + linklistsizeint *ll_cur = get_linklist_at_level(i, l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j = 0; j < size; j++) { + assert(data[j] > 0); + assert(data[j] < cur_element_count); + assert(data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; } + assert(s.size() == size); } - if(cur_element_count > 1){ - int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; - for(int i=0; i < cur_element_count; i++){ - assert(inbound_connections_num[i] > 0); - min1=std::min(inbound_connections_num[i],min1); - max1=std::max(inbound_connections_num[i],max1); - } - std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + if (cur_element_count > 1) { + int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + for (int i=0; i < cur_element_count; i++) { + assert(inbound_connections_num[i] > 0); + min1 = std::min(inbound_connections_num[i], min1); + max1 = std::max(inbound_connections_num[i], max1); } - std::cout << "integrity ok, checked " << connections_checked << " connections\n"; - + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; } - - }; - -} + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + } +}; +} // namespace hnswlib diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 1db5aabb..f11fd373 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -87,7 +87,7 @@ static bool AVX512Capable() { int nIds = cpuInfo[0]; bool HW_AVX512F = false; - if (nIds >= 0x00000007) { // AVX512 Foundation + if (nIds >= 0x00000007) { // AVX512 Foundation cpuid(cpuInfo, 0x00000007, 0); HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; } @@ -113,87 +113,87 @@ static bool AVX512Capable() { #include namespace hnswlib { - typedef size_t labeltype; +typedef size_t labeltype; - // This can be extended to store state for filtering (e.g. from a std::set) - struct FilterFunctor { - template - bool operator()(Args&&...) { return true; } - }; +// This can be extended to store state for filtering (e.g. from a std::set) +struct FilterFunctor { + template + bool operator()(Args&&...) { return true; } +}; - static FilterFunctor allowAllIds; +static FilterFunctor allowAllIds; - template - class pairGreater { - public: - bool operator()(const T& p1, const T& p2) { - return p1.first > p2.first; - } - }; - - template - static void writeBinaryPOD(std::ostream &out, const T &podRef) { - out.write((char *) &podRef, sizeof(T)); +template +class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; } +}; - template - static void readBinaryPOD(std::istream &in, T &podRef) { - in.read((char *) &podRef, sizeof(T)); - } +template +static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); +} + +template +static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); +} - template - using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); +template +using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); - template - class SpaceInterface { - public: - //virtual void search(void *); - virtual size_t get_data_size() = 0; +template +class SpaceInterface { + public: + // virtual void search(void *); + virtual size_t get_data_size() = 0; - virtual DISTFUNC get_dist_func() = 0; + virtual DISTFUNC get_dist_func() = 0; - virtual void *get_dist_func_param() = 0; + virtual void *get_dist_func_param() = 0; - virtual ~SpaceInterface() {} - }; + virtual ~SpaceInterface() {} +}; - template - class AlgorithmInterface { - public: - virtual void addPoint(const void *datapoint, labeltype label)=0; +template +class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label) = 0; - virtual std::priority_queue> - searchKnn(const void*, size_t, filter_func_t& isIdAllowed=allowAllIds) const = 0; + virtual std::priority_queue> + searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0; - // Return k nearest neighbor in the order of closer fist - virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const; + // Return k nearest neighbor in the order of closer fist + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const; - virtual void saveIndex(const std::string &location)=0; - virtual ~AlgorithmInterface(){ - } - }; - - template - std::vector> - AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - filter_func_t& isIdAllowed) const { - std::vector> result; - - // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k, isIdAllowed); - { - size_t sz = ret.size(); - result.resize(sz); - while (!ret.empty()) { - result[--sz] = ret.top(); - ret.pop(); - } + virtual void saveIndex(const std::string &location) = 0; + virtual ~AlgorithmInterface(){ + } +}; + +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + filter_func_t& isIdAllowed) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k, isIdAllowed); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); } - - return result; } + + return result; } +} // namespace hnswlib #include "space_l2.h" #include "space_ip.h" diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index d45a4c66..2b1c359e 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -3,374 +3,373 @@ namespace hnswlib { - static float - InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - float res = 0; - for (unsigned i = 0; i < qty; i++) { - res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; - } - return res; - +static float +InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; } + return res; +} - static float - InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { - return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); - } +static float +InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { + return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); +} #if defined(USE_AVX) // Favor using AVX if available. - static float - InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m256 sum256 = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } +static float +InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } - __m128 v1, v2; - __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; - return sum; - } - - static float - InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return sum; +} + +static float +InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); +} + #endif #if defined(USE_SSE) - static float - InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } +static float +InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_AVX512) - static float - InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN64 TmpRes[16]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN64 TmpRes[16]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty / 16; + size_t qty16 = qty / 16; - const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd1 = pVect1 + 16 * qty16; - __m512 sum512 = _mm512_set1_ps(0); + __m512 sum512 = _mm512_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - __m512 v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); - } + __m512 v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); + } - _mm512_store_ps(TmpRes, sum512); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; + _mm512_store_ps(TmpRes, sum512); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_AVX) - static float - InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty / 16; + size_t qty16 = qty / 16; - const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd1 = pVect1 + 16 * qty16; - __m256 sum256 = _mm256_set1_ps(0); + __m256 sum256 = _mm256_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } - _mm256_store_ps(TmpRes, sum256); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_SSE) - static float - InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - - const float *pEnd1 = pVect1 + 16 * qty16; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return sum; +static float +InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - static float - InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); - } + return sum; +} + +static float +InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; - static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; - static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; - static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; - - static float - InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - return 1.0f - (res + res_tail); - } +static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; +static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; +static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; +static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + +static float +InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return 1.0f - (res + res_tail); +} - static float - InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; +static float +InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; - float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - return 1.0f - (res + res_tail); - } + return 1.0f - (res + res_tail); +} #endif - class InnerProductSpace : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - InnerProductSpace(size_t dim) { - fstdistfunc_ = InnerProductDistance; - #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; - } else if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #elif defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #endif - #if defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; - InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; - } - #endif - - if (dim % 16 == 0) - fstdistfunc_ = InnerProductDistanceSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = InnerProductDistanceSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +class InnerProductSpace : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } #endif - dim_ = dim; - data_size_ = dim * sizeof(float); + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; } + #endif - size_t get_data_size() { - return data_size_; - } + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } - DISTFUNC get_dist_func() { - return fstdistfunc_; - } + size_t get_data_size() { + return data_size_; + } - void *get_dist_func_param() { - return &dim_; - } + DISTFUNC get_dist_func() { + return fstdistfunc_; + } - ~InnerProductSpace() {} - }; + void *get_dist_func_param() { + return &dim_; + } -} +~InnerProductSpace() {} +}; + +} // namespace hnswlib diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 355cc7b8..834d19f7 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -3,328 +3,322 @@ namespace hnswlib { - static float - L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - float res = 0; - for (size_t i = 0; i < qty; i++) { - float t = *pVect1 - *pVect2; - pVect1++; - pVect2++; - res += t * t; - } - return (res); +static float +L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; } + return (res); +} #if defined(USE_AVX512) - // Favor using AVX512 if available. - static float - L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN64 TmpRes[16]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m512 diff, v1, v2; - __m512 sum = _mm512_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - diff = _mm512_sub_ps(v1, v2); - // sum = _mm512_fmadd_ps(diff, diff, sum); - sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); - } +// Favor using AVX512 if available. +static float +L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN64 TmpRes[16]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m512 diff, v1, v2; + __m512 sum = _mm512_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + diff = _mm512_sub_ps(v1, v2); + // sum = _mm512_fmadd_ps(diff, diff, sum); + sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); + } - _mm512_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + - TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + - TmpRes[13] + TmpRes[14] + TmpRes[15]; + _mm512_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + + TmpRes[13] + TmpRes[14] + TmpRes[15]; - return (res); + return (res); } #endif #if defined(USE_AVX) - // Favor using AVX if available. - static float - L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m256 diff, v1, v2; - __m256 sum = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - } - - _mm256_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; +// Favor using AVX if available. +static float +L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); } + _mm256_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; +} + #endif #if defined(USE_SSE) - static float - L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +static float +L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; - - static float - L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - return (res + res_tail); - } +static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; + +static float +L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + return (res + res_tail); +} #endif #if defined(USE_SSE) - static float - L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2; + size_t qty4 = qty >> 2; - const float *pEnd1 = pVect1 + (qty4 << 2); + const float *pEnd1 = pVect1 + (qty4 << 2); - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} - static float - L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; +static float +L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; - float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; + float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - return (res + res_tail); - } + return (res + res_tail); +} #endif - class L2Space : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - L2Space(size_t dim) { - fstdistfunc_ = L2Sqr; - #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; - else if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #elif defined(USE_AVX) - if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #endif - - if (dim % 16 == 0) - fstdistfunc_ = L2SqrSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = L2SqrSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = L2SqrSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = L2SqrSIMD4ExtResiduals; +class L2Space : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; #endif - dim_ = dim; - data_size_ = dim * sizeof(float); - } - size_t get_data_size() { - return data_size_; - } + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } - DISTFUNC get_dist_func() { - return fstdistfunc_; - } + size_t get_data_size() { + return data_size_; + } - void *get_dist_func_param() { - return &dim_; - } + DISTFUNC get_dist_func() { + return fstdistfunc_; + } - ~L2Space() {} - }; - - static int - L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { - - size_t qty = *((size_t *) qty_ptr); - int res = 0; - unsigned char *a = (unsigned char *) pVect1; - unsigned char *b = (unsigned char *) pVect2; - - qty = qty >> 2; - for (size_t i = 0; i < qty; i++) { - - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); + void *get_dist_func_param() { + return &dim_; } - static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { - size_t qty = *((size_t*)qty_ptr); - int res = 0; - unsigned char* a = (unsigned char*)pVect1; - unsigned char* b = (unsigned char*)pVect2; - - for(size_t i = 0; i < qty; i++) - { - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); + ~L2Space() {} +}; + +static int +L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = (unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; } + return (res); +} - class L2SpaceI : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - L2SpaceI(size_t dim) { - if(dim % 4 == 0) { - fstdistfunc_ = L2SqrI4x; - } - else { - fstdistfunc_ = L2SqrI; - } - dim_ = dim; - data_size_ = dim * sizeof(unsigned char); - } +static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int res = 0; + unsigned char* a = (unsigned char*)pVect1; + unsigned char* b = (unsigned char*)pVect2; - size_t get_data_size() { - return data_size_; - } + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + } + return (res); +} - DISTFUNC get_dist_func() { - return fstdistfunc_; +class L2SpaceI : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2SpaceI(size_t dim) { + if (dim % 4 == 0) { + fstdistfunc_ = L2SqrI4x; + } else { + fstdistfunc_ = L2SqrI; } + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } - void *get_dist_func_param() { - return &dim_; - } + size_t get_data_size() { + return data_size_; + } - ~L2SpaceI() {} - }; + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + void *get_dist_func_param() { + return &dim_; + } -} + ~L2SpaceI() {} +}; +} // namespace hnswlib diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h index 5e1a4a58..2e201ec4 100644 --- a/hnswlib/visited_list_pool.h +++ b/hnswlib/visited_list_pool.h @@ -5,75 +5,74 @@ #include namespace hnswlib { - typedef unsigned short int vl_type; +typedef unsigned short int vl_type; - class VisitedList { - public: - vl_type curV; - vl_type *mass; - unsigned int numelements; +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; - VisitedList(int numelements1) { - curV = -1; - numelements = numelements1; - mass = new vl_type[numelements]; - } + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } - void reset() { + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); curV++; - if (curV == 0) { - memset(mass, 0, sizeof(vl_type) * numelements); - curV++; - } - }; + } + } - ~VisitedList() { delete[] mass; } - }; + ~VisitedList() { delete[] mass; } +}; /////////////////////////////////////////////////////////// // // Class for multi-threaded pool-management of VisitedLists // ///////////////////////////////////////////////////////// - class VisitedListPool { - std::deque pool; - std::mutex poolguard; - int numelements; - - public: - VisitedListPool(int initmaxpools, int numelements1) { - numelements = numelements1; - for (int i = 0; i < initmaxpools; i++) - pool.push_front(new VisitedList(numelements)); - } +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; - VisitedList *getFreeVisitedList() { - VisitedList *rez; - { - std::unique_lock lock(poolguard); - if (pool.size() > 0) { - rez = pool.front(); - pool.pop_front(); - } else { - rez = new VisitedList(numelements); - } - } - rez->reset(); - return rez; - }; + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } - void releaseVisitedList(VisitedList *vl) { + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { std::unique_lock lock(poolguard); - pool.push_front(vl); - }; - - ~VisitedListPool() { - while (pool.size()) { - VisitedList *rez = pool.front(); + if (pool.size() > 0) { + rez = pool.front(); pool.pop_front(); - delete rez; + } else { + rez = new VisitedList(numelements); } - }; - }; -} + } + rez->reset(); + return rez; + } + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + } + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + } +}; +} // namespace hnswlib diff --git a/main.cpp b/main.cpp index 6c8acc9b..bf0fc2bf 100644 --- a/main.cpp +++ b/main.cpp @@ -5,4 +5,4 @@ int main() { sift_test1B(); return 0; -}; \ No newline at end of file +} diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index a72b5b21..fcb444da 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -9,7 +9,7 @@ #include namespace py = pybind11; -using namespace pybind11::literals; // needed to bring in _a literal +using namespace pybind11::literals; // needed to bring in _a literal /* * replacement for the openmp '#pragma omp parallel for' directive @@ -74,60 +74,63 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn inline void assert_true(bool expr, const std::string & msg) { - if (expr == false) - throw std::runtime_error("Unpickle Error: "+msg); + if (expr == false) throw std::runtime_error("Unpickle Error: " + msg); return; } -template +template class Index { -public: - Index(const std::string &space_name, const int dim) : - space_name(space_name), dim(dim) { - normalize=false; - if(space_name=="l2") { - l2space = new hnswlib::L2Space(dim); - } - else if(space_name=="ip") { - l2space = new hnswlib::InnerProductSpace(dim); - } - else if(space_name=="cosine") { - l2space = new hnswlib::InnerProductSpace(dim); - normalize=true; - } else { - throw std::runtime_error("Space name must be one of l2, ip, or cosine."); - } - appr_alg = NULL; - ep_added = true; - index_inited = false; - num_threads_default = std::thread::hardware_concurrency(); + public: + static const int ser_version = 1; // serialization version + + std::string space_name; + int dim; + size_t seed; + size_t default_ef; - default_ef=10; - } + bool index_inited; + bool ep_added; + bool normalize; + int num_threads_default; + hnswlib::labeltype cur_l; + hnswlib::HierarchicalNSW* appr_alg; + hnswlib::SpaceInterface* l2space; + + + Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { + normalize = false; + if (space_name == "l2") { + l2space = new hnswlib::L2Space(dim); + } else if (space_name == "ip") { + l2space = new hnswlib::InnerProductSpace(dim); + } else if (space_name == "cosine") { + l2space = new hnswlib::InnerProductSpace(dim); + normalize = true; + } else { + throw std::runtime_error("Space name must be one of l2, ip, or cosine."); + } + appr_alg = NULL; + ep_added = true; + index_inited = false; + num_threads_default = std::thread::hardware_concurrency(); - static const int ser_version = 1; // serialization version + default_ef = 10; + } - std::string space_name; - int dim; - size_t seed; - size_t default_ef; - bool index_inited; - bool ep_added; - bool normalize; - int num_threads_default; - hnswlib::labeltype cur_l; - hnswlib::HierarchicalNSW *appr_alg; - hnswlib::SpaceInterface *l2space; + ~Index() { + delete l2space; + if (appr_alg) + delete appr_alg; + } - ~Index() { - delete l2space; - if (appr_alg) - delete appr_alg; - } - void init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) { + void init_new_index( + size_t maxElements, + size_t M, + size_t efConstruction, + size_t random_seed) { if (appr_alg) { throw std::runtime_error("The index is already initiated."); } @@ -136,23 +139,27 @@ class Index { index_inited = true; ep_added = false; appr_alg->ef_ = default_ef; - seed=random_seed; + seed = random_seed; } + void set_ef(size_t ef) { - default_ef=ef; + default_ef = ef; if (appr_alg) appr_alg->ef_ = ef; } + void set_num_threads(int num_threads) { this->num_threads_default = num_threads; } + void saveIndex(const std::string &path_to_index) { appr_alg->saveIndex(path_to_index); } + void loadIndex(const std::string &path_to_index, size_t max_elements) { if (appr_alg) { std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; @@ -163,15 +170,17 @@ class Index { index_inited = true; } - void normalize_vector(float *data, float *norm_array){ - float norm=0.0f; - for(int i=0;i items(input); auto buffer = items.request(); @@ -184,8 +193,7 @@ class Index { if (buffer.ndim == 2) { rows = buffer.shape[0]; features = buffer.shape[1]; - } - else{ + } else { rows = 1; features = buffer.shape[0]; } @@ -193,9 +201,9 @@ class Index { if (features != dim) throw std::runtime_error("wrong dimensionality of the vectors"); - // avoid using threads when the number of searches is small: - if(rows<=num_threads*4){ - num_threads=1; + // avoid using threads when the number of additions is small: + if (rows <= num_threads * 4) { + num_threads = 1; } std::vector ids; @@ -203,58 +211,56 @@ class Index { if (!ids_.is_none()) { py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); auto ids_numpy = items.request(); - if(ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { + if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { std::vector ids1(ids_numpy.shape[0]); for (size_t i = 0; i < ids1.size(); i++) { ids1[i] = items.data()[i]; } ids.swap(ids1); - } - else if(ids_numpy.ndim == 0 && rows == 1) { + } else if (ids_numpy.ndim == 0 && rows == 1) { ids.push_back(*items.data()); - } - else + } else { throw std::runtime_error("wrong dimensionality of the labels"); + } } - { - - int start = 0; - if (!ep_added) { - size_t id = ids.size() ? ids.at(0) : (cur_l); - float *vector_data = (float *) items.data(0); - std::vector norm_array(dim); - if(normalize){ - normalize_vector(vector_data, norm_array.data()); - vector_data = norm_array.data(); + int start = 0; + if (!ep_added) { + size_t id = ids.size() ? ids.at(0) : (cur_l); + float* vector_data = (float*)items.data(0); + std::vector norm_array(dim); + if (normalize) { + normalize_vector(vector_data, norm_array.data()); + vector_data = norm_array.data(); + } + appr_alg->addPoint((void*)vector_data, (size_t)id); + start = 1; + ep_added = true; } - appr_alg->addPoint((void *) vector_data, (size_t) id); - start = 1; - ep_added = true; - } py::gil_scoped_release l; - if(normalize==false) { + if (normalize == false) { ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { - size_t id = ids.size() ? ids.at(row) : (cur_l+row); - appr_alg->addPoint((void *) items.data(row), (size_t) id); - }); - } else{ + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)items.data(row), (size_t)id); + }); + } else { std::vector norm_array(num_threads * dim); ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { // normalize vector: size_t start_idx = threadId * dim; - normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); - size_t id = ids.size() ? ids.at(row) : (cur_l+row); - appr_alg->addPoint((void *) (norm_array.data()+start_idx), (size_t) id); - }); - }; - cur_l+=rows; + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id); + }); + } + cur_l += rows; } } + std::vector> getDataReturnList(py::object ids_ = py::none()) { std::vector ids; if (!ids_.is_none()) { @@ -262,13 +268,13 @@ class Index { auto ids_numpy = items.request(); if (ids_numpy.ndim == 0) { - throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors"); + throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors"); } else { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); } } @@ -279,10 +285,11 @@ class Index { return data; } + std::vector getIdsList() { std::vector ids; - for(auto kv : appr_alg->label_lookup_) { + for (auto kv : appr_alg->label_lookup_) { ids.push_back(kv.first); } return ids; @@ -290,133 +297,131 @@ class Index { py::dict getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ - std::unique_lock templock(appr_alg->global); + std::unique_lock templock(appr_alg->global); - size_t level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; - size_t link_npy_size = 0; - std::vector link_npy_offsets(appr_alg->cur_element_count); + size_t level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; + size_t link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - link_npy_offsets[i]=link_npy_size; - if (linkListSize) - link_npy_size += linkListSize; - } + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i] = link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } - char* data_level0_npy = (char *) malloc(level0_npy_size); - char* link_list_npy = (char *) malloc(link_npy_size); - int* element_levels_npy = (int *) malloc(appr_alg->element_levels_.size()*sizeof(int)); + char* data_level0_npy = (char*)malloc(level0_npy_size); + char* link_list_npy = (char*)malloc(link_npy_size); + int* element_levels_npy = (int*)malloc(appr_alg->element_levels_.size() * sizeof(int)); - hnswlib::labeltype* label_lookup_key_npy = (hnswlib::labeltype *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); - hnswlib::tableint* label_lookup_val_npy = (hnswlib::tableint *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + hnswlib::labeltype* label_lookup_key_npy = (hnswlib::labeltype*)malloc(appr_alg->label_lookup_.size() * sizeof(hnswlib::labeltype)); + hnswlib::tableint* label_lookup_val_npy = (hnswlib::tableint*)malloc(appr_alg->label_lookup_.size() * sizeof(hnswlib::tableint)); - memset(label_lookup_key_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); - memset(label_lookup_val_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + memset(label_lookup_key_npy, -1, appr_alg->label_lookup_.size() * sizeof(hnswlib::labeltype)); + memset(label_lookup_val_npy, -1, appr_alg->label_lookup_.size() * sizeof(hnswlib::tableint)); - size_t idx=0; - for ( auto it = appr_alg->label_lookup_.begin(); it != appr_alg->label_lookup_.end(); ++it ){ - label_lookup_key_npy[idx]= it->first; - label_lookup_val_npy[idx]= it->second; - idx++; - } + size_t idx = 0; + for (auto it = appr_alg->label_lookup_.begin(); it != appr_alg->label_lookup_.end(); ++it) { + label_lookup_key_npy[idx] = it->first; + label_lookup_val_npy[idx] = it->second; + idx++; + } - memset(link_list_npy, 0, link_npy_size); + memset(link_list_npy, 0, link_npy_size); - memcpy(data_level0_npy, appr_alg->data_level0_memory_, level0_npy_size); - memcpy(element_levels_npy, appr_alg->element_levels_.data(), appr_alg->element_levels_.size() * sizeof(int)); + memcpy(data_level0_npy, appr_alg->data_level0_memory_, level0_npy_size); + memcpy(element_levels_npy, appr_alg->element_levels_.data(), appr_alg->element_levels_.size() * sizeof(int)); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - if (linkListSize){ - memcpy(link_list_npy+link_npy_offsets[i], appr_alg->linkLists_[i], linkListSize); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize) { + memcpy(link_list_npy + link_npy_offsets[i], appr_alg->linkLists_[i], linkListSize); + } } - } - py::capsule free_when_done_l0(data_level0_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_lvl(element_levels_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_lb(label_lookup_key_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_id(label_lookup_val_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_ll(link_list_npy, [](void *f) { - delete[] f; - }); - - /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */ - /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */ - - return py::dict( - "offset_level0"_a=appr_alg->offsetLevel0_, - "max_elements"_a=appr_alg->max_elements_, - "cur_element_count"_a=appr_alg->cur_element_count, - "size_data_per_element"_a=appr_alg->size_data_per_element_, - "label_offset"_a=appr_alg->label_offset_, - "offset_data"_a=appr_alg->offsetData_, - "max_level"_a=appr_alg->maxlevel_, - "enterpoint_node"_a=appr_alg->enterpoint_node_, - "max_M"_a=appr_alg->maxM_, - "max_M0"_a=appr_alg->maxM0_, - "M"_a=appr_alg->M_, - "mult"_a=appr_alg->mult_, - "ef_construction"_a=appr_alg->ef_construction_, - "ef"_a=appr_alg->ef_, - "has_deletions"_a=(bool)appr_alg->num_deleted_, - "size_links_per_element"_a=appr_alg->size_links_per_element_, - - "label_lookup_external"_a=py::array_t( - {appr_alg->label_lookup_.size()}, // shape - {sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - label_lookup_key_npy, // the data pointer - free_when_done_lb), - - "label_lookup_internal"_a=py::array_t( - {appr_alg->label_lookup_.size()}, // shape - {sizeof(hnswlib::tableint)}, // C-style contiguous strides for double - label_lookup_val_npy, // the data pointer - free_when_done_id), - - "element_levels"_a=py::array_t( - {appr_alg->element_levels_.size()}, // shape - {sizeof(int)}, // C-style contiguous strides for double - element_levels_npy, // the data pointer - free_when_done_lvl), - - // linkLists_,element_levels_,data_level0_memory_ - "data_level0"_a=py::array_t( - {level0_npy_size}, // shape - {sizeof(char)}, // C-style contiguous strides for double - data_level0_npy, // the data pointer - free_when_done_l0), - - "link_lists"_a=py::array_t( - {link_npy_size}, // shape - {sizeof(char)}, // C-style contiguous strides for double - link_list_npy, // the data pointer - free_when_done_ll) - ); + py::capsule free_when_done_l0(data_level0_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_lvl(element_levels_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_lb(label_lookup_key_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_id(label_lookup_val_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_ll(link_list_npy, [](void* f) { + delete[] f; + }); + + /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */ + /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */ + + return py::dict( + "offset_level0"_a = appr_alg->offsetLevel0_, + "max_elements"_a = appr_alg->max_elements_, + "cur_element_count"_a = appr_alg->cur_element_count, + "size_data_per_element"_a = appr_alg->size_data_per_element_, + "label_offset"_a = appr_alg->label_offset_, + "offset_data"_a = appr_alg->offsetData_, + "max_level"_a = appr_alg->maxlevel_, + "enterpoint_node"_a = appr_alg->enterpoint_node_, + "max_M"_a = appr_alg->maxM_, + "max_M0"_a = appr_alg->maxM0_, + "M"_a = appr_alg->M_, + "mult"_a = appr_alg->mult_, + "ef_construction"_a = appr_alg->ef_construction_, + "ef"_a = appr_alg->ef_, + "has_deletions"_a = (bool)appr_alg->num_deleted_, + "size_links_per_element"_a = appr_alg->size_links_per_element_, + + "label_lookup_external"_a = py::array_t( + { appr_alg->label_lookup_.size() }, // shape + { sizeof(hnswlib::labeltype) }, // C-style contiguous strides for each index + label_lookup_key_npy, // the data pointer + free_when_done_lb), + + "label_lookup_internal"_a = py::array_t( + { appr_alg->label_lookup_.size() }, // shape + { sizeof(hnswlib::tableint) }, // C-style contiguous strides for each index + label_lookup_val_npy, // the data pointer + free_when_done_id), + + "element_levels"_a = py::array_t( + { appr_alg->element_levels_.size() }, // shape + { sizeof(int) }, // C-style contiguous strides for each index + element_levels_npy, // the data pointer + free_when_done_lvl), + + // linkLists_,element_levels_,data_level0_memory_ + "data_level0"_a = py::array_t( + { level0_npy_size }, // shape + { sizeof(char) }, // C-style contiguous strides for each index + data_level0_npy, // the data pointer + free_when_done_l0), + + "link_lists"_a = py::array_t( + { link_npy_size }, // shape + { sizeof(char) }, // C-style contiguous strides for each index + link_list_npy, // the data pointer + free_when_done_ll)); } py::dict getIndexParams() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ auto params = py::dict( - "ser_version"_a=py::int_(Index::ser_version), //serialization version - "space"_a=space_name, - "dim"_a=dim, - "index_inited"_a=index_inited, - "ep_added"_a=ep_added, - "normalize"_a=normalize, - "num_threads"_a=num_threads_default, - "seed"_a=seed - ); - - if(index_inited == false) - return py::dict( **params, "ef"_a=default_ef); + "ser_version"_a = py::int_(Index::ser_version), // serialization version + "space"_a = space_name, + "dim"_a = dim, + "index_inited"_a = index_inited, + "ep_added"_a = ep_added, + "normalize"_a = normalize, + "num_threads"_a = num_threads_default, + "seed"_a = seed); + + if (index_inited == false) + return py::dict(**params, "ef"_a = default_ef); auto ann_params = getAnnData(); @@ -424,125 +429,131 @@ class Index { } - static Index * createFromParams(const py::dict d) { - // check serialization version - assert_true(((int)py::int_(Index::ser_version)) >= d["ser_version"].cast(), "Invalid serialization version!"); + static Index* createFromParams(const py::dict d) { + // check serialization version + assert_true(((int)py::int_(Index::ser_version)) >= d["ser_version"].cast(), "Invalid serialization version!"); - auto space_name_=d["space"].cast(); - auto dim_=d["dim"].cast(); - auto index_inited_=d["index_inited"].cast(); + auto space_name_ = d["space"].cast(); + auto dim_ = d["dim"].cast(); + auto index_inited_ = d["index_inited"].cast(); - Index *new_index = new Index(space_name_, dim_); + Index* new_index = new Index(space_name_, dim_); - /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */ - /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */ - new_index->seed = d["seed"].cast(); + /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */ + /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */ + new_index->seed = d["seed"].cast(); - if (index_inited_){ - new_index->appr_alg = new hnswlib::HierarchicalNSW(new_index->l2space, d["max_elements"].cast(), d["M"].cast(), d["ef_construction"].cast(), new_index->seed); - new_index->cur_l = d["cur_element_count"].cast(); - } + if (index_inited_) { + new_index->appr_alg = new hnswlib::HierarchicalNSW( + new_index->l2space, + d["max_elements"].cast(), + d["M"].cast(), + d["ef_construction"].cast(), + new_index->seed); + new_index->cur_l = d["cur_element_count"].cast(); + } - new_index->index_inited = index_inited_; - new_index->ep_added=d["ep_added"].cast(); - new_index->num_threads_default=d["num_threads"].cast(); - new_index->default_ef=d["ef"].cast(); + new_index->index_inited = index_inited_; + new_index->ep_added = d["ep_added"].cast(); + new_index->num_threads_default = d["num_threads"].cast(); + new_index->default_ef = d["ef"].cast(); - if (index_inited_) - new_index->setAnnData(d); + if (index_inited_) + new_index->setAnnData(d); - return new_index; + return new_index; } + static Index * createFromIndex(const Index & index) { return createFromParams(index.getIndexParams()); } + void setAnnData(const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */ - std::unique_lock templock(appr_alg->global); + std::unique_lock templock(appr_alg->global); - assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast(), "Invalid value of offsetLevel0_ "); - assert_true(appr_alg->max_elements_ == d["max_elements"].cast(), "Invalid value of max_elements_ "); + assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast(), "Invalid value of offsetLevel0_ "); + assert_true(appr_alg->max_elements_ == d["max_elements"].cast(), "Invalid value of max_elements_ "); - appr_alg->cur_element_count = d["cur_element_count"].cast(); + appr_alg->cur_element_count = d["cur_element_count"].cast(); - assert_true(appr_alg->size_data_per_element_ == d["size_data_per_element"].cast(), "Invalid value of size_data_per_element_ "); - assert_true(appr_alg->label_offset_ == d["label_offset"].cast(), "Invalid value of label_offset_ "); - assert_true(appr_alg->offsetData_ == d["offset_data"].cast(), "Invalid value of offsetData_ "); + assert_true(appr_alg->size_data_per_element_ == d["size_data_per_element"].cast(), "Invalid value of size_data_per_element_ "); + assert_true(appr_alg->label_offset_ == d["label_offset"].cast(), "Invalid value of label_offset_ "); + assert_true(appr_alg->offsetData_ == d["offset_data"].cast(), "Invalid value of offsetData_ "); - appr_alg->maxlevel_ = d["max_level"].cast(); - appr_alg->enterpoint_node_ = d["enterpoint_node"].cast(); + appr_alg->maxlevel_ = d["max_level"].cast(); + appr_alg->enterpoint_node_ = d["enterpoint_node"].cast(); - assert_true(appr_alg->maxM_ == d["max_M"].cast(), "Invalid value of maxM_ "); - assert_true(appr_alg->maxM0_ == d["max_M0"].cast(), "Invalid value of maxM0_ "); - assert_true(appr_alg->M_ == d["M"].cast(), "Invalid value of M_ "); - assert_true(appr_alg->mult_ == d["mult"].cast(), "Invalid value of mult_ "); - assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast(), "Invalid value of ef_construction_ "); + assert_true(appr_alg->maxM_ == d["max_M"].cast(), "Invalid value of maxM_ "); + assert_true(appr_alg->maxM0_ == d["max_M0"].cast(), "Invalid value of maxM0_ "); + assert_true(appr_alg->M_ == d["M"].cast(), "Invalid value of M_ "); + assert_true(appr_alg->mult_ == d["mult"].cast(), "Invalid value of mult_ "); + assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast(), "Invalid value of ef_construction_ "); - appr_alg->ef_ = d["ef"].cast(); + appr_alg->ef_ = d["ef"].cast(); - assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast(), "Invalid value of size_links_per_element_ "); + assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast(), "Invalid value of size_links_per_element_ "); - auto label_lookup_key_npy = d["label_lookup_external"].cast >(); - auto label_lookup_val_npy = d["label_lookup_internal"].cast >(); - auto element_levels_npy = d["element_levels"].cast >(); - auto data_level0_npy = d["data_level0"].cast >(); - auto link_list_npy = d["link_lists"].cast >(); + auto label_lookup_key_npy = d["label_lookup_external"].cast >(); + auto label_lookup_val_npy = d["label_lookup_internal"].cast >(); + auto element_levels_npy = d["element_levels"].cast >(); + auto data_level0_npy = d["data_level0"].cast >(); + auto link_list_npy = d["link_lists"].cast >(); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - if (label_lookup_val_npy.data()[i] < 0){ - throw std::runtime_error("internal id cannot be negative!"); - } - else{ - appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + if (label_lookup_val_npy.data()[i] < 0) { + throw std::runtime_error("internal id cannot be negative!"); + } else { + appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); + } } - } - memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes()); + memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes()); - size_t link_npy_size = 0; - std::vector link_npy_offsets(appr_alg->cur_element_count); - - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - link_npy_offsets[i]=link_npy_size; - if (linkListSize) - link_npy_size += linkListSize; - } + size_t link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); - memcpy(appr_alg->data_level0_memory_, data_level0_npy.data(), data_level0_npy.nbytes()); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i] = link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } - for (size_t i = 0; i < appr_alg->max_elements_; i++) { - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - if (linkListSize == 0) { - appr_alg->linkLists_[i] = nullptr; - } else { - appr_alg->linkLists_[i] = (char *) malloc(linkListSize); - if (appr_alg->linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + memcpy(appr_alg->data_level0_memory_, data_level0_npy.data(), data_level0_npy.nbytes()); - memcpy(appr_alg->linkLists_[i], link_list_npy.data()+link_npy_offsets[i], linkListSize); + for (size_t i = 0; i < appr_alg->max_elements_; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize == 0) { + appr_alg->linkLists_[i] = nullptr; + } else { + appr_alg->linkLists_[i] = (char*)malloc(linkListSize); + if (appr_alg->linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - } - } + memcpy(appr_alg->linkLists_[i], link_list_npy.data() + link_npy_offsets[i], linkListSize); + } + } - // set num_deleted - appr_alg->num_deleted_ = 0; - bool has_deletions = d["has_deletions"].cast(); - if (has_deletions) - { - for (size_t i = 0; i < appr_alg->cur_element_count; i++) { - if(appr_alg->isMarkedDeleted(i)) - appr_alg->num_deleted_ += 1; + // process deleted elements + appr_alg->num_deleted_ = 0; + bool has_deletions = d["has_deletions"].cast(); + if (has_deletions) { + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + if (appr_alg->isMarkedDeleted(i)) { + appr_alg->num_deleted_ += 1; + } + } } - } -} + } + py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); - hnswlib::labeltype *data_numpy_l; - dist_t *data_numpy_d; + hnswlib::labeltype* data_numpy_l; + dist_t* data_numpy_d; size_t rows, features; if (num_threads <= 0) @@ -555,118 +566,126 @@ class Index { if (buffer.ndim == 2) { rows = buffer.shape[0]; features = buffer.shape[1]; - } - else{ + } else { rows = 1; features = buffer.shape[0]; } // avoid using threads when the number of searches is small: - - if(rows<=num_threads*4){ - num_threads=1; + if (rows <= num_threads * 4) { + num_threads = 1; } data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; - if(normalize==false) { + if (normalize == false) { ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue> result = appr_alg->searchKnn( - (void *) items.data(row), k); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); - for (int i = k - 1; i >= 0; i--) { - auto &result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - } - ); - } - else{ - std::vector norm_array(num_threads*features); + std::priority_queue> result = appr_alg->searchKnn( + (void*)items.data(row), k); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } else { + std::vector norm_array(num_threads * features); ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - float *data= (float *) items.data(row); - - size_t start_idx = threadId * dim; - normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); - - std::priority_queue> result = appr_alg->searchKnn( - (void *) (norm_array.data()+start_idx), k); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); - for (int i = k - 1; i >= 0; i--) { - auto &result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - } - ); + float* data = (float*)items.data(row); + + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + + std::priority_queue> result = appr_alg->searchKnn( + (void*)(norm_array.data() + start_idx), k); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); } } - py::capsule free_when_done_l(data_numpy_l, [](void *f) { + py::capsule free_when_done_l(data_numpy_l, [](void* f) { delete[] f; - }); - py::capsule free_when_done_d(data_numpy_d, [](void *f) { + }); + py::capsule free_when_done_d(data_numpy_d, [](void* f) { delete[] f; - }); + }); return py::make_tuple( - py::array_t( - {rows, k}, // shape - {k * sizeof(hnswlib::labeltype), - sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - data_numpy_l, // the data pointer - free_when_done_l), - py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for double - data_numpy_d, // the data pointer - free_when_done_d)); - + py::array_t( + { rows, k }, // shape + { k * sizeof(hnswlib::labeltype), + sizeof(hnswlib::labeltype) }, // C-style contiguous strides for each index + data_numpy_l, // the data pointer + free_when_done_l), + py::array_t( + { rows, k }, // shape + { k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index + data_numpy_d, // the data pointer + free_when_done_d)); } + void markDeleted(size_t label) { appr_alg->markDelete(label); } + void unmarkDeleted(size_t label) { appr_alg->unmarkDelete(label); } + void resizeIndex(size_t new_size) { appr_alg->resizeIndex(new_size); } + size_t getMaxElements() const { return appr_alg->max_elements_; } + size_t getCurrentCount() const { return appr_alg->cur_element_count; } }; -template +template class BFIndex { -public: - BFIndex(const std::string &space_name, const int dim) : - space_name(space_name), dim(dim) { - normalize=false; - if(space_name=="l2") { + public: + static const int ser_version = 1; // serialization version + + std::string space_name; + int dim; + bool index_inited; + bool normalize; + + hnswlib::labeltype cur_l; + hnswlib::BruteforceSearch* alg; + hnswlib::SpaceInterface* space; + + + BFIndex(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { + normalize = false; + if (space_name == "l2") { space = new hnswlib::L2Space(dim); - } - else if(space_name=="ip") { + } else if (space_name == "ip") { space = new hnswlib::InnerProductSpace(dim); - } - else if(space_name=="cosine") { + } else if (space_name == "cosine") { space = new hnswlib::InnerProductSpace(dim); - normalize=true; + normalize = true; } else { throw std::runtime_error("Space name must be one of l2, ip, or cosine."); } @@ -674,16 +693,6 @@ class BFIndex { index_inited = false; } - static const int ser_version = 1; // serialization version - - std::string space_name; - int dim; - bool index_inited; - bool normalize; - - hnswlib::labeltype cur_l; - hnswlib::BruteforceSearch *alg; - hnswlib::SpaceInterface *space; ~BFIndex() { delete space; @@ -691,6 +700,7 @@ class BFIndex { delete alg; } + void init_new_index(const size_t maxElements) { if (alg) { throw std::runtime_error("The index is already initiated."); @@ -700,15 +710,17 @@ class BFIndex { index_inited = true; } - void normalize_vector(float *data, float *norm_array){ - float norm=0.0f; - for(int i=0;i items(input); auto buffer = items.request(); @@ -739,11 +751,12 @@ class BFIndex { ids.swap(ids1); } else if (ids_numpy.ndim == 0 && rows == 1) { ids.push_back(*items.data()); - } else + } else { throw std::runtime_error("wrong dimensionality of the labels"); + } } - { + { for (size_t row = 0; row < rows; row++) { size_t id = ids.size() ? ids.at(row) : cur_l + row; if (!normalize) { @@ -758,14 +771,17 @@ class BFIndex { } } + void deleteVector(size_t label) { alg->removePoint(label); } + void saveIndex(const std::string &path_to_index) { alg->saveIndex(path_to_index); } + void loadIndex(const std::string &path_to_index, size_t max_elements) { if (alg) { std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; @@ -776,8 +792,8 @@ class BFIndex { index_inited = true; } - py::object knnQuery_return_numpy(py::object input, size_t k = 1) { + py::object knnQuery_return_numpy(py::object input, size_t k = 1) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype *data_numpy_l; @@ -820,21 +836,20 @@ class BFIndex { return py::make_tuple( py::array_t( - {rows, k}, // shape + {rows, k}, // shape {k * sizeof(hnswlib::labeltype), - sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - data_numpy_l, // the data pointer + sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index + data_numpy_l, // the data pointer free_when_done_l), py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for double - data_numpy_d, // the data pointer + {rows, k}, // shape + {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for each index + data_numpy_d, // the data pointer free_when_done_d)); - } - }; + PYBIND11_PLUGIN(hnswlib) { py::module m("hnswlib"); @@ -843,15 +858,15 @@ PYBIND11_PLUGIN(hnswlib) { /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) - .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M")=16, py::arg("ef_construction")=200, py::arg("random_seed")=100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1) - .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads")=-1) + .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100) + .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1) + .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) .def("save_index", &Index::saveIndex, py::arg("path_to_index")) - .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) + .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) @@ -865,7 +880,7 @@ PYBIND11_PLUGIN(hnswlib) { return index.index_inited ? index.appr_alg->ef_ : index.default_ef; }, [](Index & index, const size_t ef_) { - index.default_ef=ef_; + index.default_ef = ef_; if (index.appr_alg) index.appr_alg->ef_ = ef_; }) @@ -883,16 +898,14 @@ PYBIND11_PLUGIN(hnswlib) { }) .def(py::pickle( - [](const Index &ind) { // __getstate__ + [](const Index &ind) { // __getstate__ return py::make_tuple(ind.getIndexParams()); /* Return dict (wrapped in a tuple) that fully encodes state of the Index object */ }, - [](py::tuple t) { // __setstate__ + [](py::tuple t) { // __setstate__ if (t.size() != 1) throw std::runtime_error("Invalid state!"); - return Index::createFromParams(t[0].cast()); - } - )) + })) .def("__repr__", [](const Index &a) { return ""; @@ -901,11 +914,11 @@ PYBIND11_PLUGIN(hnswlib) { py::class_>(m, "BFIndex") .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &BFIndex::init_new_index, py::arg("max_elements")) - .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1) + .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1) .def("add_items", &BFIndex::addItems, py::arg("data"), py::arg("ids") = py::none()) .def("delete_vector", &BFIndex::deleteVector, py::arg("label")) .def("save_index", &BFIndex::saveIndex, py::arg("path_to_index")) - .def("load_index", &BFIndex::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) + .def("load_index", &BFIndex::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0) .def("__repr__", [](const BFIndex &a) { return ""; }); diff --git a/sift_1b.cpp b/sift_1b.cpp index 2739490c..96d83267 100644 --- a/sift_1b.cpp +++ b/sift_1b.cpp @@ -12,7 +12,7 @@ using namespace hnswlib; class StopW { std::chrono::steady_clock::time_point time_begin; -public: + public: StopW() { time_begin = std::chrono::steady_clock::now(); } @@ -25,7 +25,6 @@ class StopW { void reset() { time_begin = std::chrono::steady_clock::now(); } - }; @@ -80,8 +79,7 @@ static size_t getPeakRSS() { int fd = -1; if ((fd = open("/proc/self/psinfo", O_RDONLY)) == -1) return (size_t)0L; /* Can't open? */ - if (read(fd, &psinfo, sizeof(psinfo)) != sizeof(psinfo)) - { + if (read(fd, &psinfo, sizeof(psinfo)) != sizeof(psinfo)) { close(fd); return (size_t)0L; /* Can't read? */ } @@ -146,10 +144,16 @@ static size_t getCurrentRSS() { static void -get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t vecsize, size_t qsize, L2SpaceI &l2space, - size_t vecdim, vector>> &answers, size_t k) { - - +get_gt( + unsigned int *massQA, + unsigned char *massQ, + unsigned char *mass, + size_t vecsize, + size_t qsize, + L2SpaceI &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { (vector>>(qsize)).swap(answers); DISTFUNC fstdistfunc_ = l2space.get_dist_func(); cout << qsize << "\n"; @@ -161,43 +165,50 @@ get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t v } static float -test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +test_approx( + unsigned char *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { size_t correct = 0; size_t total = 0; - //uncomment to test in parallel mode: + // uncomment to test in parallel mode: //#pragma omp parallel for for (int i = 0; i < qsize; i++) { - std::priority_queue> result = appr_alg.searchKnn(massQ + vecdim * i, k); std::priority_queue> gt(answers[i]); unordered_set g; total += gt.size(); while (gt.size()) { - - g.insert(gt.top().second); gt.pop(); } while (result.size()) { if (g.find(result.top().second) != g.end()) { - correct++; } else { } result.pop(); } - } return 1.0f * correct / total; } static void -test_vs_recall(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { - vector efs;// = { 10,10,10,10,10 }; +test_vs_recall( + unsigned char *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { + vector efs; // = { 10,10,10,10,10 }; for (int i = k; i < 30; i++) { efs.push_back(i); } @@ -229,12 +240,9 @@ inline bool exists_test(const std::string &name) { void sift_test1B() { - - - int subset_size_milllions = 200; - int efConstruction = 40; - int M = 16; - + int subset_size_milllions = 200; + int efConstruction = 40; + int M = 16; size_t vecsize = subset_size_milllions * 1000000; @@ -248,7 +256,6 @@ void sift_test1B() { sprintf(path_gt, "../bigann/gnd/idx_%dM.ivecs", subset_size_milllions); - unsigned char *massb = new unsigned char[vecdim]; cout << "Loading GT:\n"; @@ -264,7 +271,7 @@ void sift_test1B() { } } inputGT.close(); - + cout << "Loading queries:\n"; unsigned char *massQ = new unsigned char[qsize * vecdim]; ifstream inputQ(path_q, ios::binary); @@ -280,7 +287,6 @@ void sift_test1B() { for (int j = 0; j < vecdim; j++) { massQ[i * vecdim + j] = massb[j]; } - } inputQ.close(); @@ -299,7 +305,6 @@ void sift_test1B() { cout << "Building index:\n"; appr_alg = new HierarchicalNSW(&l2space, vecsize, M, efConstruction); - input.read((char *) &in, 4); if (in != 128) { cout << "file error"; @@ -319,10 +324,9 @@ void sift_test1B() { #pragma omp parallel for for (int i = 1; i < vecsize; i++) { unsigned char mass[128]; - int j2=0; + int j2 = 0; #pragma omp critical { - input.read((char *) &in, 4); if (in != 128) { cout << "file error"; @@ -333,7 +337,7 @@ void sift_test1B() { mass[j] = massb[j]; } j1++; - j2=j1; + j2 = j1; if (j1 % report_every == 0) { cout << j1 / (0.01 * vecsize) << " %, " << report_every / (1000.0 * 1e-6 * stopw.getElapsedTimeMicro()) << " kips " << " Mem: " @@ -342,8 +346,6 @@ void sift_test1B() { } } appr_alg->addPoint((void *) (mass), (size_t) j2); - - } input.close(); cout << "Build time:" << 1e-6 * stopw_full.getElapsedTimeMicro() << " seconds\n"; @@ -360,6 +362,4 @@ void sift_test1B() { test_vs_recall(massQ, vecsize, qsize, *appr_alg, vecdim, answers, k); cout << "Actual memory usage: " << getCurrentRSS() / 1000000 << " Mb \n"; return; - - } diff --git a/sift_test.cpp b/sift_test.cpp index c6718f50..751580cb 100644 --- a/sift_test.cpp +++ b/sift_test.cpp @@ -22,7 +22,7 @@ static void readBinaryPOD(istream& in, T& podRef) { }*/ class StopW { std::chrono::steady_clock::time_point time_begin; -public: + public: StopW() { time_begin = std::chrono::steady_clock::now(); } @@ -35,11 +35,17 @@ class StopW { void reset() { time_begin = std::chrono::steady_clock::now(); } - }; -void get_gt(float *mass, float *massQ, size_t vecsize, size_t qsize, L2Space &l2space, size_t vecdim, - vector>> &answers, size_t k) { +void get_gt( + float *mass, + float *massQ, + size_t vecsize, + size_t qsize, + L2Space &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { BruteforceSearch bs(&l2space, vecsize); for (int i = 0; i < vecsize; i++) { bs.addPoint((void *) (mass + vecdim * i), (size_t) i); @@ -53,9 +59,16 @@ void get_gt(float *mass, float *massQ, size_t vecsize, size_t qsize, L2Space &l2 } void -get_gt(unsigned int *massQA, float *massQ, float *mass, size_t vecsize, size_t qsize, L2Space &l2space, size_t vecdim, - vector>> &answers, size_t k) { - +get_gt( + unsigned int *massQA, + float *massQ, + float *mass, + size_t vecsize, + size_t qsize, + L2Space &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { //answers.swap(vector>>(qsize)); (vector>>(qsize)).swap(answers); DISTFUNC fstdistfunc_ = l2space.get_dist_func(); @@ -69,13 +82,18 @@ get_gt(unsigned int *massQA, float *massQ, float *mass, size_t vecsize, size_t q } } -float test_approx(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +float test_approx( + float *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { size_t correct = 0; size_t total = 0; //#pragma omp parallel for for (int i = 0; i < qsize; i++) { - std::priority_queue> result = appr_alg.searchKnn(massQ + vecdim * i, 10); std::priority_queue> gt(answers[i]); unordered_set g; @@ -93,8 +111,14 @@ float test_approx(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +void test_vs_recall( + float *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { //vector efs = { 1,2,3,4,6,8,12,16,24,32,64,128,256,320 };// = ; { 23 }; vector efs; for (int i = 10; i < 30; i++) { @@ -121,7 +145,7 @@ void test_vs_recall(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW< } //void get_knn_quality(unsigned int *massA,size_t vecsize, size_t maxn, HierarchicalNSW &appr_alg) { // size_t total = 0; -// size_t correct = 0; +// size_t correct = 0; // for (int i = 0; i < vecsize; i++) { // int *data = (int *)(appr_alg.linkList0_ + i * appr_alg.size_links_per_element0_); // //cout << "numconn:" << *data<<"\n"; @@ -186,7 +210,7 @@ void sift_test() { //#define LOAD_I #ifdef LOAD_I - HierarchicalNSW appr_alg(&l2space, "hnswlib_sift",false); + HierarchicalNSW appr_alg(&l2space, "hnswlib_sift", false); //HierarchicalNSW appr_alg(&l2space, "D:/stuff/hnsw_lib/nmslib/similarity_search/release/temp",true); //HierarchicalNSW appr_alg(&l2space, "/mnt/d/stuff/hnsw_lib/nmslib/similarity_search/release/temp", true); @@ -243,7 +267,7 @@ void sift_test() { // // cout << appr_alg.maxlevel_ << "\n"; // //CHECK: -// //for (size_t io = 0; io < vecsize; io++) { +// //for (size_t io = 0; io < vecsize; io++) { // // if (appr_alg.getExternalLabel(io) != io) // // throw new exception("bad!"); // //} @@ -252,22 +276,22 @@ void sift_test() { // for (int i = 0; i < vecsize; i++) { // int *data = (int *)(appr_alg.linkList0_ + i * appr_alg.size_links_per_element0_); // //cout << "numconn:" << *data<<"\n"; -// tableint *datal = (tableint *)(data + 1); +// tableint *datal = (tableint *)(data + 1); // // std::priority_queue< std::pair< float, tableint >> rez; // unordered_set g; // for (int j = 0; j < *data; j++) { // g.insert(datal[j]); // } -// appr_alg.setEf(400); +// appr_alg.setEf(400); // std::priority_queue< std::pair< float, tableint >> closest_elements = appr_alg.searchKnnInternal(appr_alg.getDataByInternalId(i), 17); -// while (closest_elements.size() > 0) { +// while (closest_elements.size() > 0) { // if (closest_elements.top().second != i) { // g.insert(closest_elements.top().second); // } // closest_elements.pop(); // } -// +// // for (tableint l : g) { // float other = fstdistfunc_(appr_alg.getDataByInternalId(l), appr_alg.getDataByInternalId(i), l2space.get_dist_func_param()); // rez.emplace(other, l); @@ -285,18 +309,18 @@ void sift_test() { // } // // } -// +// // //get_knn_quality(massA, vecsize, maxn, appr_alg); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // /*test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k);*/ // -// +// // // // // /*for(int i=0;i<1000;i++) // cout << mass[i] << "\n";*/ // //("11", std::ios::binary); -} \ No newline at end of file +} From 4ab1d619338f11819084b61ea2dd90400a74ccb5 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Tue, 20 Sep 2022 08:44:02 +0300 Subject: [PATCH 29/41] Remove some code duplication in bindings (#416) * Remove some code duplication in bindings * Refactoring --- python_bindings/bindings.cpp | 141 ++++++++++++++++------------------- 1 file changed, 63 insertions(+), 78 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index fcb444da..85751c0b 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn while (true) { size_t id = current.fetch_add(1); - if ((id >= end)) { + if (id >= end) { break; } @@ -79,6 +79,54 @@ inline void assert_true(bool expr, const std::string & msg) { } +inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { + if (buffer.ndim != 2 && buffer.ndim != 1) { + char msg[256]; + snprintf(msg, sizeof(msg), + "Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.", + buffer.ndim); + throw std::runtime_error(msg); + } + if (buffer.ndim == 2) { + *rows = buffer.shape[0]; + *features = buffer.shape[1]; + } else { + *rows = 1; + *features = buffer.shape[0]; + } +} + + +inline std::vector get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) { + std::vector ids; + if (!ids_.is_none()) { + py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); + auto ids_numpy = items.request(); + // check shapes + if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) || + (ids_numpy.ndim == 0 && feature_rows == 1))) { + char msg[256]; + snprintf(msg, sizeof(msg), + "The input label shape %d does not match the input data vector shape %d", + ids_numpy.ndim, feature_rows); + throw std::runtime_error(msg); + } + // extract data + if (ids_numpy.ndim == 1) { + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); + } else if (ids_numpy.ndim == 0) { + ids.push_back(*items.data()); + } + } + + return ids; +} + + template class Index { public: @@ -146,7 +194,7 @@ class Index { void set_ef(size_t ef) { default_ef = ef; if (appr_alg) - appr_alg->ef_ = ef; + appr_alg->ef_ = ef; } @@ -188,41 +236,17 @@ class Index { num_threads = num_threads_default; size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); + throw std::runtime_error("Wrong dimensionality of the vectors"); // avoid using threads when the number of additions is small: if (rows <= num_threads * 4) { num_threads = 1; } - std::vector ids; - - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } else if (ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } else { - throw std::runtime_error("wrong dimensionality of the labels"); - } - } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); { int start = 0; @@ -503,7 +527,7 @@ class Index { for (size_t i = 0; i < appr_alg->cur_element_count; i++) { if (label_lookup_val_npy.data()[i] < 0) { - throw std::runtime_error("internal id cannot be negative!"); + throw std::runtime_error("Internal id cannot be negative!"); } else { appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); } @@ -561,15 +585,7 @@ class Index { { py::gil_scoped_release l; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); // avoid using threads when the number of searches is small: if (rows <= num_threads * 4) { @@ -725,36 +741,12 @@ class BFIndex { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); + throw std::runtime_error("Wrong dimensionality of the vectors"); - std::vector ids; - - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } else if (ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } else { - throw std::runtime_error("wrong dimensionality of the labels"); - } - } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); { for (size_t row = 0; row < rows; row++) { @@ -802,14 +794,7 @@ class BFIndex { { py::gil_scoped_release l; - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; @@ -836,14 +821,14 @@ class BFIndex { return py::make_tuple( py::array_t( - {rows, k}, // shape - {k * sizeof(hnswlib::labeltype), + { rows, k }, // shape + { k * sizeof(hnswlib::labeltype), sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index data_numpy_l, // the data pointer free_when_done_l), py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for each index + { rows, k }, // shape + { k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index data_numpy_d, // the data pointer free_when_done_d)); } From 687ca85e7a8c77419a8335bca55e1b00c2576788 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Tue, 20 Sep 2022 08:49:44 +0300 Subject: [PATCH 30/41] Update python recall test (#415) --- .github/workflows/build.yml | 2 +- python_bindings/tests/bindings_test_recall.py | 184 ++++++++++-------- 2 files changed, 99 insertions(+), 87 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4da1d76f..e70f94c7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: python -m pip install . - name: Test - run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" + run: python -m unittest discover -v --start-directory python_bindings/tests --pattern "*_test*.py" test_cpp: runs-on: ${{matrix.os}} diff --git a/python_bindings/tests/bindings_test_recall.py b/python_bindings/tests/bindings_test_recall.py index 3742fcdd..55a970d1 100644 --- a/python_bindings/tests/bindings_test_recall.py +++ b/python_bindings/tests/bindings_test_recall.py @@ -1,88 +1,100 @@ +import os import hnswlib import numpy as np - -dim = 32 -num_elements = 100000 -k = 10 -nun_queries = 10 - -# Generating sample data -data = np.float32(np.random.random((num_elements, dim))) - -# Declaring index -hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip -bf_index = hnswlib.BFIndex(space='l2', dim=dim) - -# Initing both hnsw and brute force indices -# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded -# during insertion of an element. -# The capacity can be increased by saving/loading the index, see below. -# -# hnsw construction params: -# ef_construction - controls index search speed/build speed tradeoff -# -# M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M) -# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction - -hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16) -bf_index.init_index(max_elements=num_elements) - -# Controlling the recall for hnsw by setting ef: -# higher ef leads to better accuracy, but slower search -hnsw_index.set_ef(200) - -# Set number of threads used during batch search/construction in hnsw -# By default using all available cores -hnsw_index.set_num_threads(1) - -print("Adding batch of %d elements" % (len(data))) -hnsw_index.add_items(data) -bf_index.add_items(data) - -print("Indices built") - -# Generating query data -query_data = np.float32(np.random.random((nun_queries, dim))) - -# Query the elements and measure recall: -labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k) -labels_bf, distances_bf = bf_index.knn_query(query_data, k) - -# Measure recall -correct = 0 -for i in range(nun_queries): - for label in labels_hnsw[i]: - for correct_label in labels_bf[i]: - if label == correct_label: - correct += 1 - break - -print("recall is :", float(correct)/(k*nun_queries)) - -# test serializing the brute force index -index_path = 'bf_index.bin' -print("Saving index to '%s'" % index_path) -bf_index.save_index(index_path) -del bf_index - -# Re-initiating, loading the index -bf_index = hnswlib.BFIndex(space='l2', dim=dim) - -print("\nLoading index from '%s'\n" % index_path) -bf_index.load_index(index_path) - -# Query the brute force index again to verify that we get the same results -labels_bf, distances_bf = bf_index.knn_query(query_data, k) - -# Measure recall -correct = 0 -for i in range(nun_queries): - for label in labels_hnsw[i]: - for correct_label in labels_bf[i]: - if label == correct_label: - correct += 1 - break - -print("recall after reloading is :", float(correct)/(k*nun_queries)) - - +import unittest + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + dim = 32 + num_elements = 100000 + k = 10 + num_queries = 20 + + recall_threshold = 0.95 + + # Generating sample data + data = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + # Initing both hnsw and brute force indices + # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded + # during insertion of an element. + # The capacity can be increased by saving/loading the index, see below. + # + # hnsw construction params: + # ef_construction - controls index search speed/build speed tradeoff + # + # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M) + # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction + + hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16) + bf_index.init_index(max_elements=num_elements) + + # Controlling the recall for hnsw by setting ef: + # higher ef leads to better accuracy, but slower search + hnsw_index.set_ef(200) + + # Set number of threads used during batch search/construction in hnsw + # By default using all available cores + hnsw_index.set_num_threads(1) + + print("Adding batch of %d elements" % (len(data))) + hnsw_index.add_items(data) + bf_index.add_items(data) + + print("Indices built") + + # Generating query data + query_data = np.float32(np.random.random((num_queries, dim))) + + # Query the elements and measure recall: + labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k) + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct = 0 + for i in range(num_queries): + for label in labels_hnsw[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct += 1 + break + + recall_before = float(correct) / (k*num_queries) + print("recall is :", recall_before) + self.assertGreater(recall_before, recall_threshold) + + # test serializing the brute force index + index_path = 'bf_index.bin' + print("Saving index to '%s'" % index_path) + bf_index.save_index(index_path) + del bf_index + + # Re-initiating, loading the index + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + print("\nLoading index from '%s'\n" % index_path) + bf_index.load_index(index_path) + + # Query the brute force index again to verify that we get the same results + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct = 0 + for i in range(num_queries): + for label in labels_hnsw[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct += 1 + break + + recall_after = float(correct) / (k*num_queries) + print("recall after reloading is :", recall_after) + + self.assertEqual(recall_before, recall_after) + + os.remove(index_path) From 983cea90671e3659222e06d6609a8f62a224b0ae Mon Sep 17 00:00:00 2001 From: Georgios Tsoukas Date: Wed, 9 Nov 2022 10:10:02 +0100 Subject: [PATCH 31/41] =?UTF-8?q?Python:=C2=A0filter=20elements=20with=20a?= =?UTF-8?q?n=20optional=20filtering=20function=20(#417)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Python filter option for knn query. * Implement review suggestions * Removed template filter_func_t, add filter to brute force index and update tests (credits go to dyashuni) Co-authored-by: Georgios Tsoukas --- examples/searchKnnWithFilter_test.cpp | 67 ++++++++++--------- hnswlib/bruteforce.h | 11 ++- hnswlib/hnswalg.h | 13 ++-- hnswlib/hnswlib.h | 20 +++--- python_bindings/bindings.cpp | 42 ++++++++++-- python_bindings/tests/bindings_test_filter.py | 56 ++++++++++++++++ 6 files changed, 148 insertions(+), 61 deletions(-) create mode 100644 python_bindings/tests/bindings_test_filter.py diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 4aee49b0..6102323c 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -11,20 +11,25 @@ namespace { using idx_t = hnswlib::labeltype; -bool pickIdsDivisibleByThree(unsigned int label_id) { - return label_id % 3 == 0; -} - -bool pickIdsDivisibleBySeven(unsigned int label_id) { - return label_id % 7 == 0; -} +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(idx_t label_id) { + return label_id % divisor == 0; + } +}; -bool pickNothing(unsigned int label_id) { - return false; -} +class PickNothing: public hnswlib::BaseFilterFunctor { + public: + bool operator()(idx_t label_id) { + return false; + } +}; -template -void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe // test searchKnnCloserFirst of BruteforceSearch with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_brute->searchKnn(p, k, filter_func); - auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); size_t t = gd.size(); while (!gd.empty()) { @@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe // test searchKnnCloserFirst of hnsw with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_hnsw->searchKnn(p, k, filter_func); - auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); size_t t = gd.size(); while (!gd.empty()) { @@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe delete alg_hnsw; } -template -void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { +void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -120,8 +124,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { // test searchKnnCloserFirst of BruteforceSearch with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_brute->searchKnn(p, k, filter_func); - auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); assert(0 == gd.size()); } @@ -129,8 +133,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { // test searchKnnCloserFirst of hnsw with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_hnsw->searchKnn(p, k, filter_func); - auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); assert(0 == gd.size()); } @@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { } // namespace -class CustomFilterFunctor: public hnswlib::FilterFunctor { - std::unordered_set allowed_values; +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::unordered_set allowed_values; public: - explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} + explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} - bool operator()(unsigned int id) { + bool operator()(idx_t id) { return allowed_values.count(id) != 0; } }; @@ -156,10 +160,13 @@ int main() { std::cout << "Testing ..." << std::endl; // some of the elements are filtered + PickDivisibleIds pickIdsDivisibleByThree(3); test_some_filtering(pickIdsDivisibleByThree, 3, 17); + PickDivisibleIds pickIdsDivisibleBySeven(7); test_some_filtering(pickIdsDivisibleBySeven, 7, 17); // all of the elements are filtered + PickNothing pickNothing; test_none_filtering(pickNothing, 17); // functor style which can capture context diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index ec2ef350..21130090 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -6,8 +6,8 @@ #include namespace hnswlib { -template -class BruteforceSearch : public AlgorithmInterface { +template +class BruteforceSearch : public AlgorithmInterface { public: char *data_; size_t maxelements_; @@ -98,15 +98,14 @@ class BruteforceSearch : public AlgorithmInterface { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { assert(k <= cur_element_count); std::priority_queue> topResults; if (cur_element_count == 0) return topResults; - bool is_filter_disabled = std::is_same::value; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if (is_filter_disabled || isIdAllowed(label)) { + if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.push(std::pair(dist, label)); } } @@ -115,7 +114,7 @@ class BruteforceSearch : public AlgorithmInterface { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if (is_filter_disabled || isIdAllowed(label)) { + if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.push(std::pair(dist, label)); } if (topResults.size() > k) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 32b173e1..25995134 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,8 +13,8 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; -template -class HierarchicalNSW : public AlgorithmInterface { +template +class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; static const unsigned char DELETE_MARK = 0x01; @@ -268,7 +268,7 @@ class HierarchicalNSW : public AlgorithmInterface { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -277,8 +277,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - bool is_filter_disabled = std::is_same::value; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -336,7 +335,7 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) @@ -1083,7 +1082,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { std::priority_queue> result; if (cur_element_count == 0) return result; diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index f11fd373..72c955dc 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,13 +116,11 @@ namespace hnswlib { typedef size_t labeltype; // This can be extended to store state for filtering (e.g. from a std::set) -struct FilterFunctor { - template - bool operator()(Args&&...) { return true; } +class BaseFilterFunctor { + public: + virtual bool operator()(hnswlib::labeltype id) { return true; } }; -static FilterFunctor allowAllIds; - template class pairGreater { public: @@ -157,27 +155,27 @@ class SpaceInterface { virtual ~SpaceInterface() {} }; -template +template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label) = 0; virtual std::priority_queue> - searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0; + searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const; + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; virtual void saveIndex(const std::string &location) = 0; virtual ~AlgorithmInterface(){ } }; -template +template std::vector> -AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - filter_func_t& isIdAllowed) const { +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + BaseFilterFunctor* isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 85751c0b..3da8dbba 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) { } +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::function filter; + + public: + explicit CustomFilterFunctor(const std::function& f) { + filter = f; + } + + bool operator()(hnswlib::labeltype id) { + return filter(id); + } +}; + + inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { if (buffer.ndim != 2 && buffer.ndim != 1) { char msg[256]; @@ -573,7 +588,11 @@ class Index { } - py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + int num_threads = -1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype* data_numpy_l; @@ -595,10 +614,13 @@ class Index { data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + if (normalize == false) { ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { std::priority_queue> result = appr_alg->searchKnn( - (void*)items.data(row), k); + (void*)items.data(row), k, p_idFilter); if (result.size() != k) throw std::runtime_error( "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); @@ -618,7 +640,7 @@ class Index { normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); std::priority_queue> result = appr_alg->searchKnn( - (void*)(norm_array.data() + start_idx), k); + (void*)(norm_array.data() + start_idx), k, p_idFilter); if (result.size() != k) throw std::runtime_error( "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); @@ -785,7 +807,10 @@ class BFIndex { } - py::object knnQuery_return_numpy(py::object input, size_t k = 1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype *data_numpy_l; @@ -799,9 +824,12 @@ class BFIndex { data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + for (size_t row = 0; row < rows; row++) { std::priority_queue> result = alg->searchKnn( - (void *) items.data(row), k); + (void *) items.data(row), k, p_idFilter); for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -844,7 +872,7 @@ PYBIND11_PLUGIN(hnswlib) { .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1) + .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none()) .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) @@ -899,7 +927,7 @@ PYBIND11_PLUGIN(hnswlib) { py::class_>(m, "BFIndex") .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &BFIndex::init_new_index, py::arg("max_elements")) - .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1) + .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none()) .def("add_items", &BFIndex::addItems, py::arg("data"), py::arg("ids") = py::none()) .def("delete_vector", &BFIndex::deleteVector, py::arg("label")) .def("save_index", &BFIndex::saveIndex, py::arg("path_to_index")) diff --git a/python_bindings/tests/bindings_test_filter.py b/python_bindings/tests/bindings_test_filter.py new file mode 100644 index 00000000..a0715d7c --- /dev/null +++ b/python_bindings/tests/bindings_test_filter.py @@ -0,0 +1,56 @@ +import os +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + + dim = 16 + num_elements = 10000 + + # Generating sample data + data = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + # Initiating index + # max_elements - the maximum number of elements, should be known beforehand + # (probably will be made optional in the future) + # + # ef_construction - controls index search speed/build speed tradeoff + # M - is tightly connected with internal dimensionality of the data + # strongly affects the memory consumption + + hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + bf_index.init_index(max_elements=num_elements) + + # Controlling the recall by setting ef: + # higher ef leads to better accuracy, but slower search + hnsw_index.set_ef(10) + + hnsw_index.set_num_threads(4) # by default using all available cores + + print("Adding %d elements" % (len(data))) + hnsw_index.add_items(data) + bf_index.add_items(data) + + # Query the elements for themselves and measure recall: + labels, distances = hnsw_index.knn_query(data, k=1) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3) + + print("Querying only even elements") + # Query the even elements for themselves and measure recall: + filter_function = lambda id: id%2 == 0 + labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) + # Verify that there are onle even elements: + self.assertTrue(np.max(np.mod(labels, 2)) == 0) + + labels, distances = bf_index.knn_query(data, k=1, filter=filter_function) + self.assertEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5) From 3e006ea4b0ffdcd3fb1319370859a24da897ddd0 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Thu, 12 Jan 2023 13:53:38 +0400 Subject: [PATCH 32/41] Replace deleted elements at addition (#418) * Replace deleted elements at insertion * Add multithread stress tests * Add timeout to jobs in actions * Add locks by label * Remove python 3.6 tests as it is not available in Ubuntu 22.04 * Fix multithread update of elements * Update readme and refactoring --- .github/workflows/build.yml | 6 +- .gitignore | 1 + CMakeLists.txt | 6 + README.md | 117 ++++++++- examples/multiThreadLoad_test.cpp | 140 ++++++++++ examples/multiThread_replace_test.cpp | 121 +++++++++ hnswlib/bruteforce.h | 2 +- hnswlib/hnswalg.h | 171 +++++++++--- hnswlib/hnswlib.h | 2 +- python_bindings/bindings.cpp | 57 +++- python_bindings/tests/bindings_test_filter.py | 2 +- python_bindings/tests/bindings_test_labels.py | 25 +- python_bindings/tests/bindings_test_recall.py | 2 +- .../tests/bindings_test_replace.py | 245 ++++++++++++++++++ .../tests/bindings_test_stress_mt_replace.py | 68 +++++ 15 files changed, 884 insertions(+), 81 deletions(-) create mode 100644 examples/multiThreadLoad_test.cpp create mode 100644 examples/multiThread_replace_test.cpp create mode 100644 python_bindings/tests/bindings_test_replace.py create mode 100644 python_bindings/tests/bindings_test_stress_mt_replace.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e70f94c7..e86d2545 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -19,6 +19,7 @@ jobs: run: python -m pip install . - name: Test + timeout-minutes: 15 run: python -m unittest discover -v --start-directory python_bindings/tests --pattern "*_test*.py" test_cpp: @@ -52,6 +53,7 @@ jobs: shell: bash - name: Test + timeout-minutes: 15 run: | cd build if [ "$RUNNER_OS" == "Windows" ]; then @@ -59,6 +61,8 @@ jobs: fi ./searchKnnCloserFirst_test ./searchKnnWithFilter_test + ./multiThreadLoad_test + ./multiThread_replace_test ./test_updates ./test_updates update shell: bash diff --git a/.gitignore b/.gitignore index a338107c..48f74604 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ var/ .idea/ .vscode/ .vs/ +**.DS_Store diff --git a/CMakeLists.txt b/CMakeLists.txt index e42d6cee..de951171 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,12 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp) target_link_libraries(searchKnnWithFilter_test hnswlib) + add_executable(multiThreadLoad_test examples/multiThreadLoad_test.cpp) + target_link_libraries(multiThreadLoad_test hnswlib) + + add_executable(multiThread_replace_test examples/multiThread_replace_test.cpp) + target_link_libraries(multiThread_replace_test hnswlib) + add_executable(main main.cpp sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/README.md b/README.md index c86e4391..c0b0dbcc 100644 --- a/README.md +++ b/README.md @@ -54,19 +54,22 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib. * `hnswlib.Index(space, dim)` creates a non-initialized index an HNSW in space `space` with integer dimension `dim`. `hnswlib.Index` methods: -* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100)` initializes the index from with no elements. +* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100, allow_replace_deleted = False)` initializes the index from with no elements. * `max_elements` defines the maximum number of elements that can be stored in the structure(can be increased/shrunk). * `ef_construction` defines a construction time/accuracy trade-off (see [ALGO_PARAMS.md](ALGO_PARAMS.md)). * `M` defines tha maximum number of outgoing connections in the graph ([ALGO_PARAMS.md](ALGO_PARAMS.md)). + * `allow_replace_deleted` enables replacing of deleted elements with new added ones. -* `add_items(data, ids, num_threads = -1)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. +* `add_items(data, ids, num_threads = -1, replace_deleted = False)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. * `num_threads` sets the number of cpu threads to use (-1 means use default). * `ids` are optional N-size numpy array of integer labels for all elements in `data`. - If index already has the elements with the same labels, their features will be updated. Note that update procedure is slower than insertion of a new element, but more memory- and query-efficient. + * `replace_deleted` replaces deleted elements. Note it allows to save memory. + - to use it `init_index` should be called with `allow_replace_deleted=True` * Thread-safe with other `add_items` calls, but not with `knn_query`. * `mark_deleted(label)` - marks the element as deleted, so it will be omitted from search results. Throws an exception if it is already deleted. -* + * `unmark_deleted(label)` - unmarks the element as deleted, so it will be not be omitted from search results. * `resize_index(new_size)` - changes the maximum capacity of the index. Not thread safe with `add_items` and `knn_query`. @@ -74,13 +77,15 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib. * `set_ef(ef)` - sets the query time accuracy/speed trade-off, defined by the `ef` parameter ( [ALGO_PARAMS.md](ALGO_PARAMS.md)). Note that the parameter is currently not saved along with the index, so you need to set it manually after loading. -* `knn_query(data, k = 1, num_threads = -1)` make a batch query for `k` closest elements for each element of the +* `knn_query(data, k = 1, num_threads = -1, filter = None)` make a batch query for `k` closest elements for each element of the * `data` (shape:`N*dim`). Returns a numpy array of (shape:`N*k`). * `num_threads` sets the number of cpu threads to use (-1 means use default). + * `filter` filters elements by its labels, returns elements with allowed ids * Thread-safe with other `knn_query` calls, but not with `add_items`. -* `load_index(path_to_index, max_elements = 0)` loads the index from persistence to the uninitialized index. +* `load_index(path_to_index, max_elements = 0, allow_replace_deleted = False)` loads the index from persistence to the uninitialized index. * `max_elements`(optional) resets the maximum number of elements in the structure. + * `allow_replace_deleted` specifies whether the index being loaded has enabled replacing of deleted elements. * `save_index(path_to_index)` saves the index from persistence. @@ -142,7 +147,7 @@ p.add_items(data, ids) # Controlling the recall by setting ef: p.set_ef(50) # ef should always be > k -# Query dataset, k - number of closest elements (returns 2 numpy arrays) +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) labels, distances = p.knn_query(data, k = 1) # Index objects support pickling @@ -155,7 +160,6 @@ print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") - ``` An example with updates after serialization/deserialization: @@ -196,7 +200,6 @@ p.set_ef(10) # By default using all available cores p.set_num_threads(4) - print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -226,6 +229,104 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` +An example with a filter: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# labels contain only elements with even id +``` + +An example with replacing of deleted elements: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements +hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only +``` + ### Bindings installation You can install from sources: diff --git a/examples/multiThreadLoad_test.cpp b/examples/multiThreadLoad_test.cpp new file mode 100644 index 00000000..a713b2ba --- /dev/null +++ b/examples/multiThreadLoad_test.cpp @@ -0,0 +1,140 @@ +#include "../hnswlib/hnswlib.h" +#include +#include + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int max_elements = 1000; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * max_elements); + + std::cout << "Building index" << std::endl; + int num_threads = 40; + int num_labels = 10; + + int num_iterations = 10; + int start_label = 0; + + // run threads that will add elements to the index + // about 7 threads (the number depends on num_threads and num_labels) + // will add/update element with the same label simultaneously + while (true) { + // add elements by batches + std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1); + std::vector threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + for (int iter = 0; iter < num_iterations; iter++) { + std::vector data(d); + hnswlib::labeltype label = distrib_int(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + ) + ); + } + for (auto &thread : threads) { + thread.join(); + } + if (alg_hnsw->cur_element_count > max_elements - num_labels) { + break; + } + start_label += num_labels; + } + + // insert remaining elements if needed + for (hnswlib::labeltype label = 0; label < max_elements; label++) { + auto search = alg_hnsw->label_lookup_.find(label); + if (search == alg_hnsw->label_lookup_.end()) { + std::cout << "Adding " << label << std::endl; + std::vector data(d); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + + std::cout << "Index is created" << std::endl; + + bool stop_threads = false; + std::vector threads; + + // create threads that will do markDeleted and unmarkDeleted of random elements + // each thread works with specific range of labels + std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl; + num_threads = 20; + int chunk_size = max_elements / num_threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&, thread_id] { + std::uniform_int_distribution<> distrib_int(0, chunk_size - 1); + int start_id = thread_id * chunk_size; + std::vector marked_deleted(chunk_size); + while (!stop_threads) { + int id = distrib_int(rng); + hnswlib::labeltype label = start_id + id; + if (marked_deleted[id]) { + alg_hnsw->unmarkDelete(label); + marked_deleted[id] = false; + } else { + alg_hnsw->markDelete(label); + marked_deleted[id] = true; + } + } + } + ) + ); + } + + // create threads that will add and update random elements + std::cout << "Starting add and update elements threads" << std::endl; + num_threads = 20; + std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1); + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + std::vector data(d); + while (!stop_threads) { + hnswlib::labeltype label = distrib_int_add(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + std::vector data = alg_hnsw->getDataByLabel(label); + float max_val = *max_element(data.begin(), data.end()); + // never happens but prevents compiler from deleting unused code + if (max_val > 10) { + throw std::runtime_error("Unexpected value in data"); + } + } + } + ) + ); + } + + std::cout << "Sleep and continue operations with index" << std::endl; + int sleep_ms = 60 * 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + stop_threads = true; + for (auto &thread : threads) { + thread.join(); + } + + std::cout << "Finish" << std::endl; + return 0; +} diff --git a/examples/multiThread_replace_test.cpp b/examples/multiThread_replace_test.cpp new file mode 100644 index 00000000..83ed2826 --- /dev/null +++ b/examples/multiThread_replace_test.cpp @@ -0,0 +1,121 @@ +#include "../hnswlib/hnswlib.h" +#include +#include + + +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int num_elements = 1000; + int max_elements = 2 * num_elements; + int num_threads = 50; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + + // generate batch1 and batch2 data + float* batch1 = new float[d * max_elements]; + for (int i = 0; i < d * max_elements; i++) { + batch1[i] = distrib_real(rng); + } + float* batch2 = new float[d * num_elements]; + for (int i = 0; i < d * num_elements; i++) { + batch2[i] = distrib_real(rng); + } + + // generate random labels to delete them from index + std::vector rand_labels(max_elements); + for (int i = 0; i < max_elements; i++) { + rand_labels[i] = i; + } + std::shuffle(rand_labels.begin(), rand_labels.end(), rng); + + int iter = 0; + while (iter < 200) { + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, 16, 200, 123, true); + + // add batch1 data + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(batch1 + d * row), row); + }); + + // delete half random elements of batch1 data + for (int i = 0; i < num_elements; i++) { + alg_hnsw->markDelete(rand_labels[i]); + } + + // replace deleted elements with batch2 data + ParallelFor(0, num_elements, num_threads, [&](size_t row, size_t threadId) { + int label = rand_labels[row] + max_elements; + alg_hnsw->addPoint((void*)(batch2 + d * row), label, true); + }); + + iter += 1; + + delete alg_hnsw; + } + + std::cout << "Finish" << std::endl; + + delete[] batch1; + delete[] batch2; + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 21130090..30b33ae9 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -61,7 +61,7 @@ class BruteforceSearch : public AlgorithmInterface { } - void addPoint(const void *datapoint, labeltype label) { + void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { int idx; { std::unique_lock lock(index_lock); diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 25995134..7f34e62b 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -16,11 +16,11 @@ typedef unsigned int linklistsizeint; template class HierarchicalNSW : public AlgorithmInterface { public: - static const tableint max_update_element_locks = 65536; + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; static const unsigned char DELETE_MARK = 0x01; size_t max_elements_{0}; - size_t cur_element_count{0}; + mutable std::atomic cur_element_count{0}; // current number of elements size_t size_data_per_element_{0}; size_t size_links_per_element_{0}; mutable std::atomic num_deleted_{0}; // number of deleted elements @@ -35,13 +35,10 @@ class HierarchicalNSW : public AlgorithmInterface { VisitedListPool *visited_list_pool_{nullptr}; - // Locks to prevent race condition during update/insert of an element at same time. - // Note: Locks for additions can also be used to prevent this race condition - // if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. - std::vector link_list_update_locks_; + // Locks operations with element by label value + mutable std::vector label_op_locks_; std::mutex global; - std::mutex cur_element_count_guard_; std::vector link_list_locks_; tableint enterpoint_node_{0}; @@ -57,7 +54,8 @@ class HierarchicalNSW : public AlgorithmInterface { DISTFUNC fstdistfunc_; void *dist_func_param_{nullptr}; - std::mutex label_lookup_lock; + + mutable std::mutex label_lookup_lock; // lock for label_lookup_ std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -66,6 +64,11 @@ class HierarchicalNSW : public AlgorithmInterface { mutable std::atomic metric_distance_computations{0}; mutable std::atomic metric_hops{0}; + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions + + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set deleted_elements; // contains internal ids of deleted elements + HierarchicalNSW(SpaceInterface *s) { } @@ -75,7 +78,9 @@ class HierarchicalNSW : public AlgorithmInterface { SpaceInterface *s, const std::string &location, bool nmslib = false, - size_t max_elements = 0) { + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { loadIndex(location, s, max_elements); } @@ -85,10 +90,12 @@ class HierarchicalNSW : public AlgorithmInterface { size_t max_elements, size_t M = 16, size_t ef_construction = 200, - size_t random_seed = 100) + size_t random_seed = 100, + bool allow_replace_deleted = false) : link_list_locks_(max_elements), - link_list_update_locks_(max_update_element_locks), - element_levels_(max_elements) { + label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); @@ -154,6 +161,13 @@ class HierarchicalNSW : public AlgorithmInterface { } + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } + + inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); @@ -437,6 +451,12 @@ class HierarchicalNSW : public AlgorithmInterface { tableint next_closest_entry_point = selectedNeighbors.back(); { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); + } linklistsizeint *ll_cur; if (level == 0) ll_cur = get_linklist0(cur_c); @@ -664,7 +684,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); - std::vector(max_update_element_locks).swap(link_list_update_locks_); + std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); visited_list_pool_ = new VisitedListPool(1, max_elements); @@ -693,6 +713,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (size_t i = 0; i < cur_element_count; i++) { if (isMarkedDeleted(i)) { num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); } } @@ -704,14 +725,18 @@ class HierarchicalNSW : public AlgorithmInterface { template std::vector getDataByLabel(labeltype label) const { - tableint label_internal; + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { throw std::runtime_error("Label not found"); } - label_internal = search->second; + tableint internalId = search->second; + lock_table.unlock(); - char* data_ptrv = getDataByInternalId(label_internal); + char* data_ptrv = getDataByInternalId(internalId); size_t dim = *((size_t *) dist_func_param_); std::vector data; data_t* data_ptr = (data_t*) data_ptrv; @@ -723,66 +748,90 @@ class HierarchicalNSW : public AlgorithmInterface { } - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - */ + /* + * Marks an element with the given label deleted, does NOT really change the current graph. + */ void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; + lock_table.unlock(); + markDeletedInternal(internalId); } - /** - * Uses the last 16 bits of the memory for the linked list size to store the mark, - * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. - */ + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ void markDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (!isMarkedDeleted(internalId)) { unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; *ll_cur |= DELETE_MARK; num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); + } } else { throw std::runtime_error("The requested to delete element is already deleted"); } } - /** - * Remove the deleted mark of the node, does NOT really change the current graph. - */ + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; + lock_table.unlock(); + unmarkDeletedInternal(internalId); } - /** - * Remove the deleted mark of the node. - */ + + /* + * Remove the deleted mark of the node. + */ void unmarkDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (isMarkedDeleted(internalId)) { unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; *ll_cur &= ~DELETE_MARK; num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } } else { throw std::runtime_error("The requested to undelete element is not deleted"); } } - /** - * Checks the first 16 bits of the memory to see if the element is marked deleted. - */ + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ bool isMarkedDeleted(tableint internalId) const { unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; return *ll_cur & DELETE_MARK; @@ -799,11 +848,48 @@ class HierarchicalNSW : public AlgorithmInterface { } - /** - * Adds point. Updates the point if it is already in the index + /* + * Adds point. Updates the point if it is already in the index. + * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ - void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label, -1); + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); + } + + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; + } + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); + } + lock_deleted_elements.unlock(); + + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); + } } @@ -970,13 +1056,16 @@ class HierarchicalNSW : public AlgorithmInterface { { // Checking if the element with the same label already exists // if so, updating it *instead* of creating a new element. - std::unique_lock templock_curr(cur_element_count_guard_); + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search != label_lookup_.end()) { tableint existingInternalId = search->second; - templock_curr.unlock(); - - std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); + if (allow_replace_deleted_) { + if (isMarkedDeleted(existingInternalId)) { + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); + } + } + lock_table.unlock(); if (isMarkedDeleted(existingInternalId)) { unmarkDeletedInternal(existingInternalId); @@ -995,8 +1084,6 @@ class HierarchicalNSW : public AlgorithmInterface { label_lookup_[label] = cur_c; } - // Take update lock to prevent race conditions on an element with insertion/update at the same time. - std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); std::unique_lock lock_el(link_list_locks_[cur_c]); int curlevel = getRandomLevel(mult_); if (level > 0) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 72c955dc..fb7118fa 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -158,7 +158,7 @@ class SpaceInterface { template class AlgorithmInterface { public: - virtual void addPoint(const void *datapoint, labeltype label) = 0; + virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; virtual std::priority_queue> searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3da8dbba..3196a228 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -193,12 +193,13 @@ class Index { size_t maxElements, size_t M, size_t efConstruction, - size_t random_seed) { + size_t random_seed, + bool allow_replace_deleted) { if (appr_alg) { throw std::runtime_error("The index is already initiated."); } cur_l = 0; - appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); + appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed, allow_replace_deleted); index_inited = true; ep_added = false; appr_alg->ef_ = default_ef; @@ -223,12 +224,12 @@ class Index { } - void loadIndex(const std::string &path_to_index, size_t max_elements) { + void loadIndex(const std::string &path_to_index, size_t max_elements, bool allow_replace_deleted) { if (appr_alg) { std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; delete appr_alg; } - appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); + appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements, allow_replace_deleted); cur_l = appr_alg->cur_element_count; index_inited = true; } @@ -244,7 +245,7 @@ class Index { } - void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1) { + void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); if (num_threads <= 0) @@ -273,7 +274,7 @@ class Index { normalize_vector(vector_data, norm_array.data()); vector_data = norm_array.data(); } - appr_alg->addPoint((void*)vector_data, (size_t)id); + appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); start = 1; ep_added = true; } @@ -282,7 +283,7 @@ class Index { if (normalize == false) { ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)items.data(row), (size_t)id); + appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); }); } else { std::vector norm_array(num_threads * dim); @@ -292,7 +293,7 @@ class Index { normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); }); } cur_l += rows; @@ -400,7 +401,7 @@ class Index { return py::dict( "offset_level0"_a = appr_alg->offsetLevel0_, "max_elements"_a = appr_alg->max_elements_, - "cur_element_count"_a = appr_alg->cur_element_count, + "cur_element_count"_a = (size_t)appr_alg->cur_element_count, "size_data_per_element"_a = appr_alg->size_data_per_element_, "label_offset"_a = appr_alg->label_offset_, "offset_data"_a = appr_alg->offsetData_, @@ -414,6 +415,7 @@ class Index { "ef"_a = appr_alg->ef_, "has_deletions"_a = (bool)appr_alg->num_deleted_, "size_links_per_element"_a = appr_alg->size_links_per_element_, + "allow_replace_deleted"_a = appr_alg->allow_replace_deleted_, "label_lookup_external"_a = py::array_t( { appr_alg->label_lookup_.size() }, // shape @@ -576,12 +578,19 @@ class Index { } // process deleted elements + bool allow_replace_deleted = false; + if (d.contains("allow_replace_deleted")) { + allow_replace_deleted = d["allow_replace_deleted"].cast(); + } + appr_alg->allow_replace_deleted_= allow_replace_deleted; + appr_alg->num_deleted_ = 0; bool has_deletions = d["has_deletions"].cast(); if (has_deletions) { for (size_t i = 0; i < appr_alg->cur_element_count; i++) { if (appr_alg->isMarkedDeleted(i)) { appr_alg->num_deleted_ += 1; + if (allow_replace_deleted) appr_alg->deleted_elements.insert(i); } } } @@ -871,15 +880,35 @@ PYBIND11_PLUGIN(hnswlib) { /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) - .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none()) - .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1) + .def("init_index", + &Index::init_new_index, + py::arg("max_elements"), + py::arg("M") = 16, + py::arg("ef_construction") = 200, + py::arg("random_seed") = 100, + py::arg("allow_replace_deleted") = false) + .def("knn_query", + &Index::knnQuery_return_numpy, + py::arg("data"), + py::arg("k") = 1, + py::arg("num_threads") = -1, + py::arg("filter") = py::none()) + .def("add_items", + &Index::addItems, + py::arg("data"), + py::arg("ids") = py::none(), + py::arg("num_threads") = -1, + py::arg("replace_deleted") = false) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) .def("save_index", &Index::saveIndex, py::arg("path_to_index")) - .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0) + .def("load_index", + &Index::loadIndex, + py::arg("path_to_index"), + py::arg("max_elements") = 0, + py::arg("allow_replace_deleted") = false) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) @@ -901,7 +930,7 @@ PYBIND11_PLUGIN(hnswlib) { return index.index_inited ? index.appr_alg->max_elements_ : 0; }) .def_property_readonly("element_count", [](const Index & index) { - return index.index_inited ? index.appr_alg->cur_element_count : 0; + return index.index_inited ? (size_t)index.appr_alg->cur_element_count : 0; }) .def_property_readonly("ef_construction", [](const Index & index) { return index.index_inited ? index.appr_alg->ef_construction_ : 0; diff --git a/python_bindings/tests/bindings_test_filter.py b/python_bindings/tests/bindings_test_filter.py index a0715d7c..a798e02f 100644 --- a/python_bindings/tests/bindings_test_filter.py +++ b/python_bindings/tests/bindings_test_filter.py @@ -49,7 +49,7 @@ def testRandomSelf(self): filter_function = lambda id: id%2 == 0 labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) - # Verify that there are onle even elements: + # Verify that there are only even elements: self.assertTrue(np.max(np.mod(labels, 2)) == 0) labels, distances = bf_index.knn_query(data, k=1, filter=filter_function) diff --git a/python_bindings/tests/bindings_test_labels.py b/python_bindings/tests/bindings_test_labels.py index 2b091371..524a24d5 100644 --- a/python_bindings/tests/bindings_test_labels.py +++ b/python_bindings/tests/bindings_test_labels.py @@ -95,19 +95,20 @@ def testRandomSelf(self): # Delete data1 labels1_deleted, _ = p.knn_query(data1, k=1) - - for l in labels1_deleted: - p.mark_deleted(l[0]) + # delete probable duplicates from nearest neighbors + labels1_deleted_no_dup = set(labels1_deleted.flatten()) + for l in labels1_deleted_no_dup: + p.mark_deleted(l) labels2, _ = p.knn_query(data2, k=1) items = p.get_items(labels2) diff_with_gt_labels = np.mean(np.abs(data2-items)) - self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) # console + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) print("All the data in data1 are removed") # Checking saving/loading index with elements marked as deleted @@ -119,13 +120,13 @@ def testRandomSelf(self): labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search after index loading") + self.assertTrue(False) # Unmark deleted data - for l in labels1_deleted: - p.unmark_deleted(l[0]) + for l in labels1_deleted_no_dup: + p.unmark_deleted(l) labels_restored, _ = p.knn_query(data1, k=1) self.assertAlmostEqual(np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3) print("All the data in data1 are restored") diff --git a/python_bindings/tests/bindings_test_recall.py b/python_bindings/tests/bindings_test_recall.py index 55a970d1..2190ba45 100644 --- a/python_bindings/tests/bindings_test_recall.py +++ b/python_bindings/tests/bindings_test_recall.py @@ -40,7 +40,7 @@ def testRandomSelf(self): # Set number of threads used during batch search/construction in hnsw # By default using all available cores - hnsw_index.set_num_threads(1) + hnsw_index.set_num_threads(4) print("Adding batch of %d elements" % (len(data))) hnsw_index.add_items(data) diff --git a/python_bindings/tests/bindings_test_replace.py b/python_bindings/tests/bindings_test_replace.py new file mode 100644 index 00000000..80003a3a --- /dev/null +++ b/python_bindings/tests/bindings_test_replace.py @@ -0,0 +1,245 @@ +import os +import pickle +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + """ + Tests if replace of deleted elements works correctly + Tests serialization of the index with replaced elements + """ + dim = 16 + num_elements = 5000 + max_num_elements = 2 * num_elements + + recall_threshold = 0.98 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # batch 4 + first_id += num_elements + last_id += num_elements + labels4 = np.arange(first_id, last_id) + data4 = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) + hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + hnsw_index.set_ef(100) + hnsw_index.set_num_threads(4) + + # Add batch 1 and 2 + print("Adding batch 1") + hnsw_index.add_items(data1, labels1) + print("Adding batch 2") + hnsw_index.add_items(data2, labels2) # maximum number of elements is reached + + # Delete nearest neighbors of batch 2 + print("Deleting neighbors of batch 2") + labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) + # delete probable duplicates from nearest neighbors + labels2_deleted_no_dup = set(labels2_deleted.flatten()) + num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) + for l in labels2_deleted_no_dup: + hnsw_index.mark_deleted(l) + labels1_found, _ = hnsw_index.knn_query(data1, k=1) + items = hnsw_index.get_items(labels1_found) + diff_with_gt_labels = np.mean(np.abs(data1 - items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) + + labels2_after, _ = hnsw_index.knn_query(data2, k=1) + for la in labels2_after: + if la[0] in labels2_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) + print("All the neighbors of data2 are removed") + + # Replace deleted elements + print("Inserting batch 3 by replacing deleted elements") + # Maximum number of elements is reached therefore we cannot add new items + # but we can replace the deleted ones + # Note: there may be less than num_elements elements. + # As we could delete less than num_elements because of duplicates + labels3_tr = labels3[0:labels3.shape[0] - num_duplicates] + data3_tr = data3[0:data3.shape[0] - num_duplicates] + hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # After replacing, all labels should be retrievable + print("Checking that remaining labels are in index") + # Get remaining data from batch 1 and batch 2 after deletion of elements + remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup + remaining_labels_list = list(remaining_labels) + comb_data = np.concatenate((data1, data2), axis=0) + remaining_data = comb_data[remaining_labels_list] + + returned_items = hnsw_index.get_items(remaining_labels_list) + self.assertSequenceEqual(remaining_data.tolist(), returned_items) + + returned_items = hnsw_index.get_items(labels3_tr) + self.assertSequenceEqual(data3_tr.tolist(), returned_items) + + # Check index serialization + # Delete batch 3 + print("Deleting batch 3") + for l in labels3_tr: + hnsw_index.mark_deleted(l) + + # Save index + index_path = "index.bin" + print(f"Saving index to {index_path}") + hnsw_index.save_index(index_path) + del hnsw_index + + # Reinit and load the index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. + hnsw_index.set_num_threads(4) + print(f"Loading index from {index_path}") + hnsw_index.load_index(index_path, max_elements=max_num_elements, allow_replace_deleted=True) + + # Insert batch 4 + print("Inserting batch 4 by replacing deleted elements") + labels4_tr = labels4[0:labels4.shape[0] - num_duplicates] + data4_tr = data4[0:data4.shape[0] - num_duplicates] + hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index.knn_query(data4_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels4_tr) + print(f"Recall for the 4 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + # Delete batch 4 + print("Deleting batch 4") + for l in labels4_tr: + hnsw_index.mark_deleted(l) + + print("Testing pickle serialization") + hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index)) + del hnsw_index + # Insert batch 3 + print("Inserting batch 3 by replacing deleted elements") + hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels3_tr) + print(f"Recall for the 3 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + os.remove(index_path) + + + def test_recall_degradation(self): + """ + Compares recall of the index with replaced elements and without + Measures recall degradation + """ + dim = 16 + num_elements = 10_000 + max_num_elements = 2 * num_elements + query_size = 1_000 + k = 100 + + recall_threshold = 0.98 + max_recall_diff = 0.02 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # query to test recall + query_data = np.float32(np.random.random((query_size, dim))) + + # Declaring index + hnsw_index_no_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_no_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=False) + hnsw_index_with_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_with_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + bf_index.init_index(max_elements=max_num_elements) + + hnsw_index_no_replace.set_ef(100) + hnsw_index_no_replace.set_num_threads(50) + hnsw_index_with_replace.set_ef(100) + hnsw_index_with_replace.set_num_threads(50) + + # Add data + print("Adding data") + hnsw_index_with_replace.add_items(data1, labels1) + hnsw_index_with_replace.add_items(data2, labels2) # maximum number of elements is reached + bf_index.add_items(data1, labels1) + bf_index.add_items(data3, labels3) # maximum number of elements is reached + + for l in labels2: + hnsw_index_with_replace.mark_deleted(l) + hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True) + + hnsw_index_no_replace.add_items(data1, labels1) + hnsw_index_no_replace.add_items(data3, labels3) # maximum number of elements is reached + + # Query the elements and measure recall: + labels_hnsw_with_replace, _ = hnsw_index_with_replace.knn_query(query_data, k) + labels_hnsw_no_replace, _ = hnsw_index_no_replace.knn_query(query_data, k) + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct_with_replace = 0 + correct_no_replace = 0 + for i in range(query_size): + for label in labels_hnsw_with_replace[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct_with_replace += 1 + break + for label in labels_hnsw_no_replace[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct_no_replace += 1 + break + + recall_with_replace = float(correct_with_replace) / (k*query_size) + recall_no_replace = float(correct_no_replace) / (k*query_size) + print("recall with replace:", recall_with_replace) + print("recall without replace:", recall_no_replace) + + recall_diff = abs(recall_with_replace - recall_with_replace) + + self.assertGreater(recall_no_replace, recall_threshold) + self.assertLess(recall_diff, max_recall_diff) diff --git a/python_bindings/tests/bindings_test_stress_mt_replace.py b/python_bindings/tests/bindings_test_stress_mt_replace.py new file mode 100644 index 00000000..8cd3e9bc --- /dev/null +++ b/python_bindings/tests/bindings_test_stress_mt_replace.py @@ -0,0 +1,68 @@ +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + dim = 16 + num_elements = 1_000 + max_num_elements = 2 * num_elements + + # Generating sample data + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + + for _ in range(100): + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) + hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + hnsw_index.set_ef(100) + hnsw_index.set_num_threads(50) + + # Add batch 1 and 2 + hnsw_index.add_items(data1, labels1) + hnsw_index.add_items(data2, labels2) # maximum number of elements is reached + + # Delete nearest neighbors of batch 2 + labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) + labels2_deleted_flat = labels2_deleted.flatten() + # delete probable duplicates from nearest neighbors + labels2_deleted_no_dup = set(labels2_deleted_flat) + for l in labels2_deleted_no_dup: + hnsw_index.mark_deleted(l) + labels1_found, _ = hnsw_index.knn_query(data1, k=1) + items = hnsw_index.get_items(labels1_found) + diff_with_gt_labels = np.mean(np.abs(data1 - items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) + + labels2_after, _ = hnsw_index.knn_query(data2, k=1) + labels2_after_flat = labels2_after.flatten() + common = np.intersect1d(labels2_after_flat, labels2_deleted_flat) + self.assertTrue(common.size == 0) + + # Replace deleted elements + # Maximum number of elements is reached therefore we cannot add new items + # but we can replace the deleted ones + # Note: there may be less than num_elements elements. + # As we could delete less than num_elements because of duplicates + num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) + labels3_tr = labels3[0:labels3.shape[0] - num_duplicates] + data3_tr = data3[0:data3.shape[0] - num_duplicates] + hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) From 28681fc1ab300ed588b5907d6e3b5a2ffb0b02a4 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 14 Jan 2023 15:37:10 +0530 Subject: [PATCH 33/41] Getters for max elements, element count and num deleted. (#431) --- hnswlib/hnswalg.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7f34e62b..d1597400 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -196,6 +196,17 @@ class HierarchicalNSW : public AlgorithmInterface { return (int) r; } + size_t getMaxElements() { + return max_elements_; + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + size_t getDeletedCount() { + return num_deleted_; + } std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { From 978f7137bc9555a1b61920f05d9d0d8252ca9169 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 14 Jan 2023 15:37:37 +0530 Subject: [PATCH 34/41] Fix insufficient results during filtering. (#430) Very similar in nature to https://github.com/nmslib/hnswlib/pull/344 --- hnswlib/hnswalg.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index d1597400..bef00170 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -317,7 +317,8 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) { + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { break; } candidate_set.pop(); From d86f8f941aa6eec2bbdddc51a3591d4ddd62d974 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 15 Jan 2023 10:18:50 +0400 Subject: [PATCH 35/41] Refactoring of project structure (#432) * Refactor file structure, update readme and examples * Update Makefile * Update git tester * Remove redundant updates_test.cpp, apply suggested changes to example file * Return back python3 in Makefile --- .github/workflows/build.yml | 4 +- CMakeLists.txt | 12 +- Makefile | 2 +- README.md | 107 +-------- examples/EXAMPLES.md | 206 ++++++++++++++++++ examples/example.py | 9 +- examples/example_filter.py | 45 ++++ examples/example_replace_deleted.py | 55 +++++ examples/example_search.py | 41 ++++ ...xample_old.py => example_serialization.py} | 43 ++-- examples/pyw_hnswlib.py | 4 + python_bindings/tests/__init__.py | 0 .../cpp/download_bigann.py | 0 main.cpp => tests/cpp/main.cpp | 0 .../cpp}/multiThreadLoad_test.cpp | 2 +- .../cpp}/multiThread_replace_test.cpp | 2 +- .../cpp}/searchKnnCloserFirst_test.cpp | 2 +- .../cpp}/searchKnnWithFilter_test.cpp | 2 +- sift_1b.cpp => tests/cpp/sift_1b.cpp | 2 +- sift_test.cpp => tests/cpp/sift_test.cpp | 2 +- {examples => tests/cpp}/update_gen_data.py | 0 {examples => tests/cpp}/updates_test.cpp | 4 +- .../tests => tests/python}/bindings_test.py | 0 .../python}/bindings_test_filter.py | 0 .../python}/bindings_test_getdata.py | 0 .../python}/bindings_test_labels.py | 0 .../python}/bindings_test_metadata.py | 0 .../python}/bindings_test_pickle.py | 0 .../python}/bindings_test_recall.py | 0 .../python}/bindings_test_replace.py | 0 .../python}/bindings_test_resize.py | 0 .../python}/bindings_test_spaces.py | 0 .../bindings_test_stress_mt_replace.py | 0 {examples => tests/python}/git_tester.py | 4 +- {examples => tests/python}/speedtest.py | 0 35 files changed, 412 insertions(+), 136 deletions(-) create mode 100644 examples/EXAMPLES.md create mode 100644 examples/example_filter.py create mode 100644 examples/example_replace_deleted.py create mode 100644 examples/example_search.py rename examples/{example_old.py => example_serialization.py} (59%) delete mode 100644 python_bindings/tests/__init__.py rename download_bigann.py => tests/cpp/download_bigann.py (100%) rename main.cpp => tests/cpp/main.cpp (100%) rename {examples => tests/cpp}/multiThreadLoad_test.cpp (99%) rename {examples => tests/cpp}/multiThread_replace_test.cpp (99%) rename {examples => tests/cpp}/searchKnnCloserFirst_test.cpp (98%) rename {examples => tests/cpp}/searchKnnWithFilter_test.cpp (99%) rename sift_1b.cpp => tests/cpp/sift_1b.cpp (99%) rename sift_test.cpp => tests/cpp/sift_test.cpp (99%) rename {examples => tests/cpp}/update_gen_data.py (100%) rename {examples => tests/cpp}/updates_test.cpp (99%) rename {python_bindings/tests => tests/python}/bindings_test.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_filter.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_getdata.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_labels.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_metadata.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_pickle.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_recall.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_replace.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_resize.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_spaces.py (100%) rename {python_bindings/tests => tests/python}/bindings_test_stress_mt_replace.py (100%) rename {examples => tests/python}/git_tester.py (90%) rename {examples => tests/python}/speedtest.py (100%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e86d2545..d45b8b33 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,7 +20,7 @@ jobs: - name: Test timeout-minutes: 15 - run: python -m unittest discover -v --start-directory python_bindings/tests --pattern "*_test*.py" + run: python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" test_cpp: runs-on: ${{matrix.os}} @@ -48,7 +48,7 @@ jobs: - name: Prepare test data run: | pip install numpy - cd examples + cd tests/cpp/ python update_gen_data.py shell: bash diff --git a/CMakeLists.txt b/CMakeLists.txt index de951171..9fcdcb73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,21 +16,21 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) endif() - add_executable(test_updates examples/updates_test.cpp) + add_executable(test_updates tests/cpp/updates_test.cpp) target_link_libraries(test_updates hnswlib) - add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) + add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp) target_link_libraries(searchKnnCloserFirst_test hnswlib) - add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp) + add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp) target_link_libraries(searchKnnWithFilter_test hnswlib) - add_executable(multiThreadLoad_test examples/multiThreadLoad_test.cpp) + add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp) target_link_libraries(multiThreadLoad_test hnswlib) - add_executable(multiThread_replace_test examples/multiThread_replace_test.cpp) + add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp) target_link_libraries(multiThread_replace_test hnswlib) - add_executable(main main.cpp sift_1b.cpp) + add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/Makefile b/Makefile index b5e8fda9..0de9c765 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ dist: python3 -m build --sdist test: - python3 -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" + python3 -m unittest discover --start-directory tests/python --pattern "bindings_test*.py" clean: rm -rf *.egg-info build dist tmp var tests/__pycache__ hnswlib.cpython*.so diff --git a/README.md b/README.md index c0b0dbcc..04d84d66 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ Properties of `hnswlib.Index` that support reading and writing: #### Python bindings examples +[See more examples here](examples/EXAMPLES.md) ```python import hnswlib import numpy as np @@ -229,104 +230,6 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` -An example with a filter: -```python -import hnswlib -import numpy as np - -dim = 16 -num_elements = 10000 - -# Generating sample data -data = np.float32(np.random.random((num_elements, dim))) - -# Declaring index -hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip - -# Initiating index -# max_elements - the maximum number of elements, should be known beforehand -# (probably will be made optional in the future) -# -# ef_construction - controls index search speed/build speed tradeoff -# M - is tightly connected with internal dimensionality of the data -# strongly affects the memory consumption - -hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) - -# Controlling the recall by setting ef: -# higher ef leads to better accuracy, but slower search -hnsw_index.set_ef(10) - -# Set number of threads used during batch search/construction -# By default using all available cores -hnsw_index.set_num_threads(4) - -print("Adding %d elements" % (len(data))) -# Added elements will have consecutive ids -hnsw_index.add_items(data, ids=np.arange(num_elements)) - -print("Querying only even elements") -# Define filter function that allows only even ids -filter_function = lambda idx: idx%2 == 0 -# Query the elements for themselves and search only for even elements: -labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) -# labels contain only elements with even id -``` - -An example with replacing of deleted elements: -```python -import hnswlib -import numpy as np - -dim = 16 -num_elements = 1_000 -max_num_elements = 2 * num_elements - -# Generating sample data -labels1 = np.arange(0, num_elements) -data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 -labels2 = np.arange(num_elements, 2 * num_elements) -data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 -labels3 = np.arange(2 * num_elements, 3 * num_elements) -data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 - -# Declaring index -hnsw_index = hnswlib.Index(space='l2', dim=dim) - -# Initiating index -# max_elements - the maximum number of elements, should be known beforehand -# (probably will be made optional in the future) -# -# ef_construction - controls index search speed/build speed tradeoff -# M - is tightly connected with internal dimensionality of the data -# strongly affects the memory consumption - -# Enable replacing of deleted elements -hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) - -# Controlling the recall by setting ef: -# higher ef leads to better accuracy, but slower search -hnsw_index.set_ef(10) - -# Set number of threads used during batch search/construction -# By default using all available cores -hnsw_index.set_num_threads(4) - -# Add batch 1 and 2 data -hnsw_index.add_items(data1, labels1) -hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached - -# Delete data of batch 2 -for label in labels2: - hnsw_index.mark_deleted(label) - -# Replace deleted elements -# Maximum number of elements is reached therefore we cannot add new items, -# but we can replace the deleted ones by using replace_deleted=True -hnsw_index.add_items(data3, labels3, replace_deleted=True) -# hnsw_index contains the data of batch 1 and batch 3 only -``` - ### Bindings installation You can install from sources: @@ -346,9 +249,9 @@ Contributions are highly welcome! Please make pull requests against the `develop` branch. -When making changes please run tests (and please add a test to `python_bindings/tests` in case there is new functionality): +When making changes please run tests (and please add a test to `tests/python` in case there is new functionality): ```bash -python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" +python -m unittest discover --start-directory tests/python --pattern "bindings_test*.py" ``` @@ -373,7 +276,7 @@ https://github.com/dbaranchuk/ivf-hnsw ### 200M SIFT test reproduction To download and extract the bigann dataset (from root directory): ```bash -python3 download_bigann.py +python tests/cpp/download_bigann.py ``` To compile: ```bash @@ -393,7 +296,7 @@ The size of the BigANN subset (in millions) is controlled by the variable **subs ### Updates test To generate testing data (from root directory): ```bash -cd examples +cd tests/cpp python update_gen_data.py ``` To compile (from root directory): diff --git a/examples/EXAMPLES.md b/examples/EXAMPLES.md new file mode 100644 index 00000000..71f69ff4 --- /dev/null +++ b/examples/EXAMPLES.md @@ -0,0 +1,206 @@ +# Python bindings examples + +Creating index, inserting elements, searching and pickle serialization +```python +import hnswlib +import numpy as np +import pickle + +dim = 128 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) +ids = np.arange(num_elements) + +# Declaring index +p = hnswlib.Index(space = 'l2', dim = dim) # possible options are l2, cosine or ip + +# Initializing index - the maximum number of elements should be known beforehand +p.init_index(max_elements = num_elements, ef_construction = 200, M = 16) + +# Element insertion (can be called several times): +p.add_items(data, ids) + +# Controlling the recall by setting ef: +p.set_ef(50) # ef should always be > k + +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) +labels, distances = p.knn_query(data, k = 1) + +# Index objects support pickling +# WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! +# Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load +p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip + +### Index parameters are exposed as class properties: +print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") +print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") +print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") +print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") +``` + +An example with updates after serialization/deserialization: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# We split the data in two batches: +data1 = data[:num_elements // 2] +data2 = data[num_elements // 2:] + +# Declaring index +p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initializing index +# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded +# during insertion of an element. +# The capacity can be increased by saving/loading the index, see below. +# +# ef_construction - controls index search speed/build speed tradeoff +# +# M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M) +# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction + +p.init_index(max_elements=num_elements//2, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +p.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(4) + +print("Adding first batch of %d elements" % (len(data1))) +p.add_items(data1) + +# Query the elements for themselves and measure recall: +labels, distances = p.knn_query(data1, k=1) +print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n") + +# Serializing and deleting the index: +index_path='first_half.bin' +print("Saving index to '%s'" % index_path) +p.save_index("first_half.bin") +del p + +# Re-initializing, loading the index +p = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. + +print("\nLoading index from 'first_half.bin'\n") + +# Increase the total capacity (max_elements), so that it will handle the new data +p.load_index("first_half.bin", max_elements = num_elements) + +print("Adding the second batch of %d elements" % (len(data2))) +p.add_items(data2) + +# Query the elements for themselves and measure recall: +labels, distances = p.knn_query(data, k=1) +print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") +``` + +An example with a symbolic filter `filter_function` during the search:: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# labels contain only elements with even id +``` + +An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements +hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only +``` \ No newline at end of file diff --git a/examples/example.py b/examples/example.py index a08955a1..3d6d7477 100644 --- a/examples/example.py +++ b/examples/example.py @@ -1,6 +1,12 @@ +import os import hnswlib import numpy as np + +""" +Example of index building, search and serialization/deserialization +""" + dim = 16 num_elements = 10000 @@ -34,7 +40,6 @@ # By default using all available cores p.set_num_threads(4) - print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -62,3 +67,5 @@ # Query the elements for themselves and measure recall: labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") + +os.remove("first_half.bin") diff --git a/examples/example_filter.py b/examples/example_filter.py new file mode 100644 index 00000000..10a059a8 --- /dev/null +++ b/examples/example_filter.py @@ -0,0 +1,45 @@ +import hnswlib +import numpy as np + + +""" +Example of filtering elements when searching +""" + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# labels contain only elements with even id diff --git a/examples/example_replace_deleted.py b/examples/example_replace_deleted.py new file mode 100644 index 00000000..3c0b62e7 --- /dev/null +++ b/examples/example_replace_deleted.py @@ -0,0 +1,55 @@ +import hnswlib +import numpy as np + + +""" +Example of replacing deleted elements with new ones +""" + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements +hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only diff --git a/examples/example_search.py b/examples/example_search.py new file mode 100644 index 00000000..4581843b --- /dev/null +++ b/examples/example_search.py @@ -0,0 +1,41 @@ +import hnswlib +import numpy as np +import pickle + + +""" +Example of search +""" + +dim = 128 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) +ids = np.arange(num_elements) + +# Declaring index +p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initializing index - the maximum number of elements should be known beforehand +p.init_index(max_elements=num_elements, ef_construction=200, M=16) + +# Element insertion (can be called several times): +p.add_items(data, ids) + +# Controlling the recall by setting ef: +p.set_ef(50) # ef should always be > k + +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) +labels, distances = p.knn_query(data, k=1) + +# Index objects support pickling +# WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! +# Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load +p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip + +### Index parameters are exposed as class properties: +print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") +print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") +print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") +print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") diff --git a/examples/example_old.py b/examples/example_serialization.py similarity index 59% rename from examples/example_old.py rename to examples/example_serialization.py index 6654a027..76ca1436 100644 --- a/examples/example_old.py +++ b/examples/example_serialization.py @@ -1,34 +1,45 @@ +import os + import hnswlib import numpy as np + +""" +Example of serialization/deserialization +""" + dim = 16 num_elements = 10000 # Generating sample data data = np.float32(np.random.random((num_elements, dim))) +# We split the data in two batches: +data1 = data[:num_elements // 2] +data2 = data[num_elements // 2:] + # Declaring index p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip -# Initing index -# max_elements - the maximum number of elements, should be known beforehand -# (probably will be made optional in the future) +# Initializing index +# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded +# during insertion of an element. +# The capacity can be increased by saving/loading the index, see below. # # ef_construction - controls index search speed/build speed tradeoff -# M - is tightly connected with internal dimensionality of the data -# stronlgy affects the memory consumption +# +# M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M) +# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction -p.init_index(max_elements=num_elements, ef_construction=100, M=16) +p.init_index(max_elements=num_elements//2, ef_construction=100, M=16) # Controlling the recall by setting ef: # higher ef leads to better accuracy, but slower search p.set_ef(10) -p.set_num_threads(4) # by default using all available cores - -# We split the data in two batches: -data1 = data[:num_elements // 2] -data2 = data[num_elements // 2:] +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(4) print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -43,11 +54,13 @@ p.save_index("first_half.bin") del p -# Reiniting, loading the index -p = hnswlib.Index(space='l2', dim=dim) # you can change the sa +# Re-initializing, loading the index +p = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. print("\nLoading index from 'first_half.bin'\n") -p.load_index("first_half.bin") + +# Increase the total capacity (max_elements), so that it will handle the new data +p.load_index("first_half.bin", max_elements = num_elements) print("Adding the second batch of %d elements" % (len(data2))) p.add_items(data2) @@ -55,3 +68,5 @@ # Query the elements for themselves and measure recall: labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") + +os.remove("first_half.bin") diff --git a/examples/pyw_hnswlib.py b/examples/pyw_hnswlib.py index aeb93f10..0ccfbc5e 100644 --- a/examples/pyw_hnswlib.py +++ b/examples/pyw_hnswlib.py @@ -4,6 +4,10 @@ import pickle +""" +Example of python wrapper for hnswlib that supports python objects as ids +""" + class Index(): def __init__(self, space, dim): self.index = hnswlib.Index(space, dim) diff --git a/python_bindings/tests/__init__.py b/python_bindings/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/download_bigann.py b/tests/cpp/download_bigann.py similarity index 100% rename from download_bigann.py rename to tests/cpp/download_bigann.py diff --git a/main.cpp b/tests/cpp/main.cpp similarity index 100% rename from main.cpp rename to tests/cpp/main.cpp diff --git a/examples/multiThreadLoad_test.cpp b/tests/cpp/multiThreadLoad_test.cpp similarity index 99% rename from examples/multiThreadLoad_test.cpp rename to tests/cpp/multiThreadLoad_test.cpp index a713b2ba..4d2b4aa2 100644 --- a/examples/multiThreadLoad_test.cpp +++ b/tests/cpp/multiThreadLoad_test.cpp @@ -1,4 +1,4 @@ -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include #include diff --git a/examples/multiThread_replace_test.cpp b/tests/cpp/multiThread_replace_test.cpp similarity index 99% rename from examples/multiThread_replace_test.cpp rename to tests/cpp/multiThread_replace_test.cpp index 83ed2826..203cdb0d 100644 --- a/examples/multiThread_replace_test.cpp +++ b/tests/cpp/multiThread_replace_test.cpp @@ -1,4 +1,4 @@ -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include #include diff --git a/examples/searchKnnCloserFirst_test.cpp b/tests/cpp/searchKnnCloserFirst_test.cpp similarity index 98% rename from examples/searchKnnCloserFirst_test.cpp rename to tests/cpp/searchKnnCloserFirst_test.cpp index d87102cd..9583fe22 100644 --- a/examples/searchKnnCloserFirst_test.cpp +++ b/tests/cpp/searchKnnCloserFirst_test.cpp @@ -3,7 +3,7 @@ // >>> searchKnnCloserFirst(const void* query_data, size_t k) const; // of class AlgorithmInterface -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include diff --git a/examples/searchKnnWithFilter_test.cpp b/tests/cpp/searchKnnWithFilter_test.cpp similarity index 99% rename from examples/searchKnnWithFilter_test.cpp rename to tests/cpp/searchKnnWithFilter_test.cpp index 6102323c..0557b7e4 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/tests/cpp/searchKnnWithFilter_test.cpp @@ -1,6 +1,6 @@ // This is a test file for testing the filtering feature -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include diff --git a/sift_1b.cpp b/tests/cpp/sift_1b.cpp similarity index 99% rename from sift_1b.cpp rename to tests/cpp/sift_1b.cpp index 96d83267..43777ff6 100644 --- a/sift_1b.cpp +++ b/tests/cpp/sift_1b.cpp @@ -2,7 +2,7 @@ #include #include #include -#include "hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include diff --git a/sift_test.cpp b/tests/cpp/sift_test.cpp similarity index 99% rename from sift_test.cpp rename to tests/cpp/sift_test.cpp index 751580cb..decdf605 100644 --- a/sift_test.cpp +++ b/tests/cpp/sift_test.cpp @@ -2,7 +2,7 @@ #include #include #include -#include "hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include diff --git a/examples/update_gen_data.py b/tests/cpp/update_gen_data.py similarity index 100% rename from examples/update_gen_data.py rename to tests/cpp/update_gen_data.py diff --git a/examples/updates_test.cpp b/tests/cpp/updates_test.cpp similarity index 99% rename from examples/updates_test.cpp rename to tests/cpp/updates_test.cpp index 8e4ac644..52e1fa14 100644 --- a/examples/updates_test.cpp +++ b/tests/cpp/updates_test.cpp @@ -1,4 +1,4 @@ -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include @@ -193,7 +193,7 @@ int main(int argc, char **argv) { exit(1); } - std::string path = "../examples/data/"; + std::string path = "../tests/cpp/data/"; int N; int dummy_data_multiplier; diff --git a/python_bindings/tests/bindings_test.py b/tests/python/bindings_test.py similarity index 100% rename from python_bindings/tests/bindings_test.py rename to tests/python/bindings_test.py diff --git a/python_bindings/tests/bindings_test_filter.py b/tests/python/bindings_test_filter.py similarity index 100% rename from python_bindings/tests/bindings_test_filter.py rename to tests/python/bindings_test_filter.py diff --git a/python_bindings/tests/bindings_test_getdata.py b/tests/python/bindings_test_getdata.py similarity index 100% rename from python_bindings/tests/bindings_test_getdata.py rename to tests/python/bindings_test_getdata.py diff --git a/python_bindings/tests/bindings_test_labels.py b/tests/python/bindings_test_labels.py similarity index 100% rename from python_bindings/tests/bindings_test_labels.py rename to tests/python/bindings_test_labels.py diff --git a/python_bindings/tests/bindings_test_metadata.py b/tests/python/bindings_test_metadata.py similarity index 100% rename from python_bindings/tests/bindings_test_metadata.py rename to tests/python/bindings_test_metadata.py diff --git a/python_bindings/tests/bindings_test_pickle.py b/tests/python/bindings_test_pickle.py similarity index 100% rename from python_bindings/tests/bindings_test_pickle.py rename to tests/python/bindings_test_pickle.py diff --git a/python_bindings/tests/bindings_test_recall.py b/tests/python/bindings_test_recall.py similarity index 100% rename from python_bindings/tests/bindings_test_recall.py rename to tests/python/bindings_test_recall.py diff --git a/python_bindings/tests/bindings_test_replace.py b/tests/python/bindings_test_replace.py similarity index 100% rename from python_bindings/tests/bindings_test_replace.py rename to tests/python/bindings_test_replace.py diff --git a/python_bindings/tests/bindings_test_resize.py b/tests/python/bindings_test_resize.py similarity index 100% rename from python_bindings/tests/bindings_test_resize.py rename to tests/python/bindings_test_resize.py diff --git a/python_bindings/tests/bindings_test_spaces.py b/tests/python/bindings_test_spaces.py similarity index 100% rename from python_bindings/tests/bindings_test_spaces.py rename to tests/python/bindings_test_spaces.py diff --git a/python_bindings/tests/bindings_test_stress_mt_replace.py b/tests/python/bindings_test_stress_mt_replace.py similarity index 100% rename from python_bindings/tests/bindings_test_stress_mt_replace.py rename to tests/python/bindings_test_stress_mt_replace.py diff --git a/examples/git_tester.py b/tests/python/git_tester.py similarity index 90% rename from examples/git_tester.py rename to tests/python/git_tester.py index be3b8a25..5a97f3dd 100644 --- a/examples/git_tester.py +++ b/tests/python/git_tester.py @@ -5,8 +5,8 @@ from pydriller import Repository -speedtest_src_path = os.path.join("examples", "speedtest.py") -speedtest_copy_path = os.path.join("examples", "speedtest2.py") +speedtest_src_path = os.path.join("tests", "python", "speedtest.py") +speedtest_copy_path = os.path.join("tests", "python", "speedtest2.py") shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git commits = list(Repository('.', from_tag="v0.6.0").traverse_commits()) diff --git a/examples/speedtest.py b/tests/python/speedtest.py similarity index 100% rename from examples/speedtest.py rename to tests/python/speedtest.py From 225b519e9054cddff7f6f9ce1bca08bb225693d8 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sat, 14 Jan 2023 19:09:33 +0400 Subject: [PATCH 36/41] Add warning that python filter works slow in multi-threaded mode --- python_bindings/bindings.cpp | 4 ++++ tests/python/bindings_test_filter.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3196a228..3f228832 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -611,6 +611,10 @@ class Index { if (num_threads <= 0) num_threads = num_threads_default; + if ((filter != nullptr) && (num_threads != 1)) { + std::cout << "Warning: search with python filter works slow in multi-threaded mode. For best performance set num_threads=1\n"; + } + { py::gil_scoped_release l; get_input_array_shapes(buffer, &rows, &features); diff --git a/tests/python/bindings_test_filter.py b/tests/python/bindings_test_filter.py index a798e02f..ecb79ab9 100644 --- a/tests/python/bindings_test_filter.py +++ b/tests/python/bindings_test_filter.py @@ -47,7 +47,8 @@ def testRandomSelf(self): print("Querying only even elements") # Query the even elements for themselves and measure recall: filter_function = lambda id: id%2 == 0 - labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) + # Search with python filter works slow in multi-threaded mode, therefore we set num_threads=1 + labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function, num_threads=1) self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) # Verify that there are only even elements: self.assertTrue(np.max(np.mod(labels, 2)) == 0) From 32f4b02def881a48565a9d51a4e7332a1e24b778 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 15 Jan 2023 12:07:08 +0400 Subject: [PATCH 37/41] Add comments with warnings that filter works slow in python in multithreaded mode. Add example files to CI test. --- .github/workflows/build.yml | 4 +++- README.md | 2 +- examples/EXAMPLES.md | 3 ++- examples/example_filter.py | 3 ++- python_bindings/bindings.cpp | 5 +---- tests/python/bindings_test_filter.py | 4 ++-- 6 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d45b8b33..f2662c15 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,7 +20,9 @@ jobs: - name: Test timeout-minutes: 15 - run: python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" + run: | + python -m unittest discover -v --start-directory examples --pattern "example*.py" + python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" test_cpp: runs-on: ${{matrix.os}} diff --git a/README.md b/README.md index 04d84d66..2b027216 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib. * `knn_query(data, k = 1, num_threads = -1, filter = None)` make a batch query for `k` closest elements for each element of the * `data` (shape:`N*dim`). Returns a numpy array of (shape:`N*k`). * `num_threads` sets the number of cpu threads to use (-1 means use default). - * `filter` filters elements by its labels, returns elements with allowed ids + * `filter` filters elements by its labels, returns elements with allowed ids. Note that search with a filter works slow in python in multithreaded mode. It is recommended to set `num_threads=1` * Thread-safe with other `knn_query` calls, but not with `add_items`. * `load_index(path_to_index, max_elements = 0, allow_replace_deleted = False)` loads the index from persistence to the uninitialized index. diff --git a/examples/EXAMPLES.md b/examples/EXAMPLES.md index 71f69ff4..a92f3626 100644 --- a/examples/EXAMPLES.md +++ b/examples/EXAMPLES.md @@ -147,7 +147,8 @@ print("Querying only even elements") # Define filter function that allows only even ids filter_function = lambda idx: idx%2 == 0 # Query the elements for themselves and search only for even elements: -labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# Warning: search with python filter works slow in multithreaded mode, therefore we set num_threads=1 +labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) # labels contain only elements with even id ``` diff --git a/examples/example_filter.py b/examples/example_filter.py index 10a059a8..add22a3d 100644 --- a/examples/example_filter.py +++ b/examples/example_filter.py @@ -41,5 +41,6 @@ # Define filter function that allows only even ids filter_function = lambda idx: idx%2 == 0 # Query the elements for themselves and search only for even elements: -labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1 +labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) # labels contain only elements with even id diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3f228832..5153bb58 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -611,10 +611,6 @@ class Index { if (num_threads <= 0) num_threads = num_threads_default; - if ((filter != nullptr) && (num_threads != 1)) { - std::cout << "Warning: search with python filter works slow in multi-threaded mode. For best performance set num_threads=1\n"; - } - { py::gil_scoped_release l; get_input_array_shapes(buffer, &rows, &features); @@ -627,6 +623,7 @@ class Index { data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + // Warning: search with a filter works slow in python in multithreaded mode. For best performance set num_threads=1 CustomFilterFunctor idFilter(filter); CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; diff --git a/tests/python/bindings_test_filter.py b/tests/python/bindings_test_filter.py index ecb79ab9..480c8dcd 100644 --- a/tests/python/bindings_test_filter.py +++ b/tests/python/bindings_test_filter.py @@ -47,8 +47,8 @@ def testRandomSelf(self): print("Querying only even elements") # Query the even elements for themselves and measure recall: filter_function = lambda id: id%2 == 0 - # Search with python filter works slow in multi-threaded mode, therefore we set num_threads=1 - labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function, num_threads=1) + # Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1 + labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) # Verify that there are only even elements: self.assertTrue(np.max(np.mod(labels, 2)) == 0) From dd266bca950e554bc7d2a6b633a66c9246838b18 Mon Sep 17 00:00:00 2001 From: Yury Date: Sun, 15 Jan 2023 15:38:38 -0800 Subject: [PATCH 38/41] preliminary release notes --- README.md | 39 +++++++++++++++------------------------ setup.py | 2 +- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 2b027216..98bb0fb4 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,22 @@ # Hnswlib - fast approximate nearest neighbor search -Header-only C++ HNSW implementation with python bindings. +Header-only C++ HNSW implementation with python bindings, insertions and updates. **NEWS:** +**version 0.7.0** -**version 0.6.2** - -* Fixed a bug in saving of large pickles. The pickles with > 4GB could have been corrupted. Thanks Kai Wohlfahrt for reporting. -* Thanks to ([@GuyAv46](https://github.com/GuyAv46)) hnswlib inner product now is more consitent accross architectures (SSE, AVX, etc). -* - -**version 0.6.1** - -* Thanks to ([@tony-kuo](https://github.com/tony-kuo)) hnswlib AVX512 and AVX builds are not backwards-compatible with older SSE and non-AVX512 architectures. -* Thanks to ([@psobot](https://github.com/psobot)) there is now a sencible message instead of segfault when passing a scalar to get_items. -* Thanks to ([@urigoren](https://github.com/urigoren)) hnswlib has a lazy index creation python wrapper. - -**version 0.6.0** -* Thanks to ([@dyashuni](https://github.com/dyashuni)) hnswlib now uses github actions for CI, there is a search speedup in some scenarios with deletions. `unmark_deleted(label)` is now also a part of the python interface (note now it throws an exception for double deletions). -* Thanks to ([@slice4e](https://github.com/slice4e)) we now support AVX512; thanks to ([@LTLA](https://github.com/LTLA)) the cmake interface for the lib is now updated. -* Thanks to ([@alonre24](https://github.com/alonre24)) we now have a python bindings for brute-force (and examples for recall tuning: [TESTING_RECALL.md](TESTING_RECALL.md). -* Thanks to ([@dorosy-yeong](https://github.com/dorosy-yeong)) there is a bug fixed in the handling large quantities of deleted elements and large K. - - +* Added support to filtering (#402, #430) by [@kishorenc](https://github.com/kishorenc) +* Added python interface for filtering (though note its performance is limited by GIL) (#417) by [@gtsoukas](https://github.com/gtsoukas) +* Added support for replacing the elements that were market as delete with newly inserted elements (to control the size of the index, #418) by [@dyashuni](https://github.com/dyashuni) +* Fixed data races/deadlocks in updates/insertion, added stress test for multithreaded operation (#418) by [@dyashuni](https://github.com/dyashuni) +* Documentation, tests, exception handling, refactoring (#375, #379, #380, #395, #396, #401, #406, #404, #409, #410, #416, #415, #431, #432, #433) by [@jlmelville](https://github.com/jlmelville), [@dyashuni](https://github.com/dyashuni), [@kishorenc](https://github.com/kishorenc), [@korzhenevski](https://github.com/korzhenevski), [@yoshoku](https://github.com/yoshoku), [@jianshu93](https://github.com/jianshu93), [@PLNech](https://github.com/PLNech) +* global linkages (#383) by [@MasterAler](https://github.com/MasterAler), USE_SSE usage in MSVC (#408) by [@alxvth](https://github.com/alxvth) ### Highlights: 1) Lightweight, header-only, no dependencies other than C++ 11 -2) Interfaces for C++, Java, Python and R (https://github.com/jlmelville/rcpphnsw). -3) Has full support for incremental index construction. Has support for element deletions +2) Interfaces for C++, Python, external support for Java and R (https://github.com/jlmelville/rcpphnsw). +3) Has full support for incremental index construction and updating the elements. Has support for element deletions (by marking them in index). Index is picklable. 4) Can work with custom user defined distances (C++). 5) Significantly less memory footprint and faster build time compared to current nmslib's implementation. @@ -50,7 +38,7 @@ Note that inner product is not an actual metric. An element can be closer to som For other spaces use the nmslib library https://github.com/nmslib/nmslib. -#### Short API description +#### API description * `hnswlib.Index(space, dim)` creates a non-initialized index an HNSW in space `space` with integer dimension `dim`. `hnswlib.Index` methods: @@ -263,14 +251,17 @@ https://github.com/facebookresearch/faiss ["Revisiting the Inverted Indices for Billion-Scale Approximate Nearest Neighbors"](https://arxiv.org/abs/1802.02422) (current state-of-the-art in compressed indexes, C++): https://github.com/dbaranchuk/ivf-hnsw +* Amazon PECOS https://github.com/amzn/pecos * TOROS N2 (python, C++): https://github.com/kakao/n2 * Online HNSW (C++): https://github.com/andrusha97/online-hnsw) * Go implementation: https://github.com/Bithack/go-hnsw * Python implementation (as a part of the clustering code by by Matteo Dell'Amico): https://github.com/matteodellamico/flexible-clustering +* Julia implmentation https://github.com/JuliaNeighbors/HNSW.jl * Java implementation: https://github.com/jelmerk/hnswlib * Java bindings using Java Native Access: https://github.com/stepstone-tech/hnswlib-jna -* .Net implementation: https://github.com/microsoft/HNSW.Net +* .Net implementation: https://github.com/curiosity-ai/hnsw-sharp * CUDA implementation: https://github.com/js1010/cuhnsw +* Rust implementation https://github.com/rust-cv/hnsw * Rust implementation for memory and thread safety purposes and There is A Trait to enable the user to implement its own distances. It takes as data slices of types T satisfying T:Serialize+Clone+Send+Sync.: https://github.com/jean-pierreBoth/hnswlib-rs ### 200M SIFT test reproduction diff --git a/setup.py b/setup.py index 161886fd..0126585e 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -__version__ = '0.6.1' +__version__ = '0.7.0' include_dirs = [ From d35f4288b2912629f86bdf060f80cd10a6f5a95d Mon Sep 17 00:00:00 2001 From: Yury Date: Mon, 16 Jan 2023 19:13:23 -0800 Subject: [PATCH 39/41] Add construction speed logging --- tests/python/git_tester.py | 26 ++++++++++++++----- tests/python/speedtest.py | 53 ++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/tests/python/git_tester.py b/tests/python/git_tester.py index 5a97f3dd..1f9c2ba7 100644 --- a/tests/python/git_tester.py +++ b/tests/python/git_tester.py @@ -9,19 +9,29 @@ speedtest_copy_path = os.path.join("tests", "python", "speedtest2.py") shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git -commits = list(Repository('.', from_tag="v0.6.0").traverse_commits()) +commits = list(Repository('.', from_tag="v0.6.2").traverse_commits()) print("Found commits:") for idx, commit in enumerate(commits): name = commit.msg.replace('\n', ' ').replace('\r', ' ') print(idx, commit.hash, name) for commit in commits: - name = commit.msg.replace('\n', ' ').replace('\r', ' ') + name = commit.msg.replace('\n', ' ').replace('\r', ' ').replace(",", ";") print("\nProcessing", commit.hash, name) if os.path.exists("build"): shutil.rmtree("build") os.system(f"git checkout {commit.hash}") + + # Checking we have actually switched the branch: + current_commit=list(Repository('.').traverse_commits())[-1] + if current_commit.hash != commit.hash: + print("git checkout failed!!!!") + print("git checkout failed!!!!") + print("git checkout failed!!!!") + print("git checkout failed!!!!") + continue + print("\n\n--------------------\n\n") ret = os.system("python -m pip install .") print("Install result:", ret) @@ -33,8 +43,10 @@ print("build failed!!!!") continue - os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 1') - os.system(f'python {speedtest_copy_path} -n "{name}" -d 64 -t 1') - os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 1') - os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 24') - os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 24') + # os.system(f'python {speedtest_copy_path} -n "{hash[:4]}_{name}" -d 32 -t 1') + os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 1') + os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 64') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 64 -t 1') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 1') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 24') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 24') diff --git a/tests/python/speedtest.py b/tests/python/speedtest.py index cf8e6085..8d16cfc3 100644 --- a/tests/python/speedtest.py +++ b/tests/python/speedtest.py @@ -13,50 +13,53 @@ dim = int(args.d) name = args.n threads=int(args.t) -num_elements = 1000000 * 4//dim +num_elements = 400000 # Generating sample data np.random.seed(1) data = np.float32(np.random.random((num_elements, dim))) -index_path=f'speed_index{dim}.bin' +# index_path=f'speed_index{dim}.bin' # Declaring index p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip -if not os.path.isfile(index_path) : +# if not os.path.isfile(index_path) : - p.init_index(max_elements=num_elements, ef_construction=100, M=16) +p.init_index(max_elements=num_elements, ef_construction=60, M=16) - # Controlling the recall by setting ef: - # higher ef leads to better accuracy, but slower search - p.set_ef(10) +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +p.set_ef(10) - # Set number of threads used during batch search/construction - # By default using all available cores - p.set_num_threads(12) +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(64) +t0=time.time() +p.add_items(data) +construction_time=time.time()-t0 +# Serializing and deleting the index: - p.add_items(data) - - # Serializing and deleting the index: - - print("Saving index to '%s'" % index_path) - p.save_index(index_path) +# print("Saving index to '%s'" % index_path) +# p.save_index(index_path) p.set_num_threads(threads) times=[] -time.sleep(10) -p.set_ef(100) -for _ in range(3): - p.load_index(index_path) - for _ in range(10): +time.sleep(1) +p.set_ef(15) +for _ in range(1): + # p.load_index(index_path) + for _ in range(3): t0=time.time() - labels, distances = p.knn_query(data, k=1) + qdata=data[:5000*threads] + labels, distances = p.knn_query(qdata, k=1) tt=time.time()-t0 times.append(tt) - print(f"{tt} seconds") -str_out=f"mean time:{np.mean(times)}, median time:{np.median(times)}, std time {np.std(times)} {name}" + recall=np.sum(labels.reshape(-1)==np.arange(len(qdata)))/len(qdata) + print(f"{tt} seconds, recall= {recall}") + +str_out=f"{np.mean(times)}, {np.median(times)}, {np.std(times)}, {construction_time}, {recall}, {name}" print(str_out) -with open (f"log_{dim}_t{threads}.txt","a") as f: +with open (f"log2_{dim}_t{threads}.txt","a") as f: f.write(str_out+"\n") f.flush() From 68a3387b516fa9e40ecbe0cf7ca76a72df35f07e Mon Sep 17 00:00:00 2001 From: Yury Malkov Date: Wed, 18 Jan 2023 10:50:50 -0800 Subject: [PATCH 40/41] fix a misprint Co-authored-by: drons --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 98bb0fb4..f56d490e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Header-only C++ HNSW implementation with python bindings, insertions and updates * Added support to filtering (#402, #430) by [@kishorenc](https://github.com/kishorenc) * Added python interface for filtering (though note its performance is limited by GIL) (#417) by [@gtsoukas](https://github.com/gtsoukas) -* Added support for replacing the elements that were market as delete with newly inserted elements (to control the size of the index, #418) by [@dyashuni](https://github.com/dyashuni) +* Added support for replacing the elements that were marked as delete with newly inserted elements (to control the size of the index, #418) by [@dyashuni](https://github.com/dyashuni) * Fixed data races/deadlocks in updates/insertion, added stress test for multithreaded operation (#418) by [@dyashuni](https://github.com/dyashuni) * Documentation, tests, exception handling, refactoring (#375, #379, #380, #395, #396, #401, #406, #404, #409, #410, #416, #415, #431, #432, #433) by [@jlmelville](https://github.com/jlmelville), [@dyashuni](https://github.com/dyashuni), [@kishorenc](https://github.com/kishorenc), [@korzhenevski](https://github.com/korzhenevski), [@yoshoku](https://github.com/yoshoku), [@jianshu93](https://github.com/jianshu93), [@PLNech](https://github.com/PLNech) * global linkages (#383) by [@MasterAler](https://github.com/MasterAler), USE_SSE usage in MSVC (#408) by [@alxvth](https://github.com/alxvth) From 488ab52e395dc40b18b19aa31cf045c327f8548b Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Mon, 30 Jan 2023 10:17:09 +0400 Subject: [PATCH 41/41] Add cpp examples (#435) * Add cpp examples * Add multithreaded cpp examples --- .github/workflows/build.yml | 8 +- ALGO_PARAMS.md | 2 +- CMakeLists.txt | 20 ++ README.md | 15 +- examples/cpp/EXAMPLES.md | 185 ++++++++++++++++++ examples/cpp/example_filter.cpp | 57 ++++++ examples/cpp/example_mt_filter.cpp | 124 ++++++++++++ examples/cpp/example_mt_replace_deleted.cpp | 114 +++++++++++ examples/cpp/example_mt_search.cpp | 107 ++++++++++ examples/cpp/example_replace_deleted.cpp | 54 +++++ examples/cpp/example_search.cpp | 58 ++++++ examples/{ => python}/EXAMPLES.md | 4 +- examples/{ => python}/example.py | 0 examples/{ => python}/example_filter.py | 0 .../{ => python}/example_replace_deleted.py | 0 examples/{ => python}/example_search.py | 0 .../{ => python}/example_serialization.py | 0 examples/{ => python}/pyw_hnswlib.py | 0 18 files changed, 743 insertions(+), 5 deletions(-) create mode 100644 examples/cpp/EXAMPLES.md create mode 100644 examples/cpp/example_filter.cpp create mode 100644 examples/cpp/example_mt_filter.cpp create mode 100644 examples/cpp/example_mt_replace_deleted.cpp create mode 100644 examples/cpp/example_mt_search.cpp create mode 100644 examples/cpp/example_replace_deleted.cpp create mode 100644 examples/cpp/example_search.cpp rename examples/{ => python}/EXAMPLES.md (99%) rename examples/{ => python}/example.py (100%) rename examples/{ => python}/example_filter.py (100%) rename examples/{ => python}/example_replace_deleted.py (100%) rename examples/{ => python}/example_search.py (100%) rename examples/{ => python}/example_serialization.py (100%) rename examples/{ => python}/pyw_hnswlib.py (100%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f2662c15..8cfa469a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,7 +21,7 @@ jobs: - name: Test timeout-minutes: 15 run: | - python -m unittest discover -v --start-directory examples --pattern "example*.py" + python -m unittest discover -v --start-directory examples/python --pattern "example*.py" python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" test_cpp: @@ -61,6 +61,12 @@ jobs: if [ "$RUNNER_OS" == "Windows" ]; then cp ./Release/* ./ fi + ./example_search + ./example_filter + ./example_replace_deleted + ./example_mt_search + ./example_mt_filter + ./example_mt_replace_deleted ./searchKnnCloserFirst_test ./searchKnnWithFilter_test ./multiThreadLoad_test diff --git a/ALGO_PARAMS.md b/ALGO_PARAMS.md index b0a6b7ad..0d5133f3 100644 --- a/ALGO_PARAMS.md +++ b/ALGO_PARAMS.md @@ -27,5 +27,5 @@ ef_construction leads to longer construction, but better index quality. At some not improve the quality of the index. One way to check if the selection of ef_construction was ok is to measure a recall for M nearest neighbor search when ```ef``` =```ef_construction```: if the recall is lower than 0.9, than there is room for improvement. -* ```num_elements``` - defines the maximum number of elements in the index. The index can be extened by saving/loading(load_index +* ```num_elements``` - defines the maximum number of elements in the index. The index can be extended by saving/loading (load_index function has a parameter which defines the new maximum number of elements). diff --git a/CMakeLists.txt b/CMakeLists.txt index 9fcdcb73..7cebe600 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,26 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) endif() + # examples + add_executable(example_search examples/cpp/example_search.cpp) + target_link_libraries(example_search hnswlib) + + add_executable(example_filter examples/cpp/example_filter.cpp) + target_link_libraries(example_filter hnswlib) + + add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp) + target_link_libraries(example_replace_deleted hnswlib) + + add_executable(example_mt_search examples/cpp/example_mt_search.cpp) + target_link_libraries(example_mt_search hnswlib) + + add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp) + target_link_libraries(example_mt_filter hnswlib) + + add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp) + target_link_libraries(example_mt_replace_deleted hnswlib) + + # tests add_executable(test_updates tests/cpp/updates_test.cpp) target_link_libraries(test_updates hnswlib) diff --git a/README.md b/README.md index f56d490e..3ed466a7 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,12 @@ Properties of `hnswlib.Index` that support reading and writing: #### Python bindings examples -[See more examples here](examples/EXAMPLES.md) +[See more examples here](examples/python/EXAMPLES.md): +* Creating index, inserting elements, searching, serialization/deserialization +* Filtering during the search with a boolean function +* Deleting the elements and reusing the memory of the deleted elements for newly added elements + +An example of creating index, inserting elements, searching and pickle serialization: ```python import hnswlib import numpy as np @@ -218,6 +223,14 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` +#### C++ examples +[See examples here](examples/cpp/EXAMPLES.md): +* creating index, inserting elements, searching, serialization/deserialization +* filtering during the search with a boolean function +* deleting the elements and reusing the memory of the deleted elements for newly added elements +* multithreaded usage + + ### Bindings installation You can install from sources: diff --git a/examples/cpp/EXAMPLES.md b/examples/cpp/EXAMPLES.md new file mode 100644 index 00000000..3af603d4 --- /dev/null +++ b/examples/cpp/EXAMPLES.md @@ -0,0 +1,185 @@ +# C++ examples + +Creating index, inserting elements, searching and serialization +```cpp +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Query the elements for themselves and measure recall + float correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + // Serialize index + std::string hnsw_path = "hnsw.bin"; + alg_hnsw->saveIndex(hnsw_path); + delete alg_hnsw; + + // Deserialize index and check recall + alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); + correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + recall = (float)correct / max_elements; + std::cout << "Recall of deserialized index: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} +``` + +An example of filtering with a boolean function during the search: +```cpp +#include "../../hnswlib/hnswlib.h" + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + for (int i = 0; i < max_elements; i++) { + std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); + for (auto item: result) { + if (item.second % 2 == 1) std::cout << "Error: found odd label\n"; + } + } + + delete[] data; + delete alg_hnsw; + return 0; +} +``` + +An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): +```cpp +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, 100, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + for (int i = 0; i < num_deleted; i++) { + alg_hnsw->markDelete(i); + } + + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + for (int i = 0; i < num_deleted; i++) { + int label = max_elements + i; + alg_hnsw->addPoint(add_data + i * dim, label, true); + } + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} +``` + +Multithreaded examples: +* Creating index, inserting elements, searching [example_mt_search.cpp](example_mt_search.cpp) +* Filtering during the search with a boolean function [example_mt_filter.cpp](example_mt_filter.cpp) +* Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp) \ No newline at end of file diff --git a/examples/cpp/example_filter.cpp b/examples/cpp/example_filter.cpp new file mode 100644 index 00000000..dc978c57 --- /dev/null +++ b/examples/cpp/example_filter.cpp @@ -0,0 +1,57 @@ +#include "../../hnswlib/hnswlib.h" + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + for (int i = 0; i < max_elements; i++) { + std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); + for (auto item: result) { + if (item.second % 2 == 1) std::cout << "Error: found odd label\n"; + } + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_filter.cpp b/examples/cpp/example_mt_filter.cpp new file mode 100644 index 00000000..b39de4c3 --- /dev/null +++ b/examples/cpp/example_mt_filter.cpp @@ -0,0 +1,124 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + std::vector neighbors(max_elements * k); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, k, &pickIdsDivisibleByTwo); + for (int i = 0; i < k; i++) { + hnswlib::labeltype label = result.top().second; + result.pop(); + neighbors[row * k + i] = label; + } + }); + + for (hnswlib::labeltype label: neighbors) { + if (label % 2 == 1) std::cout << "Error: found odd label\n"; + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_replace_deleted.cpp b/examples/cpp/example_mt_replace_deleted.cpp new file mode 100644 index 00000000..40a94ce7 --- /dev/null +++ b/examples/cpp/example_mt_replace_deleted.cpp @@ -0,0 +1,114 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index with allow_replace_deleted=true + int seed = 100; + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->markDelete(row); + }); + + // Generate additional random data + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) { + hnswlib::labeltype label = max_elements + row; + alg_hnsw->addPoint((void*)(add_data + dim * row), label, true); + }); + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_search.cpp b/examples/cpp/example_mt_search.cpp new file mode 100644 index 00000000..e315b9ff --- /dev/null +++ b/examples/cpp/example_mt_search.cpp @@ -0,0 +1,107 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Query the elements for themselves and measure recall + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_replace_deleted.cpp b/examples/cpp/example_replace_deleted.cpp new file mode 100644 index 00000000..64c995bb --- /dev/null +++ b/examples/cpp/example_replace_deleted.cpp @@ -0,0 +1,54 @@ +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index with allow_replace_deleted=true + int seed = 100; + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + for (int i = 0; i < num_deleted; i++) { + alg_hnsw->markDelete(i); + } + + // Generate additional random data + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + for (int i = 0; i < num_deleted; i++) { + hnswlib::labeltype label = max_elements + i; + alg_hnsw->addPoint(add_data + i * dim, label, true); + } + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_search.cpp b/examples/cpp/example_search.cpp new file mode 100644 index 00000000..2c28738f --- /dev/null +++ b/examples/cpp/example_search.cpp @@ -0,0 +1,58 @@ +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Query the elements for themselves and measure recall + float correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + // Serialize index + std::string hnsw_path = "hnsw.bin"; + alg_hnsw->saveIndex(hnsw_path); + delete alg_hnsw; + + // Deserialize index and check recall + alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); + correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + recall = (float)correct / max_elements; + std::cout << "Recall of deserialized index: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/EXAMPLES.md b/examples/python/EXAMPLES.md similarity index 99% rename from examples/EXAMPLES.md rename to examples/python/EXAMPLES.md index a92f3626..6c1b20e4 100644 --- a/examples/EXAMPLES.md +++ b/examples/python/EXAMPLES.md @@ -1,6 +1,6 @@ # Python bindings examples -Creating index, inserting elements, searching and pickle serialization +Creating index, inserting elements, searching and pickle serialization: ```python import hnswlib import numpy as np @@ -107,7 +107,7 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` -An example with a symbolic filter `filter_function` during the search:: +An example with a symbolic filter `filter_function` during the search: ```python import hnswlib import numpy as np diff --git a/examples/example.py b/examples/python/example.py similarity index 100% rename from examples/example.py rename to examples/python/example.py diff --git a/examples/example_filter.py b/examples/python/example_filter.py similarity index 100% rename from examples/example_filter.py rename to examples/python/example_filter.py diff --git a/examples/example_replace_deleted.py b/examples/python/example_replace_deleted.py similarity index 100% rename from examples/example_replace_deleted.py rename to examples/python/example_replace_deleted.py diff --git a/examples/example_search.py b/examples/python/example_search.py similarity index 100% rename from examples/example_search.py rename to examples/python/example_search.py diff --git a/examples/example_serialization.py b/examples/python/example_serialization.py similarity index 100% rename from examples/example_serialization.py rename to examples/python/example_serialization.py diff --git a/examples/pyw_hnswlib.py b/examples/python/pyw_hnswlib.py similarity index 100% rename from examples/pyw_hnswlib.py rename to examples/python/pyw_hnswlib.py