Skip to content

Commit b125ce8

Browse files
authored
optimize priority queue in getNeighborsByHeuristic2
1 parent 2fba7fb commit b125ce8

File tree

6 files changed

+158
-16
lines changed

6 files changed

+158
-16
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ if(HNSWLIB_EXAMPLES)
247247
"Please check if this is a typo.")
248248
endif()
249249
endforeach()
250+
add_subdirectory(benchmark/cpp)
250251
endif()
251252

252253
# Persist CMAKE_CXX_FLAGS in the cache for debuggability.

benchmark/cpp/CMakeLists.txt

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# CMakeLists.txt
2+
# FetchContent_MakeAvailable(), used below, was introduced in CMake 3.14,
# so 3.14 is the oldest version that can configure this directory.
cmake_minimum_required(VERSION 3.14)
3+
4+
include(FetchContent)
5+
6+
7+
# Google Benchmark
8+
# Disable Google Benchmark's own test targets and -Werror
9+
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
10+
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "" FORCE)
11+
set(BENCHMARK_ENABLE_WERROR OFF CACHE BOOL "" FORCE)
12+
FetchContent_Declare(
13+
benchmark
14+
GIT_REPOSITORY https://github.com/google/benchmark.git
15+
GIT_TAG v1.9.4
16+
GIT_SHALLOW TRUE
17+
)
18+
FetchContent_MakeAvailable(benchmark)
19+
20+
21+
# Use the upstream develop branch as the baseline to compare against
22+
FetchContent_Declare(
23+
hnswlib_std
24+
GIT_REPOSITORY https://github.com/nmslib/hnswlib.git
25+
GIT_TAG develop
26+
GIT_SHALLOW TRUE
27+
)
28+
# avoid library name conflict
29+
FetchContent_GetProperties(hnswlib_std)
30+
if(NOT hnswlib_std_POPULATED)
31+
FetchContent_Populate(hnswlib_std)
32+
# expose the baseline checkout under its own target name
33+
add_library(hnswlib_std INTERFACE)
34+
add_library(hnswlib_std::hnswlib ALIAS hnswlib_std)
35+
target_include_directories(hnswlib_std INTERFACE
36+
$<BUILD_INTERFACE:${hnswlib_std_SOURCE_DIR}>
37+
$<INSTALL_INTERFACE:include>)
38+
endif()
39+
40+
41+
# create benchmark binaries with different versions of hnswlib
42+
# for the baseline (upstream develop) library
43+
add_executable(benchmark_standard benchmarks_main.cpp
44+
bm_basic.cpp
45+
)
46+
target_link_libraries(benchmark_standard benchmark::benchmark)
47+
target_link_libraries(benchmark_standard hnswlib_std)
48+
set_target_properties(benchmark_standard PROPERTIES
49+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
50+
)
51+
52+
# for current library
53+
add_executable(benchmark_current benchmarks_main.cpp
54+
bm_basic.cpp
55+
)
56+
target_link_libraries(benchmark_current benchmark::benchmark)
57+
target_link_libraries(benchmark_current hnswlib)
58+
set_target_properties(benchmark_current PROPERTIES
59+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
60+
)

benchmark/cpp/benchmarks.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#pragma once
2+
3+
// Registers the basic HNSW benchmarks with Google Benchmark.
// Must be called after benchmark::Initialize() and before
// benchmark::RunSpecifiedBenchmarks().
// The name matches the definition in bm_basic.cpp and the call in
// benchmarks_main.cpp.
void RegisterHnswBasicBenchmarks();
4+

benchmark/cpp/benchmarks_main.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <benchmark/benchmark.h>
2+
3+
#include "benchmarks.h"
4+
5+
// Entry point of the benchmark binary.
//
// Lets Google Benchmark consume its own command-line flags, registers the
// HNSW benchmarks, runs the ones selected on the command line, and shuts the
// library down. Returns 1 if an unrecognized argument was passed, 0 otherwise.
int main(int argc, char** argv) {
    ::benchmark::Initialize(&argc, argv);
    // Initialize() strips the flags it understands; anything left over is a
    // typo or an unsupported option, so report it instead of ignoring it.
    if (::benchmark::ReportUnrecognizedArguments(argc, argv)) {
        return 1;
    }

    RegisterHnswBasicBenchmarks();

    ::benchmark::RunSpecifiedBenchmarks();
    ::benchmark::Shutdown();

    return 0;
}

benchmark/cpp/bm_basic.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include <benchmark/benchmark.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <random>
#include <vector>

