Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ if(HNSWLIB_EXAMPLES)
"Please check if this is a typo.")
endif()
endforeach()
add_subdirectory(benchmark/cpp)
endif()

# Persist CMAKE_CXX_FLAGS in the cache for debuggability.
Expand Down
60 changes: 60 additions & 0 deletions benchmark/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# CMakeLists.txt
# FetchContent_MakeAvailable (used below) was introduced in CMake 3.14, so that
# is the lowest version this file can actually be configured with.
cmake_minimum_required(VERSION 3.14)

include(FetchContent)


# Google Benchmark
# Disable Google Benchmark's own tests (and -Werror) so that only the library
# itself is built as part of this project.
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "" FORCE)
set(BENCHMARK_ENABLE_WERROR OFF CACHE BOOL "" FORCE)
FetchContent_Declare(
    benchmark
    GIT_REPOSITORY https://github.com/google/benchmark.git
    GIT_TAG v1.9.4
    GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(benchmark)


# Baseline ("standard") copy of hnswlib that the current tree is compared against.
# It is fetched from the upstream `develop` branch; since that is a moving
# branch, the baseline may change between fresh configures.
FetchContent_Declare(
hnswlib_std
GIT_REPOSITORY https://github.com/nmslib/hnswlib.git
GIT_TAG develop
GIT_SHALLOW TRUE
)
# Avoid a library name conflict with the in-tree `hnswlib` target: only download
# the sources (FetchContent_Populate does not add the fetched project's own
# CMakeLists.txt) and wrap them in a differently named target below.
FetchContent_GetProperties(hnswlib_std)
if(NOT hnswlib_std_POPULATED)
FetchContent_Populate(hnswlib_std)
# Expose the baseline sources as the header-only target `hnswlib_std`
# (alias `hnswlib_std::hnswlib`), with the fetched source root as include dir.
add_library(hnswlib_std INTERFACE)
add_library(hnswlib_std::hnswlib ALIAS hnswlib_std)
target_include_directories(hnswlib_std INTERFACE
$<BUILD_INTERFACE:${hnswlib_std_SOURCE_DIR}>
$<INSTALL_INTERFACE:include>)
endif()


# Create one benchmark binary per hnswlib version, built from the same sources.
# benchmark_standard: compiled against the fetched baseline (hnswlib_std).
add_executable(benchmark_standard
    benchmarks_main.cpp
    bm_basic.cpp
)
# Executables have no consumers, so link everything PRIVATE (keyword signature).
target_link_libraries(benchmark_standard
    PRIVATE
        benchmark::benchmark
        hnswlib_std
)
# Place the binary at the top of the build tree.
set_target_properties(benchmark_standard PROPERTIES
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
)

# benchmark_current: same sources, compiled against the in-tree hnswlib target.
add_executable(benchmark_current
    benchmarks_main.cpp
    bm_basic.cpp
)
# Executables have no consumers, so link everything PRIVATE (keyword signature).
target_link_libraries(benchmark_current
    PRIVATE
        benchmark::benchmark
        hnswlib
)
# Place the binary at the top of the build tree.
set_target_properties(benchmark_current PROPERTIES
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
)
4 changes: 4 additions & 0 deletions benchmark/cpp/benchmarks.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

// Registers the basic HNSW benchmarks with Google Benchmark.
// Defined in bm_basic.cpp; main() calls it before running the benchmarks.
void RegisterHnswBasicBenchmarks();

14 changes: 14 additions & 0 deletions benchmark/cpp/benchmarks_main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <benchmark/benchmark.h>

#include "benchmarks.h"

// Entry point of the benchmark binary.
//
// Lets Google Benchmark consume its own command-line flags, registers the
// HNSW benchmarks and runs whichever ones the flags select.
// Returns 1 if any command-line argument is left unrecognized, 0 otherwise.
int main(int argc, char** argv) {
    ::benchmark::Initialize(&argc, argv);
    // Initialize() strips the flags it understands; anything left over is a
    // typo or an unsupported option, so fail loudly instead of ignoring it.
    if (::benchmark::ReportUnrecognizedArguments(argc, argv)) {
        return 1;
    }

    RegisterHnswBasicBenchmarks();

    ::benchmark::RunSpecifiedBenchmarks();
    ::benchmark::Shutdown();

    return 0;
}
68 changes: 68 additions & 0 deletions benchmark/cpp/bm_basic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <random>
#include <vector>

#include <benchmark/benchmark.h>

#include "hnswlib/hnswalg.h"

// hnsw build benchmark

/// Scales the vector `arr[0..dim)` in place to unit Euclidean (L2) length.
///
/// @param arr  Pointer to `dim` contiguous floats; overwritten with the result.
/// @param dim  Number of elements in `arr`.
///
/// A vector whose norm is not strictly positive (all zeros, `dim == 0`, or a
/// NaN norm) is left unchanged, so the function never divides by zero.
void l2_normalize(float* arr, size_t dim) {
    float sum_of_squares = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        sum_of_squares += arr[i] * arr[i];
    }
    const float norm = std::sqrt(sum_of_squares);
    // Dividing by a zero norm would turn every element into NaN.
    if (!(norm > 0.0f)) {
        return;
    }
    for (size_t i = 0; i < dim; ++i) {
        arr[i] /= norm;
    }
}
/// Normalizes a batch of vectors stored back to back in one buffer.
///
/// @param arr         Start of `batch_size` consecutive rows of `dim` floats.
/// @param dim         Length of each row.
/// @param batch_size  Number of rows; each is passed to l2_normalize in turn.
void l2_normalize_batch(float* arr, size_t dim, size_t batch_size) {
    float* row = arr;
    for (size_t remaining = batch_size; remaining > 0; --remaining) {
        l2_normalize(row, dim);
        row += dim;  // advance to the start of the next row
    }
}
/// Fills `embeddings` with `x_data_size` deterministic pseudo-random vectors.
///
/// The values are the raw outputs of a std::mt19937 seeded with 42, converted
/// to float, and optionally L2-normalized row by row.
///
/// @param embeddings         Output rows. Grown when it holds fewer than
///                           `x_data_size` rows or a row is shorter than `dim`;
///                           rows/elements beyond that are left untouched.
/// @param dim                Number of floats written to each row.
/// @param x_data_size        Number of rows to generate.
/// @param need_l2_normalize  When true, each generated row is L2-normalized.
void prepare_data(std::vector< std::vector<float> >& embeddings, size_t dim, size_t x_data_size, bool need_l2_normalize) {
    std::mt19937 rng(42);  // same seed to ensure reproducibility
    std::vector<float> flat(x_data_size * dim);
    std::generate(flat.begin(), flat.end(), rng);
    if (need_l2_normalize) {
        l2_normalize_batch(flat.data(), dim, x_data_size);
    }

    // Do not rely on the caller having pre-sized the output: indexing or
    // copying past the end of a vector is undefined behaviour.
    if (embeddings.size() < x_data_size) {
        embeddings.resize(x_data_size);
    }
    if (dim == 0) {
        return;  // nothing to copy
    }
    for (size_t i = 0; i < x_data_size; ++i) {
        std::vector<float>& row = embeddings[i];
        if (row.size() < dim) {
            row.resize(dim);
        }
        std::memcpy(row.data(), flat.data() + i * dim, dim * sizeof(float));
    }
}


