1+ #include  < benchmark/benchmark.h> 
2+ #include  < vector> 
3+ #include  < random> 
4+ 
5+ #include  " hnswlib/hnswalg.h" 
6+ 
7+ //  hnsw build benchmark
8+ 
9+ void  l2_normalize (float * arr, size_t  dim) {
10+     float  norm = 0 ;
11+     for  (size_t  i = 0 ; i < dim; ++i) {
12+         norm += arr[i] * arr[i];
13+     }
14+     norm = std::sqrt (norm);
15+     for  (size_t  i = 0 ; i < dim; ++i) {
16+         arr[i] /= norm;
17+     }
18+ }
19+ void  l2_normalize_batch (float * arr, size_t  dim, size_t  batch_size) {
20+     for (size_t  i = 0 ; i < batch_size; ++i){
21+         l2_normalize (arr + i*dim, dim);
22+     }
23+ }
24+ void  prepare_data (std::vector< std::vector<float > >& embeddings, size_t  dim, size_t  x_data_size, bool  need_l2_normalize) {
25+     std::mt19937 rng (42 ); //  same seed to ensure reproducibility
26+     std::vector<float > datas (x_data_size*dim);
27+     std::generate (datas.begin (), datas.end (), rng);
28+     if  (need_l2_normalize) {
29+         l2_normalize_batch (datas.data (), dim, x_data_size);
30+     }
31+     for (size_t  i=0 ; i<x_data_size; ++i) {
32+         auto & emb  = embeddings[i];
33+         memcpy (emb.data (), datas.data () + i*dim, dim*sizeof (float ));
34+     }
35+ }
36+ 
37+ 
38+ static  void  BM_HnswIPAddPointWholeTimeBench (benchmark::State& state) {
39+     size_t  M = state.range (0 );
40+     size_t  ef_construction = state.range (1 );
41+     size_t  dim = state.range (2 );
42+     size_t  x_data_size = state.range (3 );
43+ 
44+     std::vector< std::vector<float > > embeddings (x_data_size, std::vector<float >(dim, 0 .0f ));
45+     prepare_data (embeddings, dim, x_data_size, true );
46+ 
47+     for  (auto  _: state) {
48+         auto  space = std::make_shared<hnswlib::InnerProductSpace>(dim);
49+         auto  index = std::make_shared<hnswlib::HierarchicalNSW<float >>(space.get (), x_data_size);
50+         for  (size_t  i = 0 ; i < x_data_size; i++) {
51+             auto & emb = embeddings[i];
52+             index->addPoint (emb.data (), i);
53+         }
54+         benchmark::DoNotOptimize (index);
55+     }
56+ 
57+     state.SetComplexityN (state.range (0 )*state.range (1 )*state.range (2 )*state.range (3 ));
58+ }
59+ 
60+ void  RegisterHnswBasicBenchmarks () {
61+     BENCHMARK (BM_HnswIPAddPointWholeTimeBench)
62+         ->ArgsProduct ({
63+             {16 ,32 },
64+             {200 ,400 },
65+             {32 , 128 },
66+             {500 ,5000 }
67+         });
68+ }
0 commit comments