#include "hnswlib/hnswalg.h"
6+
7+
// hnsw build benchmark
8+
9+
// Scales a vector in place so that its Euclidean (L2) norm becomes 1.
//
// arr: pointer to `dim` contiguous floats; overwritten with the result.
// dim: number of elements in `arr`.
//
// A vector whose norm is exactly zero (all zeros, or dim == 0) is left
// unchanged, because dividing by zero would turn every component into NaN.
void l2_normalize(float* arr, size_t dim) {
    float sum_of_squares = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        sum_of_squares += arr[i] * arr[i];
    }

    const float norm = std::sqrt(sum_of_squares);
    // Exact comparison is intended: only a true zero makes the division
    // below undefined for the data.
    if (norm == 0.0f) {
        return;
    }

    for (size_t i = 0; i < dim; ++i) {
        arr[i] /= norm;
    }
}
19+
void l2_normalize_batch(float* arr, size_t dim, size_t batch_size) {
20+
for(size_t i = 0; i < batch_size; ++i){
21+
l2_normalize(arr + i*dim, dim);
22+
}
23+
}
24+
void prepare_data(std::vector< std::vector<float> >& embeddings, size_t dim, size_t x_data_size, bool need_l2_normalize) {
25+
std::mt19937 rng(42); // same seed to ensure reproducibility
26+
std::vector<float> datas(x_data_size*dim);
27+
std::generate(datas.begin(), datas.end(), rng);
28+
if (need_l2_normalize) {
29+
l2_normalize_batch(datas.data(), dim, x_data_size);
30+
}
31+
for(size_t i=0; i<x_data_size; ++i) {
32+
auto& emb = embeddings[i];
33+
memcpy(emb.data(), datas.data() + i*dim, dim*sizeof(float));
34+
}
35+
}
36+
37+
38+
// Benchmarks building a whole HNSW index over an inner-product space.
//
// Benchmark arguments (state.range):
//   0: M               - passed to the index constructor
//   1: ef_construction - passed to the index constructor
//   2: dim             - number of floats per vector
//   3: x_data_size     - number of vectors inserted (also the index capacity)
//
// Input vectors are generated once, outside the timed loop, and are
// L2-normalized. Each timed iteration constructs a fresh space and index,
// inserts every vector with its position as the label, and destroys both
// at the end of the iteration.
static void BM_HnswIPAddPointWholeTimeBench(benchmark::State& state) {
    const size_t M = static_cast<size_t>(state.range(0));
    const size_t ef_construction = static_cast<size_t>(state.range(1));
    const size_t dim = static_cast<size_t>(state.range(2));
    const size_t x_data_size = static_cast<size_t>(state.range(3));

    std::vector< std::vector<float> > embeddings(x_data_size, std::vector<float>(dim, 0.0f));
    prepare_data(embeddings, dim, x_data_size, true);

    for (auto _ : state) {
        auto space = std::make_shared<hnswlib::InnerProductSpace>(dim);
        // M and ef_construction must be forwarded explicitly; otherwise the
        // constructor's defaults are used and the sweep over these two
        // arguments measures the same configuration every time.
        auto index = std::make_shared<hnswlib::HierarchicalNSW<float>>(
            space.get(), x_data_size, M, ef_construction);
        for (size_t i = 0; i < x_data_size; i++) {
            auto& emb = embeddings[i];
            index->addPoint(emb.data(), i);
        }
        // Keep the built index observable so the work is not optimized away.
        benchmark::DoNotOptimize(index);
    }

    state.SetComplexityN(state.range(0)*state.range(1)*state.range(2)*state.range(3));
}
59+
60+
// Registers BM_HnswIPAddPointWholeTimeBench for every combination of the
// argument values listed below (2 x 2 x 2 x 2 = 16 runs).
void RegisterHnswBasicBenchmarks() {
    // One list per benchmark argument, in state.range() order.
    const std::vector<int64_t> m_values{16, 32};
    const std::vector<int64_t> ef_construction_values{200, 400};
    const std::vector<int64_t> dims{32, 128};
    const std::vector<int64_t> dataset_sizes{500, 5000};

    BENCHMARK(BM_HnswIPAddPointWholeTimeBench)
        ->ArgsProduct({m_values, ef_construction_values, dims, dataset_sizes});
}

hnswlib/hnswalg.h

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -481,39 +481,34 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
481481
return;
482482
}
483483

484-
std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
485-
std::vector<std::pair<dist_t, tableint>> return_list;
484+
std::vector<std::pair<dist_t, tableint>> rqueue_closest;
485+
std::vector<tableint> return_id_list;
486486
while (top_candidates.size() > 0) {
487-
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
487+
rqueue_closest.emplace_back(top_candidates.top());
488488
top_candidates.pop();
489489
}
490490

491-
while (queue_closest.size()) {
492-
if (return_list.size() >= M)
491+
for(auto rit = rqueue_closest.rbegin(); rit != rqueue_closest.rend(); ++rit) {
492+
if (return_id_list.size() >= M)
493493
break;
494-
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
495-
dist_t dist_to_query = -curent_pair.first;
496-
queue_closest.pop();
494+
dist_t dist_to_query = rit->first;
497495
bool good = true;
498496

499-
for (std::pair<dist_t, tableint> second_pair : return_list) {
497+
for (const auto& id : return_id_list) {
500498
dist_t curdist =
501-
fstdistfunc_(getDataByInternalId(second_pair.second),
502-
getDataByInternalId(curent_pair.second),
499+
fstdistfunc_(getDataByInternalId(id),
500+
getDataByInternalId(rit->second),
503501
dist_func_param_);
504502
if (curdist < dist_to_query) {
505503
good = false;
506504
break;
507505
}
508506
}
509507
if (good) {
510-
return_list.push_back(curent_pair);
508+
return_id_list.push_back(rit->second);
509+
top_candidates.emplace(std::move(*rit));
511510
}
512511
}
513-
514-
for (std::pair<dist_t, tableint> curent_pair : return_list) {
515-
top_candidates.emplace(-curent_pair.first, curent_pair.second);
516-
}
517512
}
518513

519514

0 commit comments

Comments
 (0)