diff --git a/CMakeLists.txt b/CMakeLists.txt index a34a67f2..ef66818e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -247,6 +247,7 @@ if(HNSWLIB_EXAMPLES) "Please check if this is a typo.") endif() endforeach() + add_subdirectory(benchmark/cpp) endif() # Persist CMAKE_CXX_FLAGS in the cache for debuggability. diff --git a/benchmark/cpp/CMakeLists.txt b/benchmark/cpp/CMakeLists.txt new file mode 100644 index 00000000..88532bea --- /dev/null +++ b/benchmark/cpp/CMakeLists.txt @@ -0,0 +1,60 @@ +# CMakeLists.txt +cmake_minimum_required(VERSION 3.11) + +include(FetchContent) + + +# Google Benchmark +# close benchmark-test +set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "" FORCE) +set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "" FORCE) +set(BENCHMARK_ENABLE_WERROR OFF CACHE BOOL "" FORCE) +FetchContent_Declare( + benchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG v1.9.4 + GIT_SHALLOW TRUE +) +FetchContent_MakeAvailable(benchmark) + + +# Use master branch as standard +FetchContent_Declare( + hnswlib_std + GIT_REPOSITORY https://github.com/nmslib/hnswlib.git + GIT_TAG develop + GIT_SHALLOW TRUE +) +# avoid library name conflict +FetchContent_GetProperties(hnswlib_std) +if(NOT hnswlib_std_POPULATED) + FetchContent_Populate(hnswlib_std) + # rename master branch library + add_library(hnswlib_std INTERFACE) + add_library(hnswlib_std::hnswlib ALIAS hnswlib_std) + target_include_directories(hnswlib_std INTERFACE + $ + $) +endif() + + +# create benchmark binaries with different versions of hnswlib +# for standard library +add_executable(benchmark_standard benchmarks_main.cpp + bm_basic.cpp +) +target_link_libraries(benchmark_standard benchmark::benchmark) +target_link_libraries(benchmark_standard hnswlib_std) +set_target_properties(benchmark_standard PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR} +) + +# for current library +add_executable(benchmark_current benchmarks_main.cpp + bm_basic.cpp +) +target_link_libraries(benchmark_current benchmark::benchmark) +target_link_libraries(benchmark_current hnswlib) +set_target_properties(benchmark_current PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR} +) \ No newline at end of file diff --git a/benchmark/cpp/benchmarks.h b/benchmark/cpp/benchmarks.h new file mode 100644 index 00000000..f76d6d41 --- /dev/null +++ b/benchmark/cpp/benchmarks.h @@ -0,0 +1,4 @@ +#pragma once + +void RegisterHnswBasicBenchmarks(); + diff --git a/benchmark/cpp/benchmarks_main.cpp b/benchmark/cpp/benchmarks_main.cpp new file mode 100644 index 00000000..0b6df4c2 --- /dev/null +++ b/benchmark/cpp/benchmarks_main.cpp @@ -0,0 +1,14 @@ +#include + +#include "benchmarks.h" + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + + RegisterHnswBasicBenchmarks(); + + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + + return 0; +} \ No newline at end of file diff --git a/benchmark/cpp/bm_basic.cpp b/benchmark/cpp/bm_basic.cpp new file mode 100644 index 00000000..54cb1ed9 --- /dev/null +++ b/benchmark/cpp/bm_basic.cpp @@ -0,0 +1,68 @@ +#include +#include +#include + +#include "hnswlib/hnswalg.h" + +// hnsw build benchmark + +void l2_normalize(float* arr, size_t dim) { + float norm = 0; + for (size_t i = 0; i < dim; ++i) { + norm += arr[i] * arr[i]; + } + norm = std::sqrt(norm); + for (size_t i = 0; i < dim; ++i) { + arr[i] /= norm; + } +} +void l2_normalize_batch(float* arr, size_t dim, size_t batch_size) { + for(size_t i = 0; i < batch_size; ++i){ + l2_normalize(arr + i*dim, dim); + } +} +void prepare_data(std::vector< std::vector >& embeddings, size_t dim, size_t x_data_size, bool need_l2_normalize) { + std::mt19937 rng(42); // same seed to ensure reproducibility + std::vector datas(x_data_size*dim); + std::generate(datas.begin(), datas.end(), rng); + if (need_l2_normalize) { + l2_normalize_batch(datas.data(), dim, x_data_size); + } + for(size_t i=0; i > embeddings(x_data_size, std::vector(dim, 0.0f)); + prepare_data(embeddings, dim, x_data_size, true); + + for (auto _: state) { + auto space = std::make_shared(dim); + auto index = std::make_shared>(space.get(), x_data_size); + for (size_t i = 0; i < x_data_size; i++) { + auto& emb = embeddings[i]; + index->addPoint(emb.data(), i); + } + benchmark::DoNotOptimize(index); + } + + state.SetComplexityN(state.range(0)*state.range(1)*state.range(2)*state.range(3)); +} + +void RegisterHnswBasicBenchmarks() { + BENCHMARK(BM_HnswIPAddPointWholeTimeBench) + ->ArgsProduct({ + {16,32}, + {200,400}, + {32, 128}, + {500,5000} + }); +} \ No newline at end of file diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index a04b2ed4..0faeec91 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -481,25 +481,23 @@ class HierarchicalNSW : public AlgorithmInterface { return; } - std::priority_queue> queue_closest; - std::vector> return_list; + std::vector> rqueue_closest; + std::vector return_id_list; while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + rqueue_closest.emplace_back(top_candidates.top()); top_candidates.pop(); } - while (queue_closest.size()) { - if (return_list.size() >= M) + for(auto rit = rqueue_closest.rbegin(); rit != rqueue_closest.rend(); ++rit) { + if (return_id_list.size() >= M) break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); + dist_t dist_to_query = rit->first; bool good = true; - for (std::pair second_pair : return_list) { + for (const auto& id : return_id_list) { dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), + fstdistfunc_(getDataByInternalId(id), + getDataByInternalId(rit->second), dist_func_param_); if (curdist < dist_to_query) { good = false; @@ -507,13 +505,10 @@ class HierarchicalNSW : public AlgorithmInterface { } } if (good) { - return_list.push_back(curent_pair); + return_id_list.push_back(rit->second); + top_candidates.emplace(std::move(*rit)); } } - - for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); - } }