/// Benchmarks building a whole inner-product HNSW index from scratch.
///
/// Benchmark arguments (state.range):
///   0 - M                (passed to the index constructor)
///   1 - ef_construction  (passed to the index constructor)
///   2 - dim              (vector dimensionality)
///   3 - x_data_size      (number of vectors inserted, also the index capacity)
///
/// The input vectors are generated once, outside the timed loop. Each timed
/// iteration creates a fresh space and index, inserts every vector with its
/// position as label, and destroys the index again.
static void BM_HnswIPAddPointWholeTimeBench(benchmark::State& state) {
    const size_t M = static_cast<size_t>(state.range(0));
    const size_t ef_construction = static_cast<size_t>(state.range(1));
    const size_t dim = static_cast<size_t>(state.range(2));
    const size_t x_data_size = static_cast<size_t>(state.range(3));

    std::vector< std::vector<float> > embeddings(x_data_size, std::vector<float>(dim, 0.0f));
    prepare_data(embeddings, dim, x_data_size, true);

    for (auto _ : state) {
        hnswlib::InnerProductSpace space(dim);
        // M and ef_construction must be forwarded explicitly; otherwise the
        // index silently uses its defaults and the argument sweep is a no-op.
        hnswlib::HierarchicalNSW<float> index(&space, x_data_size, M, ef_construction);
        for (size_t i = 0; i < x_data_size; ++i) {
            index.addPoint(embeddings[i].data(), i);
        }
        // Keep the compiler from discarding the constructed index.
        benchmark::DoNotOptimize(index);
    }

    state.SetComplexityN(state.range(0)*state.range(1)*state.range(2)*state.range(3));
}

// Registers BM_HnswIPAddPointWholeTimeBench with Google Benchmark.
// ArgsProduct runs the benchmark once for every combination of the four value
// lists below (2 * 2 * 2 * 2 = 16 configurations). The position of each list
// matches the state.range(k) index read inside the benchmark function.
void RegisterHnswBasicBenchmarks() {
BENCHMARK(BM_HnswIPAddPointWholeTimeBench)
->ArgsProduct({
{16,32},  // range(0): M
{200,400},  // range(1): ef_construction
{32, 128},  // range(2): dim
{500,5000}  // range(3): x_data_size (number of vectors inserted)
});
}
27 changes: 11 additions & 16 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,39 +481,34 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return;
}

std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
std::vector<std::pair<dist_t, tableint>> return_list;
std::vector<std::pair<dist_t, tableint>> rqueue_closest;
std::vector<tableint> return_id_list;
while (top_candidates.size() > 0) {
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
rqueue_closest.emplace_back(top_candidates.top());
top_candidates.pop();
}

while (queue_closest.size()) {
if (return_list.size() >= M)
for(auto rit = rqueue_closest.rbegin(); rit != rqueue_closest.rend(); ++rit) {
if (return_id_list.size() >= M)
break;
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
dist_t dist_to_query = -curent_pair.first;
queue_closest.pop();
dist_t dist_to_query = rit->first;
bool good = true;

for (std::pair<dist_t, tableint> second_pair : return_list) {
for (const auto& id : return_id_list) {
dist_t curdist =
fstdistfunc_(getDataByInternalId(second_pair.second),
getDataByInternalId(curent_pair.second),
fstdistfunc_(getDataByInternalId(id),
getDataByInternalId(rit->second),
dist_func_param_);
if (curdist < dist_to_query) {
good = false;
break;
}
}
if (good) {
return_list.push_back(curent_pair);
return_id_list.push_back(rit->second);
top_candidates.emplace(std::move(*rit));
}
}

for (std::pair<dist_t, tableint> curent_pair : return_list) {
top_candidates.emplace(-curent_pair.first, curent_pair.second);
}
}


Expand Down