[feature](ANN) implemented the import process of vector index (#49703)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: #xxx
Problem Summary:
```
CREATE TABLE `vector_table` (
`siteid` int(11) NULL DEFAULT "10" COMMENT "",
`embedding` array<float> NOT NULL COMMENT "",
`comment` text NULL,
INDEX idx_test_ann (`embedding`) USING ANN PROPERTIES(
"index_type"="diskann",
"metric_type"="l2",
"dim"="8",
"search_list"="32",
"max_degree"="100"
) COMMENT 'test diskann index',
INDEX idx_comment (`comment`) USING INVERTED PROPERTIES("support_phrase" = "true", "parser" = "english", "lower_case" = "true") COMMENT 'inverted index for comment'
) ENGINE=OLAP
duplicate KEY(`siteid`) COMMENT "OLAP"
DISTRIBUTED BY HASH(`siteid`) BUCKETS 1
PROPERTIES (
"replication_num" = "1",
"storage_vault_name" = "s3_vault"
);
INSERT INTO `vector_table` (`siteid`, `embedding`,`comment`) VALUES
(10, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,20],"emb1"),
(20, [7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0,30],"emb2");
MySQL [clz]> select * from vector_table;
+--------+---------------------------+---------+
| siteid | embedding | comment |
+--------+---------------------------+---------+
| 10 | [1, 2, 3, 4, 5, 6, 7, 20] | emb1 |
| 20 | [7, 6, 5, 4, 3, 2, 1, 30] | emb2 |
+--------+---------------------------+---------+
Build image: https://clz-repo.cdn.bcebos.com/doris_ann.image.tar
Build command:
export JAVA_HOME=/usr/lib/jvm/jdk-17.0.2/;
DISABLE_BE_JAVA_EXTENSIONS=ON BUILD_TYPE=Debug DORIS_TOOLCHAIN=gcc sh build.sh --be --fe
```
### Release note
None
### Check List (For Author)
- Test <!-- At least one of them must be included. -->
- [ ] Regression test
- [ ] Unit Test
- [ ] Manual test (add detailed scripts or steps below)
- [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
- [ ] Previous test can cover this change.
- [ ] No code files have been changed.
- [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
- [ ] No.
- [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
- [ ] No.
- [ ] Yes. <!-- Add document PR link here. eg:
https://github.com/apache/doris-website/pull/1214 -->
### Check List (For Reviewer who merge this PR)
- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->diff --git a/be/CMakeLists.txt b/be/CMakeLists.txt
index d476af8..1cd1ac0 100644
--- a/be/CMakeLists.txt
+++ b/be/CMakeLists.txt
@@ -421,7 +421,7 @@
endif()
# Add flags that are common across build types
-set(CMAKE_CXX_FLAGS "${CXX_COMMON_FLAGS} ${CMAKE_CXX_FLAGS} ${EXTRA_CXX_FLAGS}")
+set(CMAKE_CXX_FLAGS "${CXX_COMMON_FLAGS} ${CMAKE_CXX_FLAGS} ${EXTRA_CXX_FLAGS} -fopenmp -fopenmp-simd ")
set(CMAKE_C_FLAGS ${CMAKE_CXX_FLAGS})
@@ -511,6 +511,8 @@
Cloud
${WL_END_GROUP}
CommonCPP
+ diskann_s
+ vector
)
set(absl_DIR ${THIRDPARTY_DIR}/lib/cmake/absl)
@@ -750,20 +752,28 @@
endif()
endfunction(pch_reuse target)
-add_subdirectory(${SRC_DIR}/agent)
-add_subdirectory(${SRC_DIR}/common)
-add_subdirectory(${SRC_DIR}/exec)
-add_subdirectory(${SRC_DIR}/exprs)
-add_subdirectory(${SRC_DIR}/gen_cpp)
-add_subdirectory(${SRC_DIR}/geo)
-add_subdirectory(${SRC_DIR}/gutil)
-add_subdirectory(${SRC_DIR}/http)
-add_subdirectory(${SRC_DIR}/io)
-add_subdirectory(${SRC_DIR}/olap)
-add_subdirectory(${SRC_DIR}/runtime)
-add_subdirectory(${SRC_DIR}/service) # this include doris_be
-add_subdirectory(${SRC_DIR}/udf)
-add_subdirectory(${SRC_DIR}/cloud)
+macro(add_subdirectory_with_log subdir)
+ message(STATUS "Start compiling ${subdir} at ${CMAKE_CURRENT_SOURCE_DIR}/${subdir}")
+ add_subdirectory(${subdir})
+endmacro()
+
+add_subdirectory_with_log(${SRC_DIR}/agent)
+add_subdirectory_with_log(${SRC_DIR}/common)
+add_subdirectory_with_log(${SRC_DIR}/exec)
+add_subdirectory_with_log(${SRC_DIR}/exprs)
+add_subdirectory_with_log(${SRC_DIR}/gen_cpp)
+add_subdirectory_with_log(${SRC_DIR}/geo)
+add_subdirectory_with_log(${SRC_DIR}/gutil)
+add_subdirectory_with_log(${SRC_DIR}/http)
+add_subdirectory_with_log(${SRC_DIR}/io)
+add_subdirectory_with_log(${SRC_DIR}/olap)
+add_subdirectory_with_log(${SRC_DIR}/runtime)
+add_subdirectory_with_log(${SRC_DIR}/service) # this include doris_be
+add_subdirectory_with_log(${SRC_DIR}/udf)
+add_subdirectory_with_log(${SRC_DIR}/cloud)
+add_subdirectory_with_log(${SRC_DIR}/extern/diskann)
+add_subdirectory_with_log(${SRC_DIR}/vector)
+
option(BUILD_META_TOOL "Build meta tool" OFF)
if (BUILD_META_TOOL)
diff --git a/be/src/cloud/CMakeLists.txt b/be/src/cloud/CMakeLists.txt
index dbe8160..fd259cb 100644
--- a/be/src/cloud/CMakeLists.txt
+++ b/be/src/cloud/CMakeLists.txt
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+
# where to put generated libraries
set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/cloud")
diff --git a/be/src/common/status.h b/be/src/common/status.h
index 0252ec8..c7fd15b 100644
--- a/be/src/common/status.h
+++ b/be/src/common/status.h
@@ -640,7 +640,7 @@
// some generally useful macros
#define RETURN_IF_ERROR(stmt) \
do { \
- Status _status_ = (stmt); \
+ doris::Status _status_ = (stmt); \
if (UNLIKELY(!_status_.ok())) { \
return _status_; \
} \
diff --git a/be/src/extern/diskann/CMakeLists.txt b/be/src/extern/diskann/CMakeLists.txt
new file mode 100644
index 0000000..f2e2415
--- /dev/null
+++ b/be/src/extern/diskann/CMakeLists.txt
@@ -0,0 +1,173 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+# Parameters:
+#
+# BOOST_ROOT:
+# Specify root of the Boost library if Boost cannot be auto-detected.
+#
+# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
+# This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
+# build-related and search-related functionality. In build-related functionality, when using tcmalloc,
+# it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
+# such behavior.
+# Contact for this feature: gopalrs.
+cmake_minimum_required(VERSION 3.15)
+set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/extern")
+project(diskann)
+set(CMAKE_STANDARD 17)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+add_compile_options(-Wno-deprecated-copy)
+add_compile_options(-Wno-reorder)
+add_compile_options(-Wno-unused-but-set-variable)
+add_compile_options(-Wno-error=unused-variable)
+
+
+message(STATUS "CMAKE_C_FLAGS in be: ${CMAKE_C_FLAGS}")
+message(STATUS "CMAKE_CXX_FLAGS in be: ${CMAKE_CXX_FLAGS}")
+message(STATUS "CMAKE_PREFIX_PATH in be: ${CMAKE_PREFIX_PATH}")
+message(STATUS "LD_LIBRARY_PATH in be: $ENV{LD_LIBRARY_PATH}")
+message(STATUS "CMAKE_MODULE_PATH in be: $ENV")
+
+
+set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
+
+
+
+include_directories(${PROJECT_SOURCE_DIR}/include)
+include_directories(/home/users/clz/baidu/third-party/doris-diskann/build)
+
+if(NOT PYBIND)
+ set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
+endif()
+# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
+# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
+if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
+ include_directories(${PROJECT_SOURCE_DIR}/gperftools/src)
+endif()
+
+#OpenMP
+find_package(OpenMP)
+message(" OpenMP1 ${OPENMP_FOUND} ${CMAKE_CXX_COMPILER}")
+if (OPENMP_FOUND)
+ set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+else()
+ message(FATAL_ERROR "No OpenMP support")
+endif()
+
+# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
+# but its size is small. Reduce number of dependent DLLs by linking statically.
+
+# find_package(Boost COMPONENTS program_options)
+# message("${Boost_PROGRAM_OPTIONS_INCLUDE_DIRS},${Boost_PROGRAM_OPTIONS_LIBRARIES}")
+# if (NOT Boost_FOUND)
+# message(FATAL_ERROR "Couldn't find Boost dependency")
+# endif()
+include_directories(/home/users/clz/baidu/third-party/palo/thirdparty/installed/include/)
+link_directories(/home/users/clz/baidu/third-party/palo/thirdparty/installed/lib64)
+
+#MKL Config
+# expected path for manual intel mkl installs
+set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
+foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
+ if (EXISTS ${POSSIBLE_OMP_PATH})
+ get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
+ endif()
+endforeach()
+
+if(NOT OMP_PATH)
+ message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
+endif()
+link_directories(${OMP_PATH})
+
+
+set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
+foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
+ if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
+ get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
+ endif()
+endforeach()
+
+set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
+foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
+ if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
+ set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
+ endif()
+endforeach()
+if(NOT MKL_PATH)
+ message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
+elseif(NOT MKL_INCLUDE_PATH)
+ message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
+endif()
+if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
+ set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
+elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
+ set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
+else()
+ message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
+endif()
+link_directories(${MKL_PATH})
+message("mkl ${MKL_PATH}")
+include_directories(${MKL_INCLUDE_PATH})
+
+# compile flags and link libraries
+add_compile_options(-m64 -Wl,--no-as-needed)
+# if (NOT PYBIND)
+# link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
+# else()
+ # static linking for python so as to minimize customer dependency issues
+ link_libraries(
+ ${MKL_PATH}/libmkl_intel_ilp64.a
+ ${MKL_PATH}/libmkl_intel_thread.a
+ ${MKL_PATH}/libmkl_core.a
+ ${MKL_DEF_SO}
+ /opt/intel/lib/intel64_lin/libiomp5.a
+ pthread
+ m
+ dl
+ )
+# endif()
+add_definitions(-DMKL_ILP64)
+
+# Section for tcmalloc. The DiskANN tools are always linked to tcmalloc.
+if(NOT PYBIND)
+ set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
+endif()
+
+if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
+ add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
+endif()
+
+set(DISKANN_ASYNC_LIB aio)
+
+
+#Main compiler/linker settings
+set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2")
+set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DNDEBUG")
+if (NOT PYBIND)
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
+ if (NOT PORTABLE)
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
+ endif()
+else()
+ # -Ofast is not supported in a python extension module
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC")
+endif()
+
+
+
+
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+
+file(GLOB CPP_SOURCES src/*.cpp)
+add_library(${PROJECT_NAME} ${CPP_SOURCES})
+add_library(${PROJECT_NAME}_s STATIC ${CPP_SOURCES})
+
+
+
diff --git a/be/src/extern/diskann/include/ThreadPool.h b/be/src/extern/diskann/include/ThreadPool.h
new file mode 100644
index 0000000..3c051dc
--- /dev/null
+++ b/be/src/extern/diskann/include/ThreadPool.h
@@ -0,0 +1,98 @@
+#ifndef THREAD_POOL_H
+#define THREAD_POOL_H
+
+#include <vector>
+#include <queue>
+#include <memory>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <future>
+#include <functional>
+#include <stdexcept>
+
+class ThreadPool {
+public:
+ ThreadPool(size_t);
+ template<class F, class... Args>
+ auto enqueue(F&& f, Args&&... args)
+ -> std::future<typename std::result_of<F(Args...)>::type>;
+ ~ThreadPool();
+private:
+ // need to keep track of threads so we can join them
+ std::vector< std::thread > workers;
+ // the task queue
+ std::queue< std::function<void()> > tasks;
+
+ // synchronization
+ std::mutex queue_mutex;
+ std::condition_variable condition;
+ bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+ : stop(false)
+{
+ for(size_t i = 0;i<threads;++i)
+ workers.emplace_back(
+ [this]
+ {
+ for(;;)
+ {
+ std::function<void()> task;
+
+ {
+ std::unique_lock<std::mutex> lock(this->queue_mutex);
+ this->condition.wait(lock,
+ [this]{ return this->stop || !this->tasks.empty(); });
+ if(this->stop && this->tasks.empty())
+ return;
+ task = std::move(this->tasks.front());
+ this->tasks.pop();
+ }
+
+ task();
+ }
+ }
+ );
+}
+
+// add new work item to the pool
+template<class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+ -> std::future<typename std::result_of<F(Args...)>::type>
+{
+ using return_type = typename std::result_of<F(Args...)>::type;
+
+ auto task = std::make_shared< std::packaged_task<return_type()> >(
+ std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+ );
+
+ std::future<return_type> res = task->get_future();
+ {
+ std::unique_lock<std::mutex> lock(queue_mutex);
+
+ // don't allow enqueueing after stopping the pool
+ if(stop)
+ throw std::runtime_error("enqueue on stopped ThreadPool");
+
+ tasks.emplace([task](){ (*task)(); });
+ }
+ condition.notify_one();
+ return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+ {
+ std::unique_lock<std::mutex> lock(queue_mutex);
+ stop = true;
+ }
+ condition.notify_all();
+ for(std::thread &worker: workers)
+ worker.join();
+}
+
+#endif
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/abstract_data_store.h b/be/src/extern/diskann/include/abstract_data_store.h
new file mode 100644
index 0000000..89856f1
--- /dev/null
+++ b/be/src/extern/diskann/include/abstract_data_store.h
@@ -0,0 +1,127 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <vector>
+#include <string>
+
+#include "types.h"
+#include "windows_customizations.h"
+#include "distance.h"
+
+namespace diskann
+{
+
+template <typename data_t> class AbstractScratch;
+
+template <typename data_t> class AbstractDataStore
+{
+ public:
+ AbstractDataStore(const location_t capacity, const size_t dim);
+
+ virtual ~AbstractDataStore() = default;
+
+ // Return number of points returned
+ virtual location_t load(const std::string &filename) = 0;
+
+ // Why does store take num_pts? Since store only has capacity, but we allow
+ // resizing we can end up in a situation where the store has spare capacity.
+ // To optimize disk utilization, we pass the number of points that are "true"
+ // points, so that the store can discard the empty locations before saving.
+ virtual size_t save(const std::string &filename, const location_t num_pts) = 0;
+
+ DISKANN_DLLEXPORT virtual location_t capacity() const;
+
+ DISKANN_DLLEXPORT virtual size_t get_dims() const;
+
+ // Implementers can choose to return _dim if they are not
+ // concerned about memory alignment.
+ // Some distance metrics (like l2) need data vectors to be aligned, so we
+ // align the dimension by padding zeros.
+ virtual size_t get_aligned_dim() const = 0;
+
+ // populate the store with vectors (either from a pointer or bin file),
+ // potentially after pre-processing the vectors if the metric deems so
+ // e.g., normalizing vectors for cosine distance over floating-point vectors
+ // useful for bulk or static index building.
+ virtual void populate_data(const data_t *vectors, const location_t num_pts) = 0;
+ virtual void populate_data(const std::string &filename, const size_t offset) = 0;
+
+ // save the first num_pts many vectors back to bin file
+ // note: cannot undo the pre-processing done in populate data
+ virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) = 0;
+
+ // Returns the updated capacity of the datastore. Clients should check
+ // if resize actually changed the capacity to new_num_points before
+ // proceeding with operations. See the code below:
+    //  auto new_capacity = data_store->resize(new_num_points);
+ // if ( new_capacity >= new_num_points) {
+ // //PROCEED
+ // else
+ // //ERROR.
+ virtual location_t resize(const location_t new_num_points);
+
+ // operations on vectors
+ // like populate_data function, but over one vector at a time useful for
+ // streaming setting
+ virtual void get_vector(const location_t i, data_t *dest) const = 0;
+ virtual void set_vector(const location_t i, const data_t *const vector) = 0;
+ virtual void prefetch_vector(const location_t loc) = 0;
+
+ // internal shuffle operations to move around vectors
+ // will bulk-move all the vectors in [old_start_loc, old_start_loc +
+ // num_points) to [new_start_loc, new_start_loc + num_points) and set the old
+ // positions to zero vectors.
+ virtual void move_vectors(const location_t old_start_loc, const location_t new_start_loc,
+ const location_t num_points) = 0;
+
+ // same as above, without resetting the vectors in [from_loc, from_loc +
+ // num_points) to zero
+ virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0;
+
+ // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query
+ // from the scratch object. Therefore every data store has to implement preprocess_query which
+ // at the least will be to copy the query into the scratch object. So making this pure virtual.
+ virtual void preprocess_query(const data_t *aligned_query,
+ AbstractScratch<data_t> *query_scratch = nullptr) const = 0;
+ // distance functions.
+ virtual float get_distance(const data_t *query, const location_t loc) const = 0;
+ virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
+ float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
+ // Specific overload for index.cpp.
+ virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
+ std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
+ virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;
+
+ // stats of the data stored in store
+ // Returns the point in the dataset that is closest to the mean of all points
+ // in the dataset
+ virtual location_t calculate_medoid() const = 0;
+
+ // REFACTOR PQ TODO: Each data store knows about its distance function, so this is
+    // redundant. However, we don't have an OptimizedDataStore yet, and to preserve code
+    // compatibility, we are exposing this function.
+ virtual Distance<data_t> *get_dist_fn() const = 0;
+
+ // search helpers
+ // if the base data is aligned per the request of the metric, this will tell
+ // how to align the query vector in a consistent manner
+ virtual size_t get_alignment_factor() const = 0;
+
+ protected:
+ // Expand the datastore to new_num_points. Returns the new capacity created,
+ // which should be == new_num_points in the normal case. Implementers can also
+    // return _capacity to indicate that they are not implementing this method.
+ virtual location_t expand(const location_t new_num_points) = 0;
+
+ // Shrink the datastore to new_num_points. It is NOT an error if shrink
+ // doesn't reduce the capacity so callers need to check this correctly. See
+ // also for "default" implementation
+ virtual location_t shrink(const location_t new_num_points) = 0;
+
+ location_t _capacity;
+ size_t _dim;
+};
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/abstract_graph_store.h b/be/src/extern/diskann/include/abstract_graph_store.h
new file mode 100644
index 0000000..4ef09c7
--- /dev/null
+++ b/be/src/extern/diskann/include/abstract_graph_store.h
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "types.h"
+
+namespace diskann
+{
+
+class AbstractGraphStore
+{
+ public:
+ AbstractGraphStore(const size_t total_pts, const size_t reserve_graph_degree)
+ : _capacity(total_pts), _reserve_graph_degree(reserve_graph_degree)
+ {
+ }
+
+ virtual ~AbstractGraphStore() = default;
+
+ // returns tuple of <nodes_read, start, num_frozen_points>
+ virtual std::tuple<uint32_t, uint32_t, size_t> load(const std::string &index_path_prefix,
+ const size_t num_points) = 0;
+ virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points,
+ const uint32_t start) = 0;
+
+ virtual int store(std::stringstream &index_stream, const size_t num_points,
+ const size_t num_frozen_points, const uint32_t start){
+ return 0;
+ }
+ virtual int save_graph(std::stringstream &out, const size_t num_points,
+ const size_t num_frozen_points, const uint32_t start){ return 0;}
+
+    // not synchronised, user should use lock when necessary.
+ virtual const std::vector<location_t> &get_neighbours(const location_t i) const = 0;
+ virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0;
+ virtual void clear_neighbours(const location_t i) = 0;
+ virtual void swap_neighbours(const location_t a, location_t b) = 0;
+
+ virtual void set_neighbours(const location_t i, std::vector<location_t> &neighbours) = 0;
+
+ virtual size_t resize_graph(const size_t new_size) = 0;
+ virtual void clear_graph() = 0;
+
+ virtual uint32_t get_max_observed_degree() = 0;
+
+ // set during load
+ virtual size_t get_max_range_of_graph() = 0;
+
+ // Total internal points _max_points + _num_frozen_points
+ size_t get_total_points()
+ {
+ return _capacity;
+ }
+
+ protected:
+ // Internal function, changes total points when resize_graph is called.
+ void set_total_points(size_t new_capacity)
+ {
+ _capacity = new_capacity;
+ }
+
+ size_t get_reserve_graph_degree()
+ {
+ return _reserve_graph_degree;
+ }
+
+ private:
+ size_t _capacity;
+ size_t _reserve_graph_degree;
+};
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/abstract_index.h b/be/src/extern/diskann/include/abstract_index.h
new file mode 100644
index 0000000..059866f
--- /dev/null
+++ b/be/src/extern/diskann/include/abstract_index.h
@@ -0,0 +1,129 @@
+#pragma once
+#include "distance.h"
+#include "parameters.h"
+#include "utils.h"
+#include "types.h"
+#include "index_config.h"
+#include "index_build_params.h"
+#include <any>
+
+namespace diskann
+{
+struct consolidation_report
+{
+ enum status_code
+ {
+ SUCCESS = 0,
+ FAIL = 1,
+ LOCK_FAIL = 2,
+ INCONSISTENT_COUNT_ERROR = 3
+ };
+ status_code _status;
+ size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete;
+ double _time;
+
+ consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots,
+ size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete,
+ double time_secs)
+ : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots),
+ _slots_released(slots_released), _delete_set_size(delete_set_size),
+ _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs)
+ {
+ }
+};
+
+/* A templated independent class for interaction with Index. Uses Type Erasure to add virtual implementations of methods
+that can take any type (using std::any) and provides a clean API that can be inherited by different types of Index.
+*/
+class AbstractIndex
+{
+ public:
+ AbstractIndex() = default;
+ virtual ~AbstractIndex() = default;
+
+ virtual void build(const std::string &data_file, const size_t num_points_to_load,
+ IndexFilterParams &build_params) = 0;
+
+ template <typename data_type, typename tag_type>
+ void build(const data_type *data, const size_t num_points_to_load, const std::vector<tag_type> &tags);
+
+ virtual void save(const char *filename, bool compact_before_save = false) = 0;
+
+#ifdef EXEC_ENV_OLS
+ virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
+#else
+ virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0;
+#endif
+
+ // For FastL2 search on optimized layout
+ template <typename data_type>
+ void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices);
+
+ // Initialize space for res_vectors before calling.
+ template <typename data_type, typename tag_type>
+ size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
+ float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
+ const std::string filter_label = "");
+
+ // Added search overload that takes L as parameter, so that we
+ // can customize L on a per-query basis without tampering with "Parameters"
+ // IDtype is either uint32_t or uint64_t
+ template <typename data_type, typename IDType>
+ std::pair<uint32_t, uint32_t> search(const data_type *query, const size_t K, const uint32_t L, IDType *indices,
+ float *distances = nullptr);
+
+ // Filter support search
+ // IndexType is either uint32_t or uint64_t
+ template <typename IndexType>
+ std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
+ const size_t K, const uint32_t L, IndexType *indices,
+ float *distances);
+
+ // insert points with labels, labels should be present for filtered index
+ template <typename data_type, typename tag_type, typename label_type>
+ int insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels);
+
+ // insert point for unfiltered index build. do not use with filtered index
+ template <typename data_type, typename tag_type> int insert_point(const data_type *point, const tag_type tag);
+
+ // delete point with tag, or return -1 if point can not be deleted
+ template <typename tag_type> int lazy_delete(const tag_type &tag);
+
+    // batch delete tags and populates failed tags if unable to delete given tags.
+ template <typename tag_type>
+ void lazy_delete(const std::vector<tag_type> &tags, std::vector<tag_type> &failed_tags);
+
+ template <typename tag_type> void get_active_tags(tsl::robin_set<tag_type> &active_tags);
+
+ template <typename data_type> void set_start_points_at_random(data_type radius, uint32_t random_seed = 0);
+
+ virtual consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters) = 0;
+
+ virtual void optimize_index_layout() = 0;
+
+ // memory should be allocated for vec before calling this function
+ template <typename tag_type, typename data_type> int get_vector_by_tag(tag_type &tag, data_type *vec);
+
+ template <typename label_type> void set_universal_label(const label_type universal_label);
+
+ private:
+ virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
+ virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
+ std::any &indices, float *distances = nullptr) = 0;
+ virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
+ const size_t K, const uint32_t L, std::any &indices,
+ float *distances) = 0;
+ virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
+ virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
+ virtual int _lazy_delete(const TagType &tag) = 0;
+ virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0;
+ virtual void _get_active_tags(TagRobinSet &active_tags) = 0;
+ virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
+ virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
+ virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
+ float *distances, DataVector &res_vectors, bool use_filters = false,
+ const std::string filter_label = "") = 0;
+ virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
+ virtual void _set_universal_label(const LabelType universal_label) = 0;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/abstract_scratch.h b/be/src/extern/diskann/include/abstract_scratch.h
new file mode 100644
index 0000000..b42a836
--- /dev/null
+++ b/be/src/extern/diskann/include/abstract_scratch.h
@@ -0,0 +1,35 @@
+#pragma once
+namespace diskann
+{
+
+template <typename data_t> class PQScratch;
+
+// By somewhat more than a coincidence, it seems that both InMemQueryScratch
+// and SSDQueryScratch have the aligned query and PQScratch objects. So we
+// can put them in a neat hierarchy and keep PQScratch as a standalone class.
+template <typename data_t> class AbstractScratch
+{
+ public:
+ AbstractScratch() = default;
+    // This class does not take any responsibility for memory management of
+ // its members. It is the responsibility of the derived classes to do so.
+ virtual ~AbstractScratch() = default;
+
+ // Scratch objects should not be copied
+ AbstractScratch(const AbstractScratch &) = delete;
+ AbstractScratch &operator=(const AbstractScratch &) = delete;
+
+ data_t *aligned_query_T()
+ {
+ return _aligned_query_T;
+ }
+ PQScratch<data_t> *pq_scratch()
+ {
+ return _pq_scratch;
+ }
+
+ protected:
+ data_t *_aligned_query_T = nullptr;
+ PQScratch<data_t> *_pq_scratch = nullptr;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/aligned_file_reader.h b/be/src/extern/diskann/include/aligned_file_reader.h
new file mode 100644
index 0000000..f8e9c30
--- /dev/null
+++ b/be/src/extern/diskann/include/aligned_file_reader.h
@@ -0,0 +1,70 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#define MAX_IO_DEPTH 128
+
+#include <vector>
+#include <atomic>
+
+#include <fcntl.h>
+#include <libaio.h>
+#include <unistd.h>
+typedef io_context_t IOContext;
+
+#include <malloc.h>
+#include <cstdio>
+#include <mutex>
+#include <thread>
+#include "tsl/robin_map.h"
+#include "utils.h"
+
+// NOTE :: all 3 fields must be 512-aligned
+struct AlignedRead
+{
+ uint64_t offset; // where to read from
+ uint64_t len; // how much to read
+ void *buf; // where to read into
+
+ AlignedRead() : offset(0), len(0), buf(nullptr)
+ {
+ }
+
+ AlignedRead(uint64_t offset, uint64_t len, void *buf) : offset(offset), len(len), buf(buf)
+ {
+ assert(IS_512_ALIGNED(offset));
+ assert(IS_512_ALIGNED(len));
+ assert(IS_512_ALIGNED(buf));
+ // assert(malloc_usable_size(buf) >= len);
+ }
+};
+
+class AlignedFileReader
+{
+ protected:
+ tsl::robin_map<std::thread::id, IOContext> ctx_map;
+ std::mutex ctx_mut;
+
+ public:
+ // returns the thread-specific context
+ // returns (io_context_t)(-1) if thread is not registered
+ virtual IOContext &get_ctx() = 0;
+
+ virtual ~AlignedFileReader(){};
+
+ // register thread-id for a context
+ virtual void register_thread() = 0;
+ // de-register thread-id for a context
+ virtual void deregister_thread() = 0;
+ virtual void deregister_all_threads() = 0;
+
+ // Open & close ops
+ // Blocking calls
+ virtual void open(const std::string &fname) = 0;
+ virtual void close() = 0;
+
+ // process batch of aligned requests in parallel
+ // NOTE :: blocking call
+ virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false) = 0;
+};
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/ann_exception.h b/be/src/extern/diskann/include/ann_exception.h
new file mode 100644
index 0000000..9f55c8c
--- /dev/null
+++ b/be/src/extern/diskann/include/ann_exception.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <string>
+#include <stdexcept>
+#include <system_error>
+#include <cstdint>
+
+#include "windows_customizations.h"
+
+#ifndef _WINDOWS
+#define __FUNCSIG__ __PRETTY_FUNCTION__
+#endif
+
+namespace diskann
+{
+
+class ANNException : public std::runtime_error // project-wide exception carrying an integer error code
+{
+  public:
+    DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode); // message + code only
+    DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, const std::string &funcSig,
+                                   const std::string &fileName, uint32_t lineNum); // adds call-site context
+
+  private:
+    int _errorCode; // stored at construction; no accessor is declared here
+};
+
+class FileException : public ANNException // raised when a file operation fails with a std::system_error
+{
+  public:
+    DISKANN_DLLEXPORT FileException(const std::string &filename, std::system_error &e, const std::string &funcSig,
+                                    const std::string &fileName, uint32_t lineNum);
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/any_wrappers.h b/be/src/extern/diskann/include/any_wrappers.h
new file mode 100644
index 0000000..da9005c
--- /dev/null
+++ b/be/src/extern/diskann/include/any_wrappers.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstdint>
+#include <cstddef>
+#include <vector>
+#include <any>
+#include "tsl/robin_set.h"
+
+namespace AnyWrapper
+{
+
+/*
+ * Base struct to hold a reference to the data.
+ * Note: No memory management, caller needs to keep the object alive.
+ */
+struct AnyReference
+{
+    template <typename Ty> AnyReference(Ty &reference) : _data(&reference) // stores Ty* type-erased in std::any
+    {
+    }
+
+    template <typename Ty> Ty &get() // throws std::bad_any_cast if Ty differs from the stored type
+    {
+        auto ptr = std::any_cast<Ty *>(_data);
+        return *ptr;
+    }
+
+  private:
+    std::any _data; // type-erased pointer to a caller-owned object
+};
+struct AnyRobinSet : public AnyReference // tag wrapper restricting AnyReference to tsl::robin_set
+{
+    template <typename T> AnyRobinSet(const tsl::robin_set<T> &robin_set) : AnyReference(robin_set) // note: deduces const element type; retrieve with a matching const get<>()
+    {
+    }
+    template <typename T> AnyRobinSet(tsl::robin_set<T> &robin_set) : AnyReference(robin_set)
+    {
+    }
+};
+
+struct AnyVector : public AnyReference // tag wrapper restricting AnyReference to std::vector
+{
+    template <typename T> AnyVector(const std::vector<T> &vector) : AnyReference(vector) // note: deduces const element type; retrieve with a matching const get<>()
+    {
+    }
+    template <typename T> AnyVector(std::vector<T> &vector) : AnyReference(vector)
+    {
+    }
+};
+} // namespace AnyWrapper
diff --git a/be/src/extern/diskann/include/boost_dynamic_bitset_fwd.h b/be/src/extern/diskann/include/boost_dynamic_bitset_fwd.h
new file mode 100644
index 0000000..5aebb2b
--- /dev/null
+++ b/be/src/extern/diskann/include/boost_dynamic_bitset_fwd.h
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+namespace boost
+{
+#ifndef BOOST_DYNAMIC_BITSET_FWD_HPP
+template <typename Block = unsigned long, typename Allocator = std::allocator<Block>> class dynamic_bitset;
+#endif
+} // namespace boost
diff --git a/be/src/extern/diskann/include/cached_io.h b/be/src/extern/diskann/include/cached_io.h
new file mode 100644
index 0000000..daef2f2
--- /dev/null
+++ b/be/src/extern/diskann/include/cached_io.h
@@ -0,0 +1,217 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+
+#include "logger.h"
+#include "ann_exception.h"
+
+// sequential cached reads
+class cached_ifstream
+{
+ public:
+ cached_ifstream()
+ {
+ }
+ cached_ifstream(const std::string &filename, uint64_t cacheSize) : cache_size(cacheSize), cur_off(0)
+ {
+ reader.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+ this->open(filename, cache_size);
+ }
+ ~cached_ifstream()
+ {
+ delete[] cache_buf;
+ reader.close();
+ }
+
+ void open(const std::string &filename, uint64_t cacheSize)
+ {
+ this->cur_off = 0;
+
+ try
+ {
+ reader.open(filename, std::ios::binary | std::ios::ate);
+ fsize = reader.tellg();
+ reader.seekg(0, std::ios::beg);
+ assert(reader.is_open());
+ assert(cacheSize > 0);
+ cacheSize = (std::min)(cacheSize, fsize);
+ this->cache_size = cacheSize;
+ cache_buf = new char[cacheSize];
+ reader.read(cache_buf, cacheSize);
+ diskann::cout << "Opened: " << filename.c_str() << ", size: " << fsize << ", cache_size: " << cacheSize
+ << std::endl;
+ }
+ catch (std::system_error &e)
+ {
+ throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ }
+
+ size_t get_file_size()
+ {
+ return fsize;
+ }
+
+ void read(char *read_buf, uint64_t n_bytes)
+ {
+ assert(cache_buf != nullptr);
+ assert(read_buf != nullptr);
+
+ if (n_bytes <= (cache_size - cur_off))
+ {
+ // case 1: cache contains all data
+ memcpy(read_buf, cache_buf + cur_off, n_bytes);
+ cur_off += n_bytes;
+ }
+ else
+ {
+ // case 2: cache contains some data
+ uint64_t cached_bytes = cache_size - cur_off;
+ if (n_bytes - cached_bytes > fsize - reader.tellg())
+ {
+ std::stringstream stream;
+ stream << "Reading beyond end of file" << std::endl;
+ stream << "n_bytes: " << n_bytes << " cached_bytes: " << cached_bytes << " fsize: " << fsize
+ << " current pos:" << reader.tellg() << std::endl;
+ diskann::cout << stream.str() << std::endl;
+ throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ memcpy(read_buf, cache_buf + cur_off, cached_bytes);
+
+ // go to disk and fetch more data
+ reader.read(read_buf + cached_bytes, n_bytes - cached_bytes);
+ // reset cur off
+ cur_off = cache_size;
+
+ uint64_t size_left = fsize - reader.tellg();
+
+ if (size_left >= cache_size)
+ {
+ reader.read(cache_buf, cache_size);
+ cur_off = 0;
+ }
+ // note that if size_left < cache_size, then cur_off = cache_size,
+ // so subsequent reads will all be directly from file
+ }
+ }
+
+ private:
+ // underlying ifstream
+ std::ifstream reader;
+ // # bytes to cache in one shot read
+ uint64_t cache_size = 0;
+ // underlying buf for cache
+ char *cache_buf = nullptr;
+ // offset into cache_buf for cur_pos
+ uint64_t cur_off = 0;
+ // file size
+ uint64_t fsize = 0;
+};
+
+// sequential cached writes
+class cached_ofstream // write-side twin of cached_ifstream: buffers sequential writes before hitting disk
+{
+  public:
+    cached_ofstream(const std::string &filename, uint64_t cache_size) : cache_size(cache_size), cur_off(0)
+    {
+        writer.exceptions(std::ifstream::failbit | std::ifstream::badbit); // NOTE(review): ifstream:: constants on an ofstream — same ios_base values, but std::ofstream:: would be clearer
+        try
+        {
+            writer.open(filename, std::ios::binary);
+            assert(writer.is_open());
+            assert(cache_size > 0);
+            cache_buf = new char[cache_size];
+            diskann::cout << "Opened: " << filename.c_str() << ", cache_size: " << cache_size << std::endl;
+        }
+        catch (std::system_error &e)
+        {
+            throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__);
+        }
+    }
+
+    ~cached_ofstream()
+    {
+        this->close(); // NOTE(review): close() writes and may throw (failbit exceptions enabled) — confirm safe in dtor
+    }
+
+    void close() // flush remaining cache, release the buffer, close the stream; safe to call twice
+    {
+        // dump any remaining data in memory
+        if (cur_off > 0)
+        {
+            this->flush_cache();
+        }
+
+        if (cache_buf != nullptr)
+        {
+            delete[] cache_buf;
+            cache_buf = nullptr;
+        }
+
+        if (writer.is_open())
+            writer.close();
+        diskann::cout << "Finished writing " << fsize << "B" << std::endl;
+    }
+
+    size_t get_file_size() // NOTE(review): counts bytes written, not position — misleading after reset(); confirm
+    {
+        return fsize;
+    }
+    // writes n_bytes from write_buf to the underlying ofstream/cache
+    void write(char *write_buf, uint64_t n_bytes)
+    {
+        assert(cache_buf != nullptr);
+        if (n_bytes <= (cache_size - cur_off))
+        {
+            // case 1: cache can take all data
+            memcpy(cache_buf + cur_off, write_buf, n_bytes);
+            cur_off += n_bytes;
+        }
+        else
+        {
+            // case 2: cache cant take all data
+            // go to disk and write existing cache data
+            writer.write(cache_buf, cur_off);
+            fsize += cur_off;
+            // write the new data to disk (bypasses the cache entirely for this request)
+            writer.write(write_buf, n_bytes);
+            fsize += n_bytes;
+            // memset all cache data and reset cur_off
+            memset(cache_buf, 0, cache_size);
+            cur_off = 0;
+        }
+    }
+
+    void flush_cache() // force buffered bytes out to the stream and clear the cache
+    {
+        assert(cache_buf != nullptr);
+        writer.write(cache_buf, cur_off);
+        fsize += cur_off;
+        memset(cache_buf, 0, cache_size);
+        cur_off = 0;
+    }
+
+    void reset() // flush, then rewind the stream to offset 0 (fsize is NOT reset — see get_file_size note)
+    {
+        flush_cache();
+        writer.seekp(0);
+    }
+
+  private:
+    // underlying ofstream
+    std::ofstream writer;
+    // # bytes to cache for one shot write
+    uint64_t cache_size = 0;
+    // underlying buf for cache
+    char *cache_buf = nullptr;
+    // offset into cache_buf for cur_pos
+    uint64_t cur_off = 0;
+
+    // file size
+    uint64_t fsize = 0;
+};
diff --git a/be/src/extern/diskann/include/combined_file.h b/be/src/extern/diskann/include/combined_file.h
new file mode 100644
index 0000000..d5285f3
--- /dev/null
+++ b/be/src/extern/diskann/include/combined_file.h
@@ -0,0 +1,105 @@
+#pragma once
+
+#include <iostream>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <cstring>
+#include <cstdio>
+#include <map>
+#include <memory>
+#include <mutex>
+
+namespace diskann {
+
+// Abstract base class for file readers.
+class Reader {
+ public:
+    virtual ~Reader() {}
+    virtual size_t read(char* buffer, uint64_t offset, size_t size) = 0;   // read `size` bytes at `offset` into buffer
+    virtual void seek(uint64_t offset) = 0;                                // move the read position
+    virtual bool open(const std::string& file_path) = 0;                   // returns false on failure
+    virtual void close() = 0;
+    virtual uint64_t get_file_size() = 0;
+    virtual uint64_t get_base_offset() = 0;                                // start of this logical file inside the physical file
+};
+
+class IOWriter { // abstract sequential byte sink used by FileMerger::save()
+ public:
+    virtual ~IOWriter() {}
+    virtual void write(const char* data, size_t size) = 0; // append `size` bytes
+    virtual void close() = 0;
+};
+
+
+
+class LocalFileReader : public Reader { // Reader over a local file, optionally windowed by [base_offset, base_offset + size)
+  private:
+    std::ifstream in;
+    std::string _file_path;
+    // base offset: where this logical file starts inside the underlying physical file
+    uint64_t _base_offset = 0;
+    uint64_t _size = 0;  // logical size of the window (0 when constructed via the default path — TODO confirm)
+    std::mutex _mutex;   // presumably serializes seek+read pairs; verify against the .cpp implementation
+  public:
+    LocalFileReader(const std::string& file_path, uint64_t start = 0, uint64_t size = 0);
+
+    LocalFileReader();
+
+    bool open(const std::string& file_path);
+
+    uint64_t get_file_size();
+
+    size_t read(char* buffer, uint64_t offset, size_t size);
+
+    void seek(uint64_t offset);
+
+    void close();
+
+    uint64_t get_base_offset() {
+        return _base_offset;
+    }
+};
+
+class LocalFileIOWriter : public IOWriter { // IOWriter backed by a local std::ofstream
+  public:
+    LocalFileIOWriter(const std::string& file_path);
+    void write(const char* data, size_t size);
+    void close();
+  private:
+    std::ofstream out;
+};
+
+class FileMerger { // packs several files into one merged file, with a meta section at the end for lookup
+  private:
+    struct FileInfo {
+        std::string alias;   // logical name used to look the file up later
+        std::string path;    // path of the original file
+        uint64_t offset;     // offset of this file's data inside the merged file
+        uint64_t size;       // size of this file's data in bytes
+    };
+    std::vector<FileInfo> files;
+    uint64_t current_offset;
+
+    // Cache of readers already instantiated, keyed by file alias.
+    std::map<std::string, std::shared_ptr<Reader>> reader_cache;
+
+  public:
+    FileMerger() : current_offset(0) {}
+
+    // Register a file under the given alias.
+    bool add(const std::string& alias, const std::string& path);
+    void clear();
+
+    // Write the merged file to disk: file data first, meta info at the end (for easy deserialization),
+    // with total_meta_size recorded as the final bytes of the file.
+    void save(IOWriter* writer);
+
+    // Look up base_offset and size for the original file by alias, build a reader positioned at base_offset,
+    // and return it. Meta is read back from the merged file (using total_meta_size to locate the meta section).
+    // Cached: if a reader for this alias was created before, it is returned directly.
+    template <typename ReaderType>
+    std::shared_ptr<ReaderType> get_reader(const std::string& alias, const std::string& merged_file_path);
+};
+
+
+}
diff --git a/be/src/extern/diskann/include/common_includes.h b/be/src/extern/diskann/include/common_includes.h
new file mode 100644
index 0000000..e1a51bd
--- /dev/null
+++ b/be/src/extern/diskann/include/common_includes.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <cassert>
+#include <chrono>
+#include <climits>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fcntl.h>
+#include <fstream>
+#include <iostream>
+#include <iomanip>
+#include <omp.h>
+#include <queue>
+#include <random>
+#include <set>
+#include <shared_mutex>
+#include <sys/stat.h>
+#include <sstream>
+#include <unordered_map>
+#include <vector>
diff --git a/be/src/extern/diskann/include/concurrent_queue.h b/be/src/extern/diskann/include/concurrent_queue.h
new file mode 100644
index 0000000..1e57bbf
--- /dev/null
+++ b/be/src/extern/diskann/include/concurrent_queue.h
@@ -0,0 +1,132 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <type_traits>
+#include <unordered_set>
+
+namespace diskann
+{
+
+template <typename T> class ConcurrentQueue // mutex-protected FIFO with optional push/pop notifications
+{
+    typedef std::chrono::microseconds chrono_us_t;
+    typedef std::unique_lock<std::mutex> mutex_locker;
+
+    std::queue<T> q;                 // underlying FIFO, guarded by mut
+    std::mutex mut;                  // protects q
+    std::mutex push_mut;             // used only for waiting on push_cv
+    std::mutex pop_mut;              // used only for waiting on pop_cv
+    std::condition_variable push_cv; // signalled via push_notify_*
+    std::condition_variable pop_cv;  // signalled via pop_notify_*
+    T null_T;                        // sentinel returned by pop() when the queue is empty
+
+  public:
+    ConcurrentQueue() // null_T stays default-constructed
+    {
+    }
+
+    ConcurrentQueue(T nullT) // nullT: sentinel value pop() returns on an empty queue
+    {
+        this->null_T = nullT;
+    }
+
+    ~ConcurrentQueue() // wake all waiters so none block on a destroyed queue
+    {
+        this->push_cv.notify_all();
+        this->pop_cv.notify_all();
+    }
+
+    // queue stats
+    uint64_t size()
+    {
+        mutex_locker lk(this->mut);
+        uint64_t ret = q.size();
+        lk.unlock();
+        return ret;
+    }
+
+    bool empty() // snapshot only: may be stale as soon as it returns
+    {
+        return (this->size() == 0);
+    }
+
+    // PUSH BACK
+    void push(T &new_val)
+    {
+        mutex_locker lk(this->mut);
+        this->q.push(new_val);
+        lk.unlock();
+    }
+
+    template <class Iterator> void insert(Iterator iter_begin, Iterator iter_end) // bulk push under a single lock
+    {
+        mutex_locker lk(this->mut);
+        for (Iterator it = iter_begin; it != iter_end; it++)
+        {
+            this->q.push(*it);
+        }
+        lk.unlock();
+    }
+
+    // POP FRONT
+    T pop() // non-blocking: returns null_T immediately when the queue is empty
+    {
+        mutex_locker lk(this->mut);
+        if (this->q.empty())
+        {
+            lk.unlock();
+            return this->null_T;
+        }
+        else
+        {
+            T ret = this->q.front();
+            this->q.pop();
+            // diskann::cout << "thread_id: " << std::this_thread::get_id() <<
+            // ", ctx: "
+            // << ret.ctx << "\n";
+            lk.unlock();
+            return ret;
+        }
+    }
+
+    // register for notifications
+    void wait_for_push_notify(chrono_us_t wait_time = chrono_us_t{10}) // returns on notify, timeout, or spurious wakeup
+    {
+        mutex_locker lk(this->push_mut);
+        this->push_cv.wait_for(lk, wait_time);
+        lk.unlock();
+    }
+
+    void wait_for_pop_notify(chrono_us_t wait_time = chrono_us_t{10}) // returns on notify, timeout, or spurious wakeup
+    {
+        mutex_locker lk(this->pop_mut);
+        this->pop_cv.wait_for(lk, wait_time);
+        lk.unlock();
+    }
+
+    // just notify functions
+    void push_notify_one()
+    {
+        this->push_cv.notify_one();
+    }
+    void push_notify_all()
+    {
+        this->push_cv.notify_all();
+    }
+    void pop_notify_one()
+    {
+        this->pop_cv.notify_one();
+    }
+    void pop_notify_all()
+    {
+        this->pop_cv.notify_all();
+    }
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/cosine_similarity.h b/be/src/extern/diskann/include/cosine_similarity.h
new file mode 100644
index 0000000..dc51f6c
--- /dev/null
+++ b/be/src/extern/diskann/include/cosine_similarity.h
@@ -0,0 +1,283 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <immintrin.h>
+#include <smmintrin.h>
+#include <tmmintrin.h>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include <stdexcept>
+
+#include "simd_utils.h"
+
+extern bool Avx2SupportedCPU;
+
+#ifdef _WINDOWS
+// SIMD implementation of Cosine similarity. Taken from hnsw library.
+
+/**
+ * Non-metric Space Library
+ *
+ * Authors: Bilegsaikhan Naidan (https://github.com/bileg), Leonid Boytsov
+ * (http://boytsov.info). With contributions from Lawrence Cayton
+ * (http://lcayton.com/) and others.
+ *
+ * For the complete list of contributors and further details see:
+ * https://github.com/searchivarius/NonMetricSpaceLib
+ *
+ * Copyright (c) 2014
+ *
+ * This code is released under the
+ * Apache License Version 2.0 http://www.apache.org/licenses/.
+ *
+ */
+
+namespace diskann
+{
+
+using namespace std;
+
+#define PORTABLE_ALIGN16 __declspec(align(16))
+
+static float NormScalarProductSIMD2(const int8_t *pVect1, const int8_t *pVect2, uint32_t qty) // cosine of two int8 vectors, clamped to [-1, 1]; uses aligned loads — inputs must be 16/32-byte aligned
+{
+    if (Avx2SupportedCPU)
+    {
+        __m256 cos, p1Len, p2Len;
+        cos = p1Len = p2Len = _mm256_setzero_ps();
+        while (qty >= 32)
+        {
+            __m256i rx = _mm256_load_si256((__m256i *)pVect1), ry = _mm256_load_si256((__m256i *)pVect2);
+            cos = _mm256_add_ps(cos, _mm256_mul_epi8(rx, ry));
+            p1Len = _mm256_add_ps(p1Len, _mm256_mul_epi8(rx, rx));
+            p2Len = _mm256_add_ps(p2Len, _mm256_mul_epi8(ry, ry));
+            pVect1 += 32;
+            pVect2 += 32;
+            qty -= 32;
+        }
+        while (qty > 0) // tail loop consumes 4 elements per pass yet loads 16 bytes — NOTE(review): reads past the logical end; confirm buffers are padded
+        {
+            __m128i rx = _mm_load_si128((__m128i *)pVect1), ry = _mm_load_si128((__m128i *)pVect2);
+            cos = _mm256_add_ps(cos, _mm256_mul32_pi8(rx, ry));
+            p1Len = _mm256_add_ps(p1Len, _mm256_mul32_pi8(rx, rx));
+            p2Len = _mm256_add_ps(p2Len, _mm256_mul32_pi8(ry, ry));
+            pVect1 += 4;
+            pVect2 += 4;
+            qty -= 4;
+        }
+        cos = _mm256_hadd_ps(_mm256_hadd_ps(cos, cos), cos); // horizontal sums
+        p1Len = _mm256_hadd_ps(_mm256_hadd_ps(p1Len, p1Len), p1Len);
+        p2Len = _mm256_hadd_ps(_mm256_hadd_ps(p2Len, p2Len), p2Len);
+        float denominator = max(numeric_limits<float>::min() * 2, sqrt(p1Len.m256_f32[0] + p1Len.m256_f32[4]) *
+                                                                      sqrt(p2Len.m256_f32[0] + p2Len.m256_f32[4])); // floor avoids divide-by-zero
+        float cosine = (cos.m256_f32[0] + cos.m256_f32[4]) / denominator;
+
+        return max(float(-1), min(float(1), cosine));
+    }
+
+    __m128 cos, p1Len, p2Len; // SSE fallback path
+    cos = p1Len = p2Len = _mm_setzero_ps();
+    __m128i rx, ry;
+    while (qty >= 16)
+    {
+        rx = _mm_load_si128((__m128i *)pVect1);
+        ry = _mm_load_si128((__m128i *)pVect2);
+        cos = _mm_add_ps(cos, _mm_mul_epi8(rx, ry));
+        p1Len = _mm_add_ps(p1Len, _mm_mul_epi8(rx, rx));
+        p2Len = _mm_add_ps(p2Len, _mm_mul_epi8(ry, ry));
+        pVect1 += 16;
+        pVect2 += 16;
+        qty -= 16;
+    }
+    while (qty > 0) // same over-read caveat as the AVX2 tail loop above
+    {
+        rx = _mm_load_si128((__m128i *)pVect1);
+        ry = _mm_load_si128((__m128i *)pVect2);
+        cos = _mm_add_ps(cos, _mm_mul32_pi8(rx, ry));
+        p1Len = _mm_add_ps(p1Len, _mm_mul32_pi8(rx, rx));
+        p2Len = _mm_add_ps(p2Len, _mm_mul32_pi8(ry, ry));
+        pVect1 += 4;
+        pVect2 += 4;
+        qty -= 4;
+    }
+    cos = _mm_hadd_ps(_mm_hadd_ps(cos, cos), cos);
+    p1Len = _mm_hadd_ps(_mm_hadd_ps(p1Len, p1Len), p1Len);
+    p2Len = _mm_hadd_ps(_mm_hadd_ps(p2Len, p2Len), p2Len);
+    float norm1 = p1Len.m128_f32[0];
+    float norm2 = p2Len.m128_f32[0];
+
+    static const float eps = numeric_limits<float>::min() * 2;
+
+    if (norm1 < eps)
+    { /*
+       * This shouldn't normally happen for this space, but
+       * if it does, we don't want to get NANs
+       */
+        if (norm2 < eps)
+        {
+            return 1;
+        }
+        return 0;
+    }
+    /*
+     * Sometimes due to rounding errors, we get values > 1 or < -1.
+     * This throws off other functions that use scalar product, e.g., acos
+     */
+    return max(float(-1), min(float(1), cos.m128_f32[0] / sqrt(norm1) / sqrt(norm2)));
+}
+
+static float NormScalarProductSIMD(const float *pVect1, const float *pVect2, uint32_t qty) // cosine of two float vectors, clamped to [-1, 1]; unaligned loads OK (loadu)
+{
+    // Didn't get significant performance gain compared with 128bit version.
+    static const float eps = numeric_limits<float>::min() * 2;
+
+    if (Avx2SupportedCPU)
+    {
+        uint32_t qty8 = qty / 8;
+
+        const float *pEnd1 = pVect1 + 8 * qty8; // end of the 8-wide vectorized region
+        const float *pEnd2 = pVect1 + qty;      // true end, for the scalar tail below
+
+        __m256 v1, v2;
+        __m256 sum_prod = _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, 0);
+        __m256 sum_square1 = sum_prod;
+        __m256 sum_square2 = sum_prod;
+
+        while (pVect1 < pEnd1)
+        {
+            v1 = _mm256_loadu_ps(pVect1);
+            pVect1 += 8;
+            v2 = _mm256_loadu_ps(pVect2);
+            pVect2 += 8;
+            sum_prod = _mm256_add_ps(sum_prod, _mm256_mul_ps(v1, v2));
+            sum_square1 = _mm256_add_ps(sum_square1, _mm256_mul_ps(v1, v1));
+            sum_square2 = _mm256_add_ps(sum_square2, _mm256_mul_ps(v2, v2));
+        }
+
+        float PORTABLE_ALIGN16 TmpResProd[8];
+        float PORTABLE_ALIGN16 TmpResSquare1[8];
+        float PORTABLE_ALIGN16 TmpResSquare2[8];
+
+        _mm256_store_ps(TmpResProd, sum_prod);
+        _mm256_store_ps(TmpResSquare1, sum_square1);
+        _mm256_store_ps(TmpResSquare2, sum_square2);
+
+        float sum = 0.0f;
+        float norm1 = 0.0f;
+        float norm2 = 0.0f;
+        for (uint32_t i = 0; i < 8; ++i) // horizontal reduction of the 8 lanes
+        {
+            sum += TmpResProd[i];
+            norm1 += TmpResSquare1[i];
+            norm2 += TmpResSquare2[i];
+        }
+
+        while (pVect1 < pEnd2) // scalar tail for the remaining qty % 8 elements
+        {
+            sum += (*pVect1) * (*pVect2);
+            norm1 += (*pVect1) * (*pVect1);
+            norm2 += (*pVect2) * (*pVect2);
+
+            ++pVect1;
+            ++pVect2;
+        }
+
+        if (norm1 < eps)
+        {
+            return norm2 < eps ? 1.0f : 0.0f;
+        }
+
+        return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2)));
+    }
+
+    __m128 v1, v2; // SSE fallback path
+    __m128 sum_prod = _mm_set1_ps(0);
+    __m128 sum_square1 = sum_prod;
+    __m128 sum_square2 = sum_prod;
+
+    while (qty >= 4)
+    {
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
+        sum_square1 = _mm_add_ps(sum_square1, _mm_mul_ps(v1, v1));
+        sum_square2 = _mm_add_ps(sum_square2, _mm_mul_ps(v2, v2));
+
+        qty -= 4;
+    }
+
+    float sum = sum_prod.m128_f32[0] + sum_prod.m128_f32[1] + sum_prod.m128_f32[2] + sum_prod.m128_f32[3]; // NOTE(review): unlike the AVX2 path, the last qty % 4 elements are never accumulated here — confirm intended
+    float norm1 = sum_square1.m128_f32[0] + sum_square1.m128_f32[1] + sum_square1.m128_f32[2] + sum_square1.m128_f32[3];
+    float norm2 = sum_square2.m128_f32[0] + sum_square2.m128_f32[1] + sum_square2.m128_f32[2] + sum_square2.m128_f32[3];
+
+    if (norm1 < eps)
+    {
+        return norm2 < eps ? 1.0f : 0.0f;
+    }
+
+    return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2)));
+}
+
+static float NormScalarProductSIMD2(const float *pVect1, const float *pVect2, uint32_t qty) // float overload: thin forwarder so the templated callers compile for both int8 and float
+{
+    return NormScalarProductSIMD(pVect1, pVect2, qty);
+}
+
+template <class T> static float CosineSimilarity2(const T *p1, const T *p2, uint32_t qty) // cosine DISTANCE (1 - similarity), floored at 0
+{
+    return std::max(0.0f, 1.0f - NormScalarProductSIMD2(p1, p2, qty));
+}
+
+// static template float CosineSimilarity2<__int8>(const __int8* pVect1,
+// const __int8* pVect2, size_t qty);
+
+// static template float CosineSimilarity2<float>(const float* pVect1,
+// const float* pVect2, size_t qty);
+
+template <class T> static void CosineSimilarityNormalize(T *pVector, uint32_t qty) // scales pVector in place by 1/||pVector||
+{
+    T sum = 0;
+    for (uint32_t i = 0; i < qty; ++i)
+    {
+        sum += pVector[i] * pVector[i];
+    }
+    sum = 1 / sqrt(sum); // NOTE(review): divides BEFORE the zero check below — if sum was 0 this is already inf and the guard can never fire; check order
+    if (sum == 0)
+    {
+        sum = numeric_limits<T>::min();
+    }
+    for (uint32_t i = 0; i < qty; ++i)
+    {
+        pVector[i] *= sum;
+    }
+}
+
+// template static void CosineSimilarityNormalize<float>(float* pVector,
+// size_t qty);
+// template static void CosineSimilarityNormalize<double>(double* pVector,
+// size_t qty);
+
+template <> void CosineSimilarityNormalize(__int8 * /*pVector*/, uint32_t /*qty*/) // integer specializations reject cosine outright; NOTE(review): non-inline specializations in a header risk ODR violations if included from multiple TUs — confirm
+{
+    throw std::runtime_error("For int8 type vector, you can not use cosine distance!");
+}
+
+template <> void CosineSimilarityNormalize(__int16 * /*pVector*/, uint32_t /*qty*/)
+{
+    throw std::runtime_error("For int16 type vector, you can not use cosine distance!");
+}
+
+template <> void CosineSimilarityNormalize(int * /*pVector*/, uint32_t /*qty*/)
+{
+    throw std::runtime_error("For int type vector, you can not use cosine distance!");
+}
+} // namespace diskann
+#endif
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/defaults.h b/be/src/extern/diskann/include/defaults.h
new file mode 100644
index 0000000..ef1750f
--- /dev/null
+++ b/be/src/extern/diskann/include/defaults.h
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <stdint.h>
+
+namespace diskann
+{
+namespace defaults // default knob values shared across index build / search entry points
+{
+const float ALPHA = 1.2f;
+const uint32_t NUM_THREADS = 0; // 0 presumably means "pick automatically" — confirm at call sites
+const uint32_t MAX_OCCLUSION_SIZE = 750;
+const bool HAS_LABELS = false;
+const uint32_t FILTER_LIST_SIZE = 0;
+const uint32_t NUM_FROZEN_POINTS_STATIC = 0;
+const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1;
+
+// In-mem index related limits
+const float GRAPH_SLACK_FACTOR = 1.3f;
+
+// SSD Index related limits
+const uint64_t MAX_GRAPH_DEGREE = 512;
+const uint64_t SECTOR_LEN = 4096;
+const uint64_t MAX_N_SECTOR_READS = 128;
+
+// following constants should always be specified, but are useful as a
+// sensible default at cli / python boundaries
+const uint32_t MAX_DEGREE = 64;
+const uint32_t BUILD_LIST_SIZE = 100;
+const uint32_t SATURATE_GRAPH = false; // NOTE(review): bool value stored as uint32_t — presumably should be 'const bool'; confirm against upstream DiskANN
+const uint32_t SEARCH_LIST_SIZE = 100;
+} // namespace defaults
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/disk_utils.h b/be/src/extern/diskann/include/disk_utils.h
new file mode 100644
index 0000000..ab63359
--- /dev/null
+++ b/be/src/extern/diskann/include/disk_utils.h
@@ -0,0 +1,135 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <algorithm>
+#include <fcntl.h>
+#include <cassert>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <set>
+#ifdef __APPLE__
+#else
+#include <malloc.h>
+#endif
+
+#ifdef _WINDOWS
+#include <Windows.h>
+typedef HANDLE FileHandle;
+#else
+#include <unistd.h>
+typedef int FileHandle;
+#endif
+
+#include "cached_io.h"
+#include "common_includes.h"
+
+#include "utils.h"
+#include "windows_customizations.h"
+
+namespace diskann
+{
+const size_t MAX_SAMPLE_POINTS_FOR_WARMUP = 100000; // cap on points sampled for cache warmup
+const double PQ_TRAINING_SET_FRACTION = 0.1;        // fraction of the dataset used to train PQ pivots
+const double SPACE_FOR_CACHED_NODES_IN_GB = 0.25;
+const double THRESHOLD_FOR_CACHING_IN_GB = 1.0;
+const uint32_t NUM_NODES_TO_CACHE = 250000;
+const uint32_t WARMUP_L = 20;
+const uint32_t NUM_KMEANS_REPS = 12;
+
+// logical file aliases for the pieces of a disk index (presumably used with FileMerger — confirm)
+const std::string PQ_PIVOTS = "pq_pivots";
+const std::string PQ_COMPRESSED = "pq_compressed";
+const std::string MEM_INDEX = "mem_index";
+const std::string DISK_INDEX_PATH = "disk_index_path";
+const std::string MEM_INDEX_DATA = "mem_index_data";
+
+template <typename T, typename LabelT> class PQFlashIndex; // fwd declaration; defined in pq_flash_index.h
+
+DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str);
+DISKANN_DLLEXPORT double get_memory_budget(double search_ram_budget_in_gb);
+DISKANN_DLLEXPORT void add_new_file_to_single_index(std::string index_file, std::string new_file);
+
+DISKANN_DLLEXPORT size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim);
+
+DISKANN_DLLEXPORT void read_idmap(const std::string &fname, std::vector<uint32_t> &ivecs);
+
+#ifdef EXEC_ENV_OLS
+template <typename T>
+DISKANN_DLLEXPORT T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, uint64_t &warmup_num,
+                                 uint64_t warmup_dim, uint64_t warmup_aligned_dim);
+#else
+template <typename T>
+DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim,
+                                 uint64_t warmup_aligned_dim);
+#endif
+
+// stitches per-shard Vamana graphs back into one index file
+DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix,
+                                   const std::string &idmaps_prefix, const std::string &idmaps_suffix,
+                                   const uint64_t nshards, uint32_t max_degree, const std::string &output_vamana,
+                                   const std::string &medoids_file, bool use_filters = false,
+                                   const std::string &labels_to_medoids_file = std::string(""));
+
+DISKANN_DLLEXPORT void extract_shard_labels(const std::string &in_label_file, const std::string &shard_ids_bin,
+                                            const std::string &shard_label_file);
+
+template <typename T>
+DISKANN_DLLEXPORT std::string preprocess_base_file(const std::string &infile, const std::string &indexPrefix,
+                                                   diskann::Metric &distMetric);
+
+// file-based build: reads vectors from base_file, writes the in-memory index to mem_index_path
+template <typename T, typename LabelT = uint32_t>
+DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::Metric _compareMetric, uint32_t L,
+                                                uint32_t R, double sampling_rate, double ram_budget,
+                                                std::string mem_index_path, std::string medoids_file,
+                                                std::string centroids_file, size_t build_pq_bytes, bool use_opq,
+                                                uint32_t num_threads, bool use_filters = false,
+                                                const std::string &label_file = std::string(""),
+                                                const std::string &labels_to_medoids_file = std::string(""),
+                                                const std::string &universal_label = "", const uint32_t Lf = 0);
+
+// stream-based overload added for the Doris import path: data and index live in stringstreams instead of files
+template <typename T, typename LabelT = uint32_t>
+DISKANN_DLLEXPORT int build_merged_vamana_index(std::stringstream &data_stream, diskann::Metric _compareMetric, uint32_t L,
+                                                uint32_t R, double sampling_rate, double ram_budget,
+                                                std::stringstream &index_stream, std::string medoids_file,
+                                                std::string centroids_file, size_t build_pq_bytes, bool use_opq,
+                                                uint32_t num_threads, bool use_filters = false,
+                                                const std::string &label_file = std::string(""),
+                                                const std::string &labels_to_medoids_file = std::string(""),
+                                                const std::string &universal_label = "", const uint32_t Lf = 0);
+
+
+template <typename T, typename LabelT>
+DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &_pFlashIndex,
+                                              T *tuning_sample, uint64_t tuning_sample_num,
+                                              uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
+                                              uint32_t start_bw = 2);
+
+template <typename T, typename LabelT = uint32_t>
+DISKANN_DLLEXPORT int build_disk_index(
+    const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters,
+    diskann::Metric _compareMetric, bool use_opq = false,
+    const std::string &codebook_prefix = "", // default is empty for no codebook pass in
+    bool use_filters = false,
+    const std::string &label_file = std::string(""), // default is empty string for no label_file
+    const std::string &universal_label = "", const uint32_t filter_threshold = 0,
+    const uint32_t Lf = 0); // default is empty string for no universal label
+
+template <typename T>
+DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file,
+                                          const std::string output_file,
+                                          const std::string reorder_data_file = std::string(""));
+
+template <typename T>
+DISKANN_DLLEXPORT void create_disk_layout(const T &data, const std::string mem_index_file,
+                                          const std::string output_file,
+                                          const std::string reorder_data_file = std::string(""));
+
+template <typename T>
+DISKANN_DLLEXPORT void create_disk_layout(std::stringstream &data_stream, std::stringstream &index_stream,
+                                          std::stringstream &disklayout_stream,
+                                          const std::string reorder_data_file = std::string(""));
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/distance.h b/be/src/extern/diskann/include/distance.h
new file mode 100644
index 0000000..f3b1de2
--- /dev/null
+++ b/be/src/extern/diskann/include/distance.h
@@ -0,0 +1,235 @@
+#pragma once
+#include "windows_customizations.h"
+#include <cstring>
+
+namespace diskann
+{
+enum Metric
+{
+ L2 = 0,
+ INNER_PRODUCT = 1,
+ COSINE = 2,
+ FAST_L2 = 3
+};
+
+template <typename T> class Distance
+{
+ public:
+ DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) : _distance_metric(dist_metric)
+ {
+ }
+
+ // distance comparison function
+ DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0;
+
+ // Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE
+ DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB,
+ uint32_t length) const;
+
+ // For MIPS, normalization adds an extra dimension to the vectors.
+ // This function lets callers know if the normalization process
+ // changes the dimension.
+ DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const;
+
+ DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const;
+
+    // This is for efficiency. If no normalization is required, the callers
+    // can simply ignore the preprocess_base_points() function.
+ DISKANN_DLLEXPORT virtual bool preprocessing_required() const;
+
+ // Check the preprocessing_required() function before calling this.
+ // Clients can call the function like this:
+ //
+ // if (metric->preprocessing_required()){
+ // T* normalized_data_batch;
+ // Split data into batches of batch_size and for each, call:
+ // metric->preprocess_base_points(data_batch, batch_size);
+ //
+ // TODO: This does not take into account the case for SSD inner product
+ // where the dimensions change after normalization.
+ DISKANN_DLLEXPORT virtual void preprocess_base_points(T *original_data, const size_t orig_dim,
+ const size_t num_points);
+
+ // Invokes normalization for a single vector during search. The scratch space
+ // has to be created by the caller keeping track of the fact that
+ // normalization might change the dimension of the query vector.
+ DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query);
+
+ // If an algorithm has a requirement that some data be aligned to a certain
+ // boundary it can use this function to indicate that requirement. Currently,
+ // we are setting it to 8 because that works well for AVX2. If we have AVX512
+ // implementations of distance algos, they might have to set this to 16
+ // (depending on how they are implemented)
+ DISKANN_DLLEXPORT virtual size_t get_required_alignment() const;
+
+ // Providing a default implementation for the virtual destructor because we
+ // don't expect most metric implementations to need it.
+ DISKANN_DLLEXPORT virtual ~Distance() = default;
+
+ protected:
+ diskann::Metric _distance_metric;
+ size_t _alignment_factor = 8;
+};
+
+class DistanceCosineInt8 : public Distance<int8_t>
+{
+ public:
+ DistanceCosineInt8() : Distance<int8_t>(diskann::Metric::COSINE)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
+};
+
+class DistanceL2Int8 : public Distance<int8_t>
+{
+ public:
+ DistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const;
+};
+
+// AVX implementations. Borrowed from HNSW code.
+class AVXDistanceL2Int8 : public Distance<int8_t>
+{
+ public:
+ AVXDistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
+};
+
+class DistanceCosineFloat : public Distance<float>
+{
+ public:
+ DistanceCosineFloat() : Distance<float>(diskann::Metric::COSINE)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
+};
+
+class DistanceL2Float : public Distance<float>
+{
+ public:
+ DistanceL2Float() : Distance<float>(diskann::Metric::L2)
+ {
+ }
+
+#ifdef _WINDOWS
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const;
+#else
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const __attribute__((hot));
+#endif
+};
+
+class AVXDistanceL2Float : public Distance<float>
+{
+ public:
+ AVXDistanceL2Float() : Distance<float>(diskann::Metric::L2)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
+};
+
+template <typename T> class SlowDistanceL2 : public Distance<T>
+{
+ public:
+ SlowDistanceL2() : Distance<T>(diskann::Metric::L2)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const;
+};
+
+class SlowDistanceCosineUInt8 : public Distance<uint8_t>
+{
+ public:
+ SlowDistanceCosineUInt8() : Distance<uint8_t>(diskann::Metric::COSINE)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const;
+};
+
+class DistanceL2UInt8 : public Distance<uint8_t>
+{
+ public:
+ DistanceL2UInt8() : Distance<uint8_t>(diskann::Metric::L2)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const;
+};
+
+template <typename T> class DistanceInnerProduct : public Distance<T>
+{
+ public:
+ DistanceInnerProduct() : Distance<T>(diskann::Metric::INNER_PRODUCT)
+ {
+ }
+
+ DistanceInnerProduct(diskann::Metric metric) : Distance<T>(metric)
+ {
+ }
+ inline float inner_product(const T *a, const T *b, unsigned size) const;
+
+ inline float compare(const T *a, const T *b, unsigned size) const
+ {
+ float result = inner_product(a, b, size);
+ // if (result < 0)
+ // return std::numeric_limits<float>::max();
+ // else
+ return -result;
+ }
+};
+
+template <typename T> class DistanceFastL2 : public DistanceInnerProduct<T>
+{
+ // currently defined only for float.
+ // templated for future use.
+ public:
+ DistanceFastL2() : DistanceInnerProduct<T>(diskann::Metric::FAST_L2)
+ {
+ }
+ float norm(const T *a, unsigned size) const;
+ float compare(const T *a, const T *b, float norm, unsigned size) const;
+};
+
+class AVXDistanceInnerProductFloat : public Distance<float>
+{
+ public:
+ AVXDistanceInnerProductFloat() : Distance<float>(diskann::Metric::INNER_PRODUCT)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
+};
+
+class AVXNormalizedCosineDistanceFloat : public Distance<float>
+{
+ private:
+ AVXDistanceInnerProductFloat _innerProduct;
+
+ protected:
+ void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const;
+
+ public:
+ AVXNormalizedCosineDistanceFloat() : Distance<float>(diskann::Metric::COSINE)
+ {
+ }
+ DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const
+ {
+ // Inner product returns negative values to indicate distance.
+ // This will ensure that cosine is between -1 and 1.
+ return 1.0f + _innerProduct.compare(a, b, length);
+ }
+ DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const override;
+
+ DISKANN_DLLEXPORT virtual bool preprocessing_required() const;
+
+ DISKANN_DLLEXPORT virtual void preprocess_base_points(float *original_data, const size_t orig_dim,
+ const size_t num_points) override;
+
+ DISKANN_DLLEXPORT virtual void preprocess_query(const float *query_vec, const size_t query_dim,
+ float *scratch_query_vector) override;
+};
+
+template <typename T> Distance<T> *get_distance_function(Metric m);
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/exceptions.h b/be/src/extern/diskann/include/exceptions.h
new file mode 100644
index 0000000..99e4e73
--- /dev/null
+++ b/be/src/extern/diskann/include/exceptions.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <stdexcept>
+
+namespace diskann
+{
+
+class NotImplementedException : public std::logic_error
+{
+ public:
+ NotImplementedException() : std::logic_error("Function not yet implemented.")
+ {
+ }
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/filter_utils.h b/be/src/extern/diskann/include/filter_utils.h
new file mode 100644
index 0000000..55f7aed
--- /dev/null
+++ b/be/src/extern/diskann/include/filter_utils.h
@@ -0,0 +1,221 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <algorithm>
+#include <fcntl.h>
+#include <cassert>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <set>
+#include <tuple>
+#include <string>
+#include <tsl/robin_map.h>
+#include <tsl/robin_set.h>
+#ifdef __APPLE__
+#else
+#include <malloc.h>
+#endif
+
+#ifdef _WINDOWS
+#include <Windows.h>
+typedef HANDLE FileHandle;
+#else
+#include <unistd.h>
+typedef int FileHandle;
+#endif
+
+#ifndef _WINDOWS
+#include <sys/uio.h>
+#endif
+
+#include "cached_io.h"
+#include "common_includes.h"
+#include "memory_mapper.h"
+#include "utils.h"
+#include "windows_customizations.h"
+
+// custom types (for readability)
+typedef tsl::robin_set<std::string> label_set;
+typedef std::string path;
+
+// structs for returning multiple items from a function
+typedef std::tuple<std::vector<label_set>, tsl::robin_map<std::string, uint32_t>, tsl::robin_set<std::string>>
+ parse_label_file_return_values;
+typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> load_label_index_return_values;
+
+namespace diskann
+{
+template <typename T>
+DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels,
+ unsigned R, unsigned L, float alpha, unsigned num_threads);
+
+DISKANN_DLLEXPORT load_label_index_return_values load_label_index(path label_index_path,
+ uint32_t label_number_of_points);
+
+template <typename LabelT>
+DISKANN_DLLEXPORT std::tuple<std::vector<std::vector<LabelT>>, tsl::robin_set<LabelT>> parse_formatted_label_file(
+ path label_file);
+
+DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label);
+
+template <typename T>
+DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files_compat(
+ path input_data_path, tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+ std::vector<label_set> point_ids_to_labels, label_set all_labels);
+
+/*
+ * For each label, generates a file containing all vectors that have said label.
+ * Also copies data from original bin file to new dimension-aligned file.
+ *
+ * Utilizes POSIX functions mmap and writev in order to minimize memory
+ * overhead, so we include an STL version as well.
+ *
+ * Each data file is saved under the following format:
+ * input_data_path + "_" + label
+ */
+#ifndef _WINDOWS
+template <typename T>
+inline tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files(
+ path input_data_path, tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+ std::vector<label_set> point_ids_to_labels, label_set all_labels)
+{
+#ifndef _WINDOWS
+ auto file_writing_timer = std::chrono::high_resolution_clock::now();
+ diskann::MemoryMapper input_data(input_data_path);
+ char *input_start = input_data.getBuf();
+
+ uint32_t number_of_points, dimension;
+ std::memcpy(&number_of_points, input_start, sizeof(uint32_t));
+ std::memcpy(&dimension, input_start + sizeof(uint32_t), sizeof(uint32_t));
+ const uint32_t VECTOR_SIZE = dimension * sizeof(T);
+ const size_t METADATA = 2 * sizeof(uint32_t);
+ if (number_of_points != point_ids_to_labels.size())
+ {
+ std::cerr << "Error: number of points in labels file and data file differ." << std::endl;
+ throw;
+ }
+
+ tsl::robin_map<std::string, iovec *> label_to_iovec_map;
+ tsl::robin_map<std::string, uint32_t> label_to_curr_iovec;
+ tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
+
+ // setup iovec list for each label
+ for (const auto &lbl : all_labels)
+ {
+ iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec));
+ if (label_iovecs == nullptr)
+ {
+ throw;
+ }
+ label_to_iovec_map[lbl] = label_iovecs;
+ label_to_curr_iovec[lbl] = 0;
+ label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]);
+ }
+
+ // each point added to corresponding per-label iovec list
+ for (uint32_t point_id = 0; point_id < number_of_points; point_id++)
+ {
+ char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id);
+ iovec curr_iovec;
+
+ curr_iovec.iov_base = curr_point;
+ curr_iovec.iov_len = VECTOR_SIZE;
+ for (const auto &lbl : point_ids_to_labels[point_id])
+ {
+ *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec;
+ label_to_curr_iovec[lbl]++;
+ label_id_to_orig_id[lbl].push_back(point_id);
+ }
+ }
+
+ // write each label iovec to resp. file
+ for (const auto &lbl : all_labels)
+ {
+ int label_input_data_fd;
+ path curr_label_input_data_path(input_data_path + "_" + lbl);
+ uint32_t curr_num_pts = labels_to_number_of_points[lbl];
+
+ label_input_data_fd =
+ open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644);
+ if (label_input_data_fd == -1)
+ throw;
+
+ // write metadata
+ uint32_t metadata[2] = {curr_num_pts, dimension};
+ int return_value = write(label_input_data_fd, metadata, sizeof(uint32_t) * 2);
+ if (return_value == -1)
+ {
+ throw;
+ }
+
+        // the IOV_MAX limit on the number of iovec structs per writev call
+        // means we may need to perform multiple writevs
+ size_t i = 0;
+ while (curr_num_pts > IOV_MAX)
+ {
+ return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX);
+ if (return_value == -1)
+ {
+ close(label_input_data_fd);
+ throw;
+ }
+ curr_num_pts -= IOV_MAX;
+ i += 1;
+ }
+ return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts);
+ if (return_value == -1)
+ {
+ close(label_input_data_fd);
+ throw;
+ }
+
+ free(label_to_iovec_map[lbl]);
+ close(label_input_data_fd);
+ }
+
+ std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer;
+ std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time "
+ << file_writing_time.count() << "\n"
+ << std::endl;
+
+ return label_id_to_orig_id;
+#endif
+}
+#endif
+
+inline std::vector<uint32_t> loadTags(const std::string &tags_file, const std::string &base_file)
+{
+ const bool tags_enabled = tags_file.empty() ? false : true;
+ std::vector<uint32_t> location_to_tag;
+ if (tags_enabled)
+ {
+ size_t tag_file_ndims, tag_file_npts;
+ std::uint32_t *tag_data;
+ diskann::load_bin<std::uint32_t>(tags_file, tag_data, tag_file_npts, tag_file_ndims);
+ if (tag_file_ndims != 1)
+ {
+ diskann::cerr << "tags file error" << std::endl;
+ throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ // check if the point count match
+ size_t base_file_npts, base_file_ndims;
+ diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims);
+ if (base_file_npts != tag_file_npts)
+ {
+ diskann::cerr << "point num in tags file mismatch" << std::endl;
+ throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ location_to_tag.assign(tag_data, tag_data + tag_file_npts);
+ delete[] tag_data;
+ }
+ return location_to_tag;
+}
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/in_mem_data_store.h b/be/src/extern/diskann/include/in_mem_data_store.h
new file mode 100644
index 0000000..0a0a617
--- /dev/null
+++ b/be/src/extern/diskann/include/in_mem_data_store.h
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+#pragma once
+
+#include <shared_mutex>
+#include <memory>
+
+#include "tsl/robin_map.h"
+#include "tsl/robin_set.h"
+#include "tsl/sparse_map.h"
+// #include "boost/dynamic_bitset.hpp"
+
+#include "abstract_data_store.h"
+
+#include "distance.h"
+#include "natural_number_map.h"
+#include "natural_number_set.h"
+#include "aligned_file_reader.h"
+
+namespace diskann
+{
+template <typename data_t> class InMemDataStore : public AbstractDataStore<data_t>
+{
+ public:
+ InMemDataStore(const location_t capacity, const size_t dim, std::unique_ptr<Distance<data_t>> distance_fn);
+ virtual ~InMemDataStore();
+
+ virtual location_t load(const std::string &filename) override;
+ virtual size_t save(const std::string &filename, const location_t num_points) override;
+
+ virtual size_t get_aligned_dim() const override;
+
+ // Populate internal data from unaligned data while doing alignment and any
+ // normalization that is required.
+ virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
+ virtual void populate_data(const std::string &filename, const size_t offset) override;
+
+ virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;
+
+ virtual void get_vector(const location_t i, data_t *target) const override;
+ virtual void set_vector(const location_t i, const data_t *const vector) override;
+ virtual void prefetch_vector(const location_t loc) override;
+
+ virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
+ const location_t num_points) override;
+ virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;
+
+ virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *query_scratch) const override;
+
+ virtual float get_distance(const data_t *preprocessed_query, const location_t loc) const override;
+ virtual float get_distance(const location_t loc1, const location_t loc2) const override;
+
+ virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
+ const uint32_t location_count, float *distances,
+ AbstractScratch<data_t> *scratch) const override;
+ virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
+ std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
+
+ virtual location_t calculate_medoid() const override;
+
+ virtual Distance<data_t> *get_dist_fn() const override;
+
+ virtual size_t get_alignment_factor() const override;
+
+ protected:
+ virtual location_t expand(const location_t new_size) override;
+ virtual location_t shrink(const location_t new_size) override;
+
+ virtual location_t load_impl(const std::string &filename);
+#ifdef EXEC_ENV_OLS
+ virtual location_t load_impl(AlignedFileReader &reader);
+#endif
+
+ private:
+ data_t *_data = nullptr;
+
+ size_t _aligned_dim;
+
+ // It may seem weird to put distance metric along with the data store class,
+ // but this gives us perf benefits as the datastore can do distance
+ // computations during search and compute norms of vectors internally without
+    // having to copy data back and forth.
+ std::unique_ptr<Distance<data_t>> _distance_fn;
+
+ // in case we need to save vector norms for optimization
+ std::shared_ptr<float[]> _pre_computed_norms;
+};
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/in_mem_graph_store.h b/be/src/extern/diskann/include/in_mem_graph_store.h
new file mode 100644
index 0000000..3952b17
--- /dev/null
+++ b/be/src/extern/diskann/include/in_mem_graph_store.h
@@ -0,0 +1,56 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "abstract_graph_store.h"
+
+namespace diskann
+{
+
+class InMemGraphStore : public AbstractGraphStore
+{
+ public:
+ InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree);
+
+ // returns tuple of <nodes_read, start, num_frozen_points>
+ virtual std::tuple<uint32_t, uint32_t, size_t> load(const std::string &index_path_prefix,
+ const size_t num_points) override;
+ virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points,
+ const uint32_t start) override;
+ virtual int store(std::stringstream &index_stream, const size_t num_points,
+ const size_t num_frozen_points, const uint32_t start) override;
+
+ virtual const std::vector<location_t> &get_neighbours(const location_t i) const override;
+ virtual void add_neighbour(const location_t i, location_t neighbour_id) override;
+ virtual void clear_neighbours(const location_t i) override;
+ virtual void swap_neighbours(const location_t a, location_t b) override;
+
+ virtual void set_neighbours(const location_t i, std::vector<location_t> &neighbors) override;
+
+ virtual size_t resize_graph(const size_t new_size) override;
+ virtual void clear_graph() override;
+
+ virtual size_t get_max_range_of_graph() override;
+ virtual uint32_t get_max_observed_degree() override;
+
+ protected:
+ virtual std::tuple<uint32_t, uint32_t, size_t> load_impl(const std::string &filename, size_t expected_num_points);
+#ifdef EXEC_ENV_OLS
+ virtual std::tuple<uint32_t, uint32_t, size_t> load_impl(AlignedFileReader &reader, size_t expected_num_points);
+#endif
+
+ int save_graph(const std::string &index_path_prefix, const size_t active_points, const size_t num_frozen_points,
+ const uint32_t start);
+
+ virtual int save_graph(std::stringstream &out, const size_t num_points,
+ const size_t num_frozen_points, const uint32_t start) override;
+
+ private:
+ size_t _max_range_of_graph = 0;
+ uint32_t _max_observed_degree = 0;
+
+ std::vector<std::vector<uint32_t>> _graph;
+};
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/index.h b/be/src/extern/diskann/include/index.h
new file mode 100644
index 0000000..40f17d4
--- /dev/null
+++ b/be/src/extern/diskann/include/index.h
@@ -0,0 +1,442 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "common_includes.h"
+
+#ifdef EXEC_ENV_OLS
+#include "aligned_file_reader.h"
+#endif
+
+#include "distance.h"
+#include "locking.h"
+#include "natural_number_map.h"
+#include "natural_number_set.h"
+#include "neighbor.h"
+#include "parameters.h"
+#include "utils.h"
+#include "windows_customizations.h"
+#include "scratch.h"
+#include "in_mem_data_store.h"
+#include "in_mem_graph_store.h"
+#include "abstract_index.h"
+
+#include "quantized_distance.h"
+#include "pq_data_store.h"
+
+#define OVERHEAD_FACTOR 1.1
+#define EXPAND_IF_FULL 0
+#define DEFAULT_MAXC 750
+
+namespace diskann
+{
+
+inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree)
+{
+ double size_of_data = ((double)size) * ROUND_UP(dim, 8) * datasize;
+ double size_of_graph = ((double)size) * degree * sizeof(uint32_t) * defaults::GRAPH_SLACK_FACTOR;
+ double size_of_locks = ((double)size) * sizeof(non_recursive_mutex);
+ double size_of_outer_vector = ((double)size) * sizeof(ptrdiff_t);
+
+ return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector);
+}
+
+template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> class Index : public AbstractIndex
+{
+ /**************************************************************************
+ *
+ * Public functions acquire one or more of _update_lock, _consolidate_lock,
+ * _tag_lock, _delete_lock before calling protected functions which DO NOT
+ * acquire these locks. They might acquire locks on _locks[i]
+ *
+ **************************************************************************/
+
+ public:
+ // Constructor for Bulk operations and for creating the index object solely
+    // for loading a preexisting index.
+ DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::shared_ptr<AbstractDataStore<T>> data_store,
+ std::unique_ptr<AbstractGraphStore> graph_store,
+ std::shared_ptr<AbstractDataStore<T>> pq_data_store = nullptr);
+
+ // Constructor for incremental index
+ DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
+ const std::shared_ptr<IndexWriteParameters> index_parameters,
+ const std::shared_ptr<IndexSearchParams> index_search_params,
+ const size_t num_frozen_pts = 0, const bool dynamic_index = false,
+ const bool enable_tags = false, const bool concurrent_consolidate = false,
+ const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
+ const bool use_opq = false, const bool filtered_index = false);
+
+ DISKANN_DLLEXPORT ~Index();
+
+ // Saves graph, data, metadata and associated tags.
+ DISKANN_DLLEXPORT void save(const char *filename, bool compact_before_save = false);
+ DISKANN_DLLEXPORT void save(std::stringstream &mem_index_stream, bool compact_before_save = false);
+
+ // Load functions
+#ifdef EXEC_ENV_OLS
+ DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
+#else
+ // Reads the number of frozen points from graph's metadata file section.
+ DISKANN_DLLEXPORT static size_t get_graph_num_frozen_points(const std::string &graph_file);
+
+ DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l);
+#endif
+
+ // get some private variables
+ DISKANN_DLLEXPORT size_t get_num_points();
+ DISKANN_DLLEXPORT size_t get_max_points();
+
+ DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation,
+ const std::vector<LabelT> &incoming_labels);
+
+ // Batch build from a file. Optionally pass tags vector.
+ DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
+ const std::vector<TagT> &tags = std::vector<TagT>());
+
+ // Batch build from a file. Optionally pass tags file.
+ DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, const char *tag_filename);
+
+ // Batch build from a data array, which must pad vectors to aligned_dim
+ DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const std::vector<TagT> &tags);
+
+ // Based on filter params builds a filtered or unfiltered index
+ DISKANN_DLLEXPORT void build(const std::string &data_file, const size_t num_points_to_load,
+ IndexFilterParams &filter_params);
+
+ // Filtered Support
+ DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
+ const size_t num_points_to_load,
+ const std::vector<TagT> &tags = std::vector<TagT>());
+
+ DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
+
+ // Get converted integer label from string to int map (_label_map)
+ DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);
+
+ // Set starting point of an index before inserting any points incrementally.
+ // The data count should be equal to _num_frozen_pts * _aligned_dim.
+ DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
+ // Set starting points to random points on a sphere of certain radius.
+ // A fixed random seed can be specified for scenarios where it's important
+ // to have higher consistency between index builds.
+ DISKANN_DLLEXPORT void set_start_points_at_random(T radius, uint32_t random_seed = 0);
+
+ // For FastL2 search on a static index, we interleave the data with graph
+ DISKANN_DLLEXPORT void optimize_index_layout();
+
+ // For FastL2 search on optimized layout
+ DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices);
+
+ // Added search overload that takes L as parameter, so that we
+ // can customize L on a per-query basis without tampering with "Parameters"
+ template <typename IDType>
+ DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
+ IDType *indices, float *distances = nullptr);
+
+ // Initialize space for res_vectors before calling.
+ DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
+ float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
+ const std::string filter_label = "");
+
+ // Filter support search
+ template <typename IndexType>
+ DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
+ const size_t K, const uint32_t L,
+ IndexType *indices, float *distances);
+
+ // Will fail if tag already in the index or if tag=0.
+ DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
+
+ // Will fail if tag already in the index or if tag=0.
+ DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag, const std::vector<LabelT> &label);
+
+ // call this before issuing deletions to sets relevant flags
+ DISKANN_DLLEXPORT int enable_delete();
+
+ // Record deleted point now and restructure graph later. Return -1 if tag
+ // not found, 0 if OK.
+ DISKANN_DLLEXPORT int lazy_delete(const TagT &tag);
+
+ // Record deleted points now and restructure graph later. Add to failed_tags
+ // if tag not found.
+ DISKANN_DLLEXPORT void lazy_delete(const std::vector<TagT> &tags, std::vector<TagT> &failed_tags);
+
+ // Call after a series of lazy deletions
+ // Returns number of live points left after consolidation
+ // If _conc_consolidates is set in the ctor, then this call can be invoked
+ // alongside inserts and lazy deletes, else it acquires _update_lock
+ DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters);
+
+ DISKANN_DLLEXPORT void prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion,
+ const float alpha);
+
+ DISKANN_DLLEXPORT bool is_index_saved();
+
+ // repositions frozen points to the end of _data - if they have been moved
+ // during deletion
+ DISKANN_DLLEXPORT void reposition_frozen_point_to_end();
+ DISKANN_DLLEXPORT void reposition_points(uint32_t old_location_start, uint32_t new_location_start,
+ uint32_t num_locations);
+
+ // DISKANN_DLLEXPORT void save_index_as_one_file(bool flag);
+
+ DISKANN_DLLEXPORT void get_active_tags(tsl::robin_set<TagT> &active_tags);
+
+ // memory should be allocated for vec before calling this function
+ DISKANN_DLLEXPORT int get_vector_by_tag(TagT &tag, T *vec);
+
+ DISKANN_DLLEXPORT void print_status();
+
+ DISKANN_DLLEXPORT void count_nodes_at_bfs_levels();
+
+ // This variable MUST be updated if the number of entries in the metadata
+    // changes.
+ DISKANN_DLLEXPORT static const int METADATA_ROWS = 5;
+
+ // ********************************
+ //
+ // Internals of the library
+ //
+ // ********************************
+
+ protected:
+ // overload of abstract index virtual methods
+ virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) override;
+
+ virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
+ std::any &indices, float *distances = nullptr) override;
+ virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
+ const std::string &filter_label_raw, const size_t K,
+ const uint32_t L, std::any &indices,
+ float *distances) override;
+
+ virtual int _insert_point(const DataType &data_point, const TagType tag) override;
+ virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override;
+
+ virtual int _lazy_delete(const TagType &tag) override;
+
+ virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override;
+
+ virtual void _get_active_tags(TagRobinSet &active_tags) override;
+
+ virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) override;
+
+ virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override;
+
+ virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;
+
+ virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
+ float *distances, DataVector &res_vectors, bool use_filters = false,
+ const std::string filter_label = "") override;
+
+ virtual void _set_universal_label(const LabelType universal_label) override;
+
+ // No copy/assign.
+ Index(const Index<T, TagT, LabelT> &) = delete;
+ Index<T, TagT, LabelT> &operator=(const Index<T, TagT, LabelT> &) = delete;
+
+ // Use after _data and _nd have been populated
+ // Acquire exclusive _update_lock before calling
+ void build_with_data_populated(const std::vector<TagT> &tags);
+
+ // generates 1 frozen point that will never be deleted from the graph
+ // This is not visible to the user
+ void generate_frozen_point();
+
+    // determines navigating node of the graph by calculating medoid of the dataset
+ uint32_t calculate_entry_point();
+
+ void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
+
+ std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
+
+ // Returns the locations of start point and frozen points suitable for use
+ // with iterate_to_fixed_point.
+ std::vector<uint32_t> get_init_ids();
+
+ // The query to use is placed in scratch->aligned_query
+ std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
+ const std::vector<uint32_t> &init_ids, bool use_filter,
+ const std::vector<LabelT> &filters, bool search_invocation);
+
+ void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
+ InMemQueryScratch<T> *scratch, bool use_filter = false,
+ uint32_t filteredLindex = 0);
+
+ void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, std::vector<uint32_t> &pruned_list,
+ InMemQueryScratch<T> *scratch);
+
+ void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, const uint32_t range,
+ const uint32_t max_candidate_size, const float alpha, std::vector<uint32_t> &pruned_list,
+ InMemQueryScratch<T> *scratch);
+
+ // Prunes candidates in @pool to a shorter list @result
+ // @pool must be sorted before calling
+ void occlude_list(const uint32_t location, std::vector<Neighbor> &pool, const float alpha, const uint32_t degree,
+ const uint32_t maxc, std::vector<uint32_t> &result, InMemQueryScratch<T> *scratch,
+ const tsl::robin_set<uint32_t> *const delete_set_ptr = nullptr);
+
+ // add reverse links from all the visited nodes to node n.
+ void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, const uint32_t range,
+ InMemQueryScratch<T> *scratch);
+
+ void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch);
+
+ // Acquire exclusive _update_lock before calling
+ void link();
+
+ // Acquire exclusive _tag_lock and _delete_lock before calling
+ int reserve_location();
+
+ // Acquire exclusive _tag_lock before calling
+ size_t release_location(int location);
+ size_t release_locations(const tsl::robin_set<uint32_t> &locations);
+
+ // Resize the index when no slots are left for insertion.
+ // Acquire exclusive _update_lock and _tag_lock before calling.
+ void resize(size_t new_max_points);
+
+ // Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock
+ // and _delete_lock before calling these functions.
+ // Renumber nodes, update tag and location maps and compact the
+ // graph, mode = _consolidated_order in case of lazy deletion and
+ // _compacted_order in case of eager deletion
+ DISKANN_DLLEXPORT void compact_data();
+ DISKANN_DLLEXPORT void compact_frozen_point();
+
+ // Remove deleted nodes from adjacency list of node loc
+ // Replace removed neighbors with second order neighbors.
+ // Also acquires _locks[i] for i = loc and out-neighbors of loc.
+ void process_delete(const tsl::robin_set<uint32_t> &old_delete_set, size_t loc, const uint32_t range,
+ const uint32_t maxc, const float alpha, InMemQueryScratch<T> *scratch);
+
+ void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r,
+ uint32_t maxc, size_t dim);
+
+ // Do not call without acquiring appropriate locks
+ // call public member functions save and load to invoke these.
+ DISKANN_DLLEXPORT size_t save_graph(std::string filename);
+ DISKANN_DLLEXPORT size_t save_graph(std::stringstream &index_stream);
+ DISKANN_DLLEXPORT size_t save_data(std::string filename);
+ DISKANN_DLLEXPORT size_t save_tags(std::string filename);
+ DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename);
+ DISKANN_DLLEXPORT size_t load_graph(const std::string filename, size_t expected_num_points);
+ DISKANN_DLLEXPORT size_t load_data(std::string filename0);
+ DISKANN_DLLEXPORT size_t load_tags(const std::string tag_file_name);
+ DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename);
+ private:
+ // Distance functions
+ Metric _dist_metric = diskann::L2;
+
+ // Data
+ std::shared_ptr<AbstractDataStore<T>> _data_store;
+
+ // Graph related data structures
+ std::unique_ptr<AbstractGraphStore> _graph_store;
+
+ char *_opt_graph = nullptr;
+
+ // Dimensions
+ size_t _dim = 0;
+ size_t _nd = 0; // number of active points i.e. existing in the graph
+ size_t _max_points = 0; // total number of points in given data set
+
+ // _num_frozen_pts is the number of points which are used as initial
+ // candidates when iterating to closest point(s). These are not visible
+ // externally and won't be returned by search. At least 1 frozen point is
+ // needed for a dynamic index. The frozen points have consecutive locations.
+ // See also _start below.
+ size_t _num_frozen_pts = 0;
+ size_t _frozen_pts_used = 0;
+ size_t _node_size;
+ size_t _data_len;
+ size_t _neighbor_len;
+
+ // Start point of the search. When _num_frozen_pts is greater than zero,
+ // this is the location of the first frozen point. Otherwise, this is a
+ // location of one of the points in index.
+ uint32_t _start = 0;
+
+ bool _has_built = false;
+ bool _saturate_graph = false;
+ bool _save_as_one_file = false; // plan to support in next version
+ bool _dynamic_index = false;
+ bool _enable_tags = false;
+    bool _normalize_vecs = false; // Using normalized L2 for cosine.
+ bool _deletes_enabled = false;
+
+ // Filter Support
+
+ bool _filtered_index = false;
+ // Location to label is only updated during insert_point(), all other reads are protected by
+ // default as a location can only be released at end of consolidate deletes
+ std::vector<std::vector<LabelT>> _location_to_labels;
+ tsl::robin_set<LabelT> _labels;
+ std::string _labels_file;
+ std::unordered_map<LabelT, uint32_t> _label_to_start_id;
+ std::unordered_map<uint32_t, uint32_t> _medoid_counts;
+
+ bool _use_universal_label = false;
+ LabelT _universal_label = 0;
+ uint32_t _filterIndexingQueueSize;
+ std::unordered_map<std::string, LabelT> _label_map;
+
+ // Indexing parameters
+ uint32_t _indexingQueueSize;
+ uint32_t _indexingRange;
+ uint32_t _indexingMaxC;
+ float _indexingAlpha;
+ uint32_t _indexingThreads;
+
+ // Query scratch data structures
+ ConcurrentQueue<InMemQueryScratch<T> *> _query_scratch;
+
+ // Flags for PQ based distance calculation
+ bool _pq_dist = false;
+ bool _use_opq = false;
+ size_t _num_pq_chunks = 0;
+ // REFACTOR
+ // uint8_t *_pq_data = nullptr;
+ std::shared_ptr<QuantizedDistance<T>> _pq_distance_fn = nullptr;
+ std::shared_ptr<AbstractDataStore<T>> _pq_data_store = nullptr;
+ bool _pq_generated = false;
+ FixedChunkPQTable _pq_table;
+
+ //
+ // Data structures, locks and flags for dynamic indexing and tags
+ //
+
+ // lazy_delete removes entry from _location_to_tag and _tag_to_location. If
+ // _location_to_tag does not resolve a location, infer that it was deleted.
+ tsl::sparse_map<TagT, uint32_t> _tag_to_location;
+ natural_number_map<uint32_t, TagT> _location_to_tag;
+
+ // _empty_slots has unallocated slots and those freed by consolidate_delete.
+ // _delete_set has locations marked deleted by lazy_delete. Will not be
+ // immediately available for insert. consolidate_delete will release these
+ // slots to _empty_slots.
+ natural_number_set<uint32_t> _empty_slots;
+ std::unique_ptr<tsl::robin_set<uint32_t>> _delete_set;
+
+ bool _data_compacted = true; // true if data has been compacted
+ bool _is_saved = false; // Checking if the index is already saved.
+ bool _conc_consolidate = false; // use _lock while searching
+
+ // Acquire locks in the order below when acquiring multiple locks
+ std::shared_timed_mutex // RW mutex between save/load (exclusive lock) and
+ _update_lock; // search/inserts/deletes/consolidate (shared lock)
+ std::shared_timed_mutex // Ensure only one consolidate or compact_data is
+ _consolidate_lock; // ever active
+ std::shared_timed_mutex // RW lock for _tag_to_location,
+ _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, _label_to_start_id
+ std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted
+ _delete_lock; // variable
+
+ // Per node lock, cardinality=_max_points + _num_frozen_points
+ std::vector<non_recursive_mutex> _locks;
+
+ static const float INDEX_GROWTH_FACTOR;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/index_build_params.h b/be/src/extern/diskann/include/index_build_params.h
new file mode 100644
index 0000000..d4f4548
--- /dev/null
+++ b/be/src/extern/diskann/include/index_build_params.h
@@ -0,0 +1,73 @@
+#pragma once
+
+#include "common_includes.h"
+#include "parameters.h"
+
+namespace diskann
+{
+// Parameters controlling a filtered-index build: where to save, the label/tag
+// input files, the universal label, and the filter threshold.
+// Instances are created only through IndexFilterParamsBuilder (private ctor).
+struct IndexFilterParams
+{
+  public:
+    std::string save_path_prefix;
+    std::string label_file;
+    std::string tags_file;
+    std::string universal_label;
+    uint32_t filter_threshold = 0;
+
+  private:
+    // NOTE(review): tags_file is never set by this constructor and the builder
+    // exposes no setter for it, so it is always default-constructed (empty) —
+    // confirm whether that is intended.
+    IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
+                      const std::string &universal_label, uint32_t filter_threshold)
+        : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label),
+          filter_threshold(filter_threshold)
+    {
+    }
+
+    // Only the builder may construct instances.
+    friend class IndexFilterParamsBuilder;
+};
+// Fluent builder for IndexFilterParams. Chain the with_* setters and call
+// build() to materialize the parameter struct.
+class IndexFilterParamsBuilder
+{
+  public:
+    IndexFilterParamsBuilder() = default;
+
+    // Output path prefix for the generated index files; must be non-empty.
+    IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
+    {
+        // empty() already covers the former `== ""` test; checking both was redundant.
+        if (save_path_prefix.empty())
+            throw ANNException("Error: save_path_prefix can't be empty", -1);
+        this->_save_path_prefix = save_path_prefix;
+        return *this;
+    }
+
+    // File holding per-point labels (may be empty for an unfiltered build).
+    IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
+    {
+        this->_label_file = label_file;
+        return *this;
+    }
+
+    // Label treated as matching every filter. (Typo fix: parameter was spelled
+    // "univeral_label".)
+    IndexFilterParamsBuilder &with_universal_label(const std::string &universal_label)
+    {
+        this->_universal_label = universal_label;
+        return *this;
+    }
+
+    // Threshold used by the filtered build path.
+    IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold)
+    {
+        this->_filter_threshold = filter_threshold;
+        return *this;
+    }
+
+    // Constructs the parameter struct from the collected values.
+    IndexFilterParams build()
+    {
+        return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold);
+    }
+
+    IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
+    IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;
+
+  private:
+    // Removed unused member _tags_file: it had no setter and was never passed
+    // to IndexFilterParams.
+    std::string _save_path_prefix;
+    std::string _label_file;
+    std::string _universal_label;
+    uint32_t _filter_threshold = 0;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/index_config.h b/be/src/extern/diskann/include/index_config.h
new file mode 100644
index 0000000..a8e64d0
--- /dev/null
+++ b/be/src/extern/diskann/include/index_config.h
@@ -0,0 +1,256 @@
+#pragma once
+
+#include "common_includes.h"
+#include "parameters.h"
+
+namespace diskann
+{
+// Backing-store choice for the vector data; only an in-memory store exists today.
+enum class DataStoreStrategy
+{
+    MEMORY
+};
+
+// Backing-store choice for the graph adjacency lists; in-memory only for now.
+enum class GraphStoreStrategy
+{
+    MEMORY
+};
+
+// Aggregated configuration consumed by IndexFactory when instantiating the
+// concrete Index template. Build instances with IndexConfigBuilder (the
+// constructor is private).
+struct IndexConfig
+{
+    DataStoreStrategy data_strategy;
+    GraphStoreStrategy graph_strategy;
+
+    Metric metric;
+    size_t dimension;
+    size_t max_points;
+
+    bool dynamic_index;
+    bool enable_tags;
+    bool pq_dist_build;
+    bool concurrent_consolidate;
+    bool use_opq;
+    bool filtered_index;
+
+    size_t num_pq_chunks;
+    size_t num_frozen_pts;
+
+    // Runtime type names ("float", "uint32", ...) used by the factory to pick
+    // the template instantiation.
+    std::string label_type;
+    std::string tag_type;
+    std::string data_type;
+
+    // Params for building index
+    std::shared_ptr<IndexWriteParameters> index_write_params;
+    // Params for searching index
+    std::shared_ptr<IndexSearchParams> index_search_params;
+
+  private:
+    // Fix: data_type was taken by non-const reference although it is only read;
+    // take it by const reference, matching tag_type and label_type.
+    IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension,
+                size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags,
+                bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index,
+                const std::string &data_type, const std::string &tag_type, const std::string &label_type,
+                std::shared_ptr<IndexWriteParameters> index_write_params,
+                std::shared_ptr<IndexSearchParams> index_search_params)
+        : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension),
+          max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build),
+          concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index),
+          num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type),
+          data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params)
+    {
+    }
+
+    friend class IndexConfigBuilder;
+};
+
+// Fluent builder for IndexConfig. All with_*/is_* setters return *this so
+// calls can be chained; build() validates the configuration and returns it.
+class IndexConfigBuilder
+{
+  public:
+    IndexConfigBuilder() = default;
+
+    IndexConfigBuilder &with_metric(Metric m)
+    {
+        this->_metric = m;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_graph_load_store_strategy(GraphStoreStrategy graph_strategy)
+    {
+        this->_graph_strategy = graph_strategy;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_data_load_store_strategy(DataStoreStrategy data_strategy)
+    {
+        this->_data_strategy = data_strategy;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_dimension(size_t dimension)
+    {
+        this->_dimension = dimension;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_max_points(size_t max_points)
+    {
+        this->_max_points = max_points;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_dynamic_index(bool dynamic_index)
+    {
+        this->_dynamic_index = dynamic_index;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_enable_tags(bool enable_tags)
+    {
+        this->_enable_tags = enable_tags;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build)
+    {
+        this->_pq_dist_build = pq_dist_build;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate)
+    {
+        this->_concurrent_consolidate = concurrent_consolidate;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_use_opq(bool use_opq)
+    {
+        this->_use_opq = use_opq;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_filtered(bool is_filtered)
+    {
+        this->_filtered_index = is_filtered;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks)
+    {
+        this->_num_pq_chunks = num_pq_chunks;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_num_frozen_pts(size_t num_frozen_pts)
+    {
+        this->_num_frozen_pts = num_frozen_pts;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_label_type(const std::string &label_type)
+    {
+        this->_label_type = label_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_tag_type(const std::string &tag_type)
+    {
+        this->_tag_type = tag_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_data_type(const std::string &data_type)
+    {
+        this->_data_type = data_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params)
+    {
+        this->_index_write_params = std::make_shared<IndexWriteParameters>(index_write_params);
+        return *this;
+    }
+
+    IndexConfigBuilder &with_index_write_params(std::shared_ptr<IndexWriteParameters> index_write_params_ptr)
+    {
+        if (index_write_params_ptr == nullptr)
+        {
+            // Tolerate a null pointer: keep the previous value and just warn.
+            diskann::cout << "Passed empty build_params while creating index config" << std::endl;
+            return *this;
+        }
+        this->_index_write_params = index_write_params_ptr;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_index_search_params(IndexSearchParams &search_params)
+    {
+        this->_index_search_params = std::make_shared<IndexSearchParams>(search_params);
+        return *this;
+    }
+
+    IndexConfigBuilder &with_index_search_params(std::shared_ptr<IndexSearchParams> search_params_ptr)
+    {
+        if (search_params_ptr == nullptr)
+        {
+            // Tolerate a null pointer: keep the previous value and just warn.
+            diskann::cout << "Passed empty search_params while creating index config" << std::endl;
+            return *this;
+        }
+        this->_index_search_params = search_params_ptr;
+        return *this;
+    }
+
+    // Validates the collected values and constructs the IndexConfig.
+    // Throws ANNException when data_type is missing, or when a dynamic index is
+    // requested without an initial search list size.
+    IndexConfig build()
+    {
+        // empty() already covers the former `== ""` test.
+        if (_data_type.empty())
+            throw ANNException("Error: data_type can not be empty", -1);
+
+        if (_dynamic_index)
+        {
+            if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0)
+                throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1);
+
+            // A dynamic index needs at least one frozen point as a stable entry
+            // point. The original code performed this check twice (the second,
+            // warning-emitting branch was dead); check once and warn.
+            if (_num_frozen_pts == 0)
+            {
+                diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting it to 1 for safety."
+                              << std::endl;
+                _num_frozen_pts = 1;
+            }
+        }
+
+        return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks,
+                           _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate,
+                           _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params,
+                           _index_search_params);
+    }
+
+    IndexConfigBuilder(const IndexConfigBuilder &) = delete;
+    IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete;
+
+  private:
+    // Every member now has an in-class default so a partially-configured builder
+    // can never hand indeterminate values to IndexConfig (the original left the
+    // strategies, _metric, _dimension and _max_points uninitialized).
+    DataStoreStrategy _data_strategy = DataStoreStrategy::MEMORY;
+    GraphStoreStrategy _graph_strategy = GraphStoreStrategy::MEMORY;
+
+    Metric _metric = diskann::L2;
+    size_t _dimension = 0;
+    size_t _max_points = 0;
+
+    bool _dynamic_index = false;
+    bool _enable_tags = false;
+    bool _pq_dist_build = false;
+    bool _concurrent_consolidate = false;
+    bool _use_opq = false;
+    bool _filtered_index{defaults::HAS_LABELS};
+
+    size_t _num_pq_chunks = 0;
+    size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC};
+
+    std::string _label_type{"uint32"};
+    std::string _tag_type{"uint32"};
+    std::string _data_type;
+
+    std::shared_ptr<IndexWriteParameters> _index_write_params;
+    std::shared_ptr<IndexSearchParams> _index_search_params;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/index_factory.h b/be/src/extern/diskann/include/index_factory.h
new file mode 100644
index 0000000..76fb0b9
--- /dev/null
+++ b/be/src/extern/diskann/include/index_factory.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include "index.h"
+#include "abstract_graph_store.h"
+#include "in_mem_graph_store.h"
+#include "pq_data_store.h"
+
+namespace diskann
+{
+// Factory that maps the runtime type names stored in an IndexConfig to the
+// matching Index<T, TagT, LabelT> template instantiation.
+class IndexFactory
+{
+  public:
+    DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config);
+
+    // Creates an index instance according to the stored config.
+    DISKANN_DLLEXPORT std::unique_ptr<AbstractIndex> create_instance();
+
+    // Typo fix: parameter name "stratagy" -> "strategy" (declaration-only rename).
+    DISKANN_DLLEXPORT static std::unique_ptr<AbstractGraphStore> construct_graphstore(
+        const GraphStoreStrategy strategy, const size_t size, const size_t reserve_graph_degree);
+
+    template <typename T>
+    DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy strategy,
+                                                                                       size_t num_points,
+                                                                                       size_t dimension, Metric m);
+    // For now PQDataStore incorporates within itself all variants of quantization that we support. In the
+    // future it may be necessary to introduce an AbstractPQDataStore class to separate various quantization
+    // flavours.
+    template <typename T>
+    DISKANN_DLLEXPORT static std::shared_ptr<PQDataStore<T>> construct_pq_datastore(DataStoreStrategy strategy,
+                                                                                    size_t num_points, size_t dimension,
+                                                                                    Metric m, size_t num_pq_chunks,
+                                                                                    bool use_opq);
+    // Builds the in-memory distance functor for metric m; the caller owns the pointer.
+    template <typename T> static Distance<T> *construct_inmem_distance_fn(Metric m);
+
+  private:
+    // Validates the config (supported type-name combinations etc.).
+    void check_config();
+
+    template <typename data_type, typename tag_type, typename label_type>
+    std::unique_ptr<AbstractIndex> create_instance();
+
+    // These overloads dispatch on runtime type names, resolving one template
+    // parameter per step.
+    std::unique_ptr<AbstractIndex> create_instance(const std::string &data_type, const std::string &tag_type,
+                                                   const std::string &label_type);
+
+    template <typename data_type>
+    std::unique_ptr<AbstractIndex> create_instance(const std::string &tag_type, const std::string &label_type);
+
+    template <typename data_type, typename tag_type>
+    std::unique_ptr<AbstractIndex> create_instance(const std::string &label_type);
+
+    std::unique_ptr<IndexConfig> _config;
+};
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/io_reader.h b/be/src/extern/diskann/include/io_reader.h
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/be/src/extern/diskann/include/io_reader.h
diff --git a/be/src/extern/diskann/include/linux_aligned_file_reader.h b/be/src/extern/diskann/include/linux_aligned_file_reader.h
new file mode 100644
index 0000000..7620e31
--- /dev/null
+++ b/be/src/extern/diskann/include/linux_aligned_file_reader.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#ifndef _WINDOWS
+
+#include "aligned_file_reader.h"
+
+// AlignedFileReader backed by Linux kernel AIO (libaio): each registered
+// thread obtains its own IOContext for submitting aligned read requests.
+class LinuxAlignedFileReader : public AlignedFileReader
+{
+  private:
+    // Size in bytes of the currently opened file.
+    uint64_t file_sz;
+    // Descriptor of the file opened via open().
+    FileHandle file_desc;
+    // Sentinel value representing "no io context".
+    io_context_t bad_ctx = (io_context_t)-1;
+
+  public:
+    LinuxAlignedFileReader();
+    ~LinuxAlignedFileReader();
+
+    // Returns the IOContext registered for the calling thread.
+    IOContext &get_ctx();
+
+    // register thread-id for a context
+    void register_thread();
+
+    // de-register thread-id for a context
+    void deregister_thread();
+    void deregister_all_threads();
+
+    // Open & close ops
+    // Blocking calls
+    void open(const std::string &fname);
+    void close();
+
+    // process batch of aligned requests in parallel
+    // NOTE :: blocking call
+    void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false);
+};
+
+#endif
diff --git a/be/src/extern/diskann/include/locking.h b/be/src/extern/diskann/include/locking.h
new file mode 100644
index 0000000..890c24a
--- /dev/null
+++ b/be/src/extern/diskann/include/locking.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+#pragma once
+
+#include <mutex>
+
+#ifdef _WINDOWS
+#include "windows_slim_lock.h"
+#endif
+
+namespace diskann
+{
+#ifdef _WINDOWS
+// On Windows, use the slim reader/writer lock wrappers.
+using non_recursive_mutex = windows_exclusive_slim_lock;
+using LockGuard = windows_exclusive_slim_lock_guard;
+#else
+// Elsewhere, use the standard (non-recursive) mutex and its RAII guard.
+using non_recursive_mutex = std::mutex;
+using LockGuard = std::lock_guard<non_recursive_mutex>;
+#endif
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/logger.h b/be/src/extern/diskann/include/logger.h
new file mode 100644
index 0000000..0b17807
--- /dev/null
+++ b/be/src/extern/diskann/include/logger.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+#pragma once
+
+#include <functional>
+#include <iostream>
+#include "windows_customizations.h"
+
+#ifdef EXEC_ENV_OLS
+#ifndef ENABLE_CUSTOM_LOGGER
+#define ENABLE_CUSTOM_LOGGER
+#endif // !ENABLE_CUSTOM_LOGGER
+#endif // EXEC_ENV_OLS
+
+namespace diskann
+{
+#ifdef ENABLE_CUSTOM_LOGGER
+// With the custom logger enabled, diskann::cout/cerr are library-provided
+// streams that route output through the registered logger callback.
+DISKANN_DLLEXPORT extern std::basic_ostream<char> cout;
+DISKANN_DLLEXPORT extern std::basic_ostream<char> cerr;
+#else
+// Otherwise diskann::cout/cerr are plain aliases of the std streams.
+using std::cerr;
+using std::cout;
+#endif
+
+// Severity levels passed to the custom logger callback.
+enum class DISKANN_DLLEXPORT LogLevel
+{
+    LL_Info = 0,
+    LL_Error,
+    LL_Count
+};
+
+#ifdef ENABLE_CUSTOM_LOGGER
+// Installs the callback invoked for each flushed log message.
+DISKANN_DLLEXPORT void SetCustomLogger(std::function<void(LogLevel, const char *)> logger);
+#endif
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/logger_impl.h b/be/src/extern/diskann/include/logger_impl.h
new file mode 100644
index 0000000..03c65e0
--- /dev/null
+++ b/be/src/extern/diskann/include/logger_impl.h
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <sstream>
+#include <mutex>
+
+#include "ann_exception.h"
+#include "logger.h"
+
+namespace diskann
+{
+#ifdef ENABLE_CUSTOM_LOGGER
+// Stream buffer that redirects diskann::cout/cerr output to the custom logger.
+// Constructed around stdout or stderr; the chosen FILE* determines the log level.
+class ANNStreamBuf : public std::basic_streambuf<char>
+{
+  public:
+    DISKANN_DLLEXPORT explicit ANNStreamBuf(FILE *fp);
+    DISKANN_DLLEXPORT ~ANNStreamBuf();
+
+    DISKANN_DLLEXPORT bool is_open() const
+    {
+        return true; // because stdout and stderr are always open.
+    }
+    DISKANN_DLLEXPORT void close();
+    DISKANN_DLLEXPORT virtual int underflow();
+    DISKANN_DLLEXPORT virtual int overflow(int c);
+    DISKANN_DLLEXPORT virtual int sync();
+
+  private:
+    // Stream this buffer wraps (stdout or stderr).
+    FILE *_fp;
+    // Accumulation buffer; see the sizing discussion below.
+    char *_buf;
+    // Next free position in _buf.
+    int _bufIndex;
+    // Serializes concurrent writers to the same stream.
+    std::mutex _mutex;
+    LogLevel _logLevel;
+
+    int flush();
+    void logImpl(char *str, int numchars);
+
+    // Why the two buffer-sizes? If we are running normally, we are basically
+    // interacting with a character output system, so we short-circuit the
+    // output process by keeping an empty buffer and writing each character
+    // to stdout/stderr. But if we are running in OLS, we have to take all
+    // the text that is written to diskann::cout/diskann:cerr, consolidate it
+    // and push it out in one-shot, because the OLS infra does not give us
+    // character based output. Therefore, we use a larger buffer that is large
+    // enough to store the longest message, and continuously add characters
+    // to it. When the calling code outputs a std::endl or std::flush, sync()
+    // will be called and will output a log level, component name, and the text
+    // that has been collected. (sync() is also called if the buffer is full, so
+    // overflows/missing text are not a concern).
+    // This implies calling code _must_ either print std::endl or std::flush
+    // to ensure that the message is written immediately.
+
+    static const int BUFFER_SIZE = 1024;
+
+    // Non-copyable (pre-C++11 idiom: declared but not defined).
+    ANNStreamBuf(const ANNStreamBuf &);
+    ANNStreamBuf &operator=(const ANNStreamBuf &);
+};
+#endif
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/math_utils.h b/be/src/extern/diskann/include/math_utils.h
new file mode 100644
index 0000000..83d189f
--- /dev/null
+++ b/be/src/extern/diskann/include/math_utils.h
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "common_includes.h"
+#include "utils.h"
+
+namespace math_utils
+{
+
+// Distance between two dim-length float vectors; the exact metric is defined
+// by the implementation file (not visible here).
+float calc_distance(float *vec_1, float *vec_2, size_t dim);
+
+// compute l2-squared norms of data stored in row major num_points * dim,
+// needs
+// to be pre-allocated
+void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim);
+
+// Multiplies data by rot_mat (optionally transposed) and returns the rotated
+// copy through new_mat.
+void rotate_data_randomly(float *data, size_t num_points, size_t dim, float *rot_mat, float *&new_mat,
+                          bool transpose_rot = false);
+
+// calculate closest center to data of num_points * dim (row major)
+// centers is num_centers * dim (row major)
+// data_l2sq has pre-computed squared norms of data
+// centers_l2sq has pre-computed squared norms of centers
+// pre-allocated center_index will contain id of k nearest centers
+// pre-allocated dist_matrix should be num_points * num_centers and contain
+// squared distances
+
+// Ideally used only by compute_closest_centers
+void compute_closest_centers_in_block(const float *const data, const size_t num_points, const size_t dim,
+                                      const float *const centers, const size_t num_centers,
+                                      const float *const docs_l2sq, const float *const centers_l2sq,
+                                      uint32_t *center_index, float *const dist_matrix, size_t k = 1);
+
+// Given data in num_points * new_dim row major
+// Pivots stored in full_pivot_data as k * new_dim row major
+// Calculate the closest pivot for each point and store it in vector
+// closest_centers_ivf (which needs to be allocated outside)
+// Additionally, if inverted index is not null (and pre-allocated), it will
+// return inverted index for each center Additionally, if pts_norms_squared is
+// not null, then it will assume that point norms are pre-computed and use
+// those
+// values
+
+void compute_closest_centers(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers,
+                             size_t k, uint32_t *closest_centers_ivf, std::vector<size_t> *inverted_index = NULL,
+                             float *pts_norms_squared = NULL);
+
+// if to_subtract is 1, will subtract nearest center from each row. Else will
+// add. Output will be in data_load itself.
+// Nearest centers need to be provided in closest_centers.
+
+void process_residuals(float *data_load, size_t num_points, size_t dim, float *cur_pivot_data, size_t num_centers,
+                       uint32_t *closest_centers, bool to_subtract);
+
+} // namespace math_utils
+
+namespace kmeans
+{
+
+// run Lloyds one iteration
+// Given data in row major num_points * dim, and centers in row major
+// num_centers * dim
+// And squared lengths of data points, output the closest center to each data
+// point, update centers, and also return inverted index.
+// If closest_centers == NULL, will allocate memory and return.
+// Similarly, if closest_docs == NULL, will allocate memory and return.
+
+float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, size_t num_centers, float *docs_l2sq,
+                  std::vector<size_t> *closest_docs, uint32_t *&closest_center);
+
+// Run Lloyds until max_reps or stopping criterion
+// If you pass NULL for closest_docs and closest_center, it will NOT return
+// the results, else it will assume appropriate allocation as closest_docs = new
+// vector<size_t> [num_centers], and closest_center = new size_t[num_points]
+// Final centers are output in centers as row major num_centers * dim
+//
+float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, const size_t num_centers,
+                 const size_t max_reps, std::vector<size_t> *closest_docs, uint32_t *closest_center);
+
+// assumes already memory allocated for pivot_data as new
+// float[num_centers*dim] and select randomly num_centers points as pivots
+void selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers);
+
+// k-means++ style pivot selection (distance-weighted sampling).
+void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers);
+} // namespace kmeans
diff --git a/be/src/extern/diskann/include/memory_mapper.h b/be/src/extern/diskann/include/memory_mapper.h
new file mode 100644
index 0000000..75faca1
--- /dev/null
+++ b/be/src/extern/diskann/include/memory_mapper.h
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#ifndef _WINDOWS
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#else
+#include <Windows.h>
+#endif
+#include <string>
+
+namespace diskann
+{
+// RAII wrapper that memory-maps a file read-only for its own lifetime and
+// exposes the mapped buffer and its size.
+class MemoryMapper
+{
+  private:
+#ifndef _WINDOWS
+    // POSIX file descriptor backing the mapping.
+    int _fd;
+#else
+    HANDLE _bareFile;
+    HANDLE _fd;
+
+#endif
+    // Start of the mapped region.
+    char *_buf;
+    // Length of the mapped region, in bytes.
+    size_t _fileSize;
+    const char *_fileName;
+
+  public:
+    MemoryMapper(const char *filename);
+    MemoryMapper(const std::string &filename);
+
+    // Returns the mapped buffer (valid until destruction).
+    char *getBuf();
+    size_t getFileSize();
+
+    ~MemoryMapper();
+};
diff --git a/be/src/extern/diskann/include/natural_number_map.h b/be/src/extern/diskann/include/natural_number_map.h
new file mode 100644
index 0000000..e846882
--- /dev/null
+++ b/be/src/extern/diskann/include/natural_number_map.h
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include <boost/dynamic_bitset.hpp>
+
+namespace diskann
+{
+// A map whose key is a natural number (from 0 onwards) and maps to a value.
+// Made as both memory and performance efficient map for scenario such as
+// DiskANN location-to-tag map. There, the pool of numbers is consecutive from
+// zero to some max value, and it's expected that most if not all keys from 0
+// up to some current maximum will be present in the map. The memory usage of
+// the map is determined by the largest inserted key since it uses vector as a
+// backing store and bitset for presence indication.
+//
+// Thread-safety: this class is not thread-safe in general.
+// Exception: multiple read-only operations are safe on the object only if
+// there are no writers to it in parallel.
+template <typename Key, typename Value> class natural_number_map
+{
+  public:
+    static_assert(std::is_trivial<Key>::value, "Key must be a trivial type");
+
+    // Represents a reference to an element in the map. Used while iterating
+    // over map entries.
+    struct position
+    {
+        size_t _key;
+        // The number of keys that were enumerated when iterating through the
+        // map so far. Used to early-terminate enumeration when there are no
+        // more entries in the map.
+        size_t _keys_already_enumerated;
+
+        // Returns whether it's valid to access the element at this position in
+        // the map.
+        bool is_valid() const;
+    };
+
+    natural_number_map();
+
+    void reserve(size_t count);
+    size_t size() const;
+
+    void set(Key key, Value value);
+    void erase(Key key);
+
+    bool contains(Key key) const;
+    bool try_get(Key key, Value &value) const;
+
+    // Returns the value at the specified position. Prerequisite: position is
+    // valid.
+    Value get(const position &pos) const;
+
+    // Finds the first element in the map, if any. Invalidated by changes in the
+    // map.
+    position find_first() const;
+
+    // Finds the next element in the map after the specified position.
+    // Invalidated by changes in the map.
+    position find_next(const position &after_position) const;
+
+    void clear();
+
+  private:
+    // Number of entries in the map. Not the same as size() of the
+    // _values_vector below.
+    size_t _size;
+
+    // Array of values. The key is the index of the value.
+    std::vector<Value> _values_vector;
+
+    // Values that are in the set have the corresponding bit index set
+    // to 1.
+    //
+    // Use a pointer here to allow for forward declaration of dynamic_bitset
+    // in public headers to avoid making boost a dependency for clients
+    // of DiskANN.
+    std::unique_ptr<boost::dynamic_bitset<>> _values_bitset;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/natural_number_set.h b/be/src/extern/diskann/include/natural_number_set.h
new file mode 100644
index 0000000..ec5b827
--- /dev/null
+++ b/be/src/extern/diskann/include/natural_number_set.h
@@ -0,0 +1,50 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <vector> // _values_vector below uses std::vector; do not rely on a transitive include
+#include "boost_dynamic_bitset_fwd.h"
+
+namespace diskann
+{
+// A set of natural numbers (from 0 onwards). Made for scenario where the
+// pool of numbers is consecutive from zero to some max value and very
+// efficient methods for "add to set", "get any value from set", "is in set"
+// are needed. The memory usage of the set is determined by the largest
+// number of inserted entries (uses a vector as a backing store) as well as
+// the largest value to be placed in it (uses bitset as well).
+//
+// Thread-safety: this class is not thread-safe in general.
+// Exception: multiple read-only operations (e.g. is_in_set, empty, size) are
+// safe on the object only if there are no writers to it in parallel.
+template <typename T> class natural_number_set
+{
+  public:
+    static_assert(std::is_trivial<T>::value, "Identifier must be a trivial type");
+
+    natural_number_set();
+
+    bool is_empty() const;
+    void reserve(size_t count);
+    void insert(T id);
+    T pop_any();
+    void clear();
+    size_t size() const;
+    bool is_in_set(T id) const;
+
+  private:
+    // Values that are currently in set.
+    std::vector<T> _values_vector;
+
+    // Values that are in the set have the corresponding bit index set
+    // to 1.
+    //
+    // Use a pointer here to allow for forward declaration of dynamic_bitset
+    // in public headers to avoid making boost a dependency for clients
+    // of DiskANN.
+    std::unique_ptr<boost::dynamic_bitset<>> _values_bitset;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/neighbor.h b/be/src/extern/diskann/include/neighbor.h
new file mode 100644
index 0000000..c58e650
--- /dev/null
+++ b/be/src/extern/diskann/include/neighbor.h
@@ -0,0 +1,160 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstddef>
+#include <mutex>
+#include <vector>
+#include "utils.h"
+
+namespace diskann
+{
+
+struct Neighbor
+{
+    unsigned id;    // node identifier
+    float distance; // distance from the query to this node
+    bool expanded;  // set once the search has expanded this node (see NeighborPriorityQueue)
+
+    Neighbor() = default;
+
+    Neighbor(unsigned id, float distance) : id{id}, distance{distance}, expanded(false)
+    {
+    }
+
+    inline bool operator<(const Neighbor &other) const // orders by distance, ties broken by smaller id
+    {
+        return distance < other.distance || (distance == other.distance && id < other.id);
+    }
+
+    inline bool operator==(const Neighbor &other) const // equality is by id only; distance is ignored
+    {
+        return (id == other.id);
+    }
+};
+
+// Invariant: after every `insert` and `closest_unexpanded()`, `_cur` points to
+// the first Neighbor which is unexpanded.
+class NeighborPriorityQueue
+{
+  public:
+    NeighborPriorityQueue() : _size(0), _capacity(0), _cur(0), _pre(0)
+    {
+    }
+
+    explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _pre(0), _data(capacity + 1)
+    {
+    }
+
+    // Inserts the item ordered into the set up to the set's capacity.
+    // The item will be dropped if it is the same id as an existing
+    // set item or it has a greater distance than the final
+    // item in the set. The set cursor that is used to pop() the
+    // next item will be set to the lowest index of an unchecked item
+    void insert(const Neighbor &nbr)
+    {
+        if (_size == _capacity && (_capacity == 0 || _data[_size - 1] < nbr)) // _capacity == 0 guard avoids reading _data[-1] on an unreserved queue
+        {
+            return;
+        }
+
+        size_t lo = 0, hi = _size;
+        while (lo < hi)
+        {
+            size_t mid = (lo + hi) >> 1;
+            if (nbr < _data[mid])
+            {
+                hi = mid;
+                // Make sure the same id isn't inserted into the set
+            }
+            else if (_data[mid].id == nbr.id)
+            {
+                return;
+            }
+            else
+            {
+                lo = mid + 1;
+            }
+        }
+
+        if (lo < _capacity)
+        {
+            std::memmove(&_data[lo + 1], &_data[lo], (_size - lo) * sizeof(Neighbor)); // NOTE(review): std::memmove needs <cstring>; presumably pulled in via utils.h — confirm
+        }
+        _data[lo] = {nbr.id, nbr.distance}; // the (id, distance) ctor resets expanded to false
+        if (_size < _capacity)
+        {
+            _size++;
+        }
+        if (lo < _cur)
+        {
+            _cur = lo;
+        }
+    }
+
+    // Marks the best unexpanded neighbor as expanded (remembering it in _pre) and returns it.
+    Neighbor closest_unexpanded()
+    {
+        _data[_cur].expanded = true;
+        _pre = _cur; // remembered so remove_pre_expanded_node() can drop this entry
+        while (_cur < _size && _data[_cur].expanded)
+        {
+            _cur++;
+        }
+        return _data[_pre];
+    }
+
+    void remove_pre_expanded_node(){ // drops the most recently expanded entry (at _pre); only valid after closest_unexpanded()
+        int len = (_size - _pre - 1) * sizeof(Neighbor);
+        if (len > 0)
+            std::memmove(&_data[_pre], &_data[_pre + 1], len);
+        _cur--;
+        _size--;
+    }
+
+    bool has_unexpanded_node() const
+    {
+        return _cur < _size;
+    }
+
+    size_t size() const
+    {
+        return _size;
+    }
+
+    size_t capacity() const
+    {
+        return _capacity;
+    }
+
+    void reserve(size_t capacity)
+    {
+        if (capacity + 1 > _data.size())
+        {
+            _data.resize(capacity + 1); // +1 spare slot lets insert() shift even when the queue is full
+        }
+        _capacity = capacity;
+    }
+
+    Neighbor &operator[](size_t i)
+    {
+        return _data[i];
+    }
+
+    Neighbor operator[](size_t i) const
+    {
+        return _data[i];
+    }
+
+    void clear()
+    {
+        _size = 0;
+        _cur = 0;
+    }
+
+  private:
+    size_t _size, _capacity, _cur, _pre;
+    std::vector<Neighbor> _data;
+};
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/parameters.h b/be/src/extern/diskann/include/parameters.h
new file mode 100644
index 0000000..0206814
--- /dev/null
+++ b/be/src/extern/diskann/include/parameters.h
@@ -0,0 +1,119 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <sstream>
+#include <typeinfo>
+#include <unordered_map>
+
+#include "omp.h"
+#include "defaults.h"
+
+namespace diskann
+{
+
+class IndexWriteParameters
+
+{
+  public:
+    const uint32_t search_list_size; // L
+    const uint32_t max_degree;       // R
+    const bool saturate_graph;         // builder default: defaults::SATURATE_GRAPH
+    const uint32_t max_occlusion_size; // C
+    const float alpha;                 // builder default: defaults::ALPHA
+    const uint32_t num_threads;        // builder normalizes 0 to omp_get_num_procs()
+    const uint32_t filter_list_size;   // Lf
+
+    IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph,
+                         const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads,
+                         const uint32_t filter_list_size)
+        : search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph),
+          max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads),
+          filter_list_size(filter_list_size)
+    {
+    }
+
+    friend class IndexWriteParametersBuilder; // preferred construction path; see builder below
+};
+
+class IndexSearchParams
+{
+  public:
+    IndexSearchParams(const uint32_t initial_search_list_size, const uint32_t num_search_threads)
+        : initial_search_list_size(initial_search_list_size), num_search_threads(num_search_threads)
+    {
+    }
+    const uint32_t initial_search_list_size; // search L
+    const uint32_t num_search_threads;       // search threads
+};
+
+class IndexWriteParametersBuilder
+{
+    /**
+     * Fluent builder pattern to keep track of the 7 non-default properties
+     * and their order. The basic ctor was getting unwieldy.
+     */
+  public:
+    IndexWriteParametersBuilder(const uint32_t search_list_size, // L
+                                const uint32_t max_degree        // R
+                                )
+        : _search_list_size(search_list_size), _max_degree(max_degree)
+    {
+    }
+
+    IndexWriteParametersBuilder &with_max_occlusion_size(const uint32_t max_occlusion_size)
+    {
+        _max_occlusion_size = max_occlusion_size;
+        return *this;
+    }
+
+    IndexWriteParametersBuilder &with_saturate_graph(const bool saturate_graph)
+    {
+        _saturate_graph = saturate_graph;
+        return *this;
+    }
+
+    IndexWriteParametersBuilder &with_alpha(const float alpha)
+    {
+        _alpha = alpha;
+        return *this;
+    }
+
+    IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads)
+    {
+        _num_threads = num_threads == 0 ? omp_get_num_procs() : num_threads; // 0 means "use all available processors"
+        return *this;
+    }
+
+    IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size)
+    {
+        _filter_list_size = filter_list_size == 0 ? _search_list_size : filter_list_size; // 0 falls back to the build search list size (L)
+        return *this;
+    }
+
+    IndexWriteParameters build() const
+    {
+        return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha,
+                                    _num_threads, _filter_list_size);
+    }
+
+    IndexWriteParametersBuilder(const IndexWriteParameters &wp)
+        : _search_list_size(wp.search_list_size), _max_degree(wp.max_degree),
+          _max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha),
+          _filter_list_size(wp.filter_list_size) // NOTE(review): wp.num_threads is NOT copied — _num_threads stays defaults::NUM_THREADS; confirm this is intentional
+    {
+    }
+    IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete;
+    IndexWriteParametersBuilder &operator=(const IndexWriteParametersBuilder &) = delete;
+
+  private:
+    uint32_t _search_list_size{};
+    uint32_t _max_degree{};
+    uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE};
+    bool _saturate_graph{defaults::SATURATE_GRAPH};
+    float _alpha{defaults::ALPHA};
+    uint32_t _num_threads{defaults::NUM_THREADS};
+    uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
+};
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/partition.h b/be/src/extern/diskann/include/partition.h
new file mode 100644
index 0000000..ccba399
--- /dev/null
+++ b/be/src/extern/diskann/include/partition.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include <cassert>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <unordered_map>
+
+#include "neighbor.h"
+#include "parameters.h"
+#include "tsl/robin_set.h"
+#include "utils.h"
+
+#include "windows_customizations.h"
+
+template <typename T>
+void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate);
+
+template <typename T>
+void gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size,
+ size_t &ndims);
+
+template <typename T>
+void gen_random_slice(std::stringstream &_data_stream, double p_val, float *&sampled_data, size_t &slice_size,
+ size_t &ndims);
+
+template <typename T>
+void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data,
+ size_t &slice_size);
+
+int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivots, const size_t num_centers,
+ const size_t dim, const size_t k_base, std::vector<size_t> &cluster_sizes);
+
+template <typename T>
+int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim,
+ const size_t k_base, std::string prefix_path);
+
+template <typename T>
+int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers,
+ const size_t dim, const size_t k_base, std::string prefix_path);
+
+template <typename T>
+int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename);
+
+template <typename T>
+int partition(const std::string data_file, const float sampling_rate, size_t num_centers, size_t max_k_means_reps,
+ const std::string prefix_path, size_t k_base);
+
+template <typename T>
+int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget,
+ size_t graph_degree, const std::string prefix_path, size_t k_base);
diff --git a/be/src/extern/diskann/include/percentile_stats.h b/be/src/extern/diskann/include/percentile_stats.h
new file mode 100644
index 0000000..344b091
--- /dev/null
+++ b/be/src/extern/diskann/include/percentile_stats.h
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <functional>
+#ifdef _WINDOWS
+#include <numeric>
+#endif
+#include <string>
+#include <vector>
+
+#include "distance.h"
+#include "parameters.h"
+
+namespace diskann
+{
+struct QueryStats
+{
+    float total_us = 0; // total time to process query in micros
+    float io_us = 0;    // total time spent in IO
+    float cpu_us = 0;   // total time spent in CPU
+
+    unsigned n_4k = 0;         // # of 4kB reads
+    unsigned n_8k = 0;         // # of 8kB reads
+    unsigned n_12k = 0;        // # of 12kB reads
+    unsigned n_ios = 0;        // total # of IOs issued
+    unsigned read_size = 0;    // total # of bytes read
+    unsigned n_cmps_saved = 0; // # cmps saved
+    unsigned n_cmps = 0;       // # cmps
+    unsigned n_cache_hits = 0; // # cache_hits
+    unsigned n_hops = 0;       // # search hops
+
+    std::string to_string(){ // renders all counters as one space-separated "key:value" line
+        std::stringstream ss; // NOTE(review): std::stringstream needs <sstream>; currently only available transitively via parameters.h — confirm
+        ss << "total_us:" << total_us
+           << " io_us:" << io_us
+           << " cpu_us:" << cpu_us
+           << " n_4k:" << n_4k
+           << " n_8k:" << n_8k
+           << " n_12k:" << n_12k
+           << " n_ios:" << n_ios
+           << " read_size:" << read_size
+           << " n_cmps_saved:" << n_cmps_saved
+           << " n_cmps:" << n_cmps
+           << " n_cache_hits:" << n_cache_hits
+           << " n_hops:" << n_hops;
+        return ss.str();
+    }
+};
+
+template <typename T>
+inline T get_percentile_stats(QueryStats *stats, uint64_t len, float percentile,
+                              const std::function<T(const QueryStats &)> &member_fn)
+{
+    std::vector<T> vals(len); // one extracted metric per query
+    for (uint64_t i = 0; i < len; i++)
+    {
+        vals[i] = member_fn(stats[i]);
+    }
+
+    std::sort(vals.begin(), vals.end(), [](const T &left, const T &right) { return left < right; }); // NOTE(review): std::sort needs <algorithm>, not included directly here — confirm transitive include
+
+    auto retval = vals[(uint64_t)(percentile * len)]; // NOTE(review): percentile >= 1.0 (or len == 0) indexes out of bounds — callers must pass 0 <= percentile < 1 and len > 0
+    vals.clear();
+    return retval;
+}
+
+template <typename T>
+inline double get_mean_stats(QueryStats *stats, uint64_t len, const std::function<T(const QueryStats &)> &member_fn)
+{
+    double avg = 0;
+    for (uint64_t i = 0; i < len; i++)
+    {
+        avg += (double)member_fn(stats[i]);
+    }
+    return avg / len; // NOTE(review): len == 0 divides by zero (yields NaN for double) — callers must ensure len > 0
+}
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq.h b/be/src/extern/diskann/include/pq.h
new file mode 100644
index 0000000..a65b904
--- /dev/null
+++ b/be/src/extern/diskann/include/pq.h
@@ -0,0 +1,117 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "utils.h"
+#include "pq_common.h"
+#include "vector/stream_wrapper.h"
+
+namespace diskann
+{
+class FixedChunkPQTable
+{
+    float *tables = nullptr; // pq_tables = float array of size [256 * ndims]
+    uint64_t ndims = 0;      // ndims = true dimension of vectors
+    uint64_t n_chunks = 0;   // number of PQ chunks
+    bool use_rotation = false;         // whether a rotation matrix (rotmat_tr) is applied — confirm semantics in .cpp
+    uint32_t *chunk_offsets = nullptr; // presumably per-chunk dimension offsets — confirm in .cpp
+    float *centroid = nullptr;         // presumably subtracted from queries in preprocess_query — confirm in .cpp
+    float *tables_tr = nullptr; // same as pq_tables, but col-major
+    float *rotmat_tr = nullptr; // rotation matrix, transposed
+
+  public:
+    FixedChunkPQTable();
+
+    virtual ~FixedChunkPQTable();
+
+#ifdef EXEC_ENV_OLS
+    void load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks);
+#else
+    void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks);
+    void load_pq_centroid_bin(std::shared_ptr<Reader> reader, size_t num_chunks);
+    void load_pq_centroid_bin(IReaderWrapperSPtr reader, size_t num_chunks);
+#endif
+
+    uint32_t get_num_chunks();
+
+    void preprocess_query(float *query_vec); // in-place; must be called before the chunk-distance helpers below
+
+    // assumes pre-processed query
+    void populate_chunk_distances(const float *query_vec, float *dist_vec);
+
+    float l2_distance(const float *query_vec, uint8_t *base_vec);
+
+    float inner_product(const float *query_vec, uint8_t *base_vec);
+
+    // assumes no rotation is involved
+    void inflate_vector(uint8_t *base_vec, float *out_vec);
+
+    void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
+};
+
+void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);
+
+void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
+ std::vector<float> &dists_out);
+
+// Need to replace calls to these with calls to vector& based functions above
+void aggregate_coords(const unsigned *ids, const uint64_t n_ids, const uint8_t *all_coords, const uint64_t ndims,
+ uint8_t *out);
+
+void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
+ float *dists_out);
+
+DISKANN_DLLEXPORT int generate_pq_pivots(const float *const train_data, size_t num_train, unsigned dim,
+ unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps,
+ std::string pq_pivots_path, bool make_zero_mean = false);
+
+DISKANN_DLLEXPORT int generate_pq_pivots(const float *const train_data, size_t num_train, unsigned dim,
+ unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps,
+ std::stringstream &pq_pivots_stream, bool make_zero_mean = false);
+
+DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_train, unsigned dim, unsigned num_centers,
+ unsigned num_pq_chunks, std::string opq_pivots_path,
+ bool make_zero_mean = false);
+
+DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_train, unsigned dim, unsigned num_centers,
+ unsigned num_pq_chunks, std::stringstream &opq_pivots_stream,
+ bool make_zero_mean = false);
+
+DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim,
+ size_t num_pq_chunks, std::vector<float> &pivot_data_vector);
+
+template <typename T>
+int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks,
+ const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path,
+ bool use_opq = false);
+
+template <typename T>
+int generate_pq_data_from_pivots(std::stringstream &data_stream, unsigned num_centers, unsigned num_pq_chunks,
+ std::stringstream &pq_pivots_stream, std::stringstream &pq_compressed_stream,
+ bool use_opq = false);
+
+DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num,
+ const float *pivot_data, const size_t pivots_num,
+ const size_t dim, const size_t num_pq_chunks,
+ std::vector<uint8_t> &pq);
+
+template <typename T>
+void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
+ const std::string &disk_pq_compressed_vectors_path,
+ const diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims);
+
+template <typename T>
+void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path,
+ const std::string &pq_compressed_vectors_path, const diskann::Metric compareMetric,
+ const double p_val, const uint64_t num_pq_chunks, const bool use_opq,
+ const std::string &codebook_prefix = "");
+
+template <typename T>
+void generate_quantized_data(std::stringstream & _data_stream, std::stringstream &pq_pivots_stream,
+ std::stringstream &pq_compressed_stream, const diskann::Metric compareMetric,
+ const double p_val, const uint64_t num_pq_chunks, const bool use_opq,
+ const std::string &codebook_prefix = "");
+
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq_common.h b/be/src/extern/diskann/include/pq_common.h
new file mode 100644
index 0000000..c6a3a57
--- /dev/null
+++ b/be/src/extern/diskann/include/pq_common.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <string>
+#include <sstream>
+
+#define NUM_PQ_BITS 8
+#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
+#define MAX_OPQ_ITERS 20
+#define NUM_KMEANS_REPS_PQ 12
+#define MAX_PQ_TRAINING_SET_SIZE 256000
+#define MAX_PQ_CHUNKS 512
+
+namespace diskann
+{
+inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
+{
+    return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_compressed.bin"; // NOTE(review): "pq" lacks the leading underscore that "_opq" has — confirm this matches the files the index writer produces
+}
+
+inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
+{
+    return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_pivots.bin"; // same "pq"/"_opq" asymmetry as above
+}
+
+inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
+{
+    return pivot_data_filename + "_rotation_matrix.bin"; // appended to the full pivots filename, not the bare prefix
+}
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq_data_store.h b/be/src/extern/diskann/include/pq_data_store.h
new file mode 100644
index 0000000..7c0cb5f
--- /dev/null
+++ b/be/src/extern/diskann/include/pq_data_store.h
@@ -0,0 +1,97 @@
+#pragma once
+#include <memory>
+#include "distance.h"
+#include "quantized_distance.h"
+#include "pq.h"
+#include "abstract_data_store.h"
+
+namespace diskann
+{
+// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because both Vamana and
+// DiskANN treat it the same way. But with DiskPQ, that may need to change.
+template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
+{
+
+  public:
+    PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr<Distance<data_t>> distance_fn,
+                std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn);
+    PQDataStore(const PQDataStore &) = delete;
+    PQDataStore &operator=(const PQDataStore &) = delete;
+    ~PQDataStore();
+
+    // Load quantized vectors from a set of files. Here filename is treated
+    // as a prefix and the files are assumed to be named with DiskANN
+    // conventions.
+    virtual location_t load(const std::string &file_prefix) override;
+
+    // Save quantized vectors to a set of files whose names start with
+    // file_prefix.
+    // Currently, the plan is to save the quantized vectors to the quantized
+    // vectors file.
+    virtual size_t save(const std::string &file_prefix, const location_t num_points) override;
+
+    // Since base class function is pure virtual, we need to declare it here, even though alignment concept is not needed
+    // for Quantized data stores.
+    virtual size_t get_aligned_dim() const override;
+
+    // Populate quantized data from unaligned data using PQ functionality
+    virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
+    virtual void populate_data(const std::string &filename, const size_t offset) override;
+
+    virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;
+
+    virtual void get_vector(const location_t i, data_t *target) const override;
+    virtual void set_vector(const location_t i, const data_t *const vector) override;
+    virtual void prefetch_vector(const location_t loc) override;
+
+    virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
+                              const location_t num_points) override;
+    virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;
+
+    virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *scratch) const override;
+
+    virtual float get_distance(const data_t *query, const location_t loc) const override;
+    virtual float get_distance(const location_t loc1, const location_t loc2) const override;
+
+    // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
+    // this function.
+    virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
+                              const uint32_t location_count, float *distances,
+                              AbstractScratch<data_t> *scratch_space) const override;
+
+    // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
+    // this function.
+    virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
+                              std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
+
+    // We are returning the distance function that is used for full precision
+    // vectors here, not the PQ distance function. This is because the callers
+    // all are expecting a Distance<T> not QuantizedDistance<T>.
+    virtual Distance<data_t> *get_dist_fn() const override;
+
+    virtual location_t calculate_medoid() const override;
+
+    virtual size_t get_alignment_factor() const override;
+
+  protected:
+    virtual location_t expand(const location_t new_size) override;
+    virtual location_t shrink(const location_t new_size) override;
+
+    virtual location_t load_impl(const std::string &filename);
+#ifdef EXEC_ENV_OLS
+    virtual location_t load_impl(AlignedFileReader &reader);
+#endif
+
+  private:
+    uint8_t *_quantized_data = nullptr; // presumably num_points * _num_chunks PQ codes — confirm layout in .cpp
+    size_t _num_chunks = 0;
+
+    // REFACTOR TODO: Doing this temporarily before refactoring OPQ into
+    // its own class. Remove later.
+    bool _use_opq = false;
+
+    Metric _distance_metric;
+    std::unique_ptr<Distance<data_t>> _distance_fn = nullptr;
+    std::unique_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq_flash_index.h b/be/src/extern/diskann/include/pq_flash_index.h
new file mode 100644
index 0000000..e5550a3
--- /dev/null
+++ b/be/src/extern/diskann/include/pq_flash_index.h
@@ -0,0 +1,286 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#include "common_includes.h"
+
+#include "aligned_file_reader.h"
+#include "concurrent_queue.h"
+#include "neighbor.h"
+#include "parameters.h"
+#include "percentile_stats.h"
+#include "pq.h"
+#include "utils.h"
+#include "windows_customizations.h"
+#include "scratch.h"
+#include "tsl/robin_map.h"
+#include "tsl/robin_set.h"
+#include "ThreadPool.h"
+
+#include "vector/stream_wrapper.h"
+
+
+#define FULL_PRECISION_REORDER_MULTIPLIER 3
+
+namespace diskann
+{
+
+struct Filter {
+ virtual bool is_member(uint32_t idx) {
+ return false;
+ }
+ virtual ~Filter() = default;
+};
+
+class BatchReader{
+ public:
+ BatchReader(uint32_t threads){
+ _pool = std::make_shared<ThreadPool>(threads);
+ }
+ void read(std::vector<AlignedRead> requests){
+ if (_reader) {
+ std::vector<std::future<void>> futures;
+ for (const auto& req : requests) {
+ auto future = _pool->enqueue([&, req]() {
+ _reader->read((char*)req.buf, req.len, req.offset);
+ });
+ futures.push_back(std::move(future));
+ }
+ for (auto& fut : futures) {
+ fut.get();
+ }
+ } else {
+ throw std::runtime_error("Reader not initialized properly.");
+ }
+ }
+
+ void set_reader(IReaderWrapperSPtr reader){
+ _reader = reader;
+ }
+
+ private:
+ std::shared_ptr<ThreadPool> _pool;
+ IReaderWrapperSPtr _reader;
+};
+
+template <typename T, typename LabelT = uint32_t> class PQFlashIndex
+{
+ public:
+ DISKANN_DLLEXPORT PQFlashIndex(IReaderWrapperSPtr reader,
+ diskann::Metric metric = diskann::Metric::L2);
+ DISKANN_DLLEXPORT PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader,
+ diskann::Metric metric = diskann::Metric::L2);
+ DISKANN_DLLEXPORT ~PQFlashIndex();
+
+ DISKANN_DLLEXPORT int load(uint32_t num_threads,
+ IReaderWrapperSPtr pq_pivots_reader,
+ IReaderWrapperSPtr pq_compressed_reader,
+ IReaderWrapperSPtr vamana_index_reader,
+ IReaderWrapperSPtr disk_layout_reader,
+ IReaderWrapperSPtr tag_reader);
+
+#ifdef EXEC_ENV_OLS
+ DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix);
+#else
+ // load compressed data, and obtains the handle to the disk-resident index
+ DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
+#endif
+
+#ifdef EXEC_ENV_OLS
+ DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
+ const char *index_filepath, const char *pivots_filepath,
+ const char *compressed_filepath);
+#else
+ DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
+ const char *pivots_filepath, const char *compressed_filepath);
+ DISKANN_DLLEXPORT int load_from_compound_file(uint32_t num_threads, const char *compound_filepath);
+#endif
+
+ DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
+
+#ifdef EXEC_ENV_OLS
+ DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin,
+ uint64_t l_search, uint64_t beamwidth,
+ uint64_t num_nodes_to_cache, uint32_t nthreads,
+ std::vector<uint32_t> &node_list);
+#else
+ DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search,
+ uint64_t beamwidth, uint64_t num_nodes_to_cache,
+ uint32_t num_threads,
+ std::vector<uint32_t> &node_list);
+#endif
+
+ DISKANN_DLLEXPORT void cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector<uint32_t> &node_list,
+ const bool shuffle = false);
+
+ DISKANN_DLLEXPORT uint32_t cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
+ uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
+ Filter *filter = nullptr,
+ QueryStats *stats = nullptr);
+
+ DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);
+
+ DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search,
+ const uint64_t max_l_search, std::vector<uint64_t> &indices,
+ std::vector<float> &distances, const uint64_t min_beam_width,
+ QueryStats *stats = nullptr);
+
+ DISKANN_DLLEXPORT uint64_t get_data_dim();
+ std::shared_ptr<AlignedFileReader> reader;
+ IReaderWrapperSPtr customReader;
+
+ DISKANN_DLLEXPORT diskann::Metric get_metric();
+
+ //
+ // node_ids: input list of node_ids to be read
+ // coord_buffers: pointers to pre-allocated buffers that coords need to copied to. If null, dont copy.
+ // nbr_buffers: pre-allocated buffers to copy neighbors into
+ //
+ // returns a vector of bool one for each node_id: true if read is success, else false
+ //
+ DISKANN_DLLEXPORT std::vector<bool> read_nodes(const std::vector<uint32_t> &node_ids,
+ std::vector<T *> &coord_buffers,
+ std::vector<std::pair<uint32_t, uint32_t *>> &nbr_buffers);
+
+ DISKANN_DLLEXPORT std::vector<std::uint8_t> get_pq_vector(std::uint64_t vid);
+ DISKANN_DLLEXPORT uint64_t get_num_points();
+
+ protected:
+ DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
+ DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);
+ DISKANN_DLLEXPORT void setup_thread_data_without_ctx(uint64_t nthreads, uint64_t visited_reserve = 4096);
+
+
+ DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
+
+ private:
+ DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
+ std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
+ DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
+ DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
+ uint32_t &num_total_labels);
+ DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
+ const uint32_t nthreads);
+ void reset_stream_for_reading(std::basic_istream<char> &infile);
+
+ // sector # on disk where node_id is present with in the graph part
+ DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id);
+
+ // ptr to start of the node
+ DISKANN_DLLEXPORT char *offset_to_node(char *sector_buf, uint64_t node_id);
+
+ // returns region of `node_buf` containing [NNBRS][NBR_ID(uint32_t)]
+ DISKANN_DLLEXPORT uint32_t *offset_to_node_nhood(char *node_buf);
+
+ // returns region of `node_buf` containing [COORD(T)]
+ DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf);
+
+ // index info for multi-node sectors
+ // nhood of node `i` is in sector: [i / nnodes_per_sector]
+ // offset in sector: [(i % nnodes_per_sector) * max_node_len]
+ //
+ // index info for multi-sector nodes
+ // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, SECTOR_LEN)]
+ // offset in sector: [0]
+ //
+ // Common info
+ // coords start at offset
+ // #nbrs of node `i`: *(unsigned*) (offset + disk_bytes_per_point)
+ // nbrs of node `i` : (unsigned*) (offset + disk_bytes_per_point + 1)
+
+ uint64_t _max_node_len = 0;
+ uint64_t _nnodes_per_sector = 0; // 0 for multi-sector nodes, >0 for multi-node sectors
+ uint64_t _max_degree = 0;
+
+ // Data used for searching with re-order vectors
+ uint64_t _ndims_reorder_vecs = 0;
+ uint64_t _reorder_data_start_sector = 0;
+ uint64_t _nvecs_per_sector = 0;
+
+ diskann::Metric metric = diskann::Metric::L2;
+
+ // used only for inner product search to re-scale the result value
+ // (due to the pre-processing of base during index build)
+ float _max_base_norm = 0.0f;
+
+ // data info
+ uint64_t _num_points = 0;
+ uint64_t _num_frozen_points = 0;
+ uint64_t _frozen_location = 0;
+ uint64_t _data_dim = 0;
+ uint64_t _aligned_dim = 0;
+ uint64_t _disk_bytes_per_point = 0; // Number of bytes
+
+ std::string _disk_index_file;
+ std::vector<std::pair<uint32_t, uint32_t>> _node_visit_counter;
+
+ // PQ data
+ // _n_chunks = # of chunks ndims is split into
+ // data: char * _n_chunks
+ // chunk_size = chunk size of each dimension chunk
+ // pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks]
+ uint8_t *data = nullptr;
+ uint64_t _n_chunks;
+ FixedChunkPQTable _pq_table;
+
+ // distance comparator
+ std::shared_ptr<Distance<T>> _dist_cmp;
+ std::shared_ptr<Distance<float>> _dist_cmp_float;
+
+ // for very large datasets: we use PQ even for the disk resident index
+ bool _use_disk_index_pq = false;
+ uint64_t _disk_pq_n_chunks = 0;
+ FixedChunkPQTable _disk_pq_table;
+
+ // medoid/start info
+
+ // graph has one entry point by default,
+ // we can optionally have multiple starting points
+ uint32_t *_medoids = nullptr;
+ // defaults to 1
+ size_t _num_medoids;
+ // by default, it is empty. If there are multiple
+ // centroids, we pick the medoid corresponding to the
+ // closest centroid as the starting point of search
+ float *_centroid_data = nullptr;
+
+ // nhood_cache; the uint32_t in nhood_cache are offsets into nhood_cache_buf
+ unsigned *_nhood_cache_buf = nullptr;
+ tsl::robin_map<uint32_t, std::pair<uint32_t, uint32_t *>> _nhood_cache;
+
+ // coord_cache; The T* in coord_cache are offsets into coord_cache_buf
+ T *_coord_cache_buf = nullptr;
+ tsl::robin_map<uint32_t, T *> _coord_cache;
+
+ // thread-specific scratch
+ ConcurrentQueue<SSDThreadData<T> *> _thread_data;
+ uint64_t _max_nthreads;
+ bool _load_flag = false;
+ bool _count_visited_nodes = false;
+ bool _reorder_data_exists = false;
+ uint64_t _reoreder_data_offset = 0;
+
+ // filter support
+ uint32_t *_pts_to_label_offsets = nullptr;
+ uint32_t *_pts_to_label_counts = nullptr;
+ LabelT *_pts_to_labels = nullptr;
+ std::unordered_map<LabelT, std::vector<uint32_t>> _filter_to_medoid_ids;
+ bool _use_universal_label = false;
+ LabelT _universal_filter_label;
+ tsl::robin_set<uint32_t> _dummy_pts;
+ tsl::robin_set<uint32_t> _has_dummy_pts;
+ tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
+ tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
+ std::unordered_map<std::string, LabelT> _label_map;
+
+ std::shared_ptr<BatchReader> _batch_reader;
+
+#ifdef EXEC_ENV_OLS
+ // Set to a larger value than the actual header to accommodate
+ // any additions we make to the header. This is an outer limit
+ // on how big the header can be.
+ static const int HEADER_SIZE = defaults::SECTOR_LEN;
+ char *getHeaderBytes();
+#endif
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq_l2_distance.h b/be/src/extern/diskann/include/pq_l2_distance.h
new file mode 100644
index 0000000..e6fc6e4
--- /dev/null
+++ b/be/src/extern/diskann/include/pq_l2_distance.h
@@ -0,0 +1,87 @@
+#pragma once
+#include "quantized_distance.h"
+
+namespace diskann
+{
+template <typename data_t> class PQL2Distance : public QuantizedDistance<data_t>
+{
+ public:
+ // REFACTOR TODO: We could take a file prefix here and load the
+ // PQ pivots file, so that the distance object is initialized
+ // immediately after construction. But this would not work well
+ // with our data store concept where the store is created first
+ // and data populated after.
+ // REFACTOR TODO: Ideally, we should only read the num_chunks from
+ // the pivots file. However, we read the pivots file only later, but
+ // clients can call functions like get_<xxx>_filename without calling
+ // load_pivot_data. Hence this. The TODO is whether we should check
+ // that the num_chunks from the file is the same as this one.
+
+ PQL2Distance(uint32_t num_chunks, bool use_opq = false);
+
+ virtual ~PQL2Distance() override;
+
+ virtual bool is_opq() const override;
+
+ virtual std::string get_quantized_vectors_filename(const std::string &prefix) const override;
+ virtual std::string get_pivot_data_filename(const std::string &prefix) const override;
+ virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const override;
+
+#ifdef EXEC_ENV_OLS
+ virtual void load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file,
+ size_t num_chunks) override;
+#else
+ virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) override;
+#endif
+
+ // Number of chunks in the PQ table. Depends on the compression level used.
+ // Has to be < ndim
+ virtual uint32_t get_num_chunks() const override;
+
+ // Preprocess the query by computing chunk distances from the query vector to
+ // various centroids. Since we don't want this class to do scratch management,
+ // we will take a PQScratch object which can come either from Index class or
+ // PQFlashIndex class.
+ virtual void preprocess_query(const data_t *aligned_query, uint32_t original_dim,
+ PQScratch<data_t> &pq_scratch) override;
+
+ // Distance function used for graph traversal. This function must be called
+ // after
+ // preprocess_query. The reason we do not call preprocess ourselves is because
+ // that function has to be called once per query, while this function is
+ // called at each iteration of the graph walk. NOTE: This function expects
+ // 1. the query to be preprocessed using preprocess_query()
+ // 2. the scratch object to contain the quantized vectors corresponding to ids
+ // in aligned_pq_coord_scratch. Done by calling aggregate_coords()
+ //
+ virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t id_count,
+ float *dists_out) override;
+
+ // Same as above, but returns the distances in a vector instead of an array.
+ // Convenience function for index.cpp.
+ virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids,
+ std::vector<float> &dists_out) override;
+
+ // Currently this function is required for DiskPQ. However, it too can be
+ // subsumed under preprocessed_distance if we add the appropriate scratch
+ // variables to PQScratch and initialize them in
+ // pq_flash_index.cpp::disk_iterate_to_fixed_point()
+ virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) override;
+
+ protected:
+ // assumes pre-processed query
+ virtual void prepopulate_chunkwise_distances(const float *query_vec, float *dist_vec);
+
+ // assumes no rotation is involved
+ // virtual void inflate_vector(uint8_t *base_vec, float *out_vec);
+
+ float *_tables = nullptr; // pq_tables = float array of size [256 * ndims]
+ uint64_t _ndims = 0; // ndims = true dimension of vectors
+ uint64_t _num_chunks = 0;
+ bool _is_opq = false;
+ uint32_t *_chunk_offsets = nullptr;
+ float *_centroid = nullptr;
+ float *_tables_tr = nullptr; // same as pq_tables, but col-major
+ float *_rotmat_tr = nullptr;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/pq_scratch.h b/be/src/extern/diskann/include/pq_scratch.h
new file mode 100644
index 0000000..95f1b13
--- /dev/null
+++ b/be/src/extern/diskann/include/pq_scratch.h
@@ -0,0 +1,23 @@
+#pragma once
+#include <cstdint>
+#include "pq_common.h"
+#include "utils.h"
+
+namespace diskann
+{
+
+template <typename T> class PQScratch
+{
+ public:
+ float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
+ float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE
+ uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE]
+ float *rotated_query = nullptr;
+ float *aligned_query_float = nullptr;
+
+ PQScratch(size_t graph_degree, size_t aligned_dim);
+ void initialize(size_t dim, const T *query, const float norm = 1.0f);
+ virtual ~PQScratch();
+};
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/program_options_utils.hpp b/be/src/extern/diskann/include/program_options_utils.hpp
new file mode 100644
index 0000000..2be6059
--- /dev/null
+++ b/be/src/extern/diskann/include/program_options_utils.hpp
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <string>
+
+namespace program_options_utils
+{
+inline const std::string make_program_description(const char *executable_name, const char *description)
+{
+ return std::string("\n")
+ .append(description)
+ .append("\n\n")
+ .append("Usage: ")
+ .append(executable_name)
+ .append(" [OPTIONS]");
+}
+
+// Required parameters
+const char *DATA_TYPE_DESCRIPTION = "data type, one of {int8, uint8, float} - float is single precision (32 bit)";
+const char *DISTANCE_FUNCTION_DESCRIPTION =
+ "distance function {l2, mips, fast_l2, cosine}. 'fast l2' and 'mips' only support data_type float";
+const char *INDEX_PATH_PREFIX_DESCRIPTION = "Path prefix to the index, e.g. '/mnt/data/my_ann_index'";
+const char *RESULT_PATH_DESCRIPTION =
+ "Path prefix for saving results of the queries, e.g. '/mnt/data/query_file_X.bin'";
+const char *QUERY_FILE_DESCRIPTION = "Query file in binary format, e.g. '/mnt/data/query_file_X.bin'";
+const char *NUMBER_OF_RESULTS_DESCRIPTION = "Number of neighbors to be returned (K in the DiskANN white paper)";
+const char *SEARCH_LIST_DESCRIPTION =
+ "Size of search list to use. This value is the number of neighbor/distance pairs to keep in memory at the same "
+ "time while performing a query. This can also be described as the size of the working set at query time. This "
+ "must be greater than or equal to the number of results/neighbors to return (K in the white paper). Corresponds "
+ "to L in the DiskANN white paper.";
+const char *INPUT_DATA_PATH = "Input data file in bin format. This is the file you want to build the index over. "
+ "File format: Shape of the vector followed by the vector of embeddings as binary data.";
+
+// Optional parameters
+const char *FILTER_LABEL_DESCRIPTION =
+ "Filter to use when running a query. 'filter_label' and 'query_filters_file' are mutually exclusive.";
+const char *FILTERS_FILE_DESCRIPTION =
+ "Filter file for Queries for Filtered Search. File format is text with one filter per line. File must "
+ "have exactly one filter OR the same number of filters as there are queries in the 'query_file'.";
+const char *LABEL_TYPE_DESCRIPTION =
+ "Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume 4 bytes of memory per "
+ "filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'.";
+const char *GROUND_TRUTH_FILE_DESCRIPTION =
+ "ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an
+ // entry for every item or just a small subset? I have so many questions about
+ // this file
+const char *NUMBER_THREADS_DESCRIPTION = "Number of threads used for building index. Defaults to number of logical "
+ "processor cores on this machine, as returned by omp_get_num_procs()";
+const char *FAIL_IF_RECALL_BELOW =
+ "Value between 0 (inclusive) and 100 (exclusive) indicating the recall tolerance percentage threshold before "
+ "program fails with a non-zero exit code. The default value of 0 means that the program will complete "
+ "successfully with any recall value. A non-zero value indicates the floor for acceptable recall values. If the "
+ "calculated recall value is below this threshold then the program will write out the results but return a non-zero "
+ "exit code as a signal that the recall was not acceptable."; // does it continue running or die immediately? Will I
+ // still get my results even if the return code is -1?
+
+const char *NUMBER_OF_NODES_TO_CACHE = "Number of BFS nodes around medoid(s) to cache. Default value: 0";
+const char *BEAMWIDTH = "Beamwidth for search. Set 0 to optimize internally. Default value: 2";
+const char *MAX_BUILD_DEGREE = "Maximum graph degree";
+const char *GRAPH_BUILD_COMPLEXITY =
+ "Size of the search working set during build time. This is the number of neighbor/distance pairs to keep in memory "
+ "while building the index. Higher value results in a higher quality graph but it will take more time to build the "
+ "graph.";
+const char *GRAPH_BUILD_ALPHA = "Alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for "
+ "denser graphs with lower diameter";
+const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for full precision build";
+const char *USE_OPQ = "Use Optimized Product Quantization (OPQ).";
+const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma "
+ "separated filters for each node with each line corresponding to a graph node";
+const char *UNIVERSAL_LABEL =
+ "Universal label, Use only in conjunction with label file for filtered index build. If a "
+ "graph node has all the labels against it, we can assign a special universal filter to the "
+ "point instead of comma separated filters for that point. The universal label should be assigned to nodes "
+ "in the labels file instead of listing all labels for a node. DiskANN will not automatically assign a "
+ "universal label to a node.";
+const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs";
+
+} // namespace program_options_utils
diff --git a/be/src/extern/diskann/include/quantized_distance.h b/be/src/extern/diskann/include/quantized_distance.h
new file mode 100644
index 0000000..cc4aea9
--- /dev/null
+++ b/be/src/extern/diskann/include/quantized_distance.h
@@ -0,0 +1,56 @@
+#pragma once
+#include <memory>
+#include <string>
+#include <vector>
+#include "abstract_scratch.h"
+
+namespace diskann
+{
+template <typename data_t> class PQScratch;
+
+template <typename data_t> class QuantizedDistance
+{
+ public:
+ QuantizedDistance() = default;
+ QuantizedDistance(const QuantizedDistance &) = delete;
+ QuantizedDistance &operator=(const QuantizedDistance &) = delete;
+ virtual ~QuantizedDistance() = default;
+
+ virtual bool is_opq() const = 0;
+ virtual std::string get_quantized_vectors_filename(const std::string &prefix) const = 0;
+ virtual std::string get_pivot_data_filename(const std::string &prefix) const = 0;
+ virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const = 0;
+
+ // Loading the PQ centroid table need not be part of the abstract class.
+ // However, we want to indicate that this function will change once we have a
+ // file reader hierarchy, so leave it here as-is.
+#ifdef EXEC_ENV_OLS
+ virtual void load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file, size_t num_chunks) = 0;
+#else
+ virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) = 0;
+#endif
+
+ // Number of chunks in the PQ table. Depends on the compression level used.
+ // Has to be < ndim
+ virtual uint32_t get_num_chunks() const = 0;
+
+ // Preprocess the query by computing chunk distances from the query vector to
+ // various centroids. Since we don't want this class to do scratch management,
+ // we will take a PQScratch object which can come either from Index class or
+ // PQFlashIndex class.
+ virtual void preprocess_query(const data_t *query_vec, uint32_t query_dim, PQScratch<data_t> &pq_scratch) = 0;
+
+ // Workhorse
+ // This function must be called after preprocess_query
+ virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t id_count, float *dists_out) = 0;
+
+ // Same as above, but convenience function for index.cpp.
+ virtual void preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids,
+ std::vector<float> &dists_out) = 0;
+
+ // Currently this function is required for DiskPQ. However, it too can be subsumed
+ // under preprocessed_distance if we add the appropriate scratch variables to
+ // PQScratch and initialize them in pq_flash_index.cpp::disk_iterate_to_fixed_point()
+ virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) = 0;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/restapi/common.h b/be/src/extern/diskann/include/restapi/common.h
new file mode 100644
index 0000000..b833963
--- /dev/null
+++ b/be/src/extern/diskann/include/restapi/common.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cpprest/base_uri.h>
+#include <restapi/search_wrapper.h>
+
+namespace diskann
+{
+// Constants
+static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances",
+ TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls",
+ TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition",
+ UNKNOWN_ERROR = "unknown_error";
+const unsigned int DEFAULT_L = 100;
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/restapi/search_wrapper.h b/be/src/extern/diskann/include/restapi/search_wrapper.h
new file mode 100644
index 0000000..ebd067d
--- /dev/null
+++ b/be/src/extern/diskann/include/restapi/search_wrapper.h
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include <stdexcept>
+
+#include <index.h>
+#include <pq_flash_index.h>
+
+namespace diskann
+{
+class SearchResult
+{
+ public:
+ SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices,
+ const float *const distances, const std::string *const tags = nullptr,
+ const unsigned *const partitions = nullptr);
+
+ const std::vector<unsigned int> &get_indices() const
+ {
+ return _indices;
+ }
+ const std::vector<float> &get_distances() const
+ {
+ return _distances;
+ }
+ bool tags_enabled() const
+ {
+ return _tags_enabled;
+ }
+ const std::vector<std::string> &get_tags() const
+ {
+ return _tags;
+ }
+ bool partitions_enabled() const
+ {
+ return _partitions_enabled;
+ }
+ const std::vector<unsigned> &get_partitions() const
+ {
+ return _partitions;
+ }
+ unsigned get_time() const
+ {
+ return _search_time_in_ms;
+ }
+
+ private:
+ unsigned int _K;
+ unsigned int _search_time_in_ms;
+ std::vector<unsigned int> _indices;
+ std::vector<float> _distances;
+
+ bool _tags_enabled;
+ std::vector<std::string> _tags;
+
+ bool _partitions_enabled;
+ std::vector<unsigned> _partitions;
+};
+
+class SearchNotImplementedException : public std::logic_error
+{
+ private:
+ std::string _errormsg;
+
+ public:
+ SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented")
+ {
+ _errormsg = "Search with data type ";
+ _errormsg += std::string(type);
+ _errormsg += " not implemented : ";
+ _errormsg += __FUNCTION__;
+ }
+
+ virtual const char *what() const throw()
+ {
+ return _errormsg.c_str();
+ }
+};
+
+class BaseSearch
+{
+ public:
+ BaseSearch(const std::string &tagsFile = "");
+ virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K,
+ const unsigned int Ls)
+ {
+ throw SearchNotImplementedException("float");
+ }
+ virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K,
+ const unsigned int Ls)
+ {
+ throw SearchNotImplementedException("int8_t");
+ }
+
+ virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K,
+ const unsigned int Ls)
+ {
+ throw SearchNotImplementedException("uint8_t");
+ }
+
+ void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags);
+
+ protected:
+ bool _tags_enabled;
+ std::vector<std::string> _tags_str;
+};
+
+template <typename T> class InMemorySearch : public BaseSearch
+{
+ public:
+ InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m,
+ uint32_t num_threads, uint32_t search_l);
+ virtual ~InMemorySearch();
+
+ SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
+
+ private:
+ unsigned int _dimensions, _numPoints;
+ std::unique_ptr<diskann::Index<T>> _index;
+};
+
+template <typename T> class PQFlashSearch : public BaseSearch
+{
+ public:
+ PQFlashSearch(const std::string &indexPrefix, const unsigned num_nodes_to_cache, const unsigned num_threads,
+ const std::string &tagsFile, Metric m);
+ virtual ~PQFlashSearch();
+
+ SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);
+
+ private:
+ unsigned int _dimensions, _numPoints;
+ std::unique_ptr<diskann::PQFlashIndex<T>> _index;
+ std::shared_ptr<AlignedFileReader> reader;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/restapi/server.h b/be/src/extern/diskann/include/restapi/server.h
new file mode 100644
index 0000000..1d75847
--- /dev/null
+++ b/be/src/extern/diskann/include/restapi/server.h
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <restapi/common.h>
+#include <cpprest/http_listener.h>
+
+namespace diskann
+{
+class Server
+{
+ public:
+ Server(web::uri &url, std::vector<std::unique_ptr<diskann::BaseSearch>> &multi_searcher,
+ const std::string &typestring);
+ virtual ~Server();
+
+ pplx::task<void> open();
+ pplx::task<void> close();
+
+ protected:
+ template <class T> void handle_post(web::http::http_request message);
+
+ template <typename T>
+ web::json::value toJsonArray(const std::vector<T> &v, std::function<web::json::value(const T &)> valConverter);
+ web::json::value prepareResponse(const int64_t &queryId, const int k);
+
+ template <class T>
+ void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector,
+ unsigned int &dimensions, unsigned &Ls);
+
+ web::json::value idsToJsonArray(const diskann::SearchResult &result);
+ web::json::value distancesToJsonArray(const diskann::SearchResult &result);
+ web::json::value tagsToJsonArray(const diskann::SearchResult &result);
+ web::json::value partitionsToJsonArray(const diskann::SearchResult &result);
+
+ SearchResult aggregate_results(const unsigned K, const std::vector<diskann::SearchResult> &results);
+
+ private:
+ bool _isDebug;
+ std::unique_ptr<web::http::experimental::listener::http_listener> _listener;
+ const bool _multi_search;
+ std::vector<std::unique_ptr<diskann::BaseSearch>> _multi_searcher;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/scratch.h b/be/src/extern/diskann/include/scratch.h
new file mode 100644
index 0000000..2f43e33
--- /dev/null
+++ b/be/src/extern/diskann/include/scratch.h
@@ -0,0 +1,216 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <vector>
+
+#include "boost_dynamic_bitset_fwd.h"
+// #include "boost/dynamic_bitset.hpp"
+#include "tsl/robin_set.h"
+#include "tsl/robin_map.h"
+#include "tsl/sparse_map.h"
+
+#include "aligned_file_reader.h"
+#include "abstract_scratch.h"
+#include "neighbor.h"
+#include "defaults.h"
+#include "concurrent_queue.h"
+
+namespace diskann
+{
+template <typename T> class PQScratch;
+
+//
+// AbstractScratch space for in-memory index based search
+//
+template <typename T> class InMemQueryScratch : public AbstractScratch<T>
+{
+ public:
+ ~InMemQueryScratch();
+ InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
+ size_t alignment_factor, bool init_pq_scratch = false);
+ void resize_for_new_L(uint32_t new_search_l);
+ void clear();
+
+ inline uint32_t get_L()
+ {
+ return _L;
+ }
+ inline uint32_t get_R()
+ {
+ return _R;
+ }
+ inline uint32_t get_maxc()
+ {
+ return _maxc;
+ }
+ inline T *aligned_query()
+ {
+ return this->_aligned_query_T;
+ }
+ inline PQScratch<T> *pq_scratch()
+ {
+ return this->_pq_scratch;
+ }
+ inline std::vector<Neighbor> &pool()
+ {
+ return _pool;
+ }
+ inline NeighborPriorityQueue &best_l_nodes()
+ {
+ return _best_l_nodes;
+ }
+ inline std::vector<float> &occlude_factor()
+ {
+ return _occlude_factor;
+ }
+ inline tsl::robin_set<uint32_t> &inserted_into_pool_rs()
+ {
+ return _inserted_into_pool_rs;
+ }
+ inline boost::dynamic_bitset<> &inserted_into_pool_bs()
+ {
+ return *_inserted_into_pool_bs;
+ }
+ inline std::vector<uint32_t> &id_scratch()
+ {
+ return _id_scratch;
+ }
+ inline std::vector<float> &dist_scratch()
+ {
+ return _dist_scratch;
+ }
+ inline tsl::robin_set<uint32_t> &expanded_nodes_set()
+ {
+ return _expanded_nodes_set;
+ }
+ inline std::vector<Neighbor> &expanded_nodes_vec()
+ {
+ return _expanded_nghrs_vec;
+ }
+ inline std::vector<uint32_t> &occlude_list_output()
+ {
+ return _occlude_list_output;
+ }
+
+ private:
+ uint32_t _L;
+ uint32_t _R;
+ uint32_t _maxc;
+
+ // _pool stores all neighbors explored from best_L_nodes.
+ // Usually around L+R, but could be higher.
+ // Initialized to 3L+R for some slack, expands as needed.
+ std::vector<Neighbor> _pool;
+
+ // _best_l_nodes is reserved for storing best L entries
+ // Underlying storage is L+1 to support inserts
+ NeighborPriorityQueue _best_l_nodes;
+
+ // _occlude_factor.size() >= pool.size() in occlude_list function
+ // _pool is clipped to maxc in occlude_list before affecting _occlude_factor
+ // _occlude_factor is initialized to maxc size
+ std::vector<float> _occlude_factor;
+
+ // Capacity initialized to 20L
+ tsl::robin_set<uint32_t> _inserted_into_pool_rs;
+
+ // Use a pointer here to allow for forward declaration of dynamic_bitset
+ // in public headers to avoid making boost a dependency for clients
+ // of DiskANN.
+ boost::dynamic_bitset<> *_inserted_into_pool_bs;
+
+ // _id_scratch.size() must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
+ std::vector<uint32_t> _id_scratch;
+
+ // _dist_scratch must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
+ // _dist_scratch should be at least the size of id_scratch
+ std::vector<float> _dist_scratch;
+
+ // Buffers used in process delete, capacity increases as needed
+ tsl::robin_set<uint32_t> _expanded_nodes_set;
+ std::vector<Neighbor> _expanded_nghrs_vec;
+ std::vector<uint32_t> _occlude_list_output;
+};
+
+//
+// AbstractScratch space for SSD index based search
+//
+
+template <typename T> class SSDQueryScratch : public AbstractScratch<T>
+{
+ public:
+ T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim]
+
+ char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN]
+ size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use
+
+ tsl::robin_set<size_t> visited;
+ NeighborPriorityQueue retset;
+ std::vector<Neighbor> full_retset;
+
+ SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
+ ~SSDQueryScratch();
+
+ void reset();
+};
+
+template <typename T> class SSDThreadData
+{
+ public:
+ SSDQueryScratch<T> scratch;
+ IOContext ctx;
+
+ SSDThreadData(size_t aligned_dim, size_t visited_reserve);
+ void clear();
+};
+
+//
+// Class to avoid the hassle of pushing and popping the query scratch.
+//
+template <typename T> class ScratchStoreManager
+{
+ public:
+ ScratchStoreManager(ConcurrentQueue<T *> &query_scratch) : _scratch_pool(query_scratch)
+ {
+ _scratch = query_scratch.pop();
+ while (_scratch == nullptr)
+ {
+ query_scratch.wait_for_push_notify();
+ _scratch = query_scratch.pop();
+ }
+ }
+ T *scratch_space()
+ {
+ return _scratch;
+ }
+
+ ~ScratchStoreManager()
+ {
+ _scratch->clear();
+ _scratch_pool.push(_scratch);
+ _scratch_pool.push_notify_all();
+ }
+
+ void destroy()
+ {
+ while (!_scratch_pool.empty())
+ {
+ auto scratch = _scratch_pool.pop();
+ while (scratch == nullptr)
+ {
+ _scratch_pool.wait_for_push_notify();
+ scratch = _scratch_pool.pop();
+ }
+ delete scratch;
+ }
+ }
+
+ private:
+ T *_scratch;
+ ConcurrentQueue<T *> &_scratch_pool;
+ ScratchStoreManager(const ScratchStoreManager<T> &);
+ ScratchStoreManager &operator=(const ScratchStoreManager<T> &);
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/simd_utils.h b/be/src/extern/diskann/include/simd_utils.h
new file mode 100644
index 0000000..4b07369
--- /dev/null
+++ b/be/src/extern/diskann/include/simd_utils.h
@@ -0,0 +1,106 @@
+#pragma once
+
+#ifdef _WINDOWS
+#include <immintrin.h>
+#include <smmintrin.h>
+#include <tmmintrin.h>
+#include <intrin.h>
+#else
+#include <immintrin.h>
+#endif
+
+namespace diskann
+{
+static inline __m256 _mm256_mul_epi8(__m256i X)
+{
+ __m256i zero = _mm256_setzero_si256();
+
+ __m256i sign_x = _mm256_cmpgt_epi8(zero, X);
+
+ __m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
+ __m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
+
+ return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, xlo), _mm256_madd_epi16(xhi, xhi)));
+}
+
+static inline __m128 _mm_mulhi_epi8(__m128i X)
+{
+ __m128i zero = _mm_setzero_si128();
+ __m128i sign_x = _mm_cmplt_epi8(X, zero);
+ __m128i xhi = _mm_unpackhi_epi8(X, sign_x);
+
+ return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
+}
+
+static inline __m128 _mm_mulhi_epi8_shift32(__m128i X)
+{
+ __m128i zero = _mm_setzero_si128();
+ X = _mm_srli_epi64(X, 32);
+ __m128i sign_x = _mm_cmplt_epi8(X, zero);
+ __m128i xhi = _mm_unpackhi_epi8(X, sign_x);
+
+ return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi)));
+}
+static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y)
+{
+ __m128i zero = _mm_setzero_si128();
+
+ __m128i sign_x = _mm_cmplt_epi8(X, zero);
+ __m128i sign_y = _mm_cmplt_epi8(Y, zero);
+
+ __m128i xlo = _mm_unpacklo_epi8(X, sign_x);
+ __m128i xhi = _mm_unpackhi_epi8(X, sign_x);
+ __m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
+ __m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
+
+ return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
+}
+static inline __m128 _mm_mul_epi8(__m128i X)
+{
+ __m128i zero = _mm_setzero_si128();
+ __m128i sign_x = _mm_cmplt_epi8(X, zero);
+ __m128i xlo = _mm_unpacklo_epi8(X, sign_x);
+ __m128i xhi = _mm_unpackhi_epi8(X, sign_x);
+
+ return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, xlo), _mm_madd_epi16(xhi, xhi)));
+}
+
+static inline __m128 _mm_mul32_pi8(__m128i X, __m128i Y)
+{
+ __m128i xlo = _mm_cvtepi8_epi16(X), ylo = _mm_cvtepi8_epi16(Y);
+ return _mm_cvtepi32_ps(_mm_unpacklo_epi32(_mm_madd_epi16(xlo, ylo), _mm_setzero_si128()));
+}
+
+static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y)
+{
+ __m256i zero = _mm256_setzero_si256();
+
+ __m256i sign_x = _mm256_cmpgt_epi8(zero, X);
+ __m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
+
+ __m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
+ __m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
+ __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
+ __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
+
+ return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
+}
+
+static inline __m256 _mm256_mul32_pi8(__m128i X, __m128i Y)
+{
+ __m256i xlo = _mm256_cvtepi8_epi16(X), ylo = _mm256_cvtepi8_epi16(Y);
+ return _mm256_blend_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(xlo, ylo)), _mm256_setzero_ps(), 252);
+}
+
+static inline float _mm256_reduce_add_ps(__m256 x)
+{
+ /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
+ const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+ const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+ /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+ const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+ /* Conversion to float is a no-op on x86-64 */
+ return _mm_cvtss_f32(x32);
+}
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/tag_uint128.h b/be/src/extern/diskann/include/tag_uint128.h
new file mode 100644
index 0000000..642de31
--- /dev/null
+++ b/be/src/extern/diskann/include/tag_uint128.h
@@ -0,0 +1,68 @@
+#pragma once
+#include <cstdint>
+#include <type_traits>
+
+namespace diskann
+{
+#pragma pack(push, 1)
+
+struct tag_uint128
+{
+ std::uint64_t _data1 = 0;
+ std::uint64_t _data2 = 0;
+
+ bool operator==(const tag_uint128 &other) const
+ {
+ return _data1 == other._data1 && _data2 == other._data2;
+ }
+
+ bool operator==(std::uint64_t other) const
+ {
+ return _data1 == other && _data2 == 0;
+ }
+
+ tag_uint128 &operator=(const tag_uint128 &other)
+ {
+ _data1 = other._data1;
+ _data2 = other._data2;
+
+ return *this;
+ }
+
+ tag_uint128 &operator=(std::uint64_t other)
+ {
+ _data1 = other;
+ _data2 = 0;
+
+ return *this;
+ }
+};
+
+#pragma pack(pop)
+} // namespace diskann
+
+namespace std
+{
+// Hash 128 input bits down to 64 bits of output.
+// This is intended to be a reasonably good hash function.
+inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high)
+{
+ // Murmur-inspired hashing.
+ const std::uint64_t kMul = 0x9ddfea08eb382d69ULL;
+ std::uint64_t a = (low ^ high) * kMul;
+ a ^= (a >> 47);
+ std::uint64_t b = (high ^ a) * kMul;
+ b ^= (b >> 47);
+ b *= kMul;
+ return b;
+}
+
+template <> struct hash<diskann::tag_uint128>
+{
+ size_t operator()(const diskann::tag_uint128 &key) const noexcept
+ {
+ return Hash128to64(key._data1, key._data2); // map -0 to 0
+ }
+};
+
+} // namespace std
\ No newline at end of file
diff --git a/be/src/extern/diskann/include/timer.h b/be/src/extern/diskann/include/timer.h
new file mode 100644
index 0000000..325edf3
--- /dev/null
+++ b/be/src/extern/diskann/include/timer.h
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+#pragma once
+
+#include <chrono>
+
+namespace diskann
+{
+class Timer
+{
+ typedef std::chrono::high_resolution_clock _clock;
+ std::chrono::time_point<_clock> check_point;
+
+ public:
+ Timer() : check_point(_clock::now())
+ {
+ }
+
+ void reset()
+ {
+ check_point = _clock::now();
+ }
+
+ long long elapsed() const
+ {
+ return std::chrono::duration_cast<std::chrono::microseconds>(_clock::now() - check_point).count();
+ }
+
+ float elapsed_seconds() const
+ {
+ return (float)elapsed() / 1000000.0f;
+ }
+
+ std::string elapsed_seconds_for_step(const std::string &step) const
+ {
+ return std::string("Time for ") + step + std::string(": ") + std::to_string(elapsed_seconds()) +
+ std::string(" seconds");
+ }
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/tsl/.clang-format b/be/src/extern/diskann/include/tsl/.clang-format
new file mode 100644
index 0000000..9d15924
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/.clang-format
@@ -0,0 +1,2 @@
+DisableFormat: true
+SortIncludes: false
diff --git a/be/src/extern/diskann/include/tsl/robin_growth_policy.h b/be/src/extern/diskann/include/tsl/robin_growth_policy.h
new file mode 100644
index 0000000..6bfa9e5f
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/robin_growth_policy.h
@@ -0,0 +1,330 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Tessil
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_ROBIN_GROWTH_POLICY_H
+#define TSL_ROBIN_GROWTH_POLICY_H
+
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <iterator>
+#include <limits>
+#include <ratio>
+#include <stdexcept>
+
+
+#ifndef tsl_assert
+# ifdef TSL_DEBUG
+# define tsl_assert(expr) assert(expr)
+# else
+# define tsl_assert(expr) (static_cast<void>(0))
+# endif
+#endif
+
+
+/**
+ * If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
+ */
+#ifndef TSL_THROW_OR_TERMINATE
+# if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
+# define TSL_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
+# else
+# ifdef NDEBUG
+# define TSL_THROW_OR_TERMINATE(ex, msg) std::terminate()
+# else
+# include <cstdio>
+# define TSL_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0)
+# endif
+# endif
+#endif
+
+
+#ifndef TSL_LIKELY
+# if defined(__GNUC__) || defined(__clang__)
+# define TSL_LIKELY(exp) (__builtin_expect(!!(exp), true))
+# else
+# define TSL_LIKELY(exp) (exp)
+# endif
+#endif
+
+
+namespace tsl {
+namespace rh {
+
+/**
+ * Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
+ * the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
+ *
+ * GrowthFactor must be a power of two >= 2.
+ */
+template<std::size_t GrowthFactor>
+class power_of_two_growth_policy {
+public:
+ /**
+ * Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
+ * This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
+ *
+ * If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
+ * bucket_for_hash must always return 0 in this case.
+ */
+ explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
+ if(min_bucket_count_in_out > max_bucket_count()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ if(min_bucket_count_in_out > 0) {
+ min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
+ m_mask = min_bucket_count_in_out - 1;
+ }
+ else {
+ m_mask = 0;
+ }
+ }
+
+ /**
+ * Return the bucket [0, bucket_count()) to which the hash belongs.
+ * If bucket_count() is 0, it must always return 0.
+ */
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return hash & m_mask;
+ }
+
+ /**
+ * Return the number of buckets that should be used on next growth.
+ */
+ std::size_t next_bucket_count() const {
+ if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ return (m_mask + 1) * GrowthFactor;
+ }
+
+ /**
+ * Return the maximum number of buckets supported by the policy.
+ */
+ std::size_t max_bucket_count() const {
+ // Largest power of two.
+ return ((std::numeric_limits<std::size_t>::max)() / 2) + 1;
+ }
+
+ /**
+ * Reset the growth policy as if it was created with a bucket count of 0.
+ * After a clear, the policy must always return 0 when bucket_for_hash is called.
+ */
+ void clear() noexcept {
+ m_mask = 0;
+ }
+
+private:
+ static std::size_t round_up_to_power_of_two(std::size_t value) {
+ if(is_power_of_two(value)) {
+ return value;
+ }
+
+ if(value == 0) {
+ return 1;
+ }
+
+ --value;
+ for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
+ value |= value >> i;
+ }
+
+ return value + 1;
+ }
+
+ static constexpr bool is_power_of_two(std::size_t value) {
+ return value != 0 && (value & (value - 1)) == 0;
+ }
+
+protected:
+ static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
+
+ std::size_t m_mask;
+};
+
+
+/**
+ * Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
+ * to a bucket. Slower but it can be useful if you want a slower growth.
+ */
+template<class GrowthFactor = std::ratio<3, 2>>
+class mod_growth_policy {
+public:
+ explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
+ if(min_bucket_count_in_out > max_bucket_count()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ if(min_bucket_count_in_out > 0) {
+ m_mod = min_bucket_count_in_out;
+ }
+ else {
+ m_mod = 1;
+ }
+ }
+
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return hash % m_mod;
+ }
+
+ std::size_t next_bucket_count() const {
+ if(m_mod == max_bucket_count()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
+ if(!std::isnormal(next_bucket_count)) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ if(next_bucket_count > double(max_bucket_count())) {
+ return max_bucket_count();
+ }
+ else {
+ return std::size_t(next_bucket_count);
+ }
+ }
+
+ std::size_t max_bucket_count() const {
+ return MAX_BUCKET_COUNT;
+ }
+
+ void clear() noexcept {
+ m_mod = 1;
+ }
+
+private:
+ static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
+ static const std::size_t MAX_BUCKET_COUNT =
+ std::size_t(double(
+ (std::numeric_limits<std::size_t>::max)() / REHASH_SIZE_MULTIPLICATION_FACTOR
+ ));
+
+ static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
+
+ std::size_t m_mod;
+};
+
+
+
+namespace detail {
+
+static constexpr const std::array<std::size_t, 40> PRIMES = {{
+ 1ul, 5ul, 17ul, 29ul, 37ul, 53ul, 67ul, 79ul, 97ul, 131ul, 193ul, 257ul, 389ul, 521ul, 769ul, 1031ul,
+ 1543ul, 2053ul, 3079ul, 6151ul, 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, 393241ul, 786433ul,
+ 1572869ul, 3145739ul, 6291469ul, 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
+ 402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul
+}};
+
+template<unsigned int IPrime>
+static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
+
+// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
+// compiler can optimize the modulo code better with a constant known at the compilation.
+static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {{
+ &mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
+ &mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
+ &mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
+ &mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37> , &mod<38>, &mod<39>
+}};
+
+}
+
+/**
+ * Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
+ * general but will probably distribute the values around better in the buckets with a poor hash function.
+ *
+ * To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
+ *
+ * With a switch the code would look like:
+ * \code
+ * switch(iprime) { // iprime is the current prime of the hash table
+ * case 0: hash % 5ul;
+ * break;
+ * case 1: hash % 17ul;
+ * break;
+ * case 2: hash % 29ul;
+ * break;
+ * ...
+ * }
+ * \endcode
+ *
+ * Due to the constant variable in the modulo the compiler is able to optimize the operation
+ * by a series of multiplications, subtractions and shifts.
+ *
+ * The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environment.
+ */
+class prime_growth_policy {
+public:
+ explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
+ auto it_prime = std::lower_bound(detail::PRIMES.begin(),
+ detail::PRIMES.end(), min_bucket_count_in_out);
+ if(it_prime == detail::PRIMES.end()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
+ if(min_bucket_count_in_out > 0) {
+ min_bucket_count_in_out = *it_prime;
+ }
+ else {
+ min_bucket_count_in_out = 0;
+ }
+ }
+
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return detail::MOD_PRIME[m_iprime](hash);
+ }
+
+ std::size_t next_bucket_count() const {
+ if(m_iprime + 1 >= detail::PRIMES.size()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
+ }
+
+ return detail::PRIMES[m_iprime + 1];
+ }
+
+ std::size_t max_bucket_count() const {
+ return detail::PRIMES.back();
+ }
+
+ void clear() noexcept {
+ m_iprime = 0;
+ }
+
+private:
+ unsigned int m_iprime;
+
+ static_assert((std::numeric_limits<decltype(m_iprime)>::max)() >= detail::PRIMES.size(),
+ "The type of m_iprime is not big enough.");
+};
+
+}
+}
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/robin_hash.h b/be/src/extern/diskann/include/tsl/robin_hash.h
new file mode 100644
index 0000000..5ecc962
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/robin_hash.h
@@ -0,0 +1,1285 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Tessil
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_ROBIN_HASH_H
+#define TSL_ROBIN_HASH_H
+
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <stdexcept>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+#include "robin_growth_policy.h"
+
+
+namespace tsl {
+
+namespace detail_robin_hash {
+
+template<typename T>
+struct make_void {
+ using type = void;
+};
+
+template<typename T, typename = void>
+struct has_is_transparent: std::false_type {
+};
+
+template<typename T>
+struct has_is_transparent<T, typename make_void<typename T::is_transparent>::type>: std::true_type {
+};
+
+template<typename U>
+struct is_power_of_two_policy: std::false_type {
+};
+
+template<std::size_t GrowthFactor>
+struct is_power_of_two_policy<tsl::rh::power_of_two_growth_policy<GrowthFactor>>: std::true_type {
+};
+
+
+
+using truncated_hash_type = std::uint_least32_t;
+
+/**
+ * Helper class that store a truncated hash if StoreHash is true and nothing otherwise.
+ */
+template<bool StoreHash>
+class bucket_entry_hash {
+public:
+ bool bucket_hash_equal(std::size_t /*hash*/) const noexcept {
+ return true;
+ }
+
+ truncated_hash_type truncated_hash() const noexcept {
+ return 0;
+ }
+
+protected:
+ void set_hash(truncated_hash_type /*hash*/) noexcept {
+ }
+};
+
+template<>
+class bucket_entry_hash<true> {
+public:
+ bool bucket_hash_equal(std::size_t hash) const noexcept {
+ return m_hash == truncated_hash_type(hash);
+ }
+
+ truncated_hash_type truncated_hash() const noexcept {
+ return m_hash;
+ }
+
+protected:
+ void set_hash(truncated_hash_type hash) noexcept {
+ m_hash = truncated_hash_type(hash);
+ }
+
+private:
+ truncated_hash_type m_hash;
+};
+
+
+/**
+ * Each bucket entry has:
+ * - A value of type `ValueType`.
+ * - An integer to store how far the value of the bucket, if any, is from its ideal bucket
+ * (ex: if the current bucket 5 has the value 'foo' and `hash('foo') % nb_buckets` == 3,
+ * `dist_from_ideal_bucket()` will return 2 as the current value of the bucket is two
+ * buckets away from its ideal bucket)
+ * If there is no value in the bucket (i.e. `empty()` is true) `dist_from_ideal_bucket()` will be < 0.
+ * - A marker which tells us if the bucket is the last bucket of the bucket array (useful for the
+ * iterator of the hash table).
+ * - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also stored in the bucket.
+ * If the size of the hash is more than 32 bits, it is truncated. We don't store the full hash
+ * as storing the hash is a potential opportunity to use the unused space due to the alignment
+ * of the bucket_entry structure. We can thus potentially store the hash without any extra space
+ * (which would not be possible with 64 bits of the hash).
+ */
+template<typename ValueType, bool StoreHash>
+class bucket_entry: public bucket_entry_hash<StoreHash> {
+ using bucket_hash = bucket_entry_hash<StoreHash>;
+
+public:
+ using value_type = ValueType;
+ using distance_type = std::int_least16_t;
+
+
+ bucket_entry() noexcept: bucket_hash(), m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
+ m_last_bucket(false)
+ {
+ tsl_assert(empty());
+ }
+
+ bucket_entry(bool last_bucket) noexcept: bucket_hash(), m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
+ m_last_bucket(last_bucket)
+ {
+ tsl_assert(empty());
+ }
+
+ bucket_entry(const bucket_entry& other) noexcept(std::is_nothrow_copy_constructible<value_type>::value):
+ bucket_hash(other),
+ m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
+ m_last_bucket(other.m_last_bucket)
+ {
+ if(!other.empty()) {
+ ::new (static_cast<void*>(std::addressof(m_value))) value_type(other.value());
+ m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
+ }
+ }
+
+ /**
+ * Never really used, but still necessary as we must call resize on an empty `std::vector<bucket_entry>`.
+ * and we need to support move-only types. See robin_hash constructor for details.
+ */
+ bucket_entry(bucket_entry&& other) noexcept(std::is_nothrow_move_constructible<value_type>::value):
+ bucket_hash(std::move(other)),
+ m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
+ m_last_bucket(other.m_last_bucket)
+ {
+ if(!other.empty()) {
+ ::new (static_cast<void*>(std::addressof(m_value))) value_type(std::move(other.value()));
+ m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
+ }
+ }
+
+ bucket_entry& operator=(const bucket_entry& other)
+ noexcept(std::is_nothrow_copy_constructible<value_type>::value)
+ {
+ if(this != &other) {
+ clear();
+
+ bucket_hash::operator=(other);
+ if(!other.empty()) {
+ ::new (static_cast<void*>(std::addressof(m_value))) value_type(other.value());
+ }
+
+ m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
+ m_last_bucket = other.m_last_bucket;
+ }
+
+ return *this;
+ }
+
+ bucket_entry& operator=(bucket_entry&& ) = delete;
+
+ ~bucket_entry() noexcept {
+ clear();
+ }
+
+ void clear() noexcept {
+ if(!empty()) {
+ destroy_value();
+ m_dist_from_ideal_bucket = EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET;
+ }
+ }
+
+ bool empty() const noexcept {
+ return m_dist_from_ideal_bucket == EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET;
+ }
+
+ value_type& value() noexcept {
+ tsl_assert(!empty());
+ return *reinterpret_cast<value_type*>(std::addressof(m_value));
+ }
+
+ const value_type& value() const noexcept {
+ tsl_assert(!empty());
+ return *reinterpret_cast<const value_type*>(std::addressof(m_value));
+ }
+
+ distance_type dist_from_ideal_bucket() const noexcept {
+ return m_dist_from_ideal_bucket;
+ }
+
+ bool last_bucket() const noexcept {
+ return m_last_bucket;
+ }
+
+ void set_as_last_bucket() noexcept {
+ m_last_bucket = true;
+ }
+
+ template<typename... Args>
+ void set_value_of_empty_bucket(distance_type dist_from_ideal_bucket,
+ truncated_hash_type hash, Args&&... value_type_args)
+ {
+ tsl_assert(dist_from_ideal_bucket >= 0);
+ tsl_assert(empty());
+
+ ::new (static_cast<void*>(std::addressof(m_value))) value_type(std::forward<Args>(value_type_args)...);
+ this->set_hash(hash);
+ m_dist_from_ideal_bucket = dist_from_ideal_bucket;
+
+ tsl_assert(!empty());
+ }
+
+ void swap_with_value_in_bucket(distance_type& dist_from_ideal_bucket,
+ truncated_hash_type& hash, value_type& value)
+ {
+ tsl_assert(!empty());
+
+ using std::swap;
+ swap(value, this->value());
+ swap(dist_from_ideal_bucket, m_dist_from_ideal_bucket);
+
+ // Avoid warning of unused variable if StoreHash is false
+ (void) hash;
+ if(StoreHash) {
+ const truncated_hash_type tmp_hash = this->truncated_hash();
+ this->set_hash(hash);
+ hash = tmp_hash;
+ }
+ }
+
+ static truncated_hash_type truncate_hash(std::size_t hash) noexcept {
+ return truncated_hash_type(hash);
+ }
+
+private:
+ void destroy_value() noexcept {
+ tsl_assert(!empty());
+ value().~value_type();
+ }
+
+private:
+ using storage = typename std::aligned_storage<sizeof(value_type), alignof(value_type)>::type;
+
+ static const distance_type EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET = -1;
+
+ distance_type m_dist_from_ideal_bucket;
+ bool m_last_bucket;
+ storage m_value;
+};
+
+
+
+/**
+ * Internal common class used by `robin_map` and `robin_set`.
+ *
+ * ValueType is what will be stored by `robin_hash` (usually `std::pair<Key, T>` for map and `Key` for set).
+ *
+ * `KeySelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a
+ * reference to the key.
+ *
+ * `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a
+ * reference to the value. `ValueSelect` should be void if there is no value (in a set for example).
+ *
+ * The strong exception guarantee only holds if the expression
+ * `std::is_nothrow_swappable<ValueType>::value && std::is_nothrow_move_constructible<ValueType>::value` is true.
+ *
+ * Behaviour is undefined if the destructor of `ValueType` throws.
+ */
+template<class ValueType,
+ class KeySelect,
+ class ValueSelect,
+ class Hash,
+ class KeyEqual,
+ class Allocator,
+ bool StoreHash,
+ class GrowthPolicy>
+class robin_hash: private Hash, private KeyEqual, private GrowthPolicy {
+private:
+ template<typename U>
+ using has_mapped_type = typename std::integral_constant<bool, !std::is_same<U, void>::value>;
+
+ static_assert(noexcept(std::declval<GrowthPolicy>().bucket_for_hash(std::size_t(0))), "GrowthPolicy::bucket_for_hash must be noexcept.");
+ static_assert(noexcept(std::declval<GrowthPolicy>().clear()), "GrowthPolicy::clear must be noexcept.");
+
+public:
+ template<bool IsConst>
+ class robin_iterator;
+
+ using key_type = typename KeySelect::key_type;
+ using value_type = ValueType;
+ using size_type = std::size_t;
+ using difference_type = std::ptrdiff_t;
+ using hasher = Hash;
+ using key_equal = KeyEqual;
+ using allocator_type = Allocator;
+ using reference = value_type&;
+ using const_reference = const value_type&;
+ using pointer = value_type*;
+ using const_pointer = const value_type*;
+ using iterator = robin_iterator<false>;
+ using const_iterator = robin_iterator<true>;
+
+
+private:
+ /**
+ * Either store the hash because we are asked by the `StoreHash` template parameter
+ * or store the hash because it doesn't cost us anything in size and can be used to speed up rehash.
+ */
+ static constexpr bool STORE_HASH = StoreHash ||
+ (
+ (sizeof(tsl::detail_robin_hash::bucket_entry<value_type, true>) ==
+ sizeof(tsl::detail_robin_hash::bucket_entry<value_type, false>))
+ &&
+ (sizeof(std::size_t) == sizeof(truncated_hash_type) ||
+ is_power_of_two_policy<GrowthPolicy>::value)
+ &&
+ // Don't store the hash for primitive types with default hash.
+ (!std::is_arithmetic<key_type>::value ||
+ !std::is_same<Hash, std::hash<key_type>>::value)
+ );
+
+ /**
+     * Only use the stored hash on lookup if we are explicitly asked. We are not sure how slow
+ * the KeyEqual operation is. An extra comparison may slow things down with a fast KeyEqual.
+ */
+ static constexpr bool USE_STORED_HASH_ON_LOOKUP = StoreHash;
+
+ /**
+ * We can only use the hash on rehash if the size of the hash type is the same as the stored one or
+ * if we use a power of two modulo. In the case of the power of two modulo, we just mask
+     * the least significant bytes, we just have to check that the truncated_hash_type didn't truncate
+ * more bytes.
+ */
+ static bool USE_STORED_HASH_ON_REHASH(size_type bucket_count) {
+ (void) bucket_count;
+ if(STORE_HASH && sizeof(std::size_t) == sizeof(truncated_hash_type)) {
+ return true;
+ }
+ else if(STORE_HASH && is_power_of_two_policy<GrowthPolicy>::value) {
+ tsl_assert(bucket_count > 0);
+ return (bucket_count - 1) <= (std::numeric_limits<truncated_hash_type>::max)();
+ }
+ else {
+ return false;
+ }
+ }
+
+ using bucket_entry = tsl::detail_robin_hash::bucket_entry<value_type, STORE_HASH>;
+ using distance_type = typename bucket_entry::distance_type;
+
+ using buckets_allocator = typename std::allocator_traits<allocator_type>::template rebind_alloc<bucket_entry>;
+ using buckets_container_type = std::vector<bucket_entry, buckets_allocator>;
+
+
+public:
+ /**
+ * The 'operator*()' and 'operator->()' methods return a const reference and const pointer respectively to the
+ * stored value type.
+ *
+ * In case of a map, to get a mutable reference to the value associated to a key (the '.second' in the
+ * stored pair), you have to call 'value()'.
+ *
+ * The main reason for this is that if we returned a `std::pair<Key, T>&` instead
+     * of a `const std::pair<Key, T>&`, the user may modify the key which will put the map in an undefined state.
+ */
+ template<bool IsConst>
+ class robin_iterator {
+ friend class robin_hash;
+
+ private:
+ using iterator_bucket = typename std::conditional<IsConst,
+ typename buckets_container_type::const_iterator,
+ typename buckets_container_type::iterator>::type;
+
+
+ robin_iterator(iterator_bucket it) noexcept: m_iterator(it) {
+ }
+
+ public:
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = const typename robin_hash::value_type;
+ using difference_type = std::ptrdiff_t;
+ using reference = value_type&;
+ using pointer = value_type*;
+
+
+ robin_iterator() noexcept {
+ }
+
+ robin_iterator(const robin_iterator<false>& other) noexcept: m_iterator(other.m_iterator) {
+ }
+
+ const typename robin_hash::key_type& key() const {
+ return KeySelect()(m_iterator->value());
+ }
+
+ template<class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value && IsConst>::type* = nullptr>
+ const typename U::value_type& value() const {
+ return U()(m_iterator->value());
+ }
+
+ template<class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value && !IsConst>::type* = nullptr>
+ typename U::value_type& value() {
+ return U()(m_iterator->value());
+ }
+
+ reference operator*() const {
+ return m_iterator->value();
+ }
+
+ pointer operator->() const {
+ return std::addressof(m_iterator->value());
+ }
+
+ robin_iterator& operator++() {
+ while(true) {
+ if(m_iterator->last_bucket()) {
+ ++m_iterator;
+ return *this;
+ }
+
+ ++m_iterator;
+ if(!m_iterator->empty()) {
+ return *this;
+ }
+ }
+ }
+
+ robin_iterator operator++(int) {
+ robin_iterator tmp(*this);
+ ++*this;
+
+ return tmp;
+ }
+
+ friend bool operator==(const robin_iterator& lhs, const robin_iterator& rhs) {
+ return lhs.m_iterator == rhs.m_iterator;
+ }
+
+ friend bool operator!=(const robin_iterator& lhs, const robin_iterator& rhs) {
+ return !(lhs == rhs);
+ }
+
+ private:
+ iterator_bucket m_iterator;
+ };
+
+
+public:
+ robin_hash(size_type bucket_count,
+ const Hash& hash,
+ const KeyEqual& equal,
+ const Allocator& alloc,
+ float max_load_factor): Hash(hash),
+ KeyEqual(equal),
+ GrowthPolicy(bucket_count),
+ m_buckets(alloc),
+ m_first_or_empty_bucket(static_empty_bucket_ptr()),
+ m_bucket_count(bucket_count),
+ m_nb_elements(0),
+ m_grow_on_next_insert(false)
+ {
+ if(bucket_count > max_bucket_count()) {
+ TSL_THROW_OR_TERMINATE(std::length_error, "The map exceeds its maxmimum size.");
+ }
+
+ if(m_bucket_count > 0) {
+ /*
+ * We can't use the `vector(size_type count, const Allocator& alloc)` constructor
+ * as it's only available in C++14 and we need to support C++11. We thus must resize after using
+ * the `vector(const Allocator& alloc)` constructor.
+ *
+ * We can't use `vector(size_type count, const T& value, const Allocator& alloc)` as it requires the
+ * value T to be copyable.
+ */
+ m_buckets.resize(m_bucket_count);
+ m_first_or_empty_bucket = m_buckets.data();
+
+ tsl_assert(!m_buckets.empty());
+ m_buckets.back().set_as_last_bucket();
+ }
+
+
+ this->max_load_factor(max_load_factor);
+ }
+
+ robin_hash(const robin_hash& other): Hash(other),
+ KeyEqual(other),
+ GrowthPolicy(other),
+ m_buckets(other.m_buckets),
+ m_first_or_empty_bucket(m_buckets.empty()?static_empty_bucket_ptr():m_buckets.data()),
+ m_bucket_count(other.m_bucket_count),
+ m_nb_elements(other.m_nb_elements),
+ m_load_threshold(other.m_load_threshold),
+ m_max_load_factor(other.m_max_load_factor),
+ m_grow_on_next_insert(other.m_grow_on_next_insert)
+ {
+ }
+
+ robin_hash(robin_hash&& other) noexcept(std::is_nothrow_move_constructible<Hash>::value &&
+ std::is_nothrow_move_constructible<KeyEqual>::value &&
+ std::is_nothrow_move_constructible<GrowthPolicy>::value &&
+ std::is_nothrow_move_constructible<buckets_container_type>::value)
+ : Hash(std::move(static_cast<Hash&>(other))),
+ KeyEqual(std::move(static_cast<KeyEqual&>(other))),
+ GrowthPolicy(std::move(static_cast<GrowthPolicy&>(other))),
+ m_buckets(std::move(other.m_buckets)),
+ m_first_or_empty_bucket(m_buckets.empty()?static_empty_bucket_ptr():m_buckets.data()),
+ m_bucket_count(other.m_bucket_count),
+ m_nb_elements(other.m_nb_elements),
+ m_load_threshold(other.m_load_threshold),
+ m_max_load_factor(other.m_max_load_factor),
+ m_grow_on_next_insert(other.m_grow_on_next_insert)
+ {
+ other.GrowthPolicy::clear();
+ other.m_buckets.clear();
+ other.m_first_or_empty_bucket = static_empty_bucket_ptr();
+ other.m_bucket_count = 0;
+ other.m_nb_elements = 0;
+ other.m_load_threshold = 0;
+ other.m_grow_on_next_insert = false;
+ }
+
+ robin_hash& operator=(const robin_hash& other) {
+ if(&other != this) {
+ Hash::operator=(other);
+ KeyEqual::operator=(other);
+ GrowthPolicy::operator=(other);
+
+ m_buckets = other.m_buckets;
+ m_first_or_empty_bucket = m_buckets.empty()?static_empty_bucket_ptr():
+ m_buckets.data();
+ m_bucket_count = other.m_bucket_count;
+ m_nb_elements = other.m_nb_elements;
+ m_load_threshold = other.m_load_threshold;
+ m_max_load_factor = other.m_max_load_factor;
+ m_grow_on_next_insert = other.m_grow_on_next_insert;
+ }
+
+ return *this;
+ }
+
+ robin_hash& operator=(robin_hash&& other) {
+ other.swap(*this);
+ other.clear();
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const {
+ return m_buckets.get_allocator();
+ }
+
+
+ /*
+ * Iterators
+ */
+ iterator begin() noexcept {
+ auto begin = m_buckets.begin();
+ while(begin != m_buckets.end() && begin->empty()) {
+ ++begin;
+ }
+
+ return iterator(begin);
+ }
+
+ const_iterator begin() const noexcept {
+ return cbegin();
+ }
+
+ const_iterator cbegin() const noexcept {
+ auto begin = m_buckets.cbegin();
+ while(begin != m_buckets.cend() && begin->empty()) {
+ ++begin;
+ }
+
+ return const_iterator(begin);
+ }
+
+ iterator end() noexcept {
+ return iterator(m_buckets.end());
+ }
+
+ const_iterator end() const noexcept {
+ return cend();
+ }
+
+ const_iterator cend() const noexcept {
+ return const_iterator(m_buckets.cend());
+ }
+
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept {
+ return m_nb_elements == 0;
+ }
+
+ size_type size() const noexcept {
+ return m_nb_elements;
+ }
+
+ size_type max_size() const noexcept {
+ return m_buckets.max_size();
+ }
+
+ /*
+ * Modifiers
+ */
+ void clear() noexcept {
+ for(auto& bucket: m_buckets) {
+ bucket.clear();
+ }
+
+ m_nb_elements = 0;
+ m_grow_on_next_insert = false;
+ }
+
+
+
+ template<typename P>
+ std::pair<iterator, bool> insert(P&& value) {
+ return insert_impl(KeySelect()(value), std::forward<P>(value));
+ }
+
+ template<typename P>
+ iterator insert(const_iterator hint, P&& value) {
+ if(hint != cend() && compare_keys(KeySelect()(*hint), KeySelect()(value))) {
+ return mutable_iterator(hint);
+ }
+
+ return insert(std::forward<P>(value)).first;
+ }
+
+ template<class InputIt>
+ void insert(InputIt first, InputIt last) {
+ if(std::is_base_of<std::forward_iterator_tag,
+ typename std::iterator_traits<InputIt>::iterator_category>::value)
+ {
+ const auto nb_elements_insert = std::distance(first, last);
+ const size_type nb_free_buckets = m_load_threshold - size();
+ tsl_assert(m_load_threshold >= size());
+
+ if(nb_elements_insert > 0 && nb_free_buckets < size_type(nb_elements_insert)) {
+ reserve(size() + size_type(nb_elements_insert));
+ }
+ }
+
+ for(; first != last; ++first) {
+ insert(*first);
+ }
+ }
+
+
+
+ template<class K, class M>
+ std::pair<iterator, bool> insert_or_assign(K&& key, M&& obj) {
+ auto it = try_emplace(std::forward<K>(key), std::forward<M>(obj));
+ if(!it.second) {
+ it.first.value() = std::forward<M>(obj);
+ }
+
+ return it;
+ }
+
+ template<class K, class M>
+ iterator insert_or_assign(const_iterator hint, K&& key, M&& obj) {
+ if(hint != cend() && compare_keys(KeySelect()(*hint), key)) {
+ auto it = mutable_iterator(hint);
+ it.value() = std::forward<M>(obj);
+
+ return it;
+ }
+
+ return insert_or_assign(std::forward<K>(key), std::forward<M>(obj)).first;
+ }
+
+
+ template<class... Args>
+ std::pair<iterator, bool> emplace(Args&&... args) {
+ return insert(value_type(std::forward<Args>(args)...));
+ }
+
+ template<class... Args>
+ iterator emplace_hint(const_iterator hint, Args&&... args) {
+ return insert(hint, value_type(std::forward<Args>(args)...));
+ }
+
+
+
+ template<class K, class... Args>
+ std::pair<iterator, bool> try_emplace(K&& key, Args&&... args) {
+ return insert_impl(key, std::piecewise_construct,
+ std::forward_as_tuple(std::forward<K>(key)),
+ std::forward_as_tuple(std::forward<Args>(args)...));
+ }
+
+ template<class K, class... Args>
+ iterator try_emplace(const_iterator hint, K&& key, Args&&... args) {
+ if(hint != cend() && compare_keys(KeySelect()(*hint), key)) {
+ return mutable_iterator(hint);
+ }
+
+ return try_emplace(std::forward<K>(key), std::forward<Args>(args)...).first;
+ }
+
+ /**
+ * Here to avoid `template<class K> size_type erase(const K& key)` being used when
+ * we use an `iterator` instead of a `const_iterator`.
+ */
+ iterator erase(iterator pos) {
+ erase_from_bucket(pos);
+
+ /**
+ * Erase bucket used a backward shift after clearing the bucket.
+ * Check if there is a new value in the bucket, if not get the next non-empty.
+ */
+ if(pos.m_iterator->empty()) {
+ ++pos;
+ }
+
+ return pos;
+ }
+
+ iterator erase(const_iterator pos) {
+ return erase(mutable_iterator(pos));
+ }
+
+ iterator erase(const_iterator first, const_iterator last) {
+ if(first == last) {
+ return mutable_iterator(first);
+ }
+
+ auto first_mutable = mutable_iterator(first);
+ auto last_mutable = mutable_iterator(last);
+ for(auto it = first_mutable.m_iterator; it != last_mutable.m_iterator; ++it) {
+ if(!it->empty()) {
+ it->clear();
+ m_nb_elements--;
+ }
+ }
+
+ if(last_mutable == end()) {
+ return end();
+ }
+
+
+ /*
+ * Backward shift on the values which come after the deleted values.
+ * We try to move the values closer to their ideal bucket.
+ */
+ std::size_t icloser_bucket = std::size_t(std::distance(m_buckets.begin(), first_mutable.m_iterator));
+ std::size_t ito_move_closer_value = std::size_t(std::distance(m_buckets.begin(), last_mutable.m_iterator));
+ tsl_assert(ito_move_closer_value > icloser_bucket);
+
+ const std::size_t ireturn_bucket = ito_move_closer_value -
+ (std::min)(ito_move_closer_value - icloser_bucket,
+ std::size_t(m_buckets[ito_move_closer_value].dist_from_ideal_bucket()));
+
+ while(ito_move_closer_value < m_buckets.size() && m_buckets[ito_move_closer_value].dist_from_ideal_bucket() > 0) {
+ icloser_bucket = ito_move_closer_value -
+ (std::min)(ito_move_closer_value - icloser_bucket,
+ std::size_t(m_buckets[ito_move_closer_value].dist_from_ideal_bucket()));
+
+
+ tsl_assert(m_buckets[icloser_bucket].empty());
+ const distance_type new_distance = distance_type(m_buckets[ito_move_closer_value].dist_from_ideal_bucket() -
+ (ito_move_closer_value - icloser_bucket));
+ m_buckets[icloser_bucket].set_value_of_empty_bucket(new_distance,
+ m_buckets[ito_move_closer_value].truncated_hash(),
+ std::move(m_buckets[ito_move_closer_value].value()));
+ m_buckets[ito_move_closer_value].clear();
+
+
+ ++icloser_bucket;
+ ++ito_move_closer_value;
+ }
+
+
+ return iterator(m_buckets.begin() + ireturn_bucket);
+ }
+
+
+ template<class K>
+ size_type erase(const K& key) {
+ return erase(key, hash_key(key));
+ }
+
+ template<class K>
+ size_type erase(const K& key, std::size_t hash) {
+ auto it = find(key, hash);
+ if(it != end()) {
+ erase_from_bucket(it);
+
+ return 1;
+ }
+ else {
+ return 0;
+ }
+ }
+
+
+
+
+
+ void swap(robin_hash& other) {
+ using std::swap;
+
+ swap(static_cast<Hash&>(*this), static_cast<Hash&>(other));
+ swap(static_cast<KeyEqual&>(*this), static_cast<KeyEqual&>(other));
+ swap(static_cast<GrowthPolicy&>(*this), static_cast<GrowthPolicy&>(other));
+ swap(m_buckets, other.m_buckets);
+ swap(m_first_or_empty_bucket, other.m_first_or_empty_bucket);
+ swap(m_bucket_count, other.m_bucket_count);
+ swap(m_nb_elements, other.m_nb_elements);
+ swap(m_load_threshold, other.m_load_threshold);
+ swap(m_max_load_factor, other.m_max_load_factor);
+ swap(m_grow_on_next_insert, other.m_grow_on_next_insert);
+ }
+
+
+ /*
+ * Lookup
+ */
+ template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
+ typename U::value_type& at(const K& key) {
+ return at(key, hash_key(key));
+ }
+
+ template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
+ typename U::value_type& at(const K& key, std::size_t hash) {
+ return const_cast<typename U::value_type&>(static_cast<const robin_hash*>(this)->at(key, hash));
+ }
+
+
+ template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
+ const typename U::value_type& at(const K& key) const {
+ return at(key, hash_key(key));
+ }
+
+ template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
+ const typename U::value_type& at(const K& key, std::size_t hash) const {
+ auto it = find(key, hash);
+ if(it != cend()) {
+ return it.value();
+ }
+ else {
+ TSL_THROW_OR_TERMINATE(std::out_of_range, "Couldn't find key.");
+ }
+ }
+
+ template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
+ typename U::value_type& operator[](K&& key) {
+ return try_emplace(std::forward<K>(key)).first.value();
+ }
+
+
+ template<class K>
+ size_type count(const K& key) const {
+ return count(key, hash_key(key));
+ }
+
+ template<class K>
+ size_type count(const K& key, std::size_t hash) const {
+ if(find(key, hash) != cend()) {
+ return 1;
+ }
+ else {
+ return 0;
+ }
+ }
+
+
+ template<class K>
+ iterator find(const K& key) {
+ return find_impl(key, hash_key(key));
+ }
+
+ template<class K>
+ iterator find(const K& key, std::size_t hash) {
+ return find_impl(key, hash);
+ }
+
+
+ template<class K>
+ const_iterator find(const K& key) const {
+ return find_impl(key, hash_key(key));
+ }
+
+ template<class K>
+ const_iterator find(const K& key, std::size_t hash) const {
+ return find_impl(key, hash);
+ }
+
+
+ template<class K>
+ std::pair<iterator, iterator> equal_range(const K& key) {
+ return equal_range(key, hash_key(key));
+ }
+
+ template<class K>
+ std::pair<iterator, iterator> equal_range(const K& key, std::size_t hash) {
+ iterator it = find(key, hash);
+ return std::make_pair(it, (it == end())?it:std::next(it));
+ }
+
+
+ template<class K>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
+ return equal_range(key, hash_key(key));
+ }
+
+ template<class K>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t hash) const {
+ const_iterator it = find(key, hash);
+ return std::make_pair(it, (it == cend())?it:std::next(it));
+ }
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const {
+ return m_bucket_count;
+ }
+
+ size_type max_bucket_count() const {
+ return (std::min)(GrowthPolicy::max_bucket_count(), m_buckets.max_size());
+ }
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const {
+ if(bucket_count() == 0) {
+ return 0;
+ }
+
+ return float(m_nb_elements)/float(bucket_count());
+ }
+
+ float max_load_factor() const {
+ return m_max_load_factor;
+ }
+
+ void max_load_factor(float ml) {
+ m_max_load_factor = (std::max)(0.1f, (std::min)(ml, 0.95f));
+ m_load_threshold = size_type(float(bucket_count())*m_max_load_factor);
+ }
+
+ void rehash(size_type count) {
+ count = (std::max)(count, size_type(std::ceil(float(size())/max_load_factor())));
+ rehash_impl(count);
+ }
+
+ void reserve(size_type count) {
+ rehash(size_type(std::ceil(float(count)/max_load_factor())));
+ }
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const {
+ return static_cast<const Hash&>(*this);
+ }
+
+ key_equal key_eq() const {
+ return static_cast<const KeyEqual&>(*this);
+ }
+
+
+ /*
+ * Other
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ return iterator(m_buckets.begin() + std::distance(m_buckets.cbegin(), pos.m_iterator));
+ }
+
+private:
+ template<class K>
+ std::size_t hash_key(const K& key) const {
+ return Hash::operator()(key);
+ }
+
+ template<class K1, class K2>
+ bool compare_keys(const K1& key1, const K2& key2) const {
+ return KeyEqual::operator()(key1, key2);
+ }
+
+ std::size_t bucket_for_hash(std::size_t hash) const {
+ const std::size_t bucket = GrowthPolicy::bucket_for_hash(hash);
+ tsl_assert(bucket < m_buckets.size() || (bucket == 0 && m_buckets.empty()));
+
+ return bucket;
+ }
+
+ template<class U = GrowthPolicy, typename std::enable_if<is_power_of_two_policy<U>::value>::type* = nullptr>
+ std::size_t next_bucket(std::size_t index) const noexcept {
+ tsl_assert(index < bucket_count());
+
+ return (index + 1) & this->m_mask;
+ }
+
+ template<class U = GrowthPolicy, typename std::enable_if<!is_power_of_two_policy<U>::value>::type* = nullptr>
+ std::size_t next_bucket(std::size_t index) const noexcept {
+ tsl_assert(index < bucket_count());
+
+ index++;
+ return (index != bucket_count())?index:0;
+ }
+
+
+
+ template<class K>
+ iterator find_impl(const K& key, std::size_t hash) {
+ return mutable_iterator(static_cast<const robin_hash*>(this)->find(key, hash));
+ }
+
+ template<class K>
+ const_iterator find_impl(const K& key, std::size_t hash) const {
+ std::size_t ibucket = bucket_for_hash(hash);
+ distance_type dist_from_ideal_bucket = 0;
+
+ while(dist_from_ideal_bucket <= (m_first_or_empty_bucket + ibucket)->dist_from_ideal_bucket()) {
+ if(TSL_LIKELY((!USE_STORED_HASH_ON_LOOKUP || (m_first_or_empty_bucket + ibucket)->bucket_hash_equal(hash)) &&
+ compare_keys(KeySelect()((m_first_or_empty_bucket + ibucket)->value()), key)))
+ {
+ return const_iterator(m_buckets.begin() + ibucket);
+ }
+
+ ibucket = next_bucket(ibucket);
+ dist_from_ideal_bucket++;
+ }
+
+ return cend();
+ }
+
+ void erase_from_bucket(iterator pos) {
+ pos.m_iterator->clear();
+ m_nb_elements--;
+
+ /**
+ * Backward shift, swap the empty bucket, previous_ibucket, with the values on its right, ibucket,
+ * until we cross another empty bucket or if the other bucket has a distance_from_ideal_bucket == 0.
+ *
+ * We try to move the values closer to their ideal bucket.
+ */
+ std::size_t previous_ibucket = std::size_t(std::distance(m_buckets.begin(), pos.m_iterator));
+ std::size_t ibucket = next_bucket(previous_ibucket);
+
+ while(m_buckets[ibucket].dist_from_ideal_bucket() > 0) {
+ tsl_assert(m_buckets[previous_ibucket].empty());
+
+ const distance_type new_distance = distance_type(m_buckets[ibucket].dist_from_ideal_bucket() - 1);
+ m_buckets[previous_ibucket].set_value_of_empty_bucket(new_distance, m_buckets[ibucket].truncated_hash(),
+ std::move(m_buckets[ibucket].value()));
+ m_buckets[ibucket].clear();
+
+ previous_ibucket = ibucket;
+ ibucket = next_bucket(ibucket);
+ }
+ }
+
+ template<class K, class... Args>
+ std::pair<iterator, bool> insert_impl(const K& key, Args&&... value_type_args) {
+ const std::size_t hash = hash_key(key);
+
+ std::size_t ibucket = bucket_for_hash(hash);
+ distance_type dist_from_ideal_bucket = 0;
+
+ while(dist_from_ideal_bucket <= (m_first_or_empty_bucket + ibucket)->dist_from_ideal_bucket()) {
+ if((!USE_STORED_HASH_ON_LOOKUP || (m_first_or_empty_bucket + ibucket)->bucket_hash_equal(hash)) &&
+ compare_keys(KeySelect()((m_first_or_empty_bucket + ibucket)->value()), key))
+ {
+ return std::make_pair(iterator(m_buckets.begin() + ibucket), false);
+ }
+
+ ibucket = next_bucket(ibucket);
+ dist_from_ideal_bucket++;
+ }
+
+ if(grow_on_high_load()) {
+ ibucket = bucket_for_hash(hash);
+ dist_from_ideal_bucket = 0;
+
+ while(dist_from_ideal_bucket <= (m_first_or_empty_bucket + ibucket)->dist_from_ideal_bucket()) {
+ ibucket = next_bucket(ibucket);
+ dist_from_ideal_bucket++;
+ }
+ }
+
+
+ if((m_first_or_empty_bucket + ibucket)->empty()) {
+ (m_first_or_empty_bucket + ibucket)->set_value_of_empty_bucket(dist_from_ideal_bucket, bucket_entry::truncate_hash(hash),
+ std::forward<Args>(value_type_args)...);
+ }
+ else {
+ insert_value(ibucket, dist_from_ideal_bucket, bucket_entry::truncate_hash(hash),
+ std::forward<Args>(value_type_args)...);
+ }
+
+
+ m_nb_elements++;
+ /*
+ * The value will be inserted in ibucket in any case, either because it was
+ * empty or by stealing the bucket (robin hood).
+ */
+ return std::make_pair(iterator(m_buckets.begin() + ibucket), true);
+ }
+
+
+ template<class... Args>
+ void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket,
+ truncated_hash_type hash, Args&&... value_type_args)
+ {
+ insert_value(ibucket, dist_from_ideal_bucket, hash, value_type(std::forward<Args>(value_type_args)...));
+ }
+
+ void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket,
+ truncated_hash_type hash, value_type&& value)
+ {
+ m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value);
+ ibucket = next_bucket(ibucket);
+ dist_from_ideal_bucket++;
+
+ while(!m_buckets[ibucket].empty()) {
+ if(dist_from_ideal_bucket > m_buckets[ibucket].dist_from_ideal_bucket()) {
+ if(dist_from_ideal_bucket >= REHASH_ON_HIGH_NB_PROBES__NPROBES &&
+ load_factor() >= REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR)
+ {
+ /**
+ * The number of probes is really high, rehash the map on the next insert.
+ * Difficult to do now as rehash may throw an exception.
+ */
+ m_grow_on_next_insert = true;
+ }
+
+ m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value);
+ }
+
+ ibucket = next_bucket(ibucket);
+ dist_from_ideal_bucket++;
+ }
+
+ m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, hash, std::move(value));
+ }
+
+
+ void rehash_impl(size_type count) {
+ robin_hash new_table(count, static_cast<Hash&>(*this), static_cast<KeyEqual&>(*this),
+ get_allocator(), m_max_load_factor);
+
+ const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_table.bucket_count());
+ for(auto& bucket: m_buckets) {
+ if(bucket.empty()) {
+ continue;
+ }
+
+ const std::size_t hash = use_stored_hash?bucket.truncated_hash():
+ new_table.hash_key(KeySelect()(bucket.value()));
+
+ new_table.insert_value_on_rehash(new_table.bucket_for_hash(hash), 0,
+ bucket_entry::truncate_hash(hash), std::move(bucket.value()));
+ }
+
+ new_table.m_nb_elements = m_nb_elements;
+ new_table.swap(*this);
+ }
+
+ void insert_value_on_rehash(std::size_t ibucket, distance_type dist_from_ideal_bucket,
+ truncated_hash_type hash, value_type&& value)
+ {
+ while(true) {
+ if(dist_from_ideal_bucket > m_buckets[ibucket].dist_from_ideal_bucket()) {
+ if(m_buckets[ibucket].empty()) {
+ m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, hash, std::move(value));
+ return;
+ }
+ else {
+ m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash, value);
+ }
+ }
+
+ dist_from_ideal_bucket++;
+ ibucket = next_bucket(ibucket);
+ }
+ }
+
+
+
+ /**
+ * Return true if the map has been rehashed.
+ */
+ bool grow_on_high_load() {
+ if(m_grow_on_next_insert || size() >= m_load_threshold) {
+ rehash_impl(GrowthPolicy::next_bucket_count());
+ m_grow_on_next_insert = false;
+
+ return true;
+ }
+
+ return false;
+ }
+
+
+public:
+ static const size_type DEFAULT_INIT_BUCKETS_SIZE = 16;
+ static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f;
+
+private:
+ static const distance_type REHASH_ON_HIGH_NB_PROBES__NPROBES = 128;
+ static constexpr float REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR = 0.15f;
+
+
+ /**
+     * Return an always valid pointer to a static empty bucket_entry with last_bucket() == true.
+ */
+ bucket_entry* static_empty_bucket_ptr() {
+ static bucket_entry empty_bucket(true);
+ return &empty_bucket;
+ }
+
+private:
+ buckets_container_type m_buckets;
+
+ /**
+ * Points to m_buckets.data() if !m_buckets.empty() otherwise points to static_empty_bucket_ptr.
+ * This variable is useful to avoid the cost of checking if m_buckets is empty when trying
+ * to find an element.
+ */
+ bucket_entry* m_first_or_empty_bucket;
+
+ /**
+ * Used a lot in find, avoid the call to m_buckets.size() which is a bit slower.
+ */
+ size_type m_bucket_count;
+
+ size_type m_nb_elements;
+
+ size_type m_load_threshold;
+ float m_max_load_factor;
+
+ bool m_grow_on_next_insert;
+};
+
+}
+
+}
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/robin_map.h b/be/src/extern/diskann/include/tsl/robin_map.h
new file mode 100644
index 0000000..5958e70
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/robin_map.h
@@ -0,0 +1,668 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Tessil
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_ROBIN_MAP_H
+#define TSL_ROBIN_MAP_H
+
+
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include "robin_hash.h"
+
+
+namespace tsl {
+
+
+/**
+ * Implementation of a hash map using open-adressing and the robin hood hashing algorithm with backward shift deletion.
+ *
+ * For operations modifying the hash map (insert, erase, rehash, ...), the strong exception guarantee
+ * is only guaranteed when the expression `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
+ * std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true, otherwise if an exception
+ * is thrown during the swap or the move, the hash map may end up in a undefined state. Per the standard
+ * a `Key` or `T` with a noexcept copy constructor and no move constructor also satisfies the
+ * `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and will thus guarantee the
+ * strong exception for the map).
+ *
+ * When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
+ * the performance during lookups if the `KeyEqual` function takes time (if it engenders a cache-miss for example)
+ * as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
+ * as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash.
+ * When it is detected that storing the hash will not incur any memory penality due to alignement (i.e.
+ * `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
+ * sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
+ * used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will
+ * not be used on lookups unless `StoreHash` is true).
+ *
+ * `GrowthPolicy` defines how the map grows and consequently how a hash value is mapped to a bucket.
+ * By default the map uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
+ * to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
+ * Other growth policies are available and you may define your own growth policy,
+ * check `tsl::rh::power_of_two_growth_policy` for the interface.
+ *
+ * If the destructor of `Key` or `T` throws an exception, the behaviour of the class is undefined.
+ *
+ * Iterators invalidation:
+ * - clear, operator=, reserve, rehash: always invalidate the iterators.
+ * - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
+ * - erase: always invalidate the iterators.
+ */
+template<class Key,
+ class T,
+ class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<std::pair<Key, T>>,
+ bool StoreHash = false,
+ class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
+class robin_map {
+private:
+ template<typename U>
+ using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
+
+ class KeySelect {
+ public:
+ using key_type = Key;
+
+ const key_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
+ return key_value.first;
+ }
+
+ key_type& operator()(std::pair<Key, T>& key_value) noexcept {
+ return key_value.first;
+ }
+ };
+
+ class ValueSelect {
+ public:
+ using value_type = T;
+
+ const value_type& operator()(const std::pair<Key, T>& key_value) const noexcept {
+ return key_value.second;
+ }
+
+ value_type& operator()(std::pair<Key, T>& key_value) noexcept {
+ return key_value.second;
+ }
+ };
+
+ using ht = detail_robin_hash::robin_hash<std::pair<Key, T>, KeySelect, ValueSelect,
+ Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
+
+public:
+ using key_type = typename ht::key_type;
+ using mapped_type = T;
+ using value_type = typename ht::value_type;
+ using size_type = typename ht::size_type;
+ using difference_type = typename ht::difference_type;
+ using hasher = typename ht::hasher;
+ using key_equal = typename ht::key_equal;
+ using allocator_type = typename ht::allocator_type;
+ using reference = typename ht::reference;
+ using const_reference = typename ht::const_reference;
+ using pointer = typename ht::pointer;
+ using const_pointer = typename ht::const_pointer;
+ using iterator = typename ht::iterator;
+ using const_iterator = typename ht::const_iterator;
+
+
+public:
+ /*
+ * Constructors
+ */
+ robin_map(): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
+ }
+
+ explicit robin_map(size_type bucket_count,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()):
+ m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
+ {
+ }
+
+ robin_map(size_type bucket_count,
+ const Allocator& alloc): robin_map(bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ robin_map(size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc): robin_map(bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+ explicit robin_map(const Allocator& alloc): robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
+ }
+
+ template<class InputIt>
+ robin_map(InputIt first, InputIt last,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()): robin_map(bucket_count, hash, equal, alloc)
+ {
+ insert(first, last);
+ }
+
+ template<class InputIt>
+ robin_map(InputIt first, InputIt last,
+ size_type bucket_count,
+ const Allocator& alloc): robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ template<class InputIt>
+ robin_map(InputIt first, InputIt last,
+ size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc): robin_map(first, last, bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+ robin_map(std::initializer_list<value_type> init,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()):
+ robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
+ {
+ }
+
+ robin_map(std::initializer_list<value_type> init,
+ size_type bucket_count,
+ const Allocator& alloc):
+ robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ robin_map(std::initializer_list<value_type> init,
+ size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc):
+ robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+ robin_map& operator=(std::initializer_list<value_type> ilist) {
+ m_ht.clear();
+
+ m_ht.reserve(ilist.size());
+ m_ht.insert(ilist.begin(), ilist.end());
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const { return m_ht.get_allocator(); }
+
+
+ /*
+ * Iterators
+ */
+ iterator begin() noexcept { return m_ht.begin(); }
+ const_iterator begin() const noexcept { return m_ht.begin(); }
+ const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
+
+ iterator end() noexcept { return m_ht.end(); }
+ const_iterator end() const noexcept { return m_ht.end(); }
+ const_iterator cend() const noexcept { return m_ht.cend(); }
+
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept { return m_ht.empty(); }
+ size_type size() const noexcept { return m_ht.size(); }
+ size_type max_size() const noexcept { return m_ht.max_size(); }
+
+ /*
+ * Modifiers
+ */
+ void clear() noexcept { m_ht.clear(); }
+
+
+
+ std::pair<iterator, bool> insert(const value_type& value) {
+ return m_ht.insert(value);
+ }
+
+ template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
+ std::pair<iterator, bool> insert(P&& value) {
+ return m_ht.emplace(std::forward<P>(value));
+ }
+
+ std::pair<iterator, bool> insert(value_type&& value) {
+ return m_ht.insert(std::move(value));
+ }
+
+
+ iterator insert(const_iterator hint, const value_type& value) {
+ return m_ht.insert(hint, value);
+ }
+
+ template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
+ iterator insert(const_iterator hint, P&& value) {
+ return m_ht.emplace_hint(hint, std::forward<P>(value));
+ }
+
+ iterator insert(const_iterator hint, value_type&& value) {
+ return m_ht.insert(hint, std::move(value));
+ }
+
+
+ template<class InputIt>
+ void insert(InputIt first, InputIt last) {
+ m_ht.insert(first, last);
+ }
+
+ void insert(std::initializer_list<value_type> ilist) {
+ m_ht.insert(ilist.begin(), ilist.end());
+ }
+
+
+
+
+ template<class M>
+ std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
+ return m_ht.insert_or_assign(k, std::forward<M>(obj));
+ }
+
+ template<class M>
+ std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
+ return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
+ }
+
+ template<class M>
+ iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
+ return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
+ }
+
+ template<class M>
+ iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
+ return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
+ }
+
+
+
+ /**
+ * Due to the way elements are stored, emplace will need to move or copy the key-value once.
+ * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
+ *
+ * Mainly here for compatibility with the std::unordered_map interface.
+ */
+ template<class... Args>
+ std::pair<iterator, bool> emplace(Args&&... args) {
+ return m_ht.emplace(std::forward<Args>(args)...);
+ }
+
+
+
+ /**
+ * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
+ * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
+ *
+ * Mainly here for compatibility with the std::unordered_map interface.
+ */
+ template<class... Args>
+ iterator emplace_hint(const_iterator hint, Args&&... args) {
+ return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
+ }
+
+
+
+
+ template<class... Args>
+ std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
+ return m_ht.try_emplace(k, std::forward<Args>(args)...);
+ }
+
+ template<class... Args>
+ std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
+ return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
+ }
+
+ template<class... Args>
+ iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
+ return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
+ }
+
+ template<class... Args>
+ iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
+ return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
+ }
+
+
+
+
+ iterator erase(iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
+ size_type erase(const key_type& key) { return m_ht.erase(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
+ */
+ size_type erase(const key_type& key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type erase(const K& key) { return m_ht.erase(key); }
+
+ /**
+ * @copydoc erase(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type erase(const K& key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+
+
+ void swap(robin_map& other) { other.m_ht.swap(m_ht); }
+
+
+
+ /*
+ * Lookup
+ */
+ T& at(const Key& key) { return m_ht.at(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
+
+
+ const T& at(const Key& key) const { return m_ht.at(key); }
+
+ /**
+ * @copydoc at(const Key& key, std::size_t precalculated_hash)
+ */
+ const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
+
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ T& at(const K& key) { return m_ht.at(key); }
+
+ /**
+ * @copydoc at(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
+
+
+ /**
+ * @copydoc at(const K& key)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const T& at(const K& key) const { return m_ht.at(key); }
+
+ /**
+ * @copydoc at(const K& key, std::size_t precalculated_hash)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
+
+
+
+
+ T& operator[](const Key& key) { return m_ht[key]; }
+ T& operator[](Key&& key) { return m_ht[std::move(key)]; }
+
+
+
+
+ size_type count(const Key& key) const { return m_ht.count(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ size_type count(const Key& key, std::size_t precalculated_hash) const {
+ return m_ht.count(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type count(const K& key) const { return m_ht.count(key); }
+
+ /**
+ * @copydoc count(const K& key) const
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
+
+
+
+
+ iterator find(const Key& key) { return m_ht.find(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
+
+ const_iterator find(const Key& key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const Key& key, std::size_t precalculated_hash)
+ */
+ const_iterator find(const Key& key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ iterator find(const K& key) { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
+
+ /**
+ * @copydoc find(const K& key)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const_iterator find(const K& key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const_iterator find(const K& key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+
+
+
+ std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
+
+ /**
+ * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
+ */
+ std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
+
+
+ /**
+ * @copydoc equal_range(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+ * as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
+
+ /**
+ * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+
+
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const { return m_ht.bucket_count(); }
+ size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
+
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const { return m_ht.load_factor(); }
+ float max_load_factor() const { return m_ht.max_load_factor(); }
+ void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
+
+ void rehash(size_type count) { m_ht.rehash(count); }
+ void reserve(size_type count) { m_ht.reserve(count); }
+
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const { return m_ht.hash_function(); }
+ key_equal key_eq() const { return m_ht.key_eq(); }
+
+ /*
+ * Other
+ */
+
+ /**
+ * Convert a const_iterator to an iterator.
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ return m_ht.mutable_iterator(pos);
+ }
+
+ friend bool operator==(const robin_map& lhs, const robin_map& rhs) {
+ if(lhs.size() != rhs.size()) {
+ return false;
+ }
+
+ for(const auto& element_lhs: lhs) {
+ const auto it_element_rhs = rhs.find(element_lhs.first);
+ if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ friend bool operator!=(const robin_map& lhs, const robin_map& rhs) {
+ return !operator==(lhs, rhs);
+ }
+
+ friend void swap(robin_map& lhs, robin_map& rhs) {
+ lhs.swap(rhs);
+ }
+
+private:
+ ht m_ht;
+};
+
+
+/**
+ * Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
+ */
+template<class Key,
+ class T,
+ class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<std::pair<Key, T>>,
+ bool StoreHash = false>
+using robin_pg_map = robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
+
+} // end namespace tsl
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/robin_set.h b/be/src/extern/diskann/include/tsl/robin_set.h
new file mode 100644
index 0000000..4e4667e
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/robin_set.h
@@ -0,0 +1,535 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Tessil
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_ROBIN_SET_H
+#define TSL_ROBIN_SET_H
+
+
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include "robin_hash.h"
+
+
+namespace tsl {
+
+
+/**
+ * Implementation of a hash set using open-adressing and the robin hood hashing algorithm with backward shift deletion.
+ *
+ * For operations modifying the hash set (insert, erase, rehash, ...), the strong exception guarantee
+ * is only guaranteed when the expression `std::is_nothrow_swappable<Key>::value &&
+ * std::is_nothrow_move_constructible<Key>::value` is true, otherwise if an exception
+ * is thrown during the swap or the move, the hash set may end up in a undefined state. Per the standard
+ * a `Key` with a noexcept copy constructor and no move constructor also satisfies the
+ * `std::is_nothrow_move_constructible<Key>::value` criterion (and will thus guarantee the
+ * strong exception for the set).
+ *
+ * When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
+ * the performance during lookups if the `KeyEqual` function takes time (or engenders a cache-miss for example)
+ * as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
+ * as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash.
+ * When it is detected that storing the hash will not incur any memory penality due to alignement (i.e.
+ * `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
+ * sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
+ * used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will
+ * not be used on lookups unless `StoreHash` is true).
+ *
+ * `GrowthPolicy` defines how the set grows and consequently how a hash value is mapped to a bucket.
+ * By default the set uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
+ * to a power of two and uses a mask to set the hash to a bucket instead of the slow modulo.
+ * Other growth policies are available and you may define your own growth policy,
+ * check `tsl::rh::power_of_two_growth_policy` for the interface.
+ *
+ * If the destructor of `Key` throws an exception, the behaviour of the class is undefined.
+ *
+ * Iterators invalidation:
+ * - clear, operator=, reserve, rehash: always invalidate the iterators.
+ * - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
+ * - erase: always invalidate the iterators.
+ */
+template<class Key,
+ class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<Key>,
+ bool StoreHash = false,
+ class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
+class robin_set {
+private:
+ template<typename U>
+ using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
+
+ class KeySelect {
+ public:
+ using key_type = Key;
+
+ const key_type& operator()(const Key& key) const noexcept {
+ return key;
+ }
+
+ key_type& operator()(Key& key) noexcept {
+ return key;
+ }
+ };
+
+ using ht = detail_robin_hash::robin_hash<Key, KeySelect, void,
+ Hash, KeyEqual, Allocator, StoreHash, GrowthPolicy>;
+
+public:
+ using key_type = typename ht::key_type;
+ using value_type = typename ht::value_type;
+ using size_type = typename ht::size_type;
+ using difference_type = typename ht::difference_type;
+ using hasher = typename ht::hasher;
+ using key_equal = typename ht::key_equal;
+ using allocator_type = typename ht::allocator_type;
+ using reference = typename ht::reference;
+ using const_reference = typename ht::const_reference;
+ using pointer = typename ht::pointer;
+ using const_pointer = typename ht::const_pointer;
+ using iterator = typename ht::iterator;
+ using const_iterator = typename ht::const_iterator;
+
+
+ /*
+ * Constructors
+ */
+ robin_set(): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
+ }
+
+ explicit robin_set(size_type bucket_count,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()):
+ m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
+ {
+ }
+
+ robin_set(size_type bucket_count,
+ const Allocator& alloc): robin_set(bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ robin_set(size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc): robin_set(bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+ explicit robin_set(const Allocator& alloc): robin_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
+ }
+
+ template<class InputIt>
+ robin_set(InputIt first, InputIt last,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()): robin_set(bucket_count, hash, equal, alloc)
+ {
+ insert(first, last);
+ }
+
+ template<class InputIt>
+ robin_set(InputIt first, InputIt last,
+ size_type bucket_count,
+ const Allocator& alloc): robin_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ template<class InputIt>
+ robin_set(InputIt first, InputIt last,
+ size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc): robin_set(first, last, bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+ robin_set(std::initializer_list<value_type> init,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
+ const Hash& hash = Hash(),
+ const KeyEqual& equal = KeyEqual(),
+ const Allocator& alloc = Allocator()):
+ robin_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
+ {
+ }
+
+ robin_set(std::initializer_list<value_type> init,
+ size_type bucket_count,
+ const Allocator& alloc):
+ robin_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
+ {
+ }
+
+ robin_set(std::initializer_list<value_type> init,
+ size_type bucket_count,
+ const Hash& hash,
+ const Allocator& alloc):
+ robin_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
+ {
+ }
+
+
+ robin_set& operator=(std::initializer_list<value_type> ilist) {
+ m_ht.clear();
+
+ m_ht.reserve(ilist.size());
+ m_ht.insert(ilist.begin(), ilist.end());
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const { return m_ht.get_allocator(); }
+
+
+ /*
+ * Iterators
+ */
+ iterator begin() noexcept { return m_ht.begin(); }
+ const_iterator begin() const noexcept { return m_ht.begin(); }
+ const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
+
+ iterator end() noexcept { return m_ht.end(); }
+ const_iterator end() const noexcept { return m_ht.end(); }
+ const_iterator cend() const noexcept { return m_ht.cend(); }
+
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept { return m_ht.empty(); }
+ size_type size() const noexcept { return m_ht.size(); }
+ size_type max_size() const noexcept { return m_ht.max_size(); }
+
+ /*
+ * Modifiers
+ */
+ void clear() noexcept { m_ht.clear(); }
+
+
+
+
+ std::pair<iterator, bool> insert(const value_type& value) {
+ return m_ht.insert(value);
+ }
+
+ std::pair<iterator, bool> insert(value_type&& value) {
+ return m_ht.insert(std::move(value));
+ }
+
+ iterator insert(const_iterator hint, const value_type& value) {
+ return m_ht.insert(hint, value);
+ }
+
+ iterator insert(const_iterator hint, value_type&& value) {
+ return m_ht.insert(hint, std::move(value));
+ }
+
+ template<class InputIt>
+ void insert(InputIt first, InputIt last) {
+ m_ht.insert(first, last);
+ }
+
+ void insert(std::initializer_list<value_type> ilist) {
+ m_ht.insert(ilist.begin(), ilist.end());
+ }
+
+
+
+
+ /**
+ * Due to the way elements are stored, emplace will need to move or copy the key-value once.
+ * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
+ *
+ * Mainly here for compatibility with the std::unordered_map interface.
+ */
+ template<class... Args>
+ std::pair<iterator, bool> emplace(Args&&... args) {
+ return m_ht.emplace(std::forward<Args>(args)...);
+ }
+
+
+
+ /**
+ * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
+ * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
+ *
+ * Mainly here for compatibility with the std::unordered_map interface.
+ */
+ template<class... Args>
+ iterator emplace_hint(const_iterator hint, Args&&... args) {
+ return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
+ }
+
+
+
+ iterator erase(iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
+ size_type erase(const key_type& key) { return m_ht.erase(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup to the value if you already have the hash.
+ */
+ size_type erase(const key_type& key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type erase(const K& key) { return m_ht.erase(key); }
+
+ /**
+ * @copydoc erase(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup to the value if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type erase(const K& key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+
+
+ void swap(robin_set& other) { other.m_ht.swap(m_ht); }
+
+
+
+ /*
+ * Lookup
+ */
+ size_type count(const Key& key) const { return m_ht.count(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type count(const K& key) const { return m_ht.count(key); }
+
+ /**
+ * @copydoc count(const K& key) const
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
+
+
+
+
+ iterator find(const Key& key) { return m_ht.find(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
+
+ const_iterator find(const Key& key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const Key& key, std::size_t precalculated_hash)
+ */
+ const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ iterator find(const K& key) { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
+
+ /**
+ * @copydoc find(const K& key)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const_iterator find(const K& key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
+
+
+
+
+ std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
+
+ /**
+ * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
+ */
+ std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
+ * If so, K must be hashable and comparable to Key.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
+     * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
+
+ /**
+ * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
+ */
+ template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+
+
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const { return m_ht.bucket_count(); }
+ size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
+
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const { return m_ht.load_factor(); }
+ float max_load_factor() const { return m_ht.max_load_factor(); }
+ void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
+
+ void rehash(size_type count) { m_ht.rehash(count); }
+ void reserve(size_type count) { m_ht.reserve(count); }
+
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const { return m_ht.hash_function(); }
+ key_equal key_eq() const { return m_ht.key_eq(); }
+
+
+ /*
+ * Other
+ */
+
+ /**
+ * Convert a const_iterator to an iterator.
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ return m_ht.mutable_iterator(pos);
+ }
+
+ friend bool operator==(const robin_set& lhs, const robin_set& rhs) {
+ if(lhs.size() != rhs.size()) {
+ return false;
+ }
+
+ for(const auto& element_lhs: lhs) {
+ const auto it_element_rhs = rhs.find(element_lhs);
+ if(it_element_rhs == rhs.cend()) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ friend bool operator!=(const robin_set& lhs, const robin_set& rhs) {
+ return !operator==(lhs, rhs);
+ }
+
+ friend void swap(robin_set& lhs, robin_set& rhs) {
+ lhs.swap(rhs);
+ }
+
+private:
+ ht m_ht;
+};
+
+
+/**
+ * Same as `tsl::robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
+ */
+template<class Key,
+ class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<Key>,
+ bool StoreHash = false>
+using robin_pg_set = robin_set<Key, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>;
+
+} // end namespace tsl
+
+#endif
+
diff --git a/be/src/extern/diskann/include/tsl/sparse_growth_policy.h b/be/src/extern/diskann/include/tsl/sparse_growth_policy.h
new file mode 100644
index 0000000..d73aaaf
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/sparse_growth_policy.h
@@ -0,0 +1,301 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_SPARSE_GROWTH_POLICY_H
+#define TSL_SPARSE_GROWTH_POLICY_H
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <iterator>
+#include <limits>
+#include <ratio>
+#include <stdexcept>
+
+namespace tsl {
+namespace sh {
+
+/**
+ * Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
+ * power of two. It allows the table to use a mask operation instead of a modulo
+ * operation to map a hash to a bucket.
+ *
+ * GrowthFactor must be a power of two >= 2.
+ */
+template <std::size_t GrowthFactor>
+class power_of_two_growth_policy {
+ public:
+ /**
+ * Called on the hash table creation and on rehash. The number of buckets for
+ * the table is passed in parameter. This number is a minimum, the policy may
+ * update this value with a higher value if needed (but not lower).
+ *
+ * If 0 is given, min_bucket_count_in_out must still be 0 after the policy
+ * creation and bucket_for_hash must always return 0 in this case.
+ */
+ explicit power_of_two_growth_policy(std::size_t &min_bucket_count_in_out) {
+ if (min_bucket_count_in_out > max_bucket_count()) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ if (min_bucket_count_in_out > 0) {
+ min_bucket_count_in_out =
+ round_up_to_power_of_two(min_bucket_count_in_out);
+ m_mask = min_bucket_count_in_out - 1;
+ } else {
+ m_mask = 0;
+ }
+ }
+
+ /**
+ * Return the bucket [0, bucket_count()) to which the hash belongs.
+ * If bucket_count() is 0, it must always return 0.
+ */
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return hash & m_mask;
+ }
+
+ /**
+ * Return the number of buckets that should be used on next growth.
+ */
+ std::size_t next_bucket_count() const {
+ if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ return (m_mask + 1) * GrowthFactor;
+ }
+
+ /**
+ * Return the maximum number of buckets supported by the policy.
+ */
+ std::size_t max_bucket_count() const {
+ // Largest power of two.
+ return (std::numeric_limits<std::size_t>::max() / 2) + 1;
+ }
+
+ /**
+ * Reset the growth policy as if it was created with a bucket count of 0.
+ * After a clear, the policy must always return 0 when bucket_for_hash is
+ * called.
+ */
+ void clear() noexcept { m_mask = 0; }
+
+ private:
+ static std::size_t round_up_to_power_of_two(std::size_t value) {
+ if (is_power_of_two(value)) {
+ return value;
+ }
+
+ if (value == 0) {
+ return 1;
+ }
+
+ --value;
+ for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
+ value |= value >> i;
+ }
+
+ return value + 1;
+ }
+
+ static constexpr bool is_power_of_two(std::size_t value) {
+ return value != 0 && (value & (value - 1)) == 0;
+ }
+
+ protected:
+ static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
+ "GrowthFactor must be a power of two >= 2.");
+
+ std::size_t m_mask;
+};
+
+/**
+ * Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
+ * to map a hash to a bucket. Slower but it can be useful if you want a slower
+ * growth.
+ */
+template <class GrowthFactor = std::ratio<3, 2>>
+class mod_growth_policy {
+ public:
+ explicit mod_growth_policy(std::size_t &min_bucket_count_in_out) {
+ if (min_bucket_count_in_out > max_bucket_count()) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ if (min_bucket_count_in_out > 0) {
+ m_mod = min_bucket_count_in_out;
+ } else {
+ m_mod = 1;
+ }
+ }
+
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return hash % m_mod;
+ }
+
+ std::size_t next_bucket_count() const {
+ if (m_mod == max_bucket_count()) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ const double next_bucket_count =
+ std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
+ if (!std::isnormal(next_bucket_count)) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ if (next_bucket_count > double(max_bucket_count())) {
+ return max_bucket_count();
+ } else {
+ return std::size_t(next_bucket_count);
+ }
+ }
+
+ std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
+
+ void clear() noexcept { m_mod = 1; }
+
+ private:
+ static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
+ 1.0 * GrowthFactor::num / GrowthFactor::den;
+ static const std::size_t MAX_BUCKET_COUNT =
+ std::size_t(double(std::numeric_limits<std::size_t>::max() /
+ REHASH_SIZE_MULTIPLICATION_FACTOR));
+
+ static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
+ "Growth factor should be >= 1.1.");
+
+ std::size_t m_mod;
+};
+
+/**
+ * Grow the hash table by using prime numbers as bucket count. Slower than
+ * tsl::sh::power_of_two_growth_policy in general but will probably distribute
+ * the values around better in the buckets with a poor hash function.
+ *
+ * To allow the compiler to optimize the modulo operation, a lookup table is
+ * used with constant primes numbers.
+ *
+ * With a switch the code would look like:
+ * \code
+ * switch(iprime) { // iprime is the current prime of the hash table
+ * case 0: hash % 5ul;
+ * break;
+ * case 1: hash % 17ul;
+ * break;
+ * case 2: hash % 29ul;
+ * break;
+ * ...
+ * }
+ * \endcode
+ *
+ * Due to the constant variable in the modulo the compiler is able to optimize
+ * the operation by a series of multiplications, subtractions and shifts.
+ *
+ * The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
+ * * 5' in a 64 bits environment.
+ */
+class prime_growth_policy {
+ public:
+ explicit prime_growth_policy(std::size_t &min_bucket_count_in_out) {
+ auto it_prime = std::lower_bound(primes().begin(), primes().end(),
+ min_bucket_count_in_out);
+ if (it_prime == primes().end()) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ m_iprime =
+ static_cast<unsigned int>(std::distance(primes().begin(), it_prime));
+ if (min_bucket_count_in_out > 0) {
+ min_bucket_count_in_out = *it_prime;
+ } else {
+ min_bucket_count_in_out = 0;
+ }
+ }
+
+ std::size_t bucket_for_hash(std::size_t hash) const noexcept {
+ return mod_prime()[m_iprime](hash);
+ }
+
+ std::size_t next_bucket_count() const {
+ if (m_iprime + 1 >= primes().size()) {
+ throw std::length_error("The hash table exceeds its maximum size.");
+ }
+
+ return primes()[m_iprime + 1];
+ }
+
+ std::size_t max_bucket_count() const { return primes().back(); }
+
+ void clear() noexcept { m_iprime = 0; }
+
+ private:
+ static const std::array<std::size_t, 40> &primes() {
+ static const std::array<std::size_t, 40> PRIMES = {
+ {1ul, 5ul, 17ul, 29ul, 37ul,
+ 53ul, 67ul, 79ul, 97ul, 131ul,
+ 193ul, 257ul, 389ul, 521ul, 769ul,
+ 1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
+ 12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
+ 393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
+ 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
+ 402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
+
+ static_assert(
+ std::numeric_limits<decltype(m_iprime)>::max() >= PRIMES.size(),
+ "The type of m_iprime is not big enough.");
+
+ return PRIMES;
+ }
+
+ static const std::array<std::size_t (*)(std::size_t), 40> &mod_prime() {
+ // MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows
+ // for faster modulo as the compiler can optimize the modulo code better
+ // with a constant known at the compilation.
+ static const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME = {
+ {&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
+ &mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
+ &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
+ &mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
+ &mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
+ &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
+
+ return MOD_PRIME;
+ }
+
+ template <unsigned int IPrime>
+ static std::size_t mod(std::size_t hash) {
+ return hash % primes()[IPrime];
+ }
+
+ private:
+ unsigned int m_iprime;
+};
+
+} // namespace sh
+} // namespace tsl
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/sparse_hash.h b/be/src/extern/diskann/include/tsl/sparse_hash.h
new file mode 100644
index 0000000..e2115b4
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/sparse_hash.h
@@ -0,0 +1,2215 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_SPARSE_HASH_H
+#define TSL_SPARSE_HASH_H
+
+#include <algorithm>
+#include <cassert>
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <stdexcept>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "sparse_growth_policy.h"
+
+#ifdef __INTEL_COMPILER
+#include <immintrin.h> // For _popcnt32 and _popcnt64
+#endif
+
+#ifdef _MSC_VER
+#include <intrin.h> // For __cpuid, __popcnt and __popcnt64
+#endif
+
+#ifdef TSL_DEBUG
+#define tsl_sh_assert(expr) assert(expr)
+#else
+#define tsl_sh_assert(expr) (static_cast<void>(0))
+#endif
+
+namespace tsl {
+
+namespace sh {
+enum class probing { linear, quadratic };
+
+enum class exception_safety { basic, strong };
+
+enum class sparsity { high, medium, low };
+} // namespace sh
+
+namespace detail_popcount {
+/**
+ * Define the popcount(ll) methods and pick-up the best depending on the
+ * compiler.
+ */
+
+// From Wikipedia: https://en.wikipedia.org/wiki/Hamming_weight
+inline int fallback_popcountll(unsigned long long int x) {
+ static_assert(
+ sizeof(unsigned long long int) == sizeof(std::uint64_t),
+ "sizeof(unsigned long long int) must be equal to sizeof(std::uint64_t). "
+ "Open a feature request if you need support for a platform where it "
+ "isn't the case.");
+
+ const std::uint64_t m1 = 0x5555555555555555ull;
+ const std::uint64_t m2 = 0x3333333333333333ull;
+ const std::uint64_t m4 = 0x0f0f0f0f0f0f0f0full;
+ const std::uint64_t h01 = 0x0101010101010101ull;
+
+ x -= (x >> 1ull) & m1;
+ x = (x & m2) + ((x >> 2ull) & m2);
+ x = (x + (x >> 4ull)) & m4;
+ return static_cast<int>((x * h01) >> (64ull - 8ull));
+}
+
+inline int fallback_popcount(unsigned int x) {
+ static_assert(sizeof(unsigned int) == sizeof(std::uint32_t) ||
+ sizeof(unsigned int) == sizeof(std::uint64_t),
+ "sizeof(unsigned int) must be equal to sizeof(std::uint32_t) "
+ "or sizeof(std::uint64_t). "
+ "Open a feature request if you need support for a platform "
+ "where it isn't the case.");
+
+ if (sizeof(unsigned int) == sizeof(std::uint32_t)) {
+ const std::uint32_t m1 = 0x55555555;
+ const std::uint32_t m2 = 0x33333333;
+ const std::uint32_t m4 = 0x0f0f0f0f;
+ const std::uint32_t h01 = 0x01010101;
+
+ x -= (x >> 1) & m1;
+ x = (x & m2) + ((x >> 2) & m2);
+ x = (x + (x >> 4)) & m4;
+ return static_cast<int>((x * h01) >> (32 - 8));
+ } else {
+ return fallback_popcountll(x);
+ }
+}
+
+#if defined(__clang__) || defined(__GNUC__)
+inline int popcountll(unsigned long long int value) {
+ return __builtin_popcountll(value);
+}
+
+inline int popcount(unsigned int value) { return __builtin_popcount(value); }
+
+#elif defined(_MSC_VER)
+/**
+ * We need to check for popcount support at runtime on Windows with __cpuid
+ * See https://msdn.microsoft.com/en-us/library/bb385231.aspx
+ */
+inline bool has_popcount_support() {
+ int cpu_infos[4];
+ __cpuid(cpu_infos, 1);
+ return (cpu_infos[2] & (1 << 23)) != 0;
+}
+
+inline int popcountll(unsigned long long int value) {
+#ifdef _WIN64
+ static_assert(
+ sizeof(unsigned long long int) == sizeof(std::int64_t),
+ "sizeof(unsigned long long int) must be equal to sizeof(std::int64_t). ");
+
+ static const bool has_popcount = has_popcount_support();
+ return has_popcount
+ ? static_cast<int>(__popcnt64(static_cast<std::int64_t>(value)))
+ : fallback_popcountll(value);
+#else
+ return fallback_popcountll(value);
+#endif
+}
+
+inline int popcount(unsigned int value) {
+ static_assert(sizeof(unsigned int) == sizeof(std::int32_t),
+ "sizeof(unsigned int) must be equal to sizeof(std::int32_t). ");
+
+ static const bool has_popcount = has_popcount_support();
+ return has_popcount
+ ? static_cast<int>(__popcnt(static_cast<std::int32_t>(value)))
+ : fallback_popcount(value);
+}
+
+#elif defined(__INTEL_COMPILER)
+inline int popcountll(unsigned long long int value) {
+ static_assert(sizeof(unsigned long long int) == sizeof(__int64), "");
+ return _popcnt64(static_cast<__int64>(value));
+}
+
+inline int popcount(unsigned int value) {
+ return _popcnt32(static_cast<int>(value));
+}
+
+#else
+inline int popcountll(unsigned long long int x) {
+ return fallback_popcountll(x);
+}
+
+inline int popcount(unsigned int x) { return fallback_popcount(x); }
+
+#endif
+} // namespace detail_popcount
+
+namespace detail_sparse_hash {
+
+template <typename T>
+struct make_void {
+ using type = void;
+};
+
+template <typename T, typename = void>
+struct has_is_transparent : std::false_type {};
+
+template <typename T>
+struct has_is_transparent<T,
+ typename make_void<typename T::is_transparent>::type>
+ : std::true_type {};
+
+template <typename U>
+struct is_power_of_two_policy : std::false_type {};
+
+template <std::size_t GrowthFactor>
+struct is_power_of_two_policy<tsl::sh::power_of_two_growth_policy<GrowthFactor>>
+ : std::true_type {};
+
+inline constexpr bool is_power_of_two(std::size_t value) {
+ return value != 0 && (value & (value - 1)) == 0;
+}
+
+inline std::size_t round_up_to_power_of_two(std::size_t value) {
+ if (is_power_of_two(value)) {
+ return value;
+ }
+
+ if (value == 0) {
+ return 1;
+ }
+
+ --value;
+ for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
+ value |= value >> i;
+ }
+
+ return value + 1;
+}
+
+template <typename T, typename U>
+static T numeric_cast(U value,
+ const char *error_message = "numeric_cast() failed.") {
+ T ret = static_cast<T>(value);
+ if (static_cast<U>(ret) != value) {
+ throw std::runtime_error(error_message);
+ }
+
+ const bool is_same_signedness =
+ (std::is_unsigned<T>::value && std::is_unsigned<U>::value) ||
+ (std::is_signed<T>::value && std::is_signed<U>::value);
+ if (!is_same_signedness && (ret < T{}) != (value < U{})) {
+ throw std::runtime_error(error_message);
+ }
+
+ return ret;
+}
+
+/**
+ * Fixed size type used to represent size_type values on serialization. Need to
+ * be big enough to represent a std::size_t on 32 and 64 bits platforms, and
+ * must be the same size on both platforms.
+ */
+using slz_size_type = std::uint64_t;
+static_assert(std::numeric_limits<slz_size_type>::max() >=
+ std::numeric_limits<std::size_t>::max(),
+ "slz_size_type must be >= std::size_t");
+
+template <class T, class Deserializer>
+static T deserialize_value(Deserializer &deserializer) {
+ // MSVC < 2017 is not conformant, circumvent the problem by removing the
+ // template keyword
+#if defined(_MSC_VER) && _MSC_VER < 1910
+ return deserializer.Deserializer::operator()<T>();
+#else
+ return deserializer.Deserializer::template operator()<T>();
+#endif
+}
+
+/**
+ * WARNING: the sparse_array class doesn't free the resources allocated through
+ * the allocator passed in parameter in each method. You have to manually call
+ * `clear(Allocator&)` when you don't need a sparse_array object anymore.
+ *
+ * The reason is that the sparse_array doesn't store the allocator to avoid
+ * wasting space in each sparse_array when the allocator has a size > 0. It only
+ * allocates/deallocates objects with the allocator that is passed in parameter.
+ *
+ *
+ *
+ * Index denotes a value between [0, BITMAP_NB_BITS), it is an index similar to
+ * std::vector. Offset denotes the real position in `m_values` corresponding to
+ * an index.
+ *
+ * We are using raw pointers instead of std::vector to avoid losing
+ * 2*sizeof(size_t) bytes to store the capacity and size of the vector in each
+ * sparse_array. We know we can only store up to BITMAP_NB_BITS elements in the
+ * array, we don't need such big types.
+ *
+ *
+ * T must be nothrow move constructible and/or copy constructible.
+ * Behaviour is undefined if the destructor of T throws an exception.
+ *
+ * See https://smerity.com/articles/2015/google_sparsehash.html for details on
+ * the idea behind the implementation.
+ *
+ * TODO Check to use std::realloc and std::memmove when possible
+ */
+template <typename T, typename Allocator, tsl::sh::sparsity Sparsity>
+class sparse_array {
+ public:
+ using value_type = T;
+ using size_type = std::uint_least8_t;
+ using allocator_type = Allocator;
+ using iterator = value_type *;
+ using const_iterator = const value_type *;
+
+ private:
+ static const size_type CAPACITY_GROWTH_STEP =
+ (Sparsity == tsl::sh::sparsity::high) ? 2
+ : (Sparsity == tsl::sh::sparsity::medium)
+ ? 4
+ : 8; // (Sparsity == tsl::sh::sparsity::low)
+
+ /**
+ * Bitmap size configuration.
+   * Use 32 bits for the bitmap on 32-bit or smaller environments as popcount on
+   * 64 bits numbers is slow on these environments. Use 64 bits bitmap
+ * otherwise.
+ */
+#if SIZE_MAX <= UINT32_MAX
+ using bitmap_type = std::uint_least32_t;
+ static const std::size_t BITMAP_NB_BITS = 32;
+ static const std::size_t BUCKET_SHIFT = 5;
+#else
+ using bitmap_type = std::uint_least64_t;
+ static const std::size_t BITMAP_NB_BITS = 64;
+ static const std::size_t BUCKET_SHIFT = 6;
+#endif
+
+ static const std::size_t BUCKET_MASK = BITMAP_NB_BITS - 1;
+
+ static_assert(is_power_of_two(BITMAP_NB_BITS),
+ "BITMAP_NB_BITS must be a power of two.");
+ static_assert(std::numeric_limits<bitmap_type>::digits >= BITMAP_NB_BITS,
+ "bitmap_type must be able to hold at least BITMAP_NB_BITS.");
+ static_assert((std::size_t(1) << BUCKET_SHIFT) == BITMAP_NB_BITS,
+ "(1 << BUCKET_SHIFT) must be equal to BITMAP_NB_BITS.");
+ static_assert(std::numeric_limits<size_type>::max() >= BITMAP_NB_BITS,
+ "size_type must be big enough to hold BITMAP_NB_BITS.");
+ static_assert(std::is_unsigned<bitmap_type>::value,
+ "bitmap_type must be unsigned.");
+ static_assert((std::numeric_limits<bitmap_type>::max() & BUCKET_MASK) ==
+ BITMAP_NB_BITS - 1,
+ "");
+
+ public:
+ /**
+ * Map an ibucket [0, bucket_count) in the hash table to a sparse_ibucket
+ * (a sparse_array holds multiple buckets, so there is less sparse_array than
+ * bucket_count).
+ *
+ * The bucket ibucket is in
+ * m_sparse_buckets[sparse_ibucket(ibucket)][index_in_sparse_bucket(ibucket)]
+ * instead of something like m_buckets[ibucket] in a classical hash table.
+ */
+ static std::size_t sparse_ibucket(std::size_t ibucket) {
+ return ibucket >> BUCKET_SHIFT;
+ }
+
+ /**
+ * Map an ibucket [0, bucket_count) in the hash table to an index in the
+ * sparse_array which corresponds to the bucket.
+ *
+ * The bucket ibucket is in
+ * m_sparse_buckets[sparse_ibucket(ibucket)][index_in_sparse_bucket(ibucket)]
+ * instead of something like m_buckets[ibucket] in a classical hash table.
+ */
+ static typename sparse_array::size_type index_in_sparse_bucket(
+ std::size_t ibucket) {
+ return static_cast<typename sparse_array::size_type>(
+ ibucket & sparse_array::BUCKET_MASK);
+ }
+
+ /**
+ * Number of sparse_array needed to hold bucket_count buckets. The count is
+ * rounded up to a power of two first; a non-zero bucket_count always needs
+ * at least one sparse_array.
+ */
+ static std::size_t nb_sparse_buckets(std::size_t bucket_count) noexcept {
+ if (bucket_count == 0) {
+ return 0;
+ }
+
+ return std::max<std::size_t>(
+ 1, sparse_ibucket(tsl::detail_sparse_hash::round_up_to_power_of_two(
+ bucket_count)));
+ }
+
+ public:
+ // Empty array with no storage; not flagged as the last sparse_array.
+ sparse_array() noexcept
+ : m_values(nullptr),
+ m_bitmap_vals(0),
+ m_bitmap_deleted_vals(0),
+ m_nb_elements(0),
+ m_capacity(0),
+ m_last_array(false) {}
+
+ // Empty array; last_bucket marks it as the last sparse_array of the table.
+ explicit sparse_array(bool last_bucket) noexcept
+ : m_values(nullptr),
+ m_bitmap_vals(0),
+ m_bitmap_deleted_vals(0),
+ m_nb_elements(0),
+ m_capacity(0),
+ m_last_array(last_bucket) {}
+
+ // Pre-allocate `capacity` slots with alloc; no value is constructed yet.
+ sparse_array(size_type capacity, Allocator &alloc)
+ : m_values(nullptr),
+ m_bitmap_vals(0),
+ m_bitmap_deleted_vals(0),
+ m_nb_elements(0),
+ m_capacity(capacity),
+ m_last_array(false) {
+ if (m_capacity > 0) {
+ m_values = alloc.allocate(m_capacity);
+ tsl_sh_assert(m_values !=
+ nullptr); // allocate should throw if there is a failure
+ }
+ }
+
+ // Allocator-extended copy constructor: copies other's values with alloc.
+ sparse_array(const sparse_array &other, Allocator &alloc)
+ : m_values(nullptr),
+ m_bitmap_vals(other.m_bitmap_vals),
+ m_bitmap_deleted_vals(other.m_bitmap_deleted_vals),
+ m_nb_elements(0),
+ m_capacity(other.m_capacity),
+ m_last_array(other.m_last_array) {
+ tsl_sh_assert(other.m_capacity >= other.m_nb_elements);
+ if (m_capacity == 0) {
+ return;
+ }
+
+ m_values = alloc.allocate(m_capacity);
+ tsl_sh_assert(m_values !=
+ nullptr); // allocate should throw if there is a failure
+ try {
+ // m_nb_elements is incremented after each successful construction so
+ // that clear() destroys exactly the already-constructed values if a
+ // copy throws.
+ for (size_type i = 0; i < other.m_nb_elements; i++) {
+ construct_value(alloc, m_values + i, other.m_values[i]);
+ m_nb_elements++;
+ }
+ } catch (...) {
+ clear(alloc);
+ throw;
+ }
+ }
+
+ // Move constructor: steals other's storage and leaves it in an empty state.
+ sparse_array(sparse_array &&other) noexcept
+ : m_values(other.m_values),
+ m_bitmap_vals(other.m_bitmap_vals),
+ m_bitmap_deleted_vals(other.m_bitmap_deleted_vals),
+ m_nb_elements(other.m_nb_elements),
+ m_capacity(other.m_capacity),
+ m_last_array(other.m_last_array) {
+ other.m_values = nullptr;
+ other.m_bitmap_vals = 0;
+ other.m_bitmap_deleted_vals = 0;
+ other.m_nb_elements = 0;
+ other.m_capacity = 0;
+ }
+
+ // Allocator-extended move constructor: allocates with alloc and moves the
+ // values element by element (other's storage cannot be stolen since it may
+ // belong to a different allocator).
+ sparse_array(sparse_array &&other, Allocator &alloc)
+ : m_values(nullptr),
+ m_bitmap_vals(other.m_bitmap_vals),
+ m_bitmap_deleted_vals(other.m_bitmap_deleted_vals),
+ m_nb_elements(0),
+ m_capacity(other.m_capacity),
+ m_last_array(other.m_last_array) {
+ tsl_sh_assert(other.m_capacity >= other.m_nb_elements);
+ if (m_capacity == 0) {
+ return;
+ }
+
+ m_values = alloc.allocate(m_capacity);
+ tsl_sh_assert(m_values !=
+ nullptr); // allocate should throw if there is a failure
+ try {
+ // Incremented per element for exception rollback, as in the copy ctor.
+ for (size_type i = 0; i < other.m_nb_elements; i++) {
+ construct_value(alloc, m_values + i, std::move(other.m_values[i]));
+ m_nb_elements++;
+ }
+ } catch (...) {
+ clear(alloc);
+ throw;
+ }
+ }
+
+ // Assignment is deleted: the sparse_array does not store its allocator, so
+ // it could not reallocate on its own.
+ sparse_array &operator=(const sparse_array &) = delete;
+ sparse_array &operator=(sparse_array &&) = delete;
+
+ ~sparse_array() noexcept {
+ // The code that manages the sparse_array must have called clear before
+ // destruction. See documentation of sparse_array for more details.
+ tsl_sh_assert(m_capacity == 0 && m_nb_elements == 0 && m_values == nullptr);
+ }
+
+ // Values are stored contiguously; iteration is a plain pointer walk over
+ // the dense value array.
+ iterator begin() noexcept { return m_values; }
+ iterator end() noexcept { return m_values + m_nb_elements; }
+ const_iterator begin() const noexcept { return cbegin(); }
+ const_iterator end() const noexcept { return cend(); }
+ const_iterator cbegin() const noexcept { return m_values; }
+ const_iterator cend() const noexcept { return m_values + m_nb_elements; }
+
+ bool empty() const noexcept { return m_nb_elements == 0; }
+
+ size_type size() const noexcept { return m_nb_elements; }
+
+ // Destroy all values and release the storage. Must be called before the
+ // destructor runs (see the assert in ~sparse_array).
+ void clear(allocator_type &alloc) noexcept {
+ destroy_and_deallocate_values(alloc, m_values, m_nb_elements, m_capacity);
+
+ m_values = nullptr;
+ m_bitmap_vals = 0;
+ m_bitmap_deleted_vals = 0;
+ m_nb_elements = 0;
+ m_capacity = 0;
+ }
+
+ // Whether this sparse_array is the last one of the bucket container.
+ bool last() const noexcept { return m_last_array; }
+
+ void set_as_last() noexcept { m_last_array = true; }
+
+ // True if the bucket at `index` (< BITMAP_NB_BITS) currently holds a value.
+ bool has_value(size_type index) const noexcept {
+ tsl_sh_assert(index < BITMAP_NB_BITS);
+ return (m_bitmap_vals & (bitmap_type(1) << index)) != 0;
+ }
+
+ // True if the bucket at `index` held a value that has since been erased.
+ bool has_deleted_value(size_type index) const noexcept {
+ tsl_sh_assert(index < BITMAP_NB_BITS);
+ return (m_bitmap_deleted_vals & (bitmap_type(1) << index)) != 0;
+ }
+
+ // Iterator to the value stored for bucket `index`; the bucket must hold a
+ // value (index_to_offset maps the bucket to its dense position).
+ iterator value(size_type index) noexcept {
+ tsl_sh_assert(has_value(index));
+ return m_values + index_to_offset(index);
+ }
+
+ const_iterator value(size_type index) const noexcept {
+ tsl_sh_assert(has_value(index));
+ return m_values + index_to_offset(index);
+ }
+
+ /**
+ * Return iterator to set value.
+ */
+ template <typename... Args>
+ iterator set(allocator_type &alloc, size_type index, Args &&...value_args) {
+ tsl_sh_assert(!has_value(index));
+
+ // Insert the value at its dense position first; only then flip the
+ // bitmaps so they stay consistent if the insertion throws.
+ const size_type offset = index_to_offset(index);
+ insert_at_offset(alloc, offset, std::forward<Args>(value_args)...);
+
+ m_bitmap_vals = (m_bitmap_vals | (bitmap_type(1) << index));
+ m_bitmap_deleted_vals =
+ (m_bitmap_deleted_vals & ~(bitmap_type(1) << index));
+
+ m_nb_elements++;
+
+ tsl_sh_assert(has_value(index));
+ tsl_sh_assert(!has_deleted_value(index));
+
+ return m_values + offset;
+ }
+
+ // Erase the value at `position`, deriving its bucket index from the dense
+ // offset.
+ iterator erase(allocator_type &alloc, iterator position) {
+ const size_type offset =
+ static_cast<size_type>(std::distance(begin(), position));
+ return erase(alloc, position, offset_to_index(offset));
+ }
+
+ // Return the next value or end if no next value
+ iterator erase(allocator_type &alloc, iterator position, size_type index) {
+ tsl_sh_assert(has_value(index));
+ tsl_sh_assert(!has_deleted_value(index));
+
+ const size_type offset =
+ static_cast<size_type>(std::distance(begin(), position));
+ erase_at_offset(alloc, offset);
+
+ // Mark the bucket as deleted (tombstone) so lookups can distinguish
+ // "never used" from "erased".
+ m_bitmap_vals = (m_bitmap_vals & ~(bitmap_type(1) << index));
+ m_bitmap_deleted_vals = (m_bitmap_deleted_vals | (bitmap_type(1) << index));
+
+ m_nb_elements--;
+
+ tsl_sh_assert(!has_value(index));
+ tsl_sh_assert(has_deleted_value(index));
+
+ return m_values + offset;
+ }
+
+ // Member-wise swap; allocators are managed by the enclosing container.
+ void swap(sparse_array &other) {
+ using std::swap;
+
+ swap(m_values, other.m_values);
+ swap(m_bitmap_vals, other.m_bitmap_vals);
+ swap(m_bitmap_deleted_vals, other.m_bitmap_deleted_vals);
+ swap(m_nb_elements, other.m_nb_elements);
+ swap(m_capacity, other.m_capacity);
+ swap(m_last_array, other.m_last_array);
+ }
+
+ // Iterators are plain pointers into m_values, so casting away constness is
+ // safe here.
+ static iterator mutable_iterator(const_iterator pos) {
+ return const_cast<iterator>(pos);
+ }
+
+ // Serialize the bucket: element count, both bitmaps, then the values in
+ // dense order.
+ template <class Serializer>
+ void serialize(Serializer &serializer) const {
+ const slz_size_type sparse_bucket_size = m_nb_elements;
+ serializer(sparse_bucket_size);
+
+ const slz_size_type bitmap_vals = m_bitmap_vals;
+ serializer(bitmap_vals);
+
+ const slz_size_type bitmap_deleted_vals = m_bitmap_deleted_vals;
+ serializer(bitmap_deleted_vals);
+
+ for (const value_type &value : *this) {
+ serializer(value);
+ }
+ }
+
+ // Rebuild a bucket from a stream written by serialize(). Only valid when
+ // the hash stays compatible: the bitmaps are restored verbatim, so every
+ // value must land in the same bucket as when it was serialized.
+ template <class Deserializer>
+ static sparse_array deserialize_hash_compatible(Deserializer &deserializer,
+ Allocator &alloc) {
+ const slz_size_type sparse_bucket_size =
+ deserialize_value<slz_size_type>(deserializer);
+ const slz_size_type bitmap_vals =
+ deserialize_value<slz_size_type>(deserializer);
+ const slz_size_type bitmap_deleted_vals =
+ deserialize_value<slz_size_type>(deserializer);
+
+ if (sparse_bucket_size > BITMAP_NB_BITS) {
+ throw std::runtime_error(
+ "Deserialized sparse_bucket_size is too big for the platform. "
+ "Maximum should be BITMAP_NB_BITS.");
+ }
+
+ sparse_array sarray;
+ if (sparse_bucket_size == 0) {
+ return sarray;
+ }
+
+ sarray.m_bitmap_vals = numeric_cast<bitmap_type>(
+ bitmap_vals, "Deserialized bitmap_vals is too big.");
+ sarray.m_bitmap_deleted_vals = numeric_cast<bitmap_type>(
+ bitmap_deleted_vals, "Deserialized bitmap_deleted_vals is too big.");
+
+ // Allocate exactly sparse_bucket_size slots; capacity == size here.
+ sarray.m_capacity = numeric_cast<size_type>(
+ sparse_bucket_size, "Deserialized sparse_bucket_size is too big.");
+ sarray.m_values = alloc.allocate(sarray.m_capacity);
+
+ try {
+ for (size_type ivalue = 0; ivalue < sarray.m_capacity; ivalue++) {
+ construct_value(alloc, sarray.m_values + ivalue,
+ deserialize_value<value_type>(deserializer));
+ sarray.m_nb_elements++;
+ }
+ } catch (...) {
+ sarray.clear(alloc);
+ throw;
+ }
+
+ return sarray;
+ }
+
+ /**
+ * Deserialize the values of the bucket and insert them all in sparse_hash
+ * through sparse_hash.insert(...).
+ */
+ template <class Deserializer, class SparseHash>
+ static void deserialize_values_into_sparse_hash(Deserializer &deserializer,
+ SparseHash &sparse_hash) {
+ const slz_size_type sparse_bucket_size =
+ deserialize_value<slz_size_type>(deserializer);
+
+ const slz_size_type bitmap_vals =
+ deserialize_value<slz_size_type>(deserializer);
+ static_cast<void>(bitmap_vals); // Ignore, not needed
+
+ const slz_size_type bitmap_deleted_vals =
+ deserialize_value<slz_size_type>(deserializer);
+ static_cast<void>(bitmap_deleted_vals); // Ignore, not needed
+
+ for (slz_size_type ivalue = 0; ivalue < sparse_bucket_size; ivalue++) {
+ sparse_hash.insert(deserialize_value<value_type>(deserializer));
+ }
+ }
+
+ private:
+ // Construct a value in-place through the allocator traits.
+ template <typename... Args>
+ static void construct_value(allocator_type &alloc, value_type *value,
+ Args &&...value_args) {
+ std::allocator_traits<allocator_type>::construct(
+ alloc, value, std::forward<Args>(value_args)...);
+ }
+
+ static void destroy_value(allocator_type &alloc, value_type *value) noexcept {
+ std::allocator_traits<allocator_type>::destroy(alloc, value);
+ }
+
+ // Destroy the first nb_values values and deallocate the whole area of
+ // capacity_values slots.
+ static void destroy_and_deallocate_values(
+ allocator_type &alloc, value_type *values, size_type nb_values,
+ size_type capacity_values) noexcept {
+ for (size_type i = 0; i < nb_values; i++) {
+ destroy_value(alloc, values + i);
+ }
+
+ alloc.deallocate(values, capacity_values);
+ }
+
+ // Number of set bits in val; dispatches to the 32-bit or 64-bit popcount
+ // depending on the width of bitmap_type.
+ static size_type popcount(bitmap_type val) noexcept {
+ if (sizeof(bitmap_type) <= sizeof(unsigned int)) {
+ return static_cast<size_type>(
+ tsl::detail_popcount::popcount(static_cast<unsigned int>(val)));
+ } else {
+ return static_cast<size_type>(tsl::detail_popcount::popcountll(val));
+ }
+ }
+
+ // Dense offset of bucket `index` = number of set bits strictly below
+ // `index` in the presence bitmap.
+ size_type index_to_offset(size_type index) const noexcept {
+ tsl_sh_assert(index < BITMAP_NB_BITS);
+ return popcount(m_bitmap_vals &
+ ((bitmap_type(1) << index) - bitmap_type(1)));
+ }
+
+ // Inverse of index_to_offset: bucket index of the value stored at the given
+ // dense offset, i.e. the position of the (offset + 1)-th set bit of the
+ // presence bitmap. TODO optimize (linear bit scan).
+ size_type offset_to_index(size_type offset) const noexcept {
+ tsl_sh_assert(offset < m_nb_elements);
+
+ bitmap_type bits = m_bitmap_vals;
+ size_type index = 0;
+ size_type remaining = offset + 1;
+
+ // Walk the bitmap from the least significant bit, counting down the set
+ // bits until the wanted one is reached.
+ while (bits != 0) {
+ if ((bits & bitmap_type(1)) == 1) {
+ remaining--;
+ if (remaining == 0) {
+ break;
+ }
+ }
+
+ index++;
+ bits >>= 1;
+ }
+
+ return index;
+ }
+
+ // Capacity grows linearly by CAPACITY_GROWTH_STEP (no doubling).
+ size_type next_capacity() const noexcept {
+ return static_cast<size_type>(m_capacity + CAPACITY_GROWTH_STEP);
+ }
+
+ /**
+ * Insertion
+ *
+ * Two situations:
+ * - Either we are in a situation where
+ * std::is_nothrow_move_constructible<value_type>::value is true. In this
+ * case, on insertion we just reallocate m_values when we reach its capacity
+ * (i.e. m_nb_elements == m_capacity), otherwise we just put the new value at
+ * its appropriate place. We can easily keep the strong exception guarantee as
+ * moving the values around is safe.
+ * - Otherwise we are in a situation where
+ * std::is_nothrow_move_constructible<value_type>::value is false. In this
+ * case on EACH insertion we allocate a new area of m_nb_elements + 1 where we
+ * copy the values of m_values into it and put the new value there. On
+ * success, we set m_values to this new area. Even if slower, it's the only
+ * way to preserve to strong exception guarantee.
+ */
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<
+ std::is_nothrow_move_constructible<U>::value>::type * = nullptr>
+ void insert_at_offset(allocator_type &alloc, size_type offset,
+ Args &&...value_args) {
+ // Fast path: a free slot exists, shift values in place. Otherwise grow.
+ if (m_nb_elements < m_capacity) {
+ insert_at_offset_no_realloc(alloc, offset,
+ std::forward<Args>(value_args)...);
+ } else {
+ insert_at_offset_realloc(alloc, offset, next_capacity(),
+ std::forward<Args>(value_args)...);
+ }
+ }
+
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<!std::is_nothrow_move_constructible<
+ U>::value>::type * = nullptr>
+ void insert_at_offset(allocator_type &alloc, size_type offset,
+ Args &&...value_args) {
+ // Non-nothrow-movable value_type: always reallocate (see comment above).
+ insert_at_offset_realloc(alloc, offset, m_nb_elements + 1,
+ std::forward<Args>(value_args)...);
+ }
+
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<
+ std::is_nothrow_move_constructible<U>::value>::type * = nullptr>
+ void insert_at_offset_no_realloc(allocator_type &alloc, size_type offset,
+ Args &&...value_args) {
+ tsl_sh_assert(offset <= m_nb_elements);
+ tsl_sh_assert(m_nb_elements < m_capacity);
+
+ // Shift [offset, m_nb_elements) one slot to the right to free the slot at
+ // offset; the moves are noexcept by the enable_if constraint above.
+ for (size_type i = m_nb_elements; i > offset; i--) {
+ construct_value(alloc, m_values + i, std::move(m_values[i - 1]));
+ destroy_value(alloc, m_values + i - 1);
+ }
+
+ try {
+ construct_value(alloc, m_values + offset,
+ std::forward<Args>(value_args)...);
+ } catch (...) {
+ // Roll back: shift the displaced values back into their original slots
+ // to restore the pre-insertion state (strong exception guarantee).
+ for (size_type i = offset; i < m_nb_elements; i++) {
+ construct_value(alloc, m_values + i, std::move(m_values[i + 1]));
+ destroy_value(alloc, m_values + i + 1);
+ }
+ throw;
+ }
+ }
+
+ // Reallocating insertion for nothrow-movable value_type: only the
+ // construction of the new value can throw; the surrounding moves are safe.
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<
+ std::is_nothrow_move_constructible<U>::value>::type * = nullptr>
+ void insert_at_offset_realloc(allocator_type &alloc, size_type offset,
+ size_type new_capacity, Args &&...value_args) {
+ tsl_sh_assert(new_capacity > m_nb_elements);
+
+ value_type *new_values = alloc.allocate(new_capacity);
+ // Allocate should throw if there is a failure
+ tsl_sh_assert(new_values != nullptr);
+
+ try {
+ construct_value(alloc, new_values + offset,
+ std::forward<Args>(value_args)...);
+ } catch (...) {
+ alloc.deallocate(new_values, new_capacity);
+ throw;
+ }
+
+ // Should not throw from here
+ for (size_type i = 0; i < offset; i++) {
+ construct_value(alloc, new_values + i, std::move(m_values[i]));
+ }
+
+ for (size_type i = offset; i < m_nb_elements; i++) {
+ construct_value(alloc, new_values + i + 1, std::move(m_values[i]));
+ }
+
+ destroy_and_deallocate_values(alloc, m_values, m_nb_elements, m_capacity);
+
+ m_values = new_values;
+ m_capacity = new_capacity;
+ }
+
+ // Reallocating insertion for non-nothrow-movable value_type: copy (not
+ // move) everything into a fresh area so the original stays intact if any
+ // copy throws (strong exception guarantee).
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<!std::is_nothrow_move_constructible<
+ U>::value>::type * = nullptr>
+ void insert_at_offset_realloc(allocator_type &alloc, size_type offset,
+ size_type new_capacity, Args &&...value_args) {
+ tsl_sh_assert(new_capacity > m_nb_elements);
+
+ value_type *new_values = alloc.allocate(new_capacity);
+ // Allocate should throw if there is a failure
+ tsl_sh_assert(new_values != nullptr);
+
+ // nb_new_values tracks how many constructions succeeded so the catch
+ // block destroys exactly those.
+ size_type nb_new_values = 0;
+ try {
+ for (size_type i = 0; i < offset; i++) {
+ construct_value(alloc, new_values + i, m_values[i]);
+ nb_new_values++;
+ }
+
+ construct_value(alloc, new_values + offset,
+ std::forward<Args>(value_args)...);
+ nb_new_values++;
+
+ for (size_type i = offset; i < m_nb_elements; i++) {
+ construct_value(alloc, new_values + i + 1, m_values[i]);
+ nb_new_values++;
+ }
+ } catch (...) {
+ destroy_and_deallocate_values(alloc, new_values, nb_new_values,
+ new_capacity);
+ throw;
+ }
+
+ tsl_sh_assert(nb_new_values == m_nb_elements + 1);
+
+ destroy_and_deallocate_values(alloc, m_values, m_nb_elements, m_capacity);
+
+ m_values = new_values;
+ m_capacity = new_capacity;
+ }
+
+ /**
+ * Erasure
+ *
+ * Two situations:
+ * - Either we are in a situation where
+ * std::is_nothrow_move_constructible<value_type>::value is true. Simply
+ * destroy the value and left-shift move the value on the right of offset.
+ * - Otherwise we are in a situation where
+ * std::is_nothrow_move_constructible<value_type>::value is false. Copy all
+ * the values except the one at offset into a new heap area. On success, we
+ * set m_values to this new area. Even if slower, it's the only way to
+ * preserve to strong exception guarantee.
+ */
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<
+ std::is_nothrow_move_constructible<U>::value>::type * = nullptr>
+ void erase_at_offset(allocator_type &alloc, size_type offset) noexcept {
+ tsl_sh_assert(offset < m_nb_elements);
+
+ destroy_value(alloc, m_values + offset);
+
+ // Close the gap by moving the tail one slot to the left (noexcept moves).
+ for (size_type i = offset + 1; i < m_nb_elements; i++) {
+ construct_value(alloc, m_values + i - 1, std::move(m_values[i]));
+ destroy_value(alloc, m_values + i);
+ }
+ }
+
+ template <typename... Args, typename U = value_type,
+ typename std::enable_if<!std::is_nothrow_move_constructible<
+ U>::value>::type * = nullptr>
+ void erase_at_offset(allocator_type &alloc, size_type offset) {
+ tsl_sh_assert(offset < m_nb_elements);
+
+ // Erasing the last element, don't need to reallocate. We keep the capacity.
+ if (offset + 1 == m_nb_elements) {
+ destroy_value(alloc, m_values + offset);
+ return;
+ }
+
+ tsl_sh_assert(m_nb_elements > 1);
+ const size_type new_capacity = m_nb_elements - 1;
+
+ value_type *new_values = alloc.allocate(new_capacity);
+ // Allocate should throw if there is a failure
+ tsl_sh_assert(new_values != nullptr);
+
+ // Copy everything except the erased value; nb_new_values counts the
+ // successful copies so the catch block can destroy exactly those.
+ size_type nb_new_values = 0;
+ try {
+ for (size_type i = 0; i < m_nb_elements; i++) {
+ if (i != offset) {
+ construct_value(alloc, new_values + nb_new_values, m_values[i]);
+ nb_new_values++;
+ }
+ }
+ } catch (...) {
+ destroy_and_deallocate_values(alloc, new_values, nb_new_values,
+ new_capacity);
+ throw;
+ }
+
+ tsl_sh_assert(nb_new_values == m_nb_elements - 1);
+
+ destroy_and_deallocate_values(alloc, m_values, m_nb_elements, m_capacity);
+
+ m_values = new_values;
+ m_capacity = new_capacity;
+ }
+
+ private:
+ // Dense array of values: m_nb_elements constructed, m_capacity allocated.
+ value_type *m_values;
+
+ // Bit i set <=> bucket i of this sparse_array currently holds a value.
+ bitmap_type m_bitmap_vals;
+ // Bit i set <=> bucket i held a value that has since been erased.
+ bitmap_type m_bitmap_deleted_vals;
+
+ size_type m_nb_elements;
+ size_type m_capacity;
+ // True for the last sparse_array of the bucket container; used by the
+ // table's iterators to detect the end.
+ bool m_last_array;
+};
+
+/**
+ * Internal common class used by `sparse_map` and `sparse_set`.
+ *
+ * `ValueType` is what will be stored by `sparse_hash` (usually `std::pair<Key,
+ * T>` for map and `Key` for set).
+ *
+ * `KeySelect` should be a `FunctionObject` which takes a `ValueType` in
+ * parameter and returns a reference to the key.
+ *
+ * `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in
+ * parameter and returns a reference to the value. `ValueSelect` should be void
+ * if there is no value (in a set for example).
+ *
+ * The strong exception guarantee only holds if `ExceptionSafety` is set to
+ * `tsl::sh::exception_safety::strong`.
+ *
+ * `ValueType` must be nothrow move constructible and/or copy constructible.
+ * Behaviour is undefined if the destructor of `ValueType` throws.
+ *
+ *
+ * The class holds its buckets in a 2-dimensional fashion. Instead of having a
+ * linear `std::vector<bucket>` for [0, bucket_count) where each bucket stores
+ * one value, we have a `std::vector<sparse_array>` (m_sparse_buckets_data)
+ * where each `sparse_array` stores multiple values (up to
+ * `sparse_array::BITMAP_NB_BITS`). To convert a one dimensional `ibucket`
+ * position to a position in `std::vector<sparse_array>` and a position in
+ * `sparse_array`, use respectively the methods
+ * `sparse_array::sparse_ibucket(ibucket)` and
+ * `sparse_array::index_in_sparse_bucket(ibucket)`.
+ */
+template <class ValueType, class KeySelect, class ValueSelect, class Hash,
+ class KeyEqual, class Allocator, class GrowthPolicy,
+ tsl::sh::exception_safety ExceptionSafety, tsl::sh::sparsity Sparsity,
+ tsl::sh::probing Probing>
+class sparse_hash : private Allocator,
+ private Hash,
+ private KeyEqual,
+ private GrowthPolicy {
+ private:
+ template <typename U>
+ using has_mapped_type =
+ typename std::integral_constant<bool, !std::is_same<U, void>::value>;
+
+ static_assert(
+ noexcept(std::declval<GrowthPolicy>().bucket_for_hash(std::size_t(0))),
+ "GrowthPolicy::bucket_for_hash must be noexcept.");
+ static_assert(noexcept(std::declval<GrowthPolicy>().clear()),
+ "GrowthPolicy::clear must be noexcept.");
+
+ public:
+ template <bool IsConst>
+ class sparse_iterator;
+
+ using key_type = typename KeySelect::key_type;
+ using value_type = ValueType;
+ using size_type = std::size_t;
+ using difference_type = std::ptrdiff_t;
+ using hasher = Hash;
+ using key_equal = KeyEqual;
+ using allocator_type = Allocator;
+ using reference = value_type &;
+ using const_reference = const value_type &;
+ using pointer = value_type *;
+ using const_pointer = const value_type *;
+ using iterator = sparse_iterator<false>;
+ using const_iterator = sparse_iterator<true>;
+
+ private:
+ using sparse_array =
+ tsl::detail_sparse_hash::sparse_array<ValueType, Allocator, Sparsity>;
+
+ using sparse_buckets_allocator = typename std::allocator_traits<
+ allocator_type>::template rebind_alloc<sparse_array>;
+ using sparse_buckets_container =
+ std::vector<sparse_array, sparse_buckets_allocator>;
+
+ public:
+ /**
+ * The `operator*()` and `operator->()` methods return a const reference and
+ * const pointer respectively to the stored value type (`Key` for a set,
+ * `std::pair<Key, T>` for a map).
+ *
+ * In case of a map, to get a mutable reference to the value `T` associated to
+ * a key (the `.second` in the stored pair), you have to call `value()`.
+ */
+ template <bool IsConst>
+ class sparse_iterator {
+ friend class sparse_hash;
+
+ private:
+ using sparse_bucket_iterator = typename std::conditional<
+ IsConst, typename sparse_buckets_container::const_iterator,
+ typename sparse_buckets_container::iterator>::type;
+
+ using sparse_array_iterator =
+ typename std::conditional<IsConst,
+ typename sparse_array::const_iterator,
+ typename sparse_array::iterator>::type;
+
+ /**
+ * sparse_array_it should be nullptr if sparse_bucket_it ==
+ * m_sparse_buckets_data.end(). (TODO better way?)
+ */
+ sparse_iterator(sparse_bucket_iterator sparse_bucket_it,
+ sparse_array_iterator sparse_array_it)
+ : m_sparse_buckets_it(sparse_bucket_it),
+ m_sparse_array_it(sparse_array_it) {}
+
+ public:
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = const typename sparse_hash::value_type;
+ using difference_type = std::ptrdiff_t;
+ using reference = value_type &;
+ using pointer = value_type *;
+
+ sparse_iterator() noexcept {}
+
+ // Copy constructor from iterator to const_iterator.
+ template <bool TIsConst = IsConst,
+ typename std::enable_if<TIsConst>::type * = nullptr>
+ sparse_iterator(const sparse_iterator<!TIsConst> &other) noexcept
+ : m_sparse_buckets_it(other.m_sparse_buckets_it),
+ m_sparse_array_it(other.m_sparse_array_it) {}
+
+ sparse_iterator(const sparse_iterator &other) = default;
+ sparse_iterator(sparse_iterator &&other) = default;
+ sparse_iterator &operator=(const sparse_iterator &other) = default;
+ sparse_iterator &operator=(sparse_iterator &&other) = default;
+
+ const typename sparse_hash::key_type &key() const {
+ return KeySelect()(*m_sparse_array_it);
+ }
+
+ template <class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value &&
+ IsConst>::type * = nullptr>
+ const typename U::value_type &value() const {
+ return U()(*m_sparse_array_it);
+ }
+
+ template <class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value &&
+ !IsConst>::type * = nullptr>
+ typename U::value_type &value() {
+ return U()(*m_sparse_array_it);
+ }
+
+ reference operator*() const { return *m_sparse_array_it; }
+
+ pointer operator->() const { return std::addressof(*m_sparse_array_it); }
+
+ sparse_iterator &operator++() {
+ tsl_sh_assert(m_sparse_array_it != nullptr);
+ ++m_sparse_array_it;
+
+ // Reached the end of the current sparse_array: advance to the next
+ // non-empty one, or to the global end (array iterator nullptr) if the
+ // current array was flagged as the last.
+ if (m_sparse_array_it == m_sparse_buckets_it->end()) {
+ do {
+ if (m_sparse_buckets_it->last()) {
+ ++m_sparse_buckets_it;
+ m_sparse_array_it = nullptr;
+ return *this;
+ }
+
+ ++m_sparse_buckets_it;
+ } while (m_sparse_buckets_it->empty());
+
+ m_sparse_array_it = m_sparse_buckets_it->begin();
+ }
+
+ return *this;
+ }
+
+ sparse_iterator operator++(int) {
+ sparse_iterator tmp(*this);
+ ++*this;
+
+ return tmp;
+ }
+
+ friend bool operator==(const sparse_iterator &lhs,
+ const sparse_iterator &rhs) {
+ return lhs.m_sparse_buckets_it == rhs.m_sparse_buckets_it &&
+ lhs.m_sparse_array_it == rhs.m_sparse_array_it;
+ }
+
+ friend bool operator!=(const sparse_iterator &lhs,
+ const sparse_iterator &rhs) {
+ return !(lhs == rhs);
+ }
+
+ private:
+ sparse_bucket_iterator m_sparse_buckets_it;
+ sparse_array_iterator m_sparse_array_it;
+ };
+
+ public:
+ // Main constructor: sets up bucket_count buckets (grouped into
+ // sparse_arrays), the hash/equal/growth-policy bases and the load factor.
+ sparse_hash(size_type bucket_count, const Hash &hash, const KeyEqual &equal,
+ const Allocator &alloc, float max_load_factor)
+ : Allocator(alloc),
+ Hash(hash),
+ KeyEqual(equal),
+ GrowthPolicy(bucket_count),
+ m_sparse_buckets_data(alloc),
+ m_sparse_buckets(static_empty_sparse_bucket_ptr()),
+ m_bucket_count(bucket_count),
+ m_nb_elements(0),
+ m_nb_deleted_buckets(0) {
+ if (m_bucket_count > max_bucket_count()) {
+ throw std::length_error("The map exceeds its maximum size.");
+ }
+
+ if (m_bucket_count > 0) {
+ /*
+ * We can't use the `vector(size_type count, const Allocator& alloc)`
+ * constructor as it's only available in C++14 and we need to support
+ * C++11. We thus must resize after using the `vector(const Allocator&
+ * alloc)` constructor.
+ *
+ * We can't use `vector(size_type count, const T& value, const Allocator&
+ * alloc)` as it requires the value T to be copyable.
+ */
+ m_sparse_buckets_data.resize(
+ sparse_array::nb_sparse_buckets(bucket_count));
+ m_sparse_buckets = m_sparse_buckets_data.data();
+
+ // Flag the last sparse_array so iterators know where to stop.
+ tsl_sh_assert(!m_sparse_buckets_data.empty());
+ m_sparse_buckets_data.back().set_as_last();
+ }
+
+ this->max_load_factor(max_load_factor);
+
+ // Check in the constructor instead of outside of a function to avoid
+ // compilation issues when value_type is not complete.
+ static_assert(std::is_nothrow_move_constructible<value_type>::value ||
+ std::is_copy_constructible<value_type>::value,
+ "Key, and T if present, must be nothrow move constructible "
+ "and/or copy constructible.");
+ }
+
+ // clear() destroys all values and releases the sparse_arrays' storage,
+ // which must happen before the arrays themselves are destroyed.
+ ~sparse_hash() { clear(); }
+
+ // Copy constructor. The allocator is obtained through
+ // select_on_container_copy_construction, per the allocator-aware container
+ // rules; the buckets are deep-copied by copy_buckets_from.
+ sparse_hash(const sparse_hash &other)
+ : Allocator(std::allocator_traits<
+ Allocator>::select_on_container_copy_construction(other)),
+ Hash(other),
+ KeyEqual(other),
+ GrowthPolicy(other),
+ m_sparse_buckets_data(
+ std::allocator_traits<
+ Allocator>::select_on_container_copy_construction(other)),
+ m_bucket_count(other.m_bucket_count),
+ m_nb_elements(other.m_nb_elements),
+ m_nb_deleted_buckets(other.m_nb_deleted_buckets),
+ m_load_threshold_rehash(other.m_load_threshold_rehash),
+ m_load_threshold_clear_deleted(other.m_load_threshold_clear_deleted),
+ m_max_load_factor(other.m_max_load_factor) {
+ // Note: these two statements were previously joined with a comma operator
+ // (`copy_buckets_from(other),`), which compiled and behaved identically
+ // but obscured the sequencing; use two plain statements instead.
+ copy_buckets_from(other);
+ m_sparse_buckets = m_sparse_buckets_data.empty()
+ ? static_empty_sparse_bucket_ptr()
+ : m_sparse_buckets_data.data();
+ }
+
+ // Move constructor: steals other's bucket vector, then resets other to a
+ // valid empty state.
+ sparse_hash(sparse_hash &&other) noexcept(
+ std::is_nothrow_move_constructible<Allocator>::value
+ &&std::is_nothrow_move_constructible<Hash>::value
+ &&std::is_nothrow_move_constructible<KeyEqual>::value
+ &&std::is_nothrow_move_constructible<GrowthPolicy>::value
+ &&std::is_nothrow_move_constructible<
+ sparse_buckets_container>::value)
+ : Allocator(std::move(other)),
+ Hash(std::move(other)),
+ KeyEqual(std::move(other)),
+ GrowthPolicy(std::move(other)),
+ m_sparse_buckets_data(std::move(other.m_sparse_buckets_data)),
+ m_sparse_buckets(m_sparse_buckets_data.empty()
+ ? static_empty_sparse_bucket_ptr()
+ : m_sparse_buckets_data.data()),
+ m_bucket_count(other.m_bucket_count),
+ m_nb_elements(other.m_nb_elements),
+ m_nb_deleted_buckets(other.m_nb_deleted_buckets),
+ m_load_threshold_rehash(other.m_load_threshold_rehash),
+ m_load_threshold_clear_deleted(other.m_load_threshold_clear_deleted),
+ m_max_load_factor(other.m_max_load_factor) {
+ // Leave other empty but usable.
+ other.GrowthPolicy::clear();
+ other.m_sparse_buckets_data.clear();
+ other.m_sparse_buckets = static_empty_sparse_bucket_ptr();
+ other.m_bucket_count = 0;
+ other.m_nb_elements = 0;
+ other.m_nb_deleted_buckets = 0;
+ other.m_load_threshold_rehash = 0;
+ other.m_load_threshold_clear_deleted = 0;
+ }
+
+ // Copy assignment; honours propagate_on_container_copy_assignment for the
+ // allocator and reuses the existing bucket vector when its size already
+ // matches other's.
+ sparse_hash &operator=(const sparse_hash &other) {
+ if (this != &other) {
+ clear();
+
+ if (std::allocator_traits<
+ Allocator>::propagate_on_container_copy_assignment::value) {
+ Allocator::operator=(other);
+ }
+
+ Hash::operator=(other);
+ KeyEqual::operator=(other);
+ GrowthPolicy::operator=(other);
+
+ if (std::allocator_traits<
+ Allocator>::propagate_on_container_copy_assignment::value) {
+ m_sparse_buckets_data =
+ sparse_buckets_container(static_cast<const Allocator &>(other));
+ } else {
+ if (m_sparse_buckets_data.size() !=
+ other.m_sparse_buckets_data.size()) {
+ m_sparse_buckets_data =
+ sparse_buckets_container(static_cast<const Allocator &>(*this));
+ } else {
+ m_sparse_buckets_data.clear();
+ }
+ }
+
+ copy_buckets_from(other);
+ m_sparse_buckets = m_sparse_buckets_data.empty()
+ ? static_empty_sparse_bucket_ptr()
+ : m_sparse_buckets_data.data();
+
+ m_bucket_count = other.m_bucket_count;
+ m_nb_elements = other.m_nb_elements;
+ m_nb_deleted_buckets = other.m_nb_deleted_buckets;
+ m_load_threshold_rehash = other.m_load_threshold_rehash;
+ m_load_threshold_clear_deleted = other.m_load_threshold_clear_deleted;
+ m_max_load_factor = other.m_max_load_factor;
+ }
+
+ return *this;
+ }
+
+ // Move assignment. When the allocator does not propagate and differs from
+ // other's, the buckets cannot be stolen and are moved element-wise through
+ // move_buckets_from; otherwise the bucket vector is moved directly.
+ sparse_hash &operator=(sparse_hash &&other) {
+ clear();
+
+ if (std::allocator_traits<
+ Allocator>::propagate_on_container_move_assignment::value) {
+ static_cast<Allocator &>(*this) =
+ std::move(static_cast<Allocator &>(other));
+ m_sparse_buckets_data = std::move(other.m_sparse_buckets_data);
+ } else if (static_cast<Allocator &>(*this) !=
+ static_cast<Allocator &>(other)) {
+ move_buckets_from(std::move(other));
+ } else {
+ static_cast<Allocator &>(*this) =
+ std::move(static_cast<Allocator &>(other));
+ m_sparse_buckets_data = std::move(other.m_sparse_buckets_data);
+ }
+
+ m_sparse_buckets = m_sparse_buckets_data.empty()
+ ? static_empty_sparse_bucket_ptr()
+ : m_sparse_buckets_data.data();
+
+ static_cast<Hash &>(*this) = std::move(static_cast<Hash &>(other));
+ static_cast<KeyEqual &>(*this) = std::move(static_cast<KeyEqual &>(other));
+ static_cast<GrowthPolicy &>(*this) =
+ std::move(static_cast<GrowthPolicy &>(other));
+ m_bucket_count = other.m_bucket_count;
+ m_nb_elements = other.m_nb_elements;
+ m_nb_deleted_buckets = other.m_nb_deleted_buckets;
+ m_load_threshold_rehash = other.m_load_threshold_rehash;
+ m_load_threshold_clear_deleted = other.m_load_threshold_clear_deleted;
+ m_max_load_factor = other.m_max_load_factor;
+
+ // Leave other empty but usable.
+ other.GrowthPolicy::clear();
+ other.m_sparse_buckets_data.clear();
+ other.m_sparse_buckets = static_empty_sparse_bucket_ptr();
+ other.m_bucket_count = 0;
+ other.m_nb_elements = 0;
+ other.m_nb_deleted_buckets = 0;
+ other.m_load_threshold_rehash = 0;
+ other.m_load_threshold_clear_deleted = 0;
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const {
+ return static_cast<const Allocator &>(*this);
+ }
+
+ /*
+ * Iterators
+ */
+ // Skip leading empty sparse_arrays; when the table is empty, the iterator
+ // carries a nullptr value-iterator (matching end()).
+ iterator begin() noexcept {
+ auto begin = m_sparse_buckets_data.begin();
+ while (begin != m_sparse_buckets_data.end() && begin->empty()) {
+ ++begin;
+ }
+
+ return iterator(begin, (begin != m_sparse_buckets_data.end())
+ ? begin->begin()
+ : nullptr);
+ }
+
+ const_iterator begin() const noexcept { return cbegin(); }
+
+ const_iterator cbegin() const noexcept {
+ auto begin = m_sparse_buckets_data.cbegin();
+ while (begin != m_sparse_buckets_data.cend() && begin->empty()) {
+ ++begin;
+ }
+
+ return const_iterator(begin, (begin != m_sparse_buckets_data.cend())
+ ? begin->cbegin()
+ : nullptr);
+ }
+
+ iterator end() noexcept {
+ return iterator(m_sparse_buckets_data.end(), nullptr);
+ }
+
+ const_iterator end() const noexcept { return cend(); }
+
+ const_iterator cend() const noexcept {
+ return const_iterator(m_sparse_buckets_data.cend(), nullptr);
+ }
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept { return m_nb_elements == 0; }
+
+ size_type size() const noexcept { return m_nb_elements; }
+
+ size_type max_size() const noexcept {
+ return std::min(std::allocator_traits<Allocator>::max_size(),
+ m_sparse_buckets_data.max_size());
+ }
+
+ /*
+ * Modifiers
+ */
+ // Clear every bucket (freeing each sparse_array's value storage) but keep
+ // the bucket vector itself; m_bucket_count is unchanged.
+ void clear() noexcept {
+ for (auto &bucket : m_sparse_buckets_data) {
+ bucket.clear(*this);
+ }
+
+ m_nb_elements = 0;
+ m_nb_deleted_buckets = 0;
+ }
+
+ template <typename P>
+ std::pair<iterator, bool> insert(P &&value) {
+ return insert_impl(KeySelect()(value), std::forward<P>(value));
+ }
+
+ template <typename P>
+ iterator insert_hint(const_iterator hint, P &&value) {
+ if (hint != cend() &&
+ compare_keys(KeySelect()(*hint), KeySelect()(value))) {
+ return mutable_iterator(hint);
+ }
+
+ return insert(std::forward<P>(value)).first;
+ }
+
+ template <class InputIt>
+ void insert(InputIt first, InputIt last) {
+ if (std::is_base_of<
+ std::forward_iterator_tag,
+ typename std::iterator_traits<InputIt>::iterator_category>::value) {
+ const auto nb_elements_insert = std::distance(first, last);
+ const size_type nb_free_buckets = m_load_threshold_rehash - size();
+ tsl_sh_assert(m_load_threshold_rehash >= size());
+
+ if (nb_elements_insert > 0 &&
+ nb_free_buckets < size_type(nb_elements_insert)) {
+ reserve(size() + size_type(nb_elements_insert));
+ }
+ }
+
+ for (; first != last; ++first) {
+ insert(*first);
+ }
+ }
+
+ template <class K, class M>
+ std::pair<iterator, bool> insert_or_assign(K &&key, M &&obj) {
+ auto it = try_emplace(std::forward<K>(key), std::forward<M>(obj));
+ if (!it.second) {
+ it.first.value() = std::forward<M>(obj);
+ }
+
+ return it;
+ }
+
+ template <class K, class M>
+ iterator insert_or_assign(const_iterator hint, K &&key, M &&obj) {
+ if (hint != cend() && compare_keys(KeySelect()(*hint), key)) {
+ auto it = mutable_iterator(hint);
+ it.value() = std::forward<M>(obj);
+
+ return it;
+ }
+
+ return insert_or_assign(std::forward<K>(key), std::forward<M>(obj)).first;
+ }
+
+ template <class... Args>
+ std::pair<iterator, bool> emplace(Args &&...args) {
+ return insert(value_type(std::forward<Args>(args)...));
+ }
+
+ template <class... Args>
+ iterator emplace_hint(const_iterator hint, Args &&...args) {
+ return insert_hint(hint, value_type(std::forward<Args>(args)...));
+ }
+
+ template <class K, class... Args>
+ std::pair<iterator, bool> try_emplace(K &&key, Args &&...args) {
+ return insert_impl(key, std::piecewise_construct,
+ std::forward_as_tuple(std::forward<K>(key)),
+ std::forward_as_tuple(std::forward<Args>(args)...));
+ }
+
+ template <class K, class... Args>
+ iterator try_emplace_hint(const_iterator hint, K &&key, Args &&...args) {
+ if (hint != cend() && compare_keys(KeySelect()(*hint), key)) {
+ return mutable_iterator(hint);
+ }
+
+ return try_emplace(std::forward<K>(key), std::forward<Args>(args)...).first;
+ }
+
+ /**
+ * Here to avoid `template<class K> size_type erase(const K& key)` being used
+ * when we use an iterator instead of a const_iterator.
+ */
+ iterator erase(iterator pos) {
+ tsl_sh_assert(pos != end() && m_nb_elements > 0);
+ auto it_sparse_array_next =
+ pos.m_sparse_buckets_it->erase(*this, pos.m_sparse_array_it);
+ m_nb_elements--;
+ m_nb_deleted_buckets++;
+
+ if (it_sparse_array_next == pos.m_sparse_buckets_it->end()) {
+ auto it_sparse_buckets_next = pos.m_sparse_buckets_it;
+ do {
+ ++it_sparse_buckets_next;
+ } while (it_sparse_buckets_next != m_sparse_buckets_data.end() &&
+ it_sparse_buckets_next->empty());
+
+ if (it_sparse_buckets_next == m_sparse_buckets_data.end()) {
+ return end();
+ } else {
+ return iterator(it_sparse_buckets_next,
+ it_sparse_buckets_next->begin());
+ }
+ } else {
+ return iterator(pos.m_sparse_buckets_it, it_sparse_array_next);
+ }
+ }
+
+ iterator erase(const_iterator pos) { return erase(mutable_iterator(pos)); }
+
+ iterator erase(const_iterator first, const_iterator last) {
+ if (first == last) {
+ return mutable_iterator(first);
+ }
+
+ // TODO Optimize, could avoid the call to std::distance.
+ const size_type nb_elements_to_erase =
+ static_cast<size_type>(std::distance(first, last));
+ auto to_delete = mutable_iterator(first);
+ for (size_type i = 0; i < nb_elements_to_erase; i++) {
+ to_delete = erase(to_delete);
+ }
+
+ return to_delete;
+ }
+
+ template <class K>
+ size_type erase(const K &key) {
+ return erase(key, hash_key(key));
+ }
+
+ template <class K>
+ size_type erase(const K &key, std::size_t hash) {
+ return erase_impl(key, hash);
+ }
+
+ void swap(sparse_hash &other) {
+ using std::swap;
+
+ if (std::allocator_traits<Allocator>::propagate_on_container_swap::value) {
+ swap(static_cast<Allocator &>(*this), static_cast<Allocator &>(other));
+ } else {
+ tsl_sh_assert(static_cast<Allocator &>(*this) ==
+ static_cast<Allocator &>(other));
+ }
+
+ swap(static_cast<Hash &>(*this), static_cast<Hash &>(other));
+ swap(static_cast<KeyEqual &>(*this), static_cast<KeyEqual &>(other));
+ swap(static_cast<GrowthPolicy &>(*this),
+ static_cast<GrowthPolicy &>(other));
+ swap(m_sparse_buckets_data, other.m_sparse_buckets_data);
+ swap(m_sparse_buckets, other.m_sparse_buckets);
+ swap(m_bucket_count, other.m_bucket_count);
+ swap(m_nb_elements, other.m_nb_elements);
+ swap(m_nb_deleted_buckets, other.m_nb_deleted_buckets);
+ swap(m_load_threshold_rehash, other.m_load_threshold_rehash);
+ swap(m_load_threshold_clear_deleted, other.m_load_threshold_clear_deleted);
+ swap(m_max_load_factor, other.m_max_load_factor);
+ }
+
+ /*
+ * Lookup
+ */
+ template <
+ class K, class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
+ typename U::value_type &at(const K &key) {
+ return at(key, hash_key(key));
+ }
+
+ template <
+ class K, class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
+ typename U::value_type &at(const K &key, std::size_t hash) {
+ return const_cast<typename U::value_type &>(
+ static_cast<const sparse_hash *>(this)->at(key, hash));
+ }
+
+ template <
+ class K, class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
+ const typename U::value_type &at(const K &key) const {
+ return at(key, hash_key(key));
+ }
+
+ template <
+ class K, class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
+ const typename U::value_type &at(const K &key, std::size_t hash) const {
+ auto it = find(key, hash);
+ if (it != cend()) {
+ return it.value();
+ } else {
+ throw std::out_of_range("Couldn't find key.");
+ }
+ }
+
+ template <
+ class K, class U = ValueSelect,
+ typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
+ typename U::value_type &operator[](K &&key) {
+ return try_emplace(std::forward<K>(key)).first.value();
+ }
+
+ template <class K>
+ bool contains(const K &key) const {
+ return contains(key, hash_key(key));
+ }
+
+ template <class K>
+ bool contains(const K &key, std::size_t hash) const {
+ return count(key, hash) != 0;
+ }
+
+ template <class K>
+ size_type count(const K &key) const {
+ return count(key, hash_key(key));
+ }
+
+ template <class K>
+ size_type count(const K &key, std::size_t hash) const {
+ if (find(key, hash) != cend()) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ template <class K>
+ iterator find(const K &key) {
+ return find_impl(key, hash_key(key));
+ }
+
+ template <class K>
+ iterator find(const K &key, std::size_t hash) {
+ return find_impl(key, hash);
+ }
+
+ template <class K>
+ const_iterator find(const K &key) const {
+ return find_impl(key, hash_key(key));
+ }
+
+ template <class K>
+ const_iterator find(const K &key, std::size_t hash) const {
+ return find_impl(key, hash);
+ }
+
+ template <class K>
+ std::pair<iterator, iterator> equal_range(const K &key) {
+ return equal_range(key, hash_key(key));
+ }
+
+ template <class K>
+ std::pair<iterator, iterator> equal_range(const K &key, std::size_t hash) {
+ iterator it = find(key, hash);
+ return std::make_pair(it, (it == end()) ? it : std::next(it));
+ }
+
+ template <class K>
+ std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
+ return equal_range(key, hash_key(key));
+ }
+
+ template <class K>
+ std::pair<const_iterator, const_iterator> equal_range(
+ const K &key, std::size_t hash) const {
+ const_iterator it = find(key, hash);
+ return std::make_pair(it, (it == cend()) ? it : std::next(it));
+ }
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const { return m_bucket_count; }
+
+ size_type max_bucket_count() const {
+ return m_sparse_buckets_data.max_size();
+ }
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const {
+ if (bucket_count() == 0) {
+ return 0;
+ }
+
+ return float(m_nb_elements) / float(bucket_count());
+ }
+
+ float max_load_factor() const { return m_max_load_factor; }
+
+ void max_load_factor(float ml) {
+ m_max_load_factor = std::max(0.1f, std::min(ml, 0.8f));
+ m_load_threshold_rehash =
+ size_type(float(bucket_count()) * m_max_load_factor);
+
+ const float max_load_factor_with_deleted_buckets =
+ m_max_load_factor + 0.5f * (1.0f - m_max_load_factor);
+ tsl_sh_assert(max_load_factor_with_deleted_buckets > 0.0f &&
+ max_load_factor_with_deleted_buckets <= 1.0f);
+ m_load_threshold_clear_deleted =
+ size_type(float(bucket_count()) * max_load_factor_with_deleted_buckets);
+ }
+
+ void rehash(size_type count) {
+ count = std::max(count,
+ size_type(std::ceil(float(size()) / max_load_factor())));
+ rehash_impl(count);
+ }
+
+ void reserve(size_type count) {
+ rehash(size_type(std::ceil(float(count) / max_load_factor())));
+ }
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const { return static_cast<const Hash &>(*this); }
+
+ key_equal key_eq() const { return static_cast<const KeyEqual &>(*this); }
+
+ /*
+ * Other
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ auto it_sparse_buckets =
+ m_sparse_buckets_data.begin() +
+ std::distance(m_sparse_buckets_data.cbegin(), pos.m_sparse_buckets_it);
+
+ return iterator(it_sparse_buckets,
+ sparse_array::mutable_iterator(pos.m_sparse_array_it));
+ }
+
+ template <class Serializer>
+ void serialize(Serializer &serializer) const {
+ serialize_impl(serializer);
+ }
+
+ template <class Deserializer>
+ void deserialize(Deserializer &deserializer, bool hash_compatible) {
+ deserialize_impl(deserializer, hash_compatible);
+ }
+
+ private:
+ template <class K>
+ std::size_t hash_key(const K &key) const {
+ return Hash::operator()(key);
+ }
+
+ template <class K1, class K2>
+ bool compare_keys(const K1 &key1, const K2 &key2) const {
+ return KeyEqual::operator()(key1, key2);
+ }
+
+ size_type bucket_for_hash(std::size_t hash) const {
+ const std::size_t bucket = GrowthPolicy::bucket_for_hash(hash);
+ tsl_sh_assert(sparse_array::sparse_ibucket(bucket) <
+ m_sparse_buckets_data.size() ||
+ (bucket == 0 && m_sparse_buckets_data.empty()));
+
+ return bucket;
+ }
+
+ template <class U = GrowthPolicy,
+ typename std::enable_if<is_power_of_two_policy<U>::value>::type * =
+ nullptr>
+ size_type next_bucket(size_type ibucket, size_type iprobe) const {
+ (void)iprobe;
+ if (Probing == tsl::sh::probing::linear) {
+ return (ibucket + 1) & this->m_mask;
+ } else {
+ tsl_sh_assert(Probing == tsl::sh::probing::quadratic);
+ return (ibucket + iprobe) & this->m_mask;
+ }
+ }
+
+ template <class U = GrowthPolicy,
+ typename std::enable_if<!is_power_of_two_policy<U>::value>::type * =
+ nullptr>
+ size_type next_bucket(size_type ibucket, size_type iprobe) const {
+ (void)iprobe;
+ if (Probing == tsl::sh::probing::linear) {
+ ibucket++;
+ return (ibucket != bucket_count()) ? ibucket : 0;
+ } else {
+ tsl_sh_assert(Probing == tsl::sh::probing::quadratic);
+ ibucket += iprobe;
+ return (ibucket < bucket_count()) ? ibucket : ibucket % bucket_count();
+ }
+ }
+
+  // TODO encapsulate m_sparse_buckets_data to avoid managing the allocator
+ void copy_buckets_from(const sparse_hash &other) {
+ m_sparse_buckets_data.reserve(other.m_sparse_buckets_data.size());
+
+ try {
+ for (const auto &bucket : other.m_sparse_buckets_data) {
+ m_sparse_buckets_data.emplace_back(bucket,
+ static_cast<Allocator &>(*this));
+ }
+ } catch (...) {
+ clear();
+ throw;
+ }
+
+ tsl_sh_assert(m_sparse_buckets_data.empty() ||
+ m_sparse_buckets_data.back().last());
+ }
+
+ void move_buckets_from(sparse_hash &&other) {
+ m_sparse_buckets_data.reserve(other.m_sparse_buckets_data.size());
+
+ try {
+ for (auto &&bucket : other.m_sparse_buckets_data) {
+ m_sparse_buckets_data.emplace_back(std::move(bucket),
+ static_cast<Allocator &>(*this));
+ }
+ } catch (...) {
+ clear();
+ throw;
+ }
+
+ tsl_sh_assert(m_sparse_buckets_data.empty() ||
+ m_sparse_buckets_data.back().last());
+ }
+
+ template <class K, class... Args>
+ std::pair<iterator, bool> insert_impl(const K &key,
+ Args &&...value_type_args) {
+ if (size() >= m_load_threshold_rehash) {
+ rehash_impl(GrowthPolicy::next_bucket_count());
+ } else if (size() + m_nb_deleted_buckets >=
+ m_load_threshold_clear_deleted) {
+ clear_deleted_buckets();
+ }
+ tsl_sh_assert(!m_sparse_buckets_data.empty());
+
+ /**
+ * We must insert the value in the first empty or deleted bucket we find. If
+ * we first find a deleted bucket, we still have to continue the search
+ * until we find an empty bucket or until we have searched all the buckets
+ * to be sure that the value is not in the hash table. We thus remember the
+ * position, if any, of the first deleted bucket we have encountered so we
+ * can insert it there if needed.
+ */
+ bool found_first_deleted_bucket = false;
+ std::size_t sparse_ibucket_first_deleted = 0;
+ typename sparse_array::size_type index_in_sparse_bucket_first_deleted = 0;
+
+ const std::size_t hash = hash_key(key);
+ std::size_t ibucket = bucket_for_hash(hash);
+
+ std::size_t probe = 0;
+ while (true) {
+ std::size_t sparse_ibucket = sparse_array::sparse_ibucket(ibucket);
+ auto index_in_sparse_bucket =
+ sparse_array::index_in_sparse_bucket(ibucket);
+
+ if (m_sparse_buckets[sparse_ibucket].has_value(index_in_sparse_bucket)) {
+ auto value_it =
+ m_sparse_buckets[sparse_ibucket].value(index_in_sparse_bucket);
+ if (compare_keys(key, KeySelect()(*value_it))) {
+ return std::make_pair(
+ iterator(m_sparse_buckets_data.begin() + sparse_ibucket,
+ value_it),
+ false);
+ }
+ } else if (m_sparse_buckets[sparse_ibucket].has_deleted_value(
+ index_in_sparse_bucket) &&
+ probe < m_bucket_count) {
+ if (!found_first_deleted_bucket) {
+ found_first_deleted_bucket = true;
+ sparse_ibucket_first_deleted = sparse_ibucket;
+ index_in_sparse_bucket_first_deleted = index_in_sparse_bucket;
+ }
+ } else if (found_first_deleted_bucket) {
+ auto it = insert_in_bucket(sparse_ibucket_first_deleted,
+ index_in_sparse_bucket_first_deleted,
+ std::forward<Args>(value_type_args)...);
+ m_nb_deleted_buckets--;
+
+ return it;
+ } else {
+ return insert_in_bucket(sparse_ibucket, index_in_sparse_bucket,
+ std::forward<Args>(value_type_args)...);
+ }
+
+ probe++;
+ ibucket = next_bucket(ibucket, probe);
+ }
+ }
+
+ template <class... Args>
+ std::pair<iterator, bool> insert_in_bucket(
+ std::size_t sparse_ibucket,
+ typename sparse_array::size_type index_in_sparse_bucket,
+ Args &&...value_type_args) {
+ auto value_it = m_sparse_buckets[sparse_ibucket].set(
+ *this, index_in_sparse_bucket, std::forward<Args>(value_type_args)...);
+ m_nb_elements++;
+
+ return std::make_pair(
+ iterator(m_sparse_buckets_data.begin() + sparse_ibucket, value_it),
+ true);
+ }
+
+ template <class K>
+ size_type erase_impl(const K &key, std::size_t hash) {
+ std::size_t ibucket = bucket_for_hash(hash);
+
+ std::size_t probe = 0;
+ while (true) {
+ const std::size_t sparse_ibucket = sparse_array::sparse_ibucket(ibucket);
+ const auto index_in_sparse_bucket =
+ sparse_array::index_in_sparse_bucket(ibucket);
+
+ if (m_sparse_buckets[sparse_ibucket].has_value(index_in_sparse_bucket)) {
+ auto value_it =
+ m_sparse_buckets[sparse_ibucket].value(index_in_sparse_bucket);
+ if (compare_keys(key, KeySelect()(*value_it))) {
+ m_sparse_buckets[sparse_ibucket].erase(*this, value_it,
+ index_in_sparse_bucket);
+ m_nb_elements--;
+ m_nb_deleted_buckets++;
+
+ return 1;
+ }
+ } else if (!m_sparse_buckets[sparse_ibucket].has_deleted_value(
+ index_in_sparse_bucket) ||
+ probe >= m_bucket_count) {
+ return 0;
+ }
+
+ probe++;
+ ibucket = next_bucket(ibucket, probe);
+ }
+ }
+
+ template <class K>
+ iterator find_impl(const K &key, std::size_t hash) {
+ return mutable_iterator(
+ static_cast<const sparse_hash *>(this)->find(key, hash));
+ }
+
+ template <class K>
+ const_iterator find_impl(const K &key, std::size_t hash) const {
+ std::size_t ibucket = bucket_for_hash(hash);
+
+ std::size_t probe = 0;
+ while (true) {
+ const std::size_t sparse_ibucket = sparse_array::sparse_ibucket(ibucket);
+ const auto index_in_sparse_bucket =
+ sparse_array::index_in_sparse_bucket(ibucket);
+
+ if (m_sparse_buckets[sparse_ibucket].has_value(index_in_sparse_bucket)) {
+ auto value_it =
+ m_sparse_buckets[sparse_ibucket].value(index_in_sparse_bucket);
+ if (compare_keys(key, KeySelect()(*value_it))) {
+ return const_iterator(m_sparse_buckets_data.cbegin() + sparse_ibucket,
+ value_it);
+ }
+ } else if (!m_sparse_buckets[sparse_ibucket].has_deleted_value(
+ index_in_sparse_bucket) ||
+ probe >= m_bucket_count) {
+ return cend();
+ }
+
+ probe++;
+ ibucket = next_bucket(ibucket, probe);
+ }
+ }
+
+ void clear_deleted_buckets() {
+ // TODO could be optimized, we could do it in-place instead of allocating a
+ // new bucket array.
+ rehash_impl(m_bucket_count);
+ tsl_sh_assert(m_nb_deleted_buckets == 0);
+ }
+
+ template <tsl::sh::exception_safety U = ExceptionSafety,
+ typename std::enable_if<U == tsl::sh::exception_safety::basic>::type
+ * = nullptr>
+ void rehash_impl(size_type count) {
+ sparse_hash new_table(count, static_cast<Hash &>(*this),
+ static_cast<KeyEqual &>(*this),
+ static_cast<Allocator &>(*this), m_max_load_factor);
+
+ for (auto &bucket : m_sparse_buckets_data) {
+ for (auto &val : bucket) {
+ new_table.insert_on_rehash(std::move(val));
+ }
+
+ // TODO try to reuse some of the memory
+ bucket.clear(*this);
+ }
+
+ new_table.swap(*this);
+ }
+
+ /**
+ * TODO: For now we copy each element into the new map. We could move
+ * them if they are nothrow_move_constructible without triggering
+ * any exception if we reserve enough space in the sparse arrays beforehand.
+ */
+ template <tsl::sh::exception_safety U = ExceptionSafety,
+ typename std::enable_if<
+ U == tsl::sh::exception_safety::strong>::type * = nullptr>
+ void rehash_impl(size_type count) {
+ sparse_hash new_table(count, static_cast<Hash &>(*this),
+ static_cast<KeyEqual &>(*this),
+ static_cast<Allocator &>(*this), m_max_load_factor);
+
+ for (const auto &bucket : m_sparse_buckets_data) {
+ for (const auto &val : bucket) {
+ new_table.insert_on_rehash(val);
+ }
+ }
+
+ new_table.swap(*this);
+ }
+
+ template <typename K>
+ void insert_on_rehash(K &&key_value) {
+ const key_type &key = KeySelect()(key_value);
+
+ const std::size_t hash = hash_key(key);
+ std::size_t ibucket = bucket_for_hash(hash);
+
+ std::size_t probe = 0;
+ while (true) {
+ std::size_t sparse_ibucket = sparse_array::sparse_ibucket(ibucket);
+ auto index_in_sparse_bucket =
+ sparse_array::index_in_sparse_bucket(ibucket);
+
+ if (!m_sparse_buckets[sparse_ibucket].has_value(index_in_sparse_bucket)) {
+ m_sparse_buckets[sparse_ibucket].set(*this, index_in_sparse_bucket,
+ std::forward<K>(key_value));
+ m_nb_elements++;
+
+ return;
+ } else {
+ tsl_sh_assert(!compare_keys(
+ key, KeySelect()(*m_sparse_buckets[sparse_ibucket].value(
+ index_in_sparse_bucket))));
+ }
+
+ probe++;
+ ibucket = next_bucket(ibucket, probe);
+ }
+ }
+
+ template <class Serializer>
+ void serialize_impl(Serializer &serializer) const {
+ const slz_size_type version = SERIALIZATION_PROTOCOL_VERSION;
+ serializer(version);
+
+ const slz_size_type bucket_count = m_bucket_count;
+ serializer(bucket_count);
+
+ const slz_size_type nb_sparse_buckets = m_sparse_buckets_data.size();
+ serializer(nb_sparse_buckets);
+
+ const slz_size_type nb_elements = m_nb_elements;
+ serializer(nb_elements);
+
+ const slz_size_type nb_deleted_buckets = m_nb_deleted_buckets;
+ serializer(nb_deleted_buckets);
+
+ const float max_load_factor = m_max_load_factor;
+ serializer(max_load_factor);
+
+ for (const auto &bucket : m_sparse_buckets_data) {
+ bucket.serialize(serializer);
+ }
+ }
+
+ template <class Deserializer>
+ void deserialize_impl(Deserializer &deserializer, bool hash_compatible) {
+ tsl_sh_assert(
+ m_bucket_count == 0 &&
+ m_sparse_buckets_data.empty()); // Current hash table must be empty
+
+ const slz_size_type version =
+ deserialize_value<slz_size_type>(deserializer);
+ // For now we only have one version of the serialization protocol.
+ // If it doesn't match there is a problem with the file.
+ if (version != SERIALIZATION_PROTOCOL_VERSION) {
+ throw std::runtime_error(
+ "Can't deserialize the sparse_map/set. The "
+ "protocol version header is invalid.");
+ }
+
+ const slz_size_type bucket_count_ds =
+ deserialize_value<slz_size_type>(deserializer);
+ const slz_size_type nb_sparse_buckets =
+ deserialize_value<slz_size_type>(deserializer);
+ const slz_size_type nb_elements =
+ deserialize_value<slz_size_type>(deserializer);
+ const slz_size_type nb_deleted_buckets =
+ deserialize_value<slz_size_type>(deserializer);
+ const float max_load_factor = deserialize_value<float>(deserializer);
+
+ if (!hash_compatible) {
+ this->max_load_factor(max_load_factor);
+ reserve(numeric_cast<size_type>(nb_elements,
+ "Deserialized nb_elements is too big."));
+ for (slz_size_type ibucket = 0; ibucket < nb_sparse_buckets; ibucket++) {
+ sparse_array::deserialize_values_into_sparse_hash(deserializer, *this);
+ }
+ } else {
+ m_bucket_count = numeric_cast<size_type>(
+ bucket_count_ds, "Deserialized bucket_count is too big.");
+
+ GrowthPolicy::operator=(GrowthPolicy(m_bucket_count));
+ // GrowthPolicy should not modify the bucket count we got from
+ // deserialization
+ if (m_bucket_count != bucket_count_ds) {
+ throw std::runtime_error(
+ "The GrowthPolicy is not the same even though "
+ "hash_compatible is true.");
+ }
+
+ if (nb_sparse_buckets !=
+ sparse_array::nb_sparse_buckets(m_bucket_count)) {
+ throw std::runtime_error("Deserialized nb_sparse_buckets is invalid.");
+ }
+
+ m_nb_elements = numeric_cast<size_type>(
+ nb_elements, "Deserialized nb_elements is too big.");
+ m_nb_deleted_buckets = numeric_cast<size_type>(
+ nb_deleted_buckets, "Deserialized nb_deleted_buckets is too big.");
+
+ m_sparse_buckets_data.reserve(numeric_cast<size_type>(
+ nb_sparse_buckets, "Deserialized nb_sparse_buckets is too big."));
+ for (slz_size_type ibucket = 0; ibucket < nb_sparse_buckets; ibucket++) {
+ m_sparse_buckets_data.emplace_back(
+ sparse_array::deserialize_hash_compatible(
+ deserializer, static_cast<Allocator &>(*this)));
+ }
+
+ if (!m_sparse_buckets_data.empty()) {
+ m_sparse_buckets_data.back().set_as_last();
+ m_sparse_buckets = m_sparse_buckets_data.data();
+ }
+
+ this->max_load_factor(max_load_factor);
+ if (load_factor() > this->max_load_factor()) {
+ throw std::runtime_error(
+ "Invalid max_load_factor. Check that the serializer and "
+ "deserializer support "
+ "floats correctly as they can be converted implicitely to ints.");
+ }
+ }
+ }
+
+ public:
+ static const size_type DEFAULT_INIT_BUCKET_COUNT = 0;
+ static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f;
+
+ /**
+   * Protocol version currently used for serialization.
+ */
+ static const slz_size_type SERIALIZATION_PROTOCOL_VERSION = 1;
+
+ /**
+ * Return an always valid pointer to an static empty bucket_entry with
+ * last_bucket() == true.
+ */
+ sparse_array *static_empty_sparse_bucket_ptr() {
+ static sparse_array empty_sparse_bucket(true);
+ return &empty_sparse_bucket;
+ }
+
+ private:
+ sparse_buckets_container m_sparse_buckets_data;
+
+ /**
+ * Points to m_sparse_buckets_data.data() if !m_sparse_buckets_data.empty()
+ * otherwise points to static_empty_sparse_bucket_ptr. This variable is useful
+ * to avoid the cost of checking if m_sparse_buckets_data is empty when trying
+ * to find an element.
+ *
+ * TODO Remove m_sparse_buckets_data and only use a pointer instead of a
+ * pointer+vector to save some space in the sparse_hash object.
+ */
+ sparse_array *m_sparse_buckets;
+
+ size_type m_bucket_count;
+ size_type m_nb_elements;
+ size_type m_nb_deleted_buckets;
+
+ /**
+ * Maximum that m_nb_elements can reach before a rehash occurs automatically
+ * to grow the hash table.
+ */
+ size_type m_load_threshold_rehash;
+
+ /**
+ * Maximum that m_nb_elements + m_nb_deleted_buckets can reach before cleaning
+ * up the buckets marked as deleted.
+ */
+ size_type m_load_threshold_clear_deleted;
+ float m_max_load_factor;
+};
+
+} // namespace detail_sparse_hash
+} // namespace tsl
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/sparse_map.h b/be/src/extern/diskann/include/tsl/sparse_map.h
new file mode 100644
index 0000000..601742d
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/sparse_map.h
@@ -0,0 +1,800 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_SPARSE_MAP_H
+#define TSL_SPARSE_MAP_H
+
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "sparse_hash.h"
+
+namespace tsl {
+
+/**
+ * Implementation of a sparse hash map using open-addressing with quadratic
+ * probing. The goal on the hash map is to be the most memory efficient
+ * possible, even at low load factor, while keeping reasonable performances.
+ *
+ * `GrowthPolicy` defines how the map grows and consequently how a hash value is
+ * mapped to a bucket. By default the map uses
+ * `tsl::sh::power_of_two_growth_policy`. This policy keeps the number of
+ * buckets to a power of two and uses a mask to map the hash to a bucket instead
+ * of the slow modulo. Other growth policies are available and you may define
+ * your own growth policy, check `tsl::sh::power_of_two_growth_policy` for the
+ * interface.
+ *
+ * `ExceptionSafety` defines the exception guarantee provided by the class. By
+ * default only the basic exception safety is guaranteed which means that all
+ * resources used by the hash map will be freed (no memory leaks) but the hash
+ * map may end-up in an undefined state if an exception is thrown (undefined
+ * here means that some elements may be missing). This can ONLY happen on rehash
+ * (either on insert or if `rehash` is called explicitly) and will occur if the
+ * Allocator can't allocate memory (`std::bad_alloc`) or if the copy constructor
+ * (when a nothrow move constructor is not available) throws an exception. This
+ * can be avoided by calling `reserve` beforehand. This basic guarantee is
+ * similar to the one of `google::sparse_hash_map` and `spp::sparse_hash_map`.
+ * It is possible to ask for the strong exception guarantee with
+ * `tsl::sh::exception_safety::strong`, the drawback is that the map will be
+ * slower on rehashes and will also need more memory on rehashes.
+ *
+ * `Sparsity` defines how much the hash set will compromise between insertion
+ * speed and memory usage. A high sparsity means less memory usage but longer
+ * insertion times, and vice-versa for low sparsity. The default
+ * `tsl::sh::sparsity::medium` sparsity offers a good compromise. It doesn't
+ * change the lookup speed.
+ *
+ * `Key` and `T` must be nothrow move constructible and/or copy constructible.
+ *
+ * If the destructor of `Key` or `T` throws an exception, the behaviour of the
+ * class is undefined.
+ *
+ * Iterators invalidation:
+ * - clear, operator=, reserve, rehash: always invalidate the iterators.
+ * - insert, emplace, emplace_hint, operator[]: if there is an effective
+ * insert, invalidate the iterators.
+ * - erase: always invalidate the iterators.
+ */
+template <class Key, class T, class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<std::pair<Key, T>>,
+ class GrowthPolicy = tsl::sh::power_of_two_growth_policy<2>,
+ tsl::sh::exception_safety ExceptionSafety =
+ tsl::sh::exception_safety::basic,
+ tsl::sh::sparsity Sparsity = tsl::sh::sparsity::medium>
+class sparse_map {
+ private:
+ template <typename U>
+ using has_is_transparent = tsl::detail_sparse_hash::has_is_transparent<U>;
+
+ class KeySelect {
+ public:
+ using key_type = Key;
+
+ const key_type &operator()(
+ const std::pair<Key, T> &key_value) const noexcept {
+ return key_value.first;
+ }
+
+ key_type &operator()(std::pair<Key, T> &key_value) noexcept {
+ return key_value.first;
+ }
+ };
+
+ class ValueSelect {
+ public:
+ using value_type = T;
+
+ const value_type &operator()(
+ const std::pair<Key, T> &key_value) const noexcept {
+ return key_value.second;
+ }
+
+ value_type &operator()(std::pair<Key, T> &key_value) noexcept {
+ return key_value.second;
+ }
+ };
+
+ using ht = detail_sparse_hash::sparse_hash<
+ std::pair<Key, T>, KeySelect, ValueSelect, Hash, KeyEqual, Allocator,
+ GrowthPolicy, ExceptionSafety, Sparsity, tsl::sh::probing::quadratic>;
+
+ public:
+ using key_type = typename ht::key_type;
+ using mapped_type = T;
+ using value_type = typename ht::value_type;
+ using size_type = typename ht::size_type;
+ using difference_type = typename ht::difference_type;
+ using hasher = typename ht::hasher;
+ using key_equal = typename ht::key_equal;
+ using allocator_type = typename ht::allocator_type;
+ using reference = typename ht::reference;
+ using const_reference = typename ht::const_reference;
+ using pointer = typename ht::pointer;
+ using const_pointer = typename ht::const_pointer;
+ using iterator = typename ht::iterator;
+ using const_iterator = typename ht::const_iterator;
+
+ public:
+ /*
+ * Constructors
+ */
+ sparse_map() : sparse_map(ht::DEFAULT_INIT_BUCKET_COUNT) {}
+
+ explicit sparse_map(size_type bucket_count, const Hash &hash = Hash(),
+ const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) {}
+
+ sparse_map(size_type bucket_count, const Allocator &alloc)
+ : sparse_map(bucket_count, Hash(), KeyEqual(), alloc) {}
+
+ sparse_map(size_type bucket_count, const Hash &hash, const Allocator &alloc)
+ : sparse_map(bucket_count, hash, KeyEqual(), alloc) {}
+
+ explicit sparse_map(const Allocator &alloc)
+ : sparse_map(ht::DEFAULT_INIT_BUCKET_COUNT, alloc) {}
+
+ template <class InputIt>
+ sparse_map(InputIt first, InputIt last,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
+ const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : sparse_map(bucket_count, hash, equal, alloc) {
+ insert(first, last);
+ }
+
+ template <class InputIt>
+ sparse_map(InputIt first, InputIt last, size_type bucket_count,
+ const Allocator &alloc)
+ : sparse_map(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
+
+ template <class InputIt>
+ sparse_map(InputIt first, InputIt last, size_type bucket_count,
+ const Hash &hash, const Allocator &alloc)
+ : sparse_map(first, last, bucket_count, hash, KeyEqual(), alloc) {}
+
+ sparse_map(std::initializer_list<value_type> init,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
+ const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : sparse_map(init.begin(), init.end(), bucket_count, hash, equal, alloc) {
+ }
+
+ sparse_map(std::initializer_list<value_type> init, size_type bucket_count,
+ const Allocator &alloc)
+ : sparse_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
+ alloc) {}
+
+ sparse_map(std::initializer_list<value_type> init, size_type bucket_count,
+ const Hash &hash, const Allocator &alloc)
+ : sparse_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
+ alloc) {}
+
+ sparse_map &operator=(std::initializer_list<value_type> ilist) {
+ m_ht.clear();
+
+ m_ht.reserve(ilist.size());
+ m_ht.insert(ilist.begin(), ilist.end());
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const { return m_ht.get_allocator(); }
+
+ /*
+ * Iterators
+ */
+ iterator begin() noexcept { return m_ht.begin(); }
+ const_iterator begin() const noexcept { return m_ht.begin(); }
+ const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
+
+ iterator end() noexcept { return m_ht.end(); }
+ const_iterator end() const noexcept { return m_ht.end(); }
+ const_iterator cend() const noexcept { return m_ht.cend(); }
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept { return m_ht.empty(); }
+ size_type size() const noexcept { return m_ht.size(); }
+ size_type max_size() const noexcept { return m_ht.max_size(); }
+
+ /*
+ * Modifiers
+ */
+ void clear() noexcept { m_ht.clear(); }
+
+ std::pair<iterator, bool> insert(const value_type &value) {
+ return m_ht.insert(value);
+ }
+
+ template <class P, typename std::enable_if<std::is_constructible<
+ value_type, P &&>::value>::type * = nullptr>
+ std::pair<iterator, bool> insert(P &&value) {
+ return m_ht.emplace(std::forward<P>(value));
+ }
+
+ std::pair<iterator, bool> insert(value_type &&value) {
+ return m_ht.insert(std::move(value));
+ }
+
+ iterator insert(const_iterator hint, const value_type &value) {
+ return m_ht.insert_hint(hint, value);
+ }
+
+ template <class P, typename std::enable_if<std::is_constructible<
+ value_type, P &&>::value>::type * = nullptr>
+ iterator insert(const_iterator hint, P &&value) {
+ return m_ht.emplace_hint(hint, std::forward<P>(value));
+ }
+
+ iterator insert(const_iterator hint, value_type &&value) {
+ return m_ht.insert_hint(hint, std::move(value));
+ }
+
+ template <class InputIt>
+ void insert(InputIt first, InputIt last) {
+ m_ht.insert(first, last);
+ }
+
+ void insert(std::initializer_list<value_type> ilist) {
+ m_ht.insert(ilist.begin(), ilist.end());
+ }
+
+ template <class M>
+ std::pair<iterator, bool> insert_or_assign(const key_type &k, M &&obj) {
+ return m_ht.insert_or_assign(k, std::forward<M>(obj));
+ }
+
+ template <class M>
+ std::pair<iterator, bool> insert_or_assign(key_type &&k, M &&obj) {
+ return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
+ }
+
+ template <class M>
+ iterator insert_or_assign(const_iterator hint, const key_type &k, M &&obj) {
+ return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
+ }
+
+ template <class M>
+ iterator insert_or_assign(const_iterator hint, key_type &&k, M &&obj) {
+ return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
+ }
+
+ /**
+ * Due to the way elements are stored, emplace will need to move or copy the
+ * key-value once. The method is equivalent to
+ * `insert(value_type(std::forward<Args>(args)...));`.
+ *
+ * Mainly here for compatibility with the `std::unordered_map` interface.
+ */
+ template <class... Args>
+ std::pair<iterator, bool> emplace(Args &&...args) {
+ return m_ht.emplace(std::forward<Args>(args)...);
+ }
+
+ /**
+ * Due to the way elements are stored, emplace_hint will need to move or copy
+ * the key-value once. The method is equivalent to `insert(hint,
+ * value_type(std::forward<Args>(args)...));`.
+ *
+ * Mainly here for compatibility with the `std::unordered_map` interface.
+ */
+ template <class... Args>
+ iterator emplace_hint(const_iterator hint, Args &&...args) {
+ return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ std::pair<iterator, bool> try_emplace(const key_type &k, Args &&...args) {
+ return m_ht.try_emplace(k, std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ std::pair<iterator, bool> try_emplace(key_type &&k, Args &&...args) {
+ return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ iterator try_emplace(const_iterator hint, const key_type &k, Args &&...args) {
+ return m_ht.try_emplace_hint(hint, k, std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ iterator try_emplace(const_iterator hint, key_type &&k, Args &&...args) {
+ return m_ht.try_emplace_hint(hint, std::move(k),
+ std::forward<Args>(args)...);
+ }
+
+ iterator erase(iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator first, const_iterator last) {
+ return m_ht.erase(first, last);
+ }
+ size_type erase(const key_type &key) { return m_ht.erase(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ size_type erase(const key_type &key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type erase(const K &key) {
+ return m_ht.erase(key);
+ }
+
+ /**
+ * @copydoc erase(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type erase(const K &key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ void swap(sparse_map &other) { other.m_ht.swap(m_ht); }
+
+ /*
+ * Lookup
+ */
+ T &at(const Key &key) { return m_ht.at(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ T &at(const Key &key, std::size_t precalculated_hash) {
+ return m_ht.at(key, precalculated_hash);
+ }
+
+ const T &at(const Key &key) const { return m_ht.at(key); }
+
+ /**
+ * @copydoc at(const Key& key, std::size_t precalculated_hash)
+ */
+ const T &at(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.at(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ T &at(const K &key) {
+ return m_ht.at(key);
+ }
+
+ /**
+ * @copydoc at(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ T &at(const K &key, std::size_t precalculated_hash) {
+ return m_ht.at(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc at(const K& key)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const T &at(const K &key) const {
+ return m_ht.at(key);
+ }
+
+ /**
+ * @copydoc at(const K& key, std::size_t precalculated_hash)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const T &at(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.at(key, precalculated_hash);
+ }
+
+ T &operator[](const Key &key) { return m_ht[key]; }
+ T &operator[](Key &&key) { return m_ht[std::move(key)]; }
+
+ size_type count(const Key &key) const { return m_ht.count(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ size_type count(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.count(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type count(const K &key) const {
+ return m_ht.count(key);
+ }
+
+ /**
+ * @copydoc count(const K& key) const
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type count(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.count(key, precalculated_hash);
+ }
+
+ iterator find(const Key &key) { return m_ht.find(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ iterator find(const Key &key, std::size_t precalculated_hash) {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ const_iterator find(const Key &key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const Key& key, std::size_t precalculated_hash)
+ */
+ const_iterator find(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ iterator find(const K &key) {
+ return m_ht.find(key);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ iterator find(const K &key, std::size_t precalculated_hash) {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const_iterator find(const K &key) const {
+ return m_ht.find(key);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const_iterator find(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ bool contains(const Key &key) const { return m_ht.contains(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The
+ * hash value should be the same as hash_function()(key). Useful to speed-up
+ * the lookup if you already have the hash.
+ */
+ bool contains(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.contains(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
+ * to Key.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ bool contains(const K &key) const {
+ return m_ht.contains(key);
+ }
+
+ /**
+ * @copydoc contains(const K& key) const
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The
+ * hash value should be the same as hash_function()(key). Useful to speed-up
+ * the lookup if you already have the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ bool contains(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.contains(key, precalculated_hash);
+ }
+
+ std::pair<iterator, iterator> equal_range(const Key &key) {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ std::pair<iterator, iterator> equal_range(const Key &key,
+ std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ std::pair<const_iterator, const_iterator> equal_range(const Key &key) const {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
+ */
+ std::pair<const_iterator, const_iterator> equal_range(
+ const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<iterator, iterator> equal_range(const K &key) {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<iterator, iterator> equal_range(const K &key,
+ std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(
+ const K &key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const { return m_ht.bucket_count(); }
+ size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const { return m_ht.load_factor(); }
+ float max_load_factor() const { return m_ht.max_load_factor(); }
+ void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
+
+ void rehash(size_type count) { m_ht.rehash(count); }
+ void reserve(size_type count) { m_ht.reserve(count); }
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const { return m_ht.hash_function(); }
+ key_equal key_eq() const { return m_ht.key_eq(); }
+
+ /*
+ * Other
+ */
+
+ /**
+ * Convert a `const_iterator` to an `iterator`.
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ return m_ht.mutable_iterator(pos);
+ }
+
+ /**
+ * Serialize the map through the `serializer` parameter.
+ *
+ * The `serializer` parameter must be a function object that supports the
+ * following call:
+ * - `template<typename U> void operator()(const U& value);` where the types
+ * `std::uint64_t`, `float` and `std::pair<Key, T>` must be supported for U.
+ *
+ * The implementation leaves binary compatibility (endianness, IEEE 754 for
+ * floats, ...) of the types it serializes in the hands of the `Serializer`
+ * function object if compatibility is required.
+ */
+ template <class Serializer>
+ void serialize(Serializer &serializer) const {
+ m_ht.serialize(serializer);
+ }
+
+ /**
+ * Deserialize a previously serialized map through the `deserializer`
+ * parameter.
+ *
+ * The `deserializer` parameter must be a function object that supports the
+ * following calls:
+ * - `template<typename U> U operator()();` where the types `std::uint64_t`,
+ * `float` and `std::pair<Key, T>` must be supported for U.
+ *
+ * If the deserialized hash map type is hash compatible with the serialized
+ * map, the deserialization process can be sped up by setting
+ * `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and
+   * GrowthPolicy must behave the same way as the ones used on the serialized
+ * map. The `std::size_t` must also be of the same size as the one on the
+ * platform used to serialize the map. If these criteria are not met, the
+   * behaviour is undefined with `hash_compatible` set to true.
+ *
+ * The behaviour is undefined if the type `Key` and `T` of the `sparse_map`
+ * are not the same as the types used during serialization.
+ *
+ * The implementation leaves binary compatibility (endianness, IEEE 754 for
+ * floats, size of int, ...) of the types it deserializes in the hands of the
+ * `Deserializer` function object if compatibility is required.
+ */
+ template <class Deserializer>
+ static sparse_map deserialize(Deserializer &deserializer,
+ bool hash_compatible = false) {
+ sparse_map map(0);
+ map.m_ht.deserialize(deserializer, hash_compatible);
+
+ return map;
+ }
+
+ friend bool operator==(const sparse_map &lhs, const sparse_map &rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+
+ for (const auto &element_lhs : lhs) {
+ const auto it_element_rhs = rhs.find(element_lhs.first);
+ if (it_element_rhs == rhs.cend() ||
+ element_lhs.second != it_element_rhs->second) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ friend bool operator!=(const sparse_map &lhs, const sparse_map &rhs) {
+ return !operator==(lhs, rhs);
+ }
+
+ friend void swap(sparse_map &lhs, sparse_map &rhs) { lhs.swap(rhs); }
+
+ private:
+ ht m_ht;
+};
+
+/**
+ * Same as `tsl::sparse_map<Key, T, Hash, KeyEqual, Allocator,
+ * tsl::sh::prime_growth_policy>`.
+ */
+template <class Key, class T, class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<std::pair<Key, T>>>
+using sparse_pg_map =
+ sparse_map<Key, T, Hash, KeyEqual, Allocator, tsl::sh::prime_growth_policy>;
+
+} // end namespace tsl
+
+#endif
diff --git a/be/src/extern/diskann/include/tsl/sparse_set.h b/be/src/extern/diskann/include/tsl/sparse_set.h
new file mode 100644
index 0000000..3ce6a58
--- /dev/null
+++ b/be/src/extern/diskann/include/tsl/sparse_set.h
@@ -0,0 +1,655 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef TSL_SPARSE_SET_H
+#define TSL_SPARSE_SET_H
+
+#include <cstddef>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "sparse_hash.h"
+
+namespace tsl {
+
+/**
+ * Implementation of a sparse hash set using open-addressing with quadratic
+ * probing. The goal on the hash set is to be the most memory efficient
+ * possible, even at low load factor, while keeping reasonable performances.
+ *
+ * `GrowthPolicy` defines how the set grows and consequently how a hash value is
+ * mapped to a bucket. By default the set uses
+ * `tsl::sh::power_of_two_growth_policy`. This policy keeps the number of
+ * buckets to a power of two and uses a mask to map the hash to a bucket instead
+ * of the slow modulo. Other growth policies are available and you may define
+ * your own growth policy, check `tsl::sh::power_of_two_growth_policy` for the
+ * interface.
+ *
+ * `ExceptionSafety` defines the exception guarantee provided by the class. By
+ * default only the basic exception safety is guaranteed which means that all
+ * resources used by the hash set will be freed (no memory leaks) but the hash
+ * set may end-up in an undefined state if an exception is thrown (undefined
+ * here means that some elements may be missing). This can ONLY happen on rehash
+ * (either on insert or if `rehash` is called explicitly) and will occur if the
+ * Allocator can't allocate memory (`std::bad_alloc`) or if the copy constructor
+ * (when a nothrow move constructor is not available) throws an exception. This
+ * can be avoided by calling `reserve` beforehand. This basic guarantee is
+ * similar to the one of `google::sparse_hash_map` and `spp::sparse_hash_map`.
+ * It is possible to ask for the strong exception guarantee with
+ * `tsl::sh::exception_safety::strong`, the drawback is that the set will be
+ * slower on rehashes and will also need more memory on rehashes.
+ *
+ * `Sparsity` defines how much the hash set will compromise between insertion
+ * speed and memory usage. A high sparsity means less memory usage but longer
+ * insertion times, and vice-versa for low sparsity. The default
+ * `tsl::sh::sparsity::medium` sparsity offers a good compromise. It doesn't
+ * change the lookup speed.
+ *
+ * `Key` must be nothrow move constructible and/or copy constructible.
+ *
+ * If the destructor of `Key` throws an exception, the behaviour of the class is
+ * undefined.
+ *
+ * Iterators invalidation:
+ * - clear, operator=, reserve, rehash: always invalidate the iterators.
+ * - insert, emplace, emplace_hint: if there is an effective insert, invalidate
+ * the iterators.
+ * - erase: always invalidate the iterators.
+ */
+template <class Key, class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<Key>,
+ class GrowthPolicy = tsl::sh::power_of_two_growth_policy<2>,
+ tsl::sh::exception_safety ExceptionSafety =
+ tsl::sh::exception_safety::basic,
+ tsl::sh::sparsity Sparsity = tsl::sh::sparsity::medium>
+class sparse_set {
+ private:
+ template <typename U>
+ using has_is_transparent = tsl::detail_sparse_hash::has_is_transparent<U>;
+
+ class KeySelect {
+ public:
+ using key_type = Key;
+
+ const key_type &operator()(const Key &key) const noexcept { return key; }
+
+ key_type &operator()(Key &key) noexcept { return key; }
+ };
+
+ using ht =
+ detail_sparse_hash::sparse_hash<Key, KeySelect, void, Hash, KeyEqual,
+ Allocator, GrowthPolicy, ExceptionSafety,
+ Sparsity, tsl::sh::probing::quadratic>;
+
+ public:
+ using key_type = typename ht::key_type;
+ using value_type = typename ht::value_type;
+ using size_type = typename ht::size_type;
+ using difference_type = typename ht::difference_type;
+ using hasher = typename ht::hasher;
+ using key_equal = typename ht::key_equal;
+ using allocator_type = typename ht::allocator_type;
+ using reference = typename ht::reference;
+ using const_reference = typename ht::const_reference;
+ using pointer = typename ht::pointer;
+ using const_pointer = typename ht::const_pointer;
+ using iterator = typename ht::iterator;
+ using const_iterator = typename ht::const_iterator;
+
+ /*
+ * Constructors
+ */
+ sparse_set() : sparse_set(ht::DEFAULT_INIT_BUCKET_COUNT) {}
+
+ explicit sparse_set(size_type bucket_count, const Hash &hash = Hash(),
+ const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR) {}
+
+ sparse_set(size_type bucket_count, const Allocator &alloc)
+ : sparse_set(bucket_count, Hash(), KeyEqual(), alloc) {}
+
+ sparse_set(size_type bucket_count, const Hash &hash, const Allocator &alloc)
+ : sparse_set(bucket_count, hash, KeyEqual(), alloc) {}
+
+ explicit sparse_set(const Allocator &alloc)
+ : sparse_set(ht::DEFAULT_INIT_BUCKET_COUNT, alloc) {}
+
+ template <class InputIt>
+ sparse_set(InputIt first, InputIt last,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
+ const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : sparse_set(bucket_count, hash, equal, alloc) {
+ insert(first, last);
+ }
+
+ template <class InputIt>
+ sparse_set(InputIt first, InputIt last, size_type bucket_count,
+ const Allocator &alloc)
+ : sparse_set(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
+
+ template <class InputIt>
+ sparse_set(InputIt first, InputIt last, size_type bucket_count,
+ const Hash &hash, const Allocator &alloc)
+ : sparse_set(first, last, bucket_count, hash, KeyEqual(), alloc) {}
+
+ sparse_set(std::initializer_list<value_type> init,
+ size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
+ const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
+ const Allocator &alloc = Allocator())
+ : sparse_set(init.begin(), init.end(), bucket_count, hash, equal, alloc) {
+ }
+
+ sparse_set(std::initializer_list<value_type> init, size_type bucket_count,
+ const Allocator &alloc)
+ : sparse_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
+ alloc) {}
+
+ sparse_set(std::initializer_list<value_type> init, size_type bucket_count,
+ const Hash &hash, const Allocator &alloc)
+ : sparse_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
+ alloc) {}
+
+ sparse_set &operator=(std::initializer_list<value_type> ilist) {
+ m_ht.clear();
+
+ m_ht.reserve(ilist.size());
+ m_ht.insert(ilist.begin(), ilist.end());
+
+ return *this;
+ }
+
+ allocator_type get_allocator() const { return m_ht.get_allocator(); }
+
+ /*
+ * Iterators
+ */
+ iterator begin() noexcept { return m_ht.begin(); }
+ const_iterator begin() const noexcept { return m_ht.begin(); }
+ const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
+
+ iterator end() noexcept { return m_ht.end(); }
+ const_iterator end() const noexcept { return m_ht.end(); }
+ const_iterator cend() const noexcept { return m_ht.cend(); }
+
+ /*
+ * Capacity
+ */
+ bool empty() const noexcept { return m_ht.empty(); }
+ size_type size() const noexcept { return m_ht.size(); }
+ size_type max_size() const noexcept { return m_ht.max_size(); }
+
+ /*
+ * Modifiers
+ */
+ void clear() noexcept { m_ht.clear(); }
+
+ std::pair<iterator, bool> insert(const value_type &value) {
+ return m_ht.insert(value);
+ }
+
+ std::pair<iterator, bool> insert(value_type &&value) {
+ return m_ht.insert(std::move(value));
+ }
+
+ iterator insert(const_iterator hint, const value_type &value) {
+ return m_ht.insert_hint(hint, value);
+ }
+
+ iterator insert(const_iterator hint, value_type &&value) {
+ return m_ht.insert_hint(hint, std::move(value));
+ }
+
+ template <class InputIt>
+ void insert(InputIt first, InputIt last) {
+ m_ht.insert(first, last);
+ }
+
+ void insert(std::initializer_list<value_type> ilist) {
+ m_ht.insert(ilist.begin(), ilist.end());
+ }
+
+ /**
+ * Due to the way elements are stored, emplace will need to move or copy the
+ * key once. The method is equivalent to
+ * `insert(value_type(std::forward<Args>(args)...));`.
+ *
+ * Mainly here for compatibility with the `std::unordered_set` interface.
+ */
+ template <class... Args>
+ std::pair<iterator, bool> emplace(Args &&...args) {
+ return m_ht.emplace(std::forward<Args>(args)...);
+ }
+
+ /**
+ * Due to the way elements are stored, emplace_hint will need to move or copy
+ * the key once. The method is equivalent to `insert(hint,
+ * value_type(std::forward<Args>(args)...));`.
+ *
+ * Mainly here for compatibility with the `std::unordered_set` interface.
+ */
+ template <class... Args>
+ iterator emplace_hint(const_iterator hint, Args &&...args) {
+ return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
+ }
+
+ iterator erase(iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator pos) { return m_ht.erase(pos); }
+ iterator erase(const_iterator first, const_iterator last) {
+ return m_ht.erase(first, last);
+ }
+ size_type erase(const key_type &key) { return m_ht.erase(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ size_type erase(const key_type &key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type erase(const K &key) {
+ return m_ht.erase(key);
+ }
+
+ /**
+ * @copydoc erase(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type erase(const K &key, std::size_t precalculated_hash) {
+ return m_ht.erase(key, precalculated_hash);
+ }
+
+ void swap(sparse_set &other) { other.m_ht.swap(m_ht); }
+
+ /*
+ * Lookup
+ */
+ size_type count(const Key &key) const { return m_ht.count(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ size_type count(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.count(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type count(const K &key) const {
+ return m_ht.count(key);
+ }
+
+ /**
+ * @copydoc count(const K& key) const
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ size_type count(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.count(key, precalculated_hash);
+ }
+
+ iterator find(const Key &key) { return m_ht.find(key); }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ iterator find(const Key &key, std::size_t precalculated_hash) {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ const_iterator find(const Key &key) const { return m_ht.find(key); }
+
+ /**
+ * @copydoc find(const Key& key, std::size_t precalculated_hash)
+ */
+ const_iterator find(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ iterator find(const K &key) {
+ return m_ht.find(key);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ iterator find(const K &key, std::size_t precalculated_hash) {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const_iterator find(const K &key) const {
+ return m_ht.find(key);
+ }
+
+ /**
+ * @copydoc find(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ const_iterator find(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.find(key, precalculated_hash);
+ }
+
+ bool contains(const Key &key) const { return m_ht.contains(key); }
+
+ /**
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The
+ * hash value should be the same as hash_function()(key). Useful to speed-up
+ * the lookup if you already have the hash.
+ */
+ bool contains(const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.contains(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * KeyEqual::is_transparent exists. If so, K must be hashable and comparable
+ * to Key.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ bool contains(const K &key) const {
+ return m_ht.contains(key);
+ }
+
+ /**
+ * @copydoc contains(const K& key) const
+ *
+ * Use the hash value 'precalculated_hash' instead of hashing the key. The
+ * hash value should be the same as hash_function()(key). Useful to speed-up
+ * the lookup if you already have the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ bool contains(const K &key, std::size_t precalculated_hash) const {
+ return m_ht.contains(key, precalculated_hash);
+ }
+
+ std::pair<iterator, iterator> equal_range(const Key &key) {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ std::pair<iterator, iterator> equal_range(const Key &key,
+ std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ std::pair<const_iterator, const_iterator> equal_range(const Key &key) const {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
+ */
+ std::pair<const_iterator, const_iterator> equal_range(
+ const Key &key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * This overload only participates in the overload resolution if the typedef
+ * `KeyEqual::is_transparent` exists. If so, `K` must be hashable and
+ * comparable to `Key`.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<iterator, iterator> equal_range(const K &key) {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ *
+ * Use the hash value `precalculated_hash` instead of hashing the key. The
+ * hash value should be the same as `hash_function()(key)`, otherwise the
+ * behaviour is undefined. Useful to speed-up the lookup if you already have
+ * the hash.
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<iterator, iterator> equal_range(const K &key,
+ std::size_t precalculated_hash) {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
+ return m_ht.equal_range(key);
+ }
+
+ /**
+ * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
+ */
+ template <
+ class K, class KE = KeyEqual,
+ typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
+ std::pair<const_iterator, const_iterator> equal_range(
+ const K &key, std::size_t precalculated_hash) const {
+ return m_ht.equal_range(key, precalculated_hash);
+ }
+
+ /*
+ * Bucket interface
+ */
+ size_type bucket_count() const { return m_ht.bucket_count(); }
+ size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
+
+ /*
+ * Hash policy
+ */
+ float load_factor() const { return m_ht.load_factor(); }
+ float max_load_factor() const { return m_ht.max_load_factor(); }
+ void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
+
+ void rehash(size_type count) { m_ht.rehash(count); }
+ void reserve(size_type count) { m_ht.reserve(count); }
+
+ /*
+ * Observers
+ */
+ hasher hash_function() const { return m_ht.hash_function(); }
+ key_equal key_eq() const { return m_ht.key_eq(); }
+
+ /*
+ * Other
+ */
+
+ /**
+ * Convert a `const_iterator` to an `iterator`.
+ */
+ iterator mutable_iterator(const_iterator pos) {
+ return m_ht.mutable_iterator(pos);
+ }
+
+ /**
+ * Serialize the set through the `serializer` parameter.
+ *
+ * The `serializer` parameter must be a function object that supports the
+ * following call:
+ * - `void operator()(const U& value);` where the types `std::uint64_t`,
+ * `float` and `Key` must be supported for U.
+ *
+ * The implementation leaves binary compatibility (endianness, IEEE 754 for
+ * floats, ...) of the types it serializes in the hands of the `Serializer`
+ * function object if compatibility is required.
+ */
+ template <class Serializer>
+ void serialize(Serializer &serializer) const {
+ m_ht.serialize(serializer);
+ }
+
+ /**
+ * Deserialize a previously serialized set through the `deserializer`
+ * parameter.
+ *
+ * The `deserializer` parameter must be a function object that supports the
+ * following calls:
+ * - `template<typename U> U operator()();` where the types `std::uint64_t`,
+ * `float` and `Key` must be supported for U.
+ *
+ * If the deserialized hash set type is hash compatible with the serialized
+ * set, the deserialization process can be sped up by setting
+ * `hash_compatible` to true. To be hash compatible, the Hash, KeyEqual and
+ * GrowthPolicy must behave the same way as the ones used on the serialized
+ * set. The `std::size_t` must also be of the same size as the one on the
+ * platform used to serialize the set. If these criteria are not met, the
+ * behaviour is undefined with `hash_compatible` set to true.
+ *
+ * The behaviour is undefined if the type `Key` of the `sparse_set` is not the
+ * same as the type used during serialization.
+ *
+ * The implementation leaves binary compatibility (endianness, IEEE 754 for
+ * floats, size of int, ...) of the types it deserializes in the hands of the
+ * `Deserializer` function object if compatibility is required.
+ */
+ template <class Deserializer>
+ static sparse_set deserialize(Deserializer &deserializer,
+ bool hash_compatible = false) {
+ sparse_set set(0);
+ set.m_ht.deserialize(deserializer, hash_compatible);
+
+ return set;
+ }
+
+ friend bool operator==(const sparse_set &lhs, const sparse_set &rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+
+ for (const auto &element_lhs : lhs) {
+ const auto it_element_rhs = rhs.find(element_lhs);
+ if (it_element_rhs == rhs.cend()) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ friend bool operator!=(const sparse_set &lhs, const sparse_set &rhs) {
+ return !operator==(lhs, rhs);
+ }
+
+ friend void swap(sparse_set &lhs, sparse_set &rhs) { lhs.swap(rhs); }
+
+ private:
+ ht m_ht;
+};
+
+/**
+ * Same as `tsl::sparse_set<Key, Hash, KeyEqual, Allocator,
+ * tsl::sh::prime_growth_policy>`.
+ */
+template <class Key, class Hash = std::hash<Key>,
+ class KeyEqual = std::equal_to<Key>,
+ class Allocator = std::allocator<Key>>
+using sparse_pg_set =
+ sparse_set<Key, Hash, KeyEqual, Allocator, tsl::sh::prime_growth_policy>;
+
+} // end namespace tsl
+
+#endif
diff --git a/be/src/extern/diskann/include/types.h b/be/src/extern/diskann/include/types.h
new file mode 100644
index 0000000..953d59a
--- /dev/null
+++ b/be/src/extern/diskann/include/types.h
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstdint>
+#include <cstddef>
+#include <any>
+#include "any_wrappers.h"
+
+namespace diskann
+{
+typedef uint32_t location_t;
+
+using DataType = std::any;
+using TagType = std::any;
+using LabelType = std::any;
+using TagVector = AnyWrapper::AnyVector;
+using DataVector = AnyWrapper::AnyVector;
+using Labelvector = AnyWrapper::AnyVector;
+using TagRobinSet = AnyWrapper::AnyRobinSet;
+} // namespace diskann
diff --git a/be/src/extern/diskann/include/utils.h b/be/src/extern/diskann/include/utils.h
new file mode 100644
index 0000000..0d02c2f
--- /dev/null
+++ b/be/src/extern/diskann/include/utils.h
@@ -0,0 +1,1370 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+
+
+#include <errno.h>
+
+#include "common_includes.h"
+
+#ifdef __APPLE__
+#else
+#include <malloc.h>
+#endif
+
+#ifdef _WINDOWS
+#include <Windows.h>
+typedef HANDLE FileHandle;
+#else
+#include <unistd.h>
+typedef int FileHandle;
+#endif
+
+#include "distance.h"
+#include "logger.h"
+#include "cached_io.h"
+#include "ann_exception.h"
+#include "windows_customizations.h"
+#include "tsl/robin_set.h"
+#include "types.h"
+#include "tag_uint128.h"
+#include <any>
+
+#ifdef EXEC_ENV_OLS
+#include "content_buf.h"
+#include "memory_mapped_files.h"
+#endif
+#include "combined_file.h"
+#include "ThreadPool.h"
+
+#include <immintrin.h>
+
+#include "vector/stream_wrapper.h"
+
+// taken from
+// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
+// round up X to the nearest multiple of Y
+#define ROUND_UP(X, Y) ((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y))
+
+#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0))
+
+// round down X to the nearest multiple of Y
+#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y))
+
+// alignment tests
+#define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0)
+#define IS_512_ALIGNED(X) IS_ALIGNED(X, 512)
+#define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096)
+#define METADATA_SIZE \
+ 4096 // all metadata of individual sub-component files is written in first
+ // 4KB for unified files
+
+#define BUFFER_SIZE_FOR_CACHED_IO (size_t)1024 * (size_t)1048576
+
+#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
+#define PBWIDTH 60
+
+inline bool file_exists_impl(const std::string &name, bool dirCheck = false)
+{
+ int val;
+#ifndef _WINDOWS
+ struct stat buffer;
+ val = stat(name.c_str(), &buffer);
+#else
+ // It is the 21st century but Windows API still thinks in 32-bit terms.
+ // Turns out calling stat() on a file > 4GB results in errno = 132
+ // (OVERFLOW). How silly is this!? So calling _stat64()
+ struct _stat64 buffer;
+ val = _stat64(name.c_str(), &buffer);
+#endif
+
+ if (val != 0)
+ {
+ switch (errno)
+ {
+ case EINVAL:
+ diskann::cout << "Invalid argument passed to stat()" << std::endl;
+ break;
+ case ENOENT:
+ // file is not existing, not an issue, so we won't cout anything.
+ break;
+ default:
+ diskann::cout << "Unexpected error in stat():" << errno << std::endl;
+ break;
+ }
+ return false;
+ }
+ else
+ {
+ // the file entry exists. If reqd, check if this is a directory.
+ return dirCheck ? buffer.st_mode & S_IFDIR : true;
+ }
+}
+
+inline bool file_exists(const std::string &name, bool dirCheck = false)
+{
+#ifdef EXEC_ENV_OLS
+ bool exists = file_exists_impl(name, dirCheck);
+ if (exists)
+ {
+ return true;
+ }
+ if (!dirCheck)
+ {
+ // try with .enc extension
+ std::string enc_name = name + ENCRYPTED_EXTENSION;
+ return file_exists_impl(enc_name, dirCheck);
+ }
+ else
+ {
+ return exists;
+ }
+#else
+ return file_exists_impl(name, dirCheck);
+#endif
+}
+
+inline void open_file_to_write(std::ofstream &writer, const std::string &filename)
+{
+ writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
+ if (!file_exists(filename))
+ writer.open(filename, std::ios::binary | std::ios::out);
+ else
+ writer.open(filename, std::ios::binary | std::ios::in | std::ios::out);
+
+ if (writer.fail())
+ {
+ char buff[1024];
+#ifdef _WINDOWS
+ auto ret = std::to_string(strerror_s(buff, 1024, errno));
+#else
+ auto ret = std::string(strerror_r(errno, buff, 1024));
+#endif
+ auto message = std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret;
+ diskann::cerr << message << std::endl;
+ throw diskann::ANNException(message, -1);
+ }
+}
+
+inline size_t get_file_size(const std::string &fname)
+{
+ std::ifstream reader(fname, std::ios::binary | std::ios::ate);
+ if (!reader.fail() && reader.is_open())
+ {
+ size_t end_pos = reader.tellg();
+ reader.close();
+ return end_pos;
+ }
+ else
+ {
+ diskann::cerr << "Could not open file: " << fname << std::endl;
+ return 0;
+ }
+}
+
+inline int delete_file(const std::string &fileName)
+{
+ if (file_exists(fileName))
+ {
+ auto rc = ::remove(fileName.c_str());
+ if (rc != 0)
+ {
+ diskann::cerr << "Could not delete file: " << fileName
+ << " even though it exists. This might indicate a permissions "
+ "issue. "
+ "If you see this message, please contact the diskann team."
+ << std::endl;
+ }
+ return rc;
+ }
+ else
+ {
+ return 0;
+ }
+}
+
+// generates formatted_label and _labels_map file.
+inline void convert_labels_string_to_int(const std::string &inFileName, const std::string &outFileName,
+ const std::string &mapFileName, const std::string &unv_label)
+{
+ std::unordered_map<std::string, uint32_t> string_int_map;
+ std::ofstream label_writer(outFileName);
+ std::ifstream label_reader(inFileName);
+ if (unv_label != "")
+ string_int_map[unv_label] = 0; // if universal label is provided map it to 0 always
+ std::string line, token;
+ while (std::getline(label_reader, line))
+ {
+ std::istringstream new_iss(line);
+ std::vector<uint32_t> lbls;
+ while (getline(new_iss, token, ','))
+ {
+ token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+ token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+ if (string_int_map.find(token) == string_int_map.end())
+ {
+ uint32_t nextId = (uint32_t)string_int_map.size() + 1;
+ string_int_map[token] = nextId; // nextId can never be 0
+ }
+ lbls.push_back(string_int_map[token]);
+ }
+ if (lbls.size() <= 0)
+ {
+ std::cout << "No label found";
+ exit(-1);
+ }
+ for (size_t j = 0; j < lbls.size(); j++)
+ {
+ if (j != lbls.size() - 1)
+ label_writer << lbls[j] << ",";
+ else
+ label_writer << lbls[j] << std::endl;
+ }
+ }
+ label_writer.close();
+
+ std::ofstream map_writer(mapFileName);
+ for (auto mp : string_int_map)
+ {
+ map_writer << mp.first << "\t" << mp.second << std::endl;
+ }
+ map_writer.close();
+}
+
+#ifdef EXEC_ENV_OLS
+class AlignedFileReader;
+#endif
+
+namespace diskann
+{
+static const size_t MAX_SIZE_OF_STREAMBUF = 2LL * 1024 * 1024 * 1024;
+
+inline void print_error_and_terminate(std::stringstream &error_stream)
+{
+ diskann::cerr << error_stream.str() << std::endl;
+ throw diskann::ANNException(error_stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+}
+
+inline void report_memory_allocation_failure()
+{
+ std::stringstream stream;
+ stream << "Memory Allocation Failed.";
+ print_error_and_terminate(stream);
+}
+
+inline void report_misalignment_of_requested_size(size_t align)
+{
+ std::stringstream stream;
+ stream << "Requested memory size is not a multiple of " << align << ". Can not be allocated.";
+ print_error_and_terminate(stream);
+}
+
+inline void alloc_aligned(void **ptr, size_t size, size_t align)
+{
+ *ptr = nullptr;
+ if (IS_ALIGNED(size, align) == 0)
+ report_misalignment_of_requested_size(align);
+#ifndef _WINDOWS
+ *ptr = ::aligned_alloc(align, size);
+#else
+ *ptr = ::_aligned_malloc(size, align); // note the swapped arguments!
+#endif
+ if (*ptr == nullptr)
+ report_memory_allocation_failure();
+}
+
+inline void realloc_aligned(void **ptr, size_t size, size_t align)
+{
+ if (IS_ALIGNED(size, align) == 0)
+ report_misalignment_of_requested_size(align);
+#ifdef _WINDOWS
+ *ptr = ::_aligned_realloc(*ptr, size, align);
+#else
+ diskann::cerr << "No aligned realloc on GCC. Must malloc and mem_align, "
+ "left it out for now."
+ << std::endl;
+#endif
+ if (*ptr == nullptr)
+ report_memory_allocation_failure();
+}
+
+inline void check_stop(std::string arnd)
+{
+ int brnd;
+ diskann::cout << arnd << std::endl;
+ std::cin >> brnd;
+}
+
+inline void aligned_free(void *ptr)
+{
+ // Gopal. Must have a check here if the pointer was actually allocated by
+ // _alloc_aligned
+ if (ptr == nullptr)
+ {
+ return;
+ }
+#ifndef _WINDOWS
+ free(ptr);
+#else
+ ::_aligned_free(ptr);
+#endif
+}
+
+// Fills addr[0..size) with `size` distinct pseudo-random values in [0, N).
+// Strategy: draw values in [0, N - size), sort, bump duplicates upward so the
+// sequence is strictly increasing, then rotate everything by a random offset
+// modulo N. Assumes size <= N — TODO(review) confirm callers guarantee this.
+inline void GenRandom(std::mt19937 &rng, unsigned *addr, unsigned size, unsigned N)
+{
+    for (unsigned i = 0; i < size; ++i)
+    {
+        addr[i] = rng() % (N - size);
+    }
+
+    std::sort(addr, addr + size);
+    for (unsigned i = 1; i < size; ++i)
+    {
+        if (addr[i] <= addr[i - 1])
+        {
+            addr[i] = addr[i - 1] + 1;
+        }
+    }
+    unsigned off = rng() % N;
+    for (unsigned i = 0; i < size; ++i)
+    {
+        addr[i] = (addr[i] + off) % N;
+    }
+}
+
+// get_bin_metadata functions START
+// All overloads read the standard .bin header: two little-endian int32 values
+// (#rows, #cols) located at `offset` within the source.
+inline void get_bin_metadata_impl(std::basic_istream<char> &reader, size_t &nrows, size_t &ncols, size_t offset = 0)
+{
+    int nrows_32, ncols_32;
+    reader.seekg(offset, reader.beg);
+    reader.read((char *)&nrows_32, sizeof(int));
+    reader.read((char *)&ncols_32, sizeof(int));
+    nrows = nrows_32;
+    ncols = ncols_32;
+}
+
+#ifdef EXEC_ENV_OLS
+// Memory-mapped variant: reads the two int32 header values directly from the
+// mapped content at `offset` — no stream involved.
+inline void get_bin_metadata(MemoryMappedFiles &files, const std::string &bin_file, size_t &nrows, size_t &ncols,
+                             size_t offset = 0)
+{
+    diskann::cout << "Getting metadata for file: " << bin_file << std::endl;
+    auto fc = files.getContent(bin_file);
+    // auto cb = ContentBuf((char*) fc._content, fc._size);
+    // std::basic_istream<char> reader(&cb);
+    // get_bin_metadata_impl(reader, nrows, ncols, offset);
+
+    int nrows_32, ncols_32;
+    int32_t *metadata_ptr = (int32_t *)((char *)fc._content + offset);
+    nrows_32 = *metadata_ptr;
+    ncols_32 = *(metadata_ptr + 1);
+    nrows = nrows_32;
+    ncols = ncols_32;
+}
+#endif
+
+// File-path variant: opens the file in binary mode and delegates to the impl.
+inline void get_bin_metadata(const std::string &bin_file, size_t &nrows, size_t &ncols, size_t offset = 0)
+{
+    std::ifstream reader(bin_file.c_str(), std::ios::binary);
+    get_bin_metadata_impl(reader, nrows, ncols, offset);
+}
+
+
+
+// Reader-wrapper variant (used by the Doris import path): positions are
+// passed explicitly to read(), so no seek is needed.
+inline void get_bin_metadata(IReaderWrapperSPtr reader, size_t &nrows, size_t &ncols, size_t offset = 0)
+{
+    int nrows_32, ncols_32;
+    reader->read((char *)&nrows_32, sizeof(int), offset);
+    reader->read((char *)&ncols_32, sizeof(int), offset + 4);
+    nrows = nrows_32;
+    ncols = ncols_32;
+}
+
+// In-memory stringstream variant.
+inline void get_bin_metadata(std::stringstream &reader, size_t &nrows, size_t &ncols, size_t offset = 0)
+{
+    int nrows_32, ncols_32;
+    reader.seekg(offset, reader.beg);
+    reader.read((char *)&nrows_32, sizeof(int));
+    reader.read((char *)&ncols_32, sizeof(int));
+    nrows = nrows_32;
+    ncols = ncols_32;
+}
+
+
+// get_bin_metadata functions END
+
+#ifndef EXEC_ENV_OLS
+// Reads the frozen-point count from a graph index file header.
+// Header layout read here: size_t expected_file_size, uint32 max_degree,
+// uint32 start, size_t file_frozen_pts. Throws (badbit/failbit exceptions
+// enabled) if the file cannot be opened or is truncated.
+inline size_t get_graph_num_frozen_points(const std::string &graph_file)
+{
+    size_t expected_file_size;
+    uint32_t max_observed_degree, start;
+    size_t file_frozen_pts;
+
+    std::ifstream in;
+    in.exceptions(std::ios::badbit | std::ios::failbit);
+
+    in.open(graph_file, std::ios::binary);
+    in.read((char *)&expected_file_size, sizeof(size_t));
+    in.read((char *)&max_observed_degree, sizeof(uint32_t));
+    in.read((char *)&start, sizeof(uint32_t));
+    in.read((char *)&file_frozen_pts, sizeof(size_t));
+
+    return file_frozen_pts;
+}
+#endif
+
+// Debug helper: renders the first `num` elements of `data` as a string of
+// the form "[v0,v1,...,]\n" (note the trailing comma before ']').
+template <typename T> inline std::string getValues(T *data, size_t num)
+{
+    std::stringstream stream;
+    stream << "[";
+    for (size_t i = 0; i < num; i++)
+    {
+        stream << std::to_string(data[i]) << ",";
+    }
+    stream << "]" << std::endl;
+
+    return stream.str();
+}
+
+// load_bin functions START
+// Reads a .bin blob ([int32 npts][int32 dim][npts*dim T values]) starting at
+// `file_offset` and allocates `data` with new[] — caller owns the buffer.
+template <typename T>
+inline void load_bin_impl(std::basic_istream<char> &reader, T *&data, size_t &npts, size_t &dim, size_t file_offset = 0)
+{
+    int npts_i32, dim_i32;
+
+    reader.seekg(file_offset, reader.beg);
+    reader.read((char *)&npts_i32, sizeof(int));
+    reader.read((char *)&dim_i32, sizeof(int));
+    npts = (unsigned)npts_i32;
+    dim = (unsigned)dim_i32;
+
+    //std::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." << std::endl;
+
+    data = new T[npts * dim];
+    reader.read((char *)data, npts * dim * sizeof(T));
+}
+
+#ifdef EXEC_ENV_OLS
+// Memory-mapped variant: `data` points INTO the mapped file (no copy), so it
+// must not be freed and is only valid while the mapping lives.
+template <typename T>
+inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, T *&data, size_t &npts, size_t &dim,
+                     size_t offset = 0)
+{
+    diskann::cout << "Reading bin file " << bin_file.c_str() << " at offset: " << offset << "..." << std::endl;
+    auto fc = files.getContent(bin_file);
+
+    uint32_t t_npts, t_dim;
+    uint32_t *contentAsIntPtr = (uint32_t *)((char *)fc._content + offset);
+    t_npts = *(contentAsIntPtr);
+    t_dim = *(contentAsIntPtr + 1);
+
+    npts = t_npts;
+    dim = t_dim;
+
+    data = (T *)((char *)fc._content + offset + 2 * sizeof(uint32_t)); // No need to copy!
+}
+
+// Declarations for AlignedFileReader-based loaders (implemented elsewhere).
+DISKANN_DLLEXPORT void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset = 0);
+template <typename T>
+DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, size_t offset = 0);
+template <typename T>
+DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr<T[]> &data, size_t &npts, size_t &ndim,
+                                size_t offset = 0);
+
+template <typename T>
+DISKANN_DLLEXPORT void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &npts, size_t &dim,
+                                                   const size_t &rounded_dim, size_t offset = 0);
+
+// Unlike load_bin, assumes that data is already allocated 'size' entries
+template <typename T>
+DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, T *data, size_t size, size_t offset = 0);
+
+template <typename T> DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, T &value, size_t offset = 0);
+#endif
+
+// Loads a .bin file from disk into a freshly new[]'d buffer (caller owns it).
+// Stream exceptions are enabled, so I/O failures surface as FileException.
+template <typename T>
+inline void load_bin(const std::string &bin_file, T *&data, size_t &npts, size_t &dim, size_t offset = 0)
+{
+    diskann::cout << "Reading bin file " << bin_file.c_str() << " ..." << std::endl;
+    std::ifstream reader;
+    reader.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+
+    try
+    {
+        diskann::cout << "Opening bin file " << bin_file.c_str() << "... " << std::endl;
+        reader.open(bin_file, std::ios::binary | std::ios::ate);
+        reader.seekg(0);
+        load_bin_impl<T>(reader, data, npts, dim, offset);
+    }
+    catch (std::system_error &e)
+    {
+        throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    diskann::cout << "done." << std::endl;
+}
+
+
+
+// Reader-wrapper variant (Doris import path). Reads are addressed by
+// absolute offsets: header at `offset`, payload at `offset + 8`.
+template <typename T>
+inline void load_bin(IReaderWrapperSPtr data_stream, T *&data, size_t &npts, size_t &dim, size_t offset = 0)
+{
+    try
+    {
+        int npts_i32, dim_i32;
+        data_stream->seek(offset);
+        data_stream->read((char *)&npts_i32, sizeof(int), offset);
+        data_stream->read((char *)&dim_i32, sizeof(int), offset+4);
+        npts = (unsigned)npts_i32;
+        dim = (unsigned)dim_i32;
+        //std::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." << std::endl;
+        data = new T[npts * dim];
+        data_stream->read((char *)data, npts * dim * sizeof(T), offset + 8);
+    }
+    catch (std::system_error &e)
+    {
+        throw FileException("", e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    diskann::cout << "done." << std::endl;
+}
+
+
+// In-memory stringstream variant of load_bin; same layout and ownership
+// semantics as the file-based overload.
+template <typename T>
+inline void load_bin(std::stringstream &data_stream, T *&data, size_t &npts, size_t &dim, size_t offset = 0)
+{
+    try
+    {
+        int npts_i32, dim_i32;
+        data_stream.seekg(offset);
+        data_stream.read((char *)&npts_i32, sizeof(int));
+        data_stream.read((char *)&dim_i32, sizeof(int));
+        npts = (unsigned)npts_i32;
+        dim = (unsigned)dim_i32;
+        //std::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." << std::endl;
+        data = new T[npts * dim];
+        data_stream.read((char *)data, npts * dim * sizeof(T));
+    }
+    catch (std::system_error &e)
+    {
+        throw FileException("", e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    diskann::cout << "done." << std::endl;
+}
+
+
+
+// Debug helper: blocks until the user types a number on stdin.
+inline void wait_for_keystroke()
+{
+    int a;
+    std::cout << "Press any number to continue.." << std::endl;
+    std::cin >> a;
+}
+// load_bin functions END
+
+// Loads a ground-truth file: [int32 npts][int32 dim][npts*dim uint32 ids]
+// optionally followed by [npts*dim float dists]. The file size determines
+// whether distances are present. Both `ids` and `dists` are new[]'d here;
+// NOTE(review): `dists` is only allocated when distances exist in the file —
+// callers must not free/read it otherwise.
+inline void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
+{
+    size_t read_blk_size = 64 * 1024 * 1024;
+    cached_ifstream reader(bin_file, read_blk_size);
+    diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
+    size_t actual_file_size = reader.get_file_size();
+
+    int npts_i32, dim_i32;
+    reader.read((char *)&npts_i32, sizeof(int));
+    reader.read((char *)&dim_i32, sizeof(int));
+    npts = (unsigned)npts_i32;
+    dim = (unsigned)dim_i32;
+
+    diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
+
+    int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
+                            // only ids, -1 is error
+    size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
+
+    if (actual_file_size == expected_file_size_with_dists)
+        truthset_type = 1;
+
+    size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
+
+    if (actual_file_size == expected_file_size_just_ids)
+        truthset_type = 2;
+
+    if (truthset_type == -1)
+    {
+        std::stringstream stream;
+        stream << "Error. File size mismatch. File should have bin format, with "
+                  "npts followed by ngt followed by npts*ngt ids and optionally "
+                  "followed by npts*ngt distance values; actual size: "
+               << actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
+               << expected_file_size_just_ids;
+        diskann::cout << stream.str();
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    ids = new uint32_t[npts * dim];
+    reader.read((char *)ids, npts * dim * sizeof(uint32_t));
+
+    if (truthset_type == 1)
+    {
+        dists = new float[npts * dim];
+        reader.read((char *)dists, npts * dim * sizeof(float));
+    }
+}
+
+// Loads a truthset that must contain both ids and distances, then keeps for
+// each query only the ids whose distance is <= `range`. Also prints the
+// min/max distances seen. Requires the with-distances layout (throws
+// otherwise), so `dists` is always allocated before use below.
+inline void prune_truthset_for_range(const std::string &bin_file, float range,
+                                     std::vector<std::vector<uint32_t>> &groundtruth, size_t &npts)
+{
+    size_t read_blk_size = 64 * 1024 * 1024;
+    cached_ifstream reader(bin_file, read_blk_size);
+    diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::endl;
+    size_t actual_file_size = reader.get_file_size();
+
+    int npts_i32, dim_i32;
+    reader.read((char *)&npts_i32, sizeof(int));
+    reader.read((char *)&dim_i32, sizeof(int));
+    npts = (unsigned)npts_i32;
+    uint64_t dim = (unsigned)dim_i32;
+    uint32_t *ids;
+    float *dists;
+
+    diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
+
+    int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
+                            // only ids, -1 is error
+    size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
+
+    if (actual_file_size == expected_file_size_with_dists)
+        truthset_type = 1;
+
+    if (truthset_type == -1)
+    {
+        std::stringstream stream;
+        stream << "Error. File size mismatch. File should have bin format, with "
+                  "npts followed by ngt followed by npts*ngt ids and optionally "
+                  "followed by npts*ngt distance values; actual size: "
+               << actual_file_size << ", expected: " << expected_file_size_with_dists;
+        diskann::cout << stream.str();
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    ids = new uint32_t[npts * dim];
+    reader.read((char *)ids, npts * dim * sizeof(uint32_t));
+
+    if (truthset_type == 1)
+    {
+        dists = new float[npts * dim];
+        reader.read((char *)dists, npts * dim * sizeof(float));
+    }
+    float min_dist = std::numeric_limits<float>::max();
+    float max_dist = 0;
+    groundtruth.resize(npts);
+    for (uint32_t i = 0; i < npts; i++)
+    {
+        groundtruth[i].clear();
+        for (uint32_t j = 0; j < dim; j++)
+        {
+            // keep only neighbors within the requested range
+            if (dists[i * dim + j] <= range)
+            {
+                groundtruth[i].emplace_back(ids[i * dim + j]);
+            }
+            min_dist = min_dist > dists[i * dim + j] ? dists[i * dim + j] : min_dist;
+            max_dist = max_dist < dists[i * dim + j] ? dists[i * dim + j] : max_dist;
+        }
+        // std::cout<<groundtruth[i].size() << " " ;
+    }
+    std::cout << "Min dist: " << min_dist << ", Max dist: " << max_dist << std::endl;
+    delete[] ids;
+    delete[] dists;
+}
+
+// Loads a range-search ground-truth file with layout:
+// [uint32 npts][uint32 total_results][uint32 gt_count[npts]][uint32 ids...].
+// On return `gt_num` is the number of query points and groundtruth[i] holds
+// gt_count[i] ids. Also prints percentiles of the per-query result counts.
+// FIX(review): the third parameter was garbled as `uint64_t >_num` — an
+// HTML-escaping artifact of `uint64_t &gt_num` (the body assigns `gt_num`);
+// restored the reference parameter so the function compiles.
+inline void load_range_truthset(const std::string &bin_file, std::vector<std::vector<uint32_t>> &groundtruth,
+                                uint64_t &gt_num)
+{
+    size_t read_blk_size = 64 * 1024 * 1024;
+    cached_ifstream reader(bin_file, read_blk_size);
+    diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::flush;
+    size_t actual_file_size = reader.get_file_size();
+
+    int nptsuint32_t, totaluint32_t;
+    reader.read((char *)&nptsuint32_t, sizeof(int));
+    reader.read((char *)&totaluint32_t, sizeof(int));
+
+    gt_num = (uint64_t)nptsuint32_t;
+    uint64_t total_res = (uint64_t)totaluint32_t;
+
+    diskann::cout << "Metadata: #pts = " << gt_num << ", #total_results = " << total_res << "..." << std::endl;
+
+    size_t expected_file_size = 2 * sizeof(uint32_t) + gt_num * sizeof(uint32_t) + total_res * sizeof(uint32_t);
+
+    if (actual_file_size != expected_file_size)
+    {
+        std::stringstream stream;
+        stream << "Error. File size mismatch in range truthset. actual size: " << actual_file_size
+               << ", expected: " << expected_file_size;
+        diskann::cout << stream.str();
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    groundtruth.clear();
+    groundtruth.resize(gt_num);
+    std::vector<uint32_t> gt_count(gt_num);
+
+    reader.read((char *)gt_count.data(), sizeof(uint32_t) * gt_num);
+
+    std::vector<uint32_t> gt_stats(gt_count);
+    std::sort(gt_stats.begin(), gt_stats.end());
+
+    std::cout << "GT count percentiles:" << std::endl;
+    for (uint32_t p = 0; p < 100; p += 5)
+        std::cout << "percentile " << p << ": " << gt_stats[static_cast<size_t>(std::floor((p / 100.0) * gt_num))]
+                  << std::endl;
+    std::cout << "percentile 100"
+              << ": " << gt_stats[gt_num - 1] << std::endl;
+
+    for (uint32_t i = 0; i < gt_num; i++)
+    {
+        groundtruth[i].clear();
+        groundtruth[i].resize(gt_count[i]);
+        if (gt_count[i] != 0)
+            reader.read((char *)groundtruth[i].data(), sizeof(uint32_t) * gt_count[i]);
+    }
+}
+
+#ifdef EXEC_ENV_OLS
+// unique_ptr wrapper over the memory-mapped load_bin. NOTE(review): the raw
+// overload returns a pointer INTO the mapping, so data.reset(ptr) hands the
+// unique_ptr memory it does not own — presumably safe only because the OLS
+// environment manages the mapping; confirm before freeing.
+template <typename T>
+inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, std::unique_ptr<T[]> &data, size_t &npts,
+                     size_t &dim, size_t offset = 0)
+{
+    T *ptr;
+    load_bin<T>(files, bin_file, ptr, npts, dim, offset);
+    data.reset(ptr);
+}
+#endif
+
+// Byte-for-byte copy of in_file to out_file via streambuf iterators.
+inline void copy_file(std::string in_file, std::string out_file)
+{
+    std::ifstream source(in_file, std::ios::binary);
+    std::ofstream dest(out_file, std::ios::binary);
+
+    std::istreambuf_iterator<char> begin_source(source);
+    std::istreambuf_iterator<char> end_source;
+    std::ostreambuf_iterator<char> begin_dest(dest);
+    std::copy(begin_source, end_source, begin_dest);
+
+    source.close();
+    dest.close();
+}
+
+DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs,
+ unsigned *our_results, unsigned dim_or, unsigned recall_at);
+
+DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs,
+ unsigned *our_results, unsigned dim_or, unsigned recall_at,
+ const tsl::robin_set<unsigned> &active_tags);
+
+DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries,
+ std::vector<std::vector<uint32_t>> &groundtruth,
+ std::vector<std::vector<uint32_t>> &our_results);
+
+// unique_ptr convenience wrappers: each delegates to the raw-pointer
+// load_bin overload and transfers ownership of the new[]'d buffer.
+template <typename T>
+inline void load_bin(const std::string &bin_file, std::unique_ptr<T[]> &data, size_t &npts, size_t &dim,
+                     size_t offset = 0)
+{
+    T *ptr;
+    load_bin<T>(bin_file, ptr, npts, dim, offset);
+    data.reset(ptr);
+}
+
+// Reader-wrapper variant.
+template <typename T>
+inline void load_bin(IReaderWrapperSPtr data_source, std::unique_ptr<T[]> &data, size_t &npts, size_t &dim,
+                     size_t offset = 0)
+{
+    T *ptr;
+    load_bin<T>(data_source, ptr, npts, dim, offset);
+    data.reset(ptr);
+}
+
+// In-memory stringstream variant.
+template <typename T>
+inline void load_bin(std::stringstream &data_source, std::unique_ptr<T[]> &data, size_t &npts, size_t &dim,
+                     size_t offset = 0)
+{
+    T *ptr;
+    load_bin<T>(data_source, ptr, npts, dim, offset);
+    data.reset(ptr);
+}
+
+// Opens `filename` for binary writing without truncating an existing file
+// (in|out mode preserves current contents so callers can seekp into it).
+// Throws ANNException with the strerror text if the open fails.
+inline void open_file_to_write(std::ofstream &writer, const std::string &filename)
+{
+    writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
+    if (!file_exists(filename))
+        writer.open(filename, std::ios::binary | std::ios::out);
+    else
+        writer.open(filename, std::ios::binary | std::ios::in | std::ios::out);
+
+    if (writer.fail())
+    {
+        char buff[1024];
+#ifdef _WINDOWS
+        auto ret = std::to_string(strerror_s(buff, 1024, errno));
+#else
+        // NOTE(review): assumes the GNU strerror_r signature (returns char*);
+        // the XSI variant returning int would not compile here.
+        auto ret = std::string(strerror_r(errno, buff, 1024));
+#endif
+        std::string error_message =
+            std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret;
+        diskann::cerr << error_message << std::endl;
+        throw diskann::ANNException(error_message, -1);
+    }
+}
+
+// Writes a .bin blob ([int32 npts][int32 ndims][payload]) to `filename` at
+// `offset`. Returns the number of bytes written (header + payload).
+template <typename T>
+inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0)
+{
+    std::ofstream writer;
+    open_file_to_write(writer, filename);
+
+    diskann::cout << "Writing bin: " << filename.c_str() << std::endl;
+    writer.seekp(offset, writer.beg);
+    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
+    size_t bytes_written = npts * ndims * sizeof(T) + 2 * sizeof(uint32_t);
+    writer.write((char *)&npts_i32, sizeof(int));
+    writer.write((char *)&ndims_i32, sizeof(int));
+    diskann::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << bytes_written << "B"
+                  << std::endl;
+
+    writer.write((char *)data, npts * ndims * sizeof(T));
+    writer.close();
+    diskann::cout << "Finished writing bin." << std::endl;
+    return bytes_written;
+}
+
+// In-memory variant of save_bin targeting a stringstream. A failed stream is
+// only logged, not thrown — the byte count returned is the intended size.
+template <typename T>
+inline size_t save_bin(std::stringstream &writer, T *data, size_t npts, size_t ndims, size_t offset = 0)
+{
+    writer.seekp(offset, writer.beg);
+    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
+    size_t bytes_written = npts * ndims * sizeof(T) + 2 * sizeof(uint32_t);
+    writer.write((char *)&npts_i32, sizeof(int));
+    writer.write((char *)&ndims_i32, sizeof(int));
+    diskann::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << bytes_written << "B"
+                  << std::endl;
+
+    writer.write((char *)data, npts * ndims * sizeof(T));
+    if (writer.fail()) {
+        std::cerr << "Error: writer is in a failed state!" << std::endl;
+    }
+    diskann::cout << "Finished writing bin." << std::endl;
+    return bytes_written;
+}
+
+// Renders an in-place console progress bar for `percentage` in [0, 1],
+// using the PBSTR/PBWIDTH constants defined elsewhere in this header.
+inline void print_progress(double percentage)
+{
+    int val = (int)(percentage * 100);
+    int lpad = (int)(percentage * PBWIDTH);
+    int rpad = PBWIDTH - lpad;
+    printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
+    fflush(stdout);
+}
+
+// load_aligned_bin functions START
+
+// Reads a .bin blob into an aligned buffer: each point is padded from `dim`
+// to `rounded_dim` (dim rounded up to a multiple of 8) with zeros, and the
+// buffer is aligned to 8*sizeof(T). Validates the file size against the
+// header before allocating. Caller frees with aligned_free.
+template <typename T>
+inline void load_aligned_bin_impl(std::basic_istream<char> &reader, size_t actual_file_size, T *&data, size_t &npts,
+                                  size_t &dim, size_t &rounded_dim)
+{
+    int npts_i32, dim_i32;
+    reader.read((char *)&npts_i32, sizeof(int));
+    reader.read((char *)&dim_i32, sizeof(int));
+    npts = (unsigned)npts_i32;
+    dim = (unsigned)dim_i32;
+
+    size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
+    if (actual_file_size != expected_actual_file_size)
+    {
+        std::stringstream stream;
+        stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
+               << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
+               << std::endl;
+        diskann::cout << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    rounded_dim = ROUND_UP(dim, 8);
+    diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << ", aligned_dim = " << rounded_dim << "... "
+                  << std::flush;
+    size_t allocSize = npts * rounded_dim * sizeof(T);
+    diskann::cout << "allocating aligned memory of " << allocSize << " bytes... " << std::flush;
+    alloc_aligned(((void **)&data), allocSize, 8 * sizeof(T));
+    diskann::cout << "done. Copying data to mem_aligned buffer..." << std::flush;
+
+    for (size_t i = 0; i < npts; i++)
+    {
+        // read one point, then zero the padding tail up to rounded_dim
+        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
+        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
+    }
+    diskann::cout << " done." << std::endl;
+}
+
+#ifdef EXEC_ENV_OLS
+// Memory-mapped front-end for load_aligned_bin_impl: wraps the mapped file
+// content in an istream and delegates. The aligned buffer is a copy, so it
+// is safe after the mapping is released.
+template <typename T>
+inline void load_aligned_bin(MemoryMappedFiles &files, const std::string &bin_file, T *&data, size_t &npts, size_t &dim,
+                             size_t &rounded_dim)
+{
+    try
+    {
+        diskann::cout << "Opening bin file " << bin_file << " ..." << std::flush;
+        FileContent fc = files.getContent(bin_file);
+        ContentBuf buf((char *)fc._content, fc._size);
+        std::basic_istream<char> reader(&buf);
+
+        size_t actual_file_size = fc._size;
+        load_aligned_bin_impl(reader, actual_file_size, data, npts, dim, rounded_dim);
+    }
+    catch (std::system_error &e)
+    {
+        throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+}
+#endif
+
+// File-based front-end: determines the file size via ios::ate + tellg, then
+// delegates to load_aligned_bin_impl. I/O errors surface as FileException.
+template <typename T>
+inline void load_aligned_bin(const std::string &bin_file, T *&data, size_t &npts, size_t &dim, size_t &rounded_dim)
+{
+    std::ifstream reader;
+    reader.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+
+    try
+    {
+        diskann::cout << "Reading (with alignment) bin file " << bin_file << " ..." << std::flush;
+        reader.open(bin_file, std::ios::binary | std::ios::ate);
+
+        uint64_t fsize = reader.tellg();
+        reader.seekg(0);
+        load_aligned_bin_impl(reader, fsize, data, npts, dim, rounded_dim);
+    }
+    catch (std::system_error &e)
+    {
+        throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+}
+
+// Element-wise cast of an npts x dim matrix from InType to OutType,
+// parallelized with OpenMP over rows. Buffers must not alias.
+template <typename InType, typename OutType>
+void convert_types(const InType *srcmat, OutType *destmat, size_t npts, size_t dim)
+{
+#pragma omp parallel for schedule(static, 65536)
+    for (int64_t i = 0; i < (int64_t)npts; i++)
+    {
+        for (uint64_t j = 0; j < dim; j++)
+        {
+            destmat[i * dim + j] = (OutType)srcmat[i * dim + j];
+        }
+    }
+}
+
+// this function will take in_file of n*d dimensions and save the output as a
+// floating point matrix
+// with n*(d+1) dimensions. All vectors are scaled by a large value M so that
+// the norms are <=1 and the final coordinate is set so that the resulting
+// norm (in d+1 coordinates) is equal to 1 this is a classical transformation
+// from MIPS to L2 search from "On Symmetric and Asymmetric LSHs for Inner
+// Product Search" by Neyshabur and Srebro
+
+// MIPS-to-L2 transform (Neyshabur & Srebro): converts an n x d base file into
+// an n x (d+1) float file. Pass 1 computes the max L2 norm M over all points;
+// pass 2 scales every vector by M and appends sqrt(1 - ||x/M||^2) as the
+// extra coordinate so every output vector has unit norm. Returns M.
+template <typename T> float prepare_base_for_inner_products(const std::string in_file, const std::string out_file)
+{
+    std::cout << "Pre-processing base file by adding extra coordinate" << std::endl;
+    std::ifstream in_reader(in_file.c_str(), std::ios::binary);
+    std::ofstream out_writer(out_file.c_str(), std::ios::binary);
+    uint64_t npts, in_dims, out_dims;
+    float max_norm = 0;
+
+    uint32_t npts32, dims32;
+    in_reader.read((char *)&npts32, sizeof(uint32_t));
+    in_reader.read((char *)&dims32, sizeof(uint32_t));
+
+    npts = npts32;
+    in_dims = dims32;
+    out_dims = in_dims + 1;
+    uint32_t outdims32 = (uint32_t)out_dims;
+
+    out_writer.write((char *)&npts32, sizeof(uint32_t));
+    out_writer.write((char *)&outdims32, sizeof(uint32_t));
+
+    size_t BLOCK_SIZE = 100000;
+    size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE;
+    std::unique_ptr<T[]> in_block_data = std::make_unique<T[]>(block_size * in_dims);
+    std::unique_ptr<float[]> out_block_data = std::make_unique<float[]>(block_size * out_dims);
+
+    std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims);
+    uint64_t num_blocks = DIV_ROUND_UP(npts, block_size);
+
+    std::vector<float> norms(npts, 0);
+
+    // Pass 1: accumulate squared norms per point; track the maximum.
+    for (uint64_t b = 0; b < num_blocks; b++)
+    {
+        uint64_t start_id = b * block_size;
+        uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts;
+        uint64_t block_pts = end_id - start_id;
+        in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T));
+        for (uint64_t p = 0; p < block_pts; p++)
+        {
+            for (uint64_t j = 0; j < in_dims; j++)
+            {
+                norms[start_id + p] += in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j];
+            }
+            max_norm = max_norm > norms[start_id + p] ? max_norm : norms[start_id + p];
+        }
+    }
+
+    max_norm = std::sqrt(max_norm);
+
+    // Pass 2: rewind past the header and emit scaled vectors + extra coord.
+    in_reader.seekg(2 * sizeof(uint32_t), std::ios::beg);
+    for (uint64_t b = 0; b < num_blocks; b++)
+    {
+        uint64_t start_id = b * block_size;
+        uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts;
+        uint64_t block_pts = end_id - start_id;
+        in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T));
+        for (uint64_t p = 0; p < block_pts; p++)
+        {
+            for (uint64_t j = 0; j < in_dims; j++)
+            {
+                out_block_data[p * out_dims + j] = in_block_data[p * in_dims + j] / max_norm;
+            }
+            // clamp against tiny negative values from floating-point error
+            float res = 1 - (norms[start_id + p] / (max_norm * max_norm));
+            res = res <= 0 ? 0 : std::sqrt(res);
+            out_block_data[p * out_dims + out_dims - 1] = res;
+        }
+        out_writer.write((char *)out_block_data.get(), block_pts * out_dims * sizeof(float));
+    }
+    out_writer.close();
+    return max_norm;
+}
+
+// plain saves data as npts X ndims array into filename
+// Tvecs format: each point is stored as [uint32 ndims][ndims T values].
+template <typename T> void save_Tvecs(const char *filename, T *data, size_t npts, size_t ndims)
+{
+    std::string fname(filename);
+
+    // create cached ofstream with 64MB cache
+    cached_ofstream writer(fname, 64 * 1048576);
+
+    unsigned dims_u32 = (unsigned)ndims;
+
+    // start writing
+    for (size_t i = 0; i < npts; i++)
+    {
+        // write dims in u32
+        writer.write((char *)&dims_u32, sizeof(unsigned));
+
+        // get cur point in data
+        T *cur_pt = data + i * ndims;
+        writer.write((char *)cur_pt, ndims * sizeof(T));
+    }
+}
+
+// Writes an aligned in-memory matrix (row stride `aligned_dim`) to disk in
+// plain .bin layout, dropping the per-row padding beyond `ndims`.
+// Returns the number of bytes written (header + payload).
+template <typename T>
+inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims,
+                                           size_t aligned_dim, size_t offset = 0)
+{
+    std::ofstream writer; //(filename, std::ios::binary | std::ios::out);
+    open_file_to_write(writer, filename);
+    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
+    size_t bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T);
+    writer.seekp(offset, writer.beg);
+    writer.write((char *)&npts_i32, sizeof(int));
+    writer.write((char *)&ndims_i32, sizeof(int));
+    for (size_t i = 0; i < npts; i++)
+    {
+        writer.write((char *)(data + i * aligned_dim), ndims * sizeof(T));
+    }
+    writer.close();
+    return bytes_written;
+}
+
+// Reads a .bin file at `offset` into a PRE-ALLOCATED aligned buffer `data`
+// (row stride `rounded_dim`), zero-filling each row's padding tail.
+// Throws if `data` is null — unlike load_aligned_bin, this never allocates.
+template <typename T>
+inline void copy_aligned_data_from_file(const char *bin_file, T *&data, size_t &npts, size_t &dim,
+                                        const size_t &rounded_dim, size_t offset = 0)
+{
+    if (data == nullptr)
+    {
+        diskann::cerr << "Memory was not allocated for " << data << " before calling the load function. Exiting..."
+                      << std::endl;
+        throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file function", -1, __FUNCSIG__,
+                                    __FILE__, __LINE__);
+    }
+    std::ifstream reader;
+    reader.exceptions(std::ios::badbit | std::ios::failbit);
+    reader.open(bin_file, std::ios::binary);
+    reader.seekg(offset, reader.beg);
+
+    int npts_i32, dim_i32;
+    reader.read((char *)&npts_i32, sizeof(int));
+    reader.read((char *)&dim_i32, sizeof(int));
+    npts = (unsigned)npts_i32;
+    dim = (unsigned)dim_i32;
+
+    // NOTE(review): assumes the caller's buffer holds npts * rounded_dim
+    // elements for the npts/dim found in the file — no bounds check here.
+    for (size_t i = 0; i < npts; i++)
+    {
+        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
+        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
+    }
+}
+
+// NOTE :: good efficiency when total_vec_size is integral multiple of 64
+// Prefetches the vector into L1 cache (T0 hint), one 64-byte line at a time;
+// any tail shorter than a cache line is intentionally skipped.
+inline void prefetch_vector(const char *vec, size_t vecsize)
+{
+    size_t max_prefetch_size = (vecsize / 64) * 64;
+    for (size_t d = 0; d < max_prefetch_size; d += 64)
+        _mm_prefetch((const char *)vec + d, _MM_HINT_T0);
+}
+
+// NOTE :: good efficiency when total_vec_size is integral multiple of 64
+// Same as prefetch_vector but with the L2 (T1) hint.
+inline void prefetch_vector_l2(const char *vec, size_t vecsize)
+{
+    size_t max_prefetch_size = (vecsize / 64) * 64;
+    for (size_t d = 0; d < max_prefetch_size; d += 64)
+        _mm_prefetch((const char *)vec + d, _MM_HINT_T1);
+}
+
+// NOTE: Implementation in utils.cpp.
+void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, uint64_t npts, uint64_t ndims);
+
+DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName);
+
+// Renders a 64-bit tag as its decimal string.
+inline std::string get_tag_string(std::uint64_t tag)
+{
+    return std::to_string(tag);
+}
+
+// Renders a 128-bit tag as "<high>_<low>" (_data2 is the high half).
+inline std::string get_tag_string(const tag_uint128 &tag)
+{
+    std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1);
+    return str;
+}
+
+}; // namespace diskann
+
+// Pairs a pivot id with its distance. The comparison operators are
+// deliberately inverted (operator< compares p.piv_dist < this->piv_dist) so
+// that max-heap/priority-queue containers order pivots by ascending distance.
+struct PivotContainer
+{
+    PivotContainer() = default;
+
+    PivotContainer(size_t pivo_id, float pivo_dist) : piv_id{pivo_id}, piv_dist{pivo_dist}
+    {
+    }
+
+    // "less than" == has a LARGER distance (reversed ordering on purpose)
+    bool operator<(const PivotContainer &p) const
+    {
+        return p.piv_dist < piv_dist;
+    }
+
+    // "greater than" == has a SMALLER distance (reversed ordering on purpose)
+    bool operator>(const PivotContainer &p) const
+    {
+        return p.piv_dist > piv_dist;
+    }
+
+    size_t piv_id;
+    float piv_dist;
+};
+
+// Checks that the actual size of an open index file matches the uint64 size
+// stored in its first 8 bytes. Restores the stream position to the start.
+// Returns false (after logging) on mismatch; throws if the stream is closed.
+inline bool validate_index_file_size(std::ifstream &in)
+{
+    if (!in.is_open())
+        throw diskann::ANNException("Index file size check called on unopened file stream", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    in.seekg(0, in.end);
+    size_t actual_file_size = in.tellg();
+    in.seekg(0, in.beg);
+    size_t expected_file_size;
+    in.read((char *)&expected_file_size, sizeof(uint64_t));
+    in.seekg(0, in.beg);
+    if (actual_file_size != expected_file_size)
+    {
+        diskann::cerr << "Index file size error. Expected size (metadata): " << expected_file_size
+                      << ", actual file size : " << actual_file_size << "." << std::endl;
+        return false;
+    }
+    return true;
+}
+
+// Returns the L2 norm of arr[0..dim). Accumulates in float regardless of T.
+template <typename T> inline float get_norm(T *arr, const size_t dim)
+{
+    float sum = 0.0f;
+    for (uint32_t i = 0; i < dim; i++)
+    {
+        sum += arr[i] * arr[i];
+    }
+    return sqrt(sum);
+}
+
+// This function is valid only for float data type.
+// Scales arr in place to unit L2 norm. NOTE(review): no guard against a
+// zero-norm vector — division by zero yields inf/NaN here.
+template <typename T = float> inline void normalize(T *arr, const size_t dim)
+{
+    float norm = get_norm(arr, dim);
+    for (uint32_t i = 0; i < dim; i++)
+    {
+        arr[i] = (T)(arr[i] / norm);
+    }
+}
+
+// Reads one filter label per line from `filename`. Stops at the first empty
+// line, exits the process if a line contains a comma (one filter per query),
+// and strips a single trailing CR/LF. When `unique` is true, duplicate lines
+// are dropped (the dedup set is only populated in that case, so with
+// unique=false every line is kept). Throws on open failure or blank name.
+inline std::vector<std::string> read_file_to_vector_of_strings(const std::string &filename, bool unique = false)
+{
+    std::vector<std::string> result;
+    std::set<std::string> elementSet;
+    if (filename != "")
+    {
+        std::ifstream file(filename);
+        if (file.fail())
+        {
+            throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
+        }
+        std::string line;
+        while (std::getline(file, line))
+        {
+            if (line.empty())
+            {
+                break;
+            }
+            if (line.find(',') != std::string::npos)
+            {
+                std::cerr << "Every query must have exactly one filter" << std::endl;
+                exit(-1);
+            }
+            if (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
+            {
+                line.erase(line.size() - 1);
+            }
+            if (!elementSet.count(line))
+            {
+                result.push_back(line);
+            }
+            if (unique)
+            {
+                elementSet.insert(line);
+            }
+        }
+        file.close();
+    }
+    else
+    {
+        throw diskann::ANNException(std::string("Failed to open file. filename can not be blank"), -1);
+    }
+    return result;
+}
+
+// Best-effort removal of temporary files named "<path>_<suffix>" for every
+// (path, suffix) pair. Failures are logged as warnings, never thrown.
+inline void clean_up_artifacts(tsl::robin_set<std::string> paths_to_clean, tsl::robin_set<std::string> path_suffixes)
+{
+    try
+    {
+        for (const auto &path : paths_to_clean)
+        {
+            for (const auto &suffix : path_suffixes)
+            {
+                std::string curr_path_to_clean(path + "_" + suffix);
+                if (std::remove(curr_path_to_clean.c_str()) != 0)
+                    diskann::cout << "Warning: Unable to remove file :" << curr_path_to_clean << std::endl;
+            }
+        }
+        diskann::cout << "Cleaned all artifacts" << std::endl;
+    }
+    catch (const std::exception &e)
+    {
+        diskann::cout << "Warning: Unable to clean all artifacts " << e.what() << std::endl;
+    }
+}
+
+// Maps a numeric element type to its canonical DiskANN type-name string.
+// The deleted primary template makes any unsupported T a compile error.
+template <typename T> inline const char *diskann_type_to_name() = delete;
+template <> inline const char *diskann_type_to_name<float>()
+{
+    return "float";
+}
+template <> inline const char *diskann_type_to_name<uint8_t>()
+{
+    return "uint8";
+}
+template <> inline const char *diskann_type_to_name<int8_t>()
+{
+    return "int8";
+}
+template <> inline const char *diskann_type_to_name<uint16_t>()
+{
+    return "uint16";
+}
+template <> inline const char *diskann_type_to_name<int16_t>()
+{
+    return "int16";
+}
+template <> inline const char *diskann_type_to_name<uint32_t>()
+{
+    return "uint32";
+}
+template <> inline const char *diskann_type_to_name<int32_t>()
+{
+    return "int32";
+}
+template <> inline const char *diskann_type_to_name<uint64_t>()
+{
+    return "uint64";
+}
+template <> inline const char *diskann_type_to_name<int64_t>()
+{
+    return "int64";
+}
+
+#ifdef _WINDOWS
+#include <intrin.h>
+#include <Psapi.h>
+
+extern bool AvxSupportedCPU;
+extern bool Avx2SupportedCPU;
+
+// Returns the process's private (committed) memory usage in bytes.
+inline size_t getMemoryUsage()
+{
+    PROCESS_MEMORY_COUNTERS_EX pmc;
+    GetProcessMemoryInfo(GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS *)&pmc, sizeof(pmc));
+    return pmc.PrivateUsage;
+}
+
+// Converts a Win32 error code to its system message text (empty on failure).
+// NOTE(review): the FORMAT_MESSAGE_ALLOCATE_BUFFER text is never LocalFree'd.
+inline std::string getWindowsErrorMessage(DWORD lastError)
+{
+    char *errorText;
+    FormatMessageA(
+        // use system message tables to retrieve error text
+        FORMAT_MESSAGE_FROM_SYSTEM
+        // allocate buffer on local heap for error text
+        | FORMAT_MESSAGE_ALLOCATE_BUFFER
+        // Important! will fail otherwise, since we're not
+        // (and CANNOT) pass insertion parameters
+        | FORMAT_MESSAGE_IGNORE_INSERTS,
+        NULL, // unused with FORMAT_MESSAGE_FROM_SYSTEM
+        lastError, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+        (LPSTR)&errorText, // output
+        0,                 // minimum size for output buffer
+        NULL);             // arguments - see note
+
+    return errorText != nullptr ? std::string(errorText) : std::string();
+}
+
+// Logs the process's working-set / private-bytes counters (in GB) with a
+// caller-supplied message prefix.
+inline void printProcessMemory(const char *message)
+{
+    PROCESS_MEMORY_COUNTERS counters;
+    HANDLE h = GetCurrentProcess();
+    GetProcessMemoryInfo(h, &counters, sizeof(counters));
+    diskann::cout << message
+                  << " [Peaking Working Set size: " << counters.PeakWorkingSetSize * 1.0 / (1024.0 * 1024 * 1024)
+                  << "GB Working set size: " << counters.WorkingSetSize * 1.0 / (1024.0 * 1024 * 1024)
+                  << "GB Private bytes " << counters.PagefileUsage * 1.0 / (1024 * 1024 * 1024) << "GB]" << std::endl;
+}
+#else
+
+// need to check and change this
+// Non-Windows stub: AVX2 support is assumed (runtime detection TODO).
+inline bool avx2Supported()
+{
+    return true;
+}
+// Non-Windows stub: no-op.
+inline void printProcessMemory(const char *)
+{
+}
+
+inline size_t getMemoryUsage()
+{ // for non-windows, we have not implemented this function
+    return 0;
+}
+
+#endif
+
+extern bool AvxSupportedCPU;
+extern bool Avx2SupportedCPU;
diff --git a/be/src/extern/diskann/include/windows_aligned_file_reader.h b/be/src/extern/diskann/include/windows_aligned_file_reader.h
new file mode 100644
index 0000000..0d9a317
--- /dev/null
+++ b/be/src/extern/diskann/include/windows_aligned_file_reader.h
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+#ifdef _WINDOWS
+#ifndef USE_BING_INFRA
+#include <Windows.h>
+#include <fcntl.h>
+#include <malloc.h>
+#include <minwinbase.h>
+
+#include <cstdio>
+#include <mutex>
+#include <thread>
+#include "aligned_file_reader.h"
+#include "tsl/robin_map.h"
+#include "utils.h"
+#include "windows_customizations.h"
+
+// Windows implementation of AlignedFileReader built on the Win32 file APIs.
+class WindowsAlignedFileReader : public AlignedFileReader
+{
+  private:
+#ifdef UNICODE
+    std::wstring m_filename; // path of the opened file (wide string in UNICODE builds)
+#else
+    std::string m_filename;
+#endif
+
+  protected:
+    // virtual IOContext createContext();
+
+  public:
+    DISKANN_DLLEXPORT WindowsAlignedFileReader(){};
+    DISKANN_DLLEXPORT virtual ~WindowsAlignedFileReader(){};
+
+    // Open & close ops
+    // Blocking calls
+    DISKANN_DLLEXPORT virtual void open(const std::string &fname) override;
+    DISKANN_DLLEXPORT virtual void close() override;
+
+    // Registers the calling thread so it receives its own IO context.
+    DISKANN_DLLEXPORT virtual void register_thread() override;
+    DISKANN_DLLEXPORT virtual void deregister_thread() override
+    {
+        // TODO: Needs implementation.
+    }
+    DISKANN_DLLEXPORT virtual void deregister_all_threads() override
+    {
+        // TODO: Needs implementation.
+    }
+    // Returns the IO context associated with the calling thread.
+    DISKANN_DLLEXPORT virtual IOContext &get_ctx() override;
+
+    // Processes a batch of aligned read requests in parallel.
+    // NOTE :: blocking call for the calling thread, but safe to invoke from
+    // multiple threads concurrently.
+    DISKANN_DLLEXPORT virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async) override;
+};
+#endif // USE_BING_INFRA
+#endif //_WINDOWS
diff --git a/be/src/extern/diskann/include/windows_customizations.h b/be/src/extern/diskann/include/windows_customizations.h
new file mode 100644
index 0000000..e6c5846
--- /dev/null
+++ b/be/src/extern/diskann/include/windows_customizations.h
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#ifdef _WINDOWS
+
+#ifdef _WINDLL
+#define DISKANN_DLLEXPORT __declspec(dllexport)
+#else
+#define DISKANN_DLLEXPORT __declspec(dllimport)
+#endif
+
+#else
+#define DISKANN_DLLEXPORT
+#endif
diff --git a/be/src/extern/diskann/include/windows_slim_lock.h b/be/src/extern/diskann/include/windows_slim_lock.h
new file mode 100644
index 0000000..7fc09b8
--- /dev/null
+++ b/be/src/extern/diskann/include/windows_slim_lock.h
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+#pragma once
+
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#include "Windows.h"
+
+namespace diskann
+{
+// A thin C++ wrapper around Windows exclusive functionality of Windows
+// SlimReaderWriterLock.
+//
+// The SlimReaderWriterLock is simpler/more lightweight than std::mutex
+// (8 bytes vs 80 bytes), which is useful in the scenario where DiskANN has
+// one lock per vector in the index. It does not support recursive locking and
+// requires Windows Vista or later.
+//
+// Full documentation can be found at.
+// https://msdn.microsoft.com/en-us/library/windows/desktop/aa904937(v=vs.85).aspx
+class windows_exclusive_slim_lock
+{
+  public:
+    windows_exclusive_slim_lock() : _srw_lock(SRWLOCK_INIT)
+    {
+    }
+
+    // Copying (and therefore moving) is disabled: an SRWLOCK must never be
+    // duplicated, especially while held.
+    windows_exclusive_slim_lock(const windows_exclusive_slim_lock &) = delete;
+    windows_exclusive_slim_lock &operator=(const windows_exclusive_slim_lock &) = delete;
+
+    // Blocks until exclusive ownership of the lock is acquired.
+    void lock()
+    {
+        AcquireSRWLockExclusive(&_srw_lock);
+    }
+
+    // Attempts to acquire exclusive ownership without blocking; returns true
+    // on success.
+    bool try_lock()
+    {
+        return TryAcquireSRWLockExclusive(&_srw_lock) != FALSE;
+    }
+
+    // Releases exclusive ownership previously acquired by lock()/try_lock().
+    void unlock()
+    {
+        ReleaseSRWLockExclusive(&_srw_lock);
+    }
+
+  private:
+    SRWLOCK _srw_lock;
+};
+
+// An exclusive lock over a SlimReaderWriterLock.
+class windows_exclusive_slim_lock_guard
+{
+  public:
+    // Acquires p_lock exclusively for the lifetime of this guard (RAII).
+    windows_exclusive_slim_lock_guard(windows_exclusive_slim_lock &p_lock) : _lock(p_lock)
+    {
+        _lock.lock();
+    }
+
+    // The lock is non-copyable. This also disables move constructor/operator=.
+    windows_exclusive_slim_lock_guard(const windows_exclusive_slim_lock_guard &) = delete;
+    windows_exclusive_slim_lock_guard &operator=(const windows_exclusive_slim_lock_guard &) = delete;
+
+    // Releases the lock when the guard goes out of scope.
+    ~windows_exclusive_slim_lock_guard()
+    {
+        _lock.unlock();
+    }
+
+  private:
+    windows_exclusive_slim_lock &_lock;
+};
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/abstract_data_store.cpp b/be/src/extern/diskann/src/abstract_data_store.cpp
new file mode 100644
index 0000000..0cff015
--- /dev/null
+++ b/be/src/extern/diskann/src/abstract_data_store.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <vector>
+#include "abstract_data_store.h"
+
+namespace diskann
+{
+
+template <typename data_t>
+AbstractDataStore<data_t>::AbstractDataStore(const location_t capacity, const size_t dim)
+    : _capacity(capacity), _dim(dim)
+{
+    // Records the maximum number of points and their dimensionality; actual
+    // storage allocation is left to the concrete subclass.
+}
+
+// Maximum number of points this store can currently hold.
+template <typename data_t> location_t AbstractDataStore<data_t>::capacity() const
+{
+    return _capacity;
+}
+
+// Dimensionality of each stored vector.
+template <typename data_t> size_t AbstractDataStore<data_t>::get_dims() const
+{
+    return _dim;
+}
+
+template <typename data_t> location_t AbstractDataStore<data_t>::resize(const location_t new_num_points)
+{
+    // Grow or shrink the store so it can hold new_num_points; a no-op when
+    // the capacity already matches. Returns the resulting capacity.
+    if (new_num_points == _capacity)
+    {
+        return _capacity;
+    }
+    return new_num_points > _capacity ? expand(new_num_points) : shrink(new_num_points);
+}
+
+template DISKANN_DLLEXPORT class AbstractDataStore<float>;
+template DISKANN_DLLEXPORT class AbstractDataStore<int8_t>;
+template DISKANN_DLLEXPORT class AbstractDataStore<uint8_t>;
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/abstract_index.cpp b/be/src/extern/diskann/src/abstract_index.cpp
new file mode 100644
index 0000000..9266582
--- /dev/null
+++ b/be/src/extern/diskann/src/abstract_index.cpp
@@ -0,0 +1,334 @@
+#include "common_includes.h"
+#include "windows_customizations.h"
+#include "abstract_index.h"
+
+namespace diskann
+{
+
+// Type-erasing wrapper: packs the typed data pointer and tag vector into
+// std::any / TagVector and forwards to the concrete index's _build.
+template <typename data_type, typename tag_type>
+void AbstractIndex::build(const data_type *data, const size_t num_points_to_load, const std::vector<tag_type> &tags)
+{
+    auto any_data = std::any(data);
+    auto any_tags_vec = TagVector(tags);
+    this->_build(any_data, num_points_to_load, any_tags_vec);
+}
+
+// K-NN search (top-K with candidate-list size L); forwards to _search and
+// returns whatever pair the underlying implementation produces.
+template <typename data_type, typename IDType>
+std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L,
+                                                    IDType *indices, float *distances)
+{
+    auto any_indices = std::any(indices);
+    auto any_query = std::any(query);
+    return _search(any_query, K, L, any_indices, distances);
+}
+
+// Tag-based search: results are reported as user tags (and optionally raw
+// vectors via res_vectors) rather than internal ids.
+template <typename data_type, typename tag_type>
+size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
+                                       float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
+                                       const std::string filter_label)
+{
+    auto any_query = std::any(query);
+    auto any_tags = std::any(tags);
+    auto any_res_vectors = DataVector(res_vectors);
+    return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label);
+}
+
+// Filtered search restricted to points carrying raw_label.
+template <typename IndexType>
+std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
+                                                                 const size_t K, const uint32_t L, IndexType *indices,
+                                                                 float *distances)
+{
+    auto any_indices = std::any(indices);
+    return _search_with_filters(query, raw_label, K, L, any_indices, distances);
+}
+
+// Search over the memory-optimized (flattened) index layout.
+template <typename data_type>
+void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices)
+{
+    auto any_query = std::any(query);
+    this->_search_with_optimized_layout(any_query, K, L, indices);
+}
+
+// Inserts one point identified by `tag`; returns the implementation's
+// status code.
+template <typename data_type, typename tag_type>
+int AbstractIndex::insert_point(const data_type *point, const tag_type tag)
+{
+    auto any_point = std::any(point);
+    auto any_tag = std::any(tag);
+    return this->_insert_point(any_point, any_tag);
+}
+
+// Inserts one point with filter labels attached.
+template <typename data_type, typename tag_type, typename label_type>
+int AbstractIndex::insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels)
+{
+    auto any_point = std::any(point);
+    auto any_tag = std::any(tag);
+    auto any_labels = Labelvector(labels);
+    return this->_insert_point(any_point, any_tag, any_labels);
+}
+
+// Marks the point with `tag` as deleted; physical removal is deferred.
+template <typename tag_type> int AbstractIndex::lazy_delete(const tag_type &tag)
+{
+    auto any_tag = std::any(tag);
+    return this->_lazy_delete(any_tag);
+}
+
+// Batch lazy delete; tags that could not be deleted are reported back
+// through failed_tags.
+template <typename tag_type>
+void AbstractIndex::lazy_delete(const std::vector<tag_type> &tags, std::vector<tag_type> &failed_tags)
+{
+    auto any_tags = TagVector(tags);
+    auto any_failed_tags = TagVector(failed_tags);
+    this->_lazy_delete(any_tags, any_failed_tags);
+}
+
+// Fills active_tags with the tags of all live (non-deleted) points.
+template <typename tag_type> void AbstractIndex::get_active_tags(tsl::robin_set<tag_type> &active_tags)
+{
+    auto any_active_tags = TagRobinSet(active_tags);
+    this->_get_active_tags(any_active_tags);
+}
+
+// Seeds the index's start points with random vectors of the given radius.
+template <typename data_type> void AbstractIndex::set_start_points_at_random(data_type radius, uint32_t random_seed)
+{
+    auto any_radius = std::any(radius);
+    this->_set_start_points_at_random(any_radius, random_seed);
+}
+
+// Copies the vector identified by `tag` into vec; returns a status code.
+template <typename tag_type, typename data_type> int AbstractIndex::get_vector_by_tag(tag_type &tag, data_type *vec)
+{
+    auto any_tag = std::any(tag);
+    auto any_data_ptr = std::any(vec);
+    return this->_get_vector_by_tag(any_tag, any_data_ptr);
+}
+
+// Sets the label that is treated as matching every filter.
+template <typename label_type> void AbstractIndex::set_universal_label(const label_type universal_label)
+{
+    auto any_label = std::any(universal_label);
+    this->_set_universal_label(any_label);
+}
+
+// exports
+template DISKANN_DLLEXPORT void AbstractIndex::build<float, int32_t>(const float *data, const size_t num_points_to_load,
+ const std::vector<int32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<int8_t, int32_t>(const int8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<int32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<uint8_t, int32_t>(const uint8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<int32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<float, uint32_t>(const float *data,
+ const size_t num_points_to_load,
+ const std::vector<uint32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<int8_t, uint32_t>(const int8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<uint32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<uint8_t, uint32_t>(const uint8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<uint32_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<float, int64_t>(const float *data, const size_t num_points_to_load,
+ const std::vector<int64_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<int8_t, int64_t>(const int8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<int64_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<uint8_t, int64_t>(const uint8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<int64_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<float, uint64_t>(const float *data,
+ const size_t num_points_to_load,
+ const std::vector<uint64_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<int8_t, uint64_t>(const int8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<uint64_t> &tags);
+template DISKANN_DLLEXPORT void AbstractIndex::build<uint8_t, uint64_t>(const uint8_t *data,
+ const size_t num_points_to_load,
+ const std::vector<uint64_t> &tags);
+
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<float, uint32_t>(
+ const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<uint8_t, uint32_t>(
+ const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<int8_t, uint32_t>(
+ const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<float, uint64_t>(
+ const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<uint8_t, uint64_t>(
+ const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<int8_t, uint64_t>(
+ const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
+ const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
+ float *distances);
+
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
+ const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
+ float *distances);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(
+ const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
+ std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int32_t>(
+ const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
+ std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(
+ const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
+ std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(
+ const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
+ std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint32_t>(
+ const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
+ std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(
+ const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
+ std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(
+ const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
+ std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int64_t>(
+ const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
+ std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(
+ const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
+ std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(
+ const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
+ std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint64_t>(
+ const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
+ std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(
+ const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
+ std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
+
+template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<float>(const float *query, size_t K,
+ size_t L, uint32_t *indices);
+template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<uint8_t>(const uint8_t *query, size_t K,
+ size_t L, uint32_t *indices);
+template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<int8_t>(const int8_t *query, size_t K,
+ size_t L, uint32_t *indices);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t>(const float *point, const int32_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t>(const uint8_t *point, const int32_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t>(const int8_t *point, const int32_t tag);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t>(const float *point, const uint32_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t>(const uint8_t *point, const uint32_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t>(const int8_t *point, const uint32_t tag);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t>(const float *point, const int64_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t>(const uint8_t *point, const int64_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t>(const int8_t *point, const int64_t tag);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t>(const float *point, const uint64_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t>(const uint8_t *point, const uint64_t tag);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t>(const int8_t *point, const uint64_t tag);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t, uint16_t>(
+ const float *point, const int32_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t, uint16_t>(
+ const uint8_t *point, const int32_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t, uint16_t>(
+ const int8_t *point, const int32_t tag, const std::vector<uint16_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t, uint16_t>(
+ const float *point, const uint32_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t, uint16_t>(
+ const uint8_t *point, const uint32_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t, uint16_t>(
+ const int8_t *point, const uint32_t tag, const std::vector<uint16_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t, uint16_t>(
+ const float *point, const int64_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t, uint16_t>(
+ const uint8_t *point, const int64_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t, uint16_t>(
+ const int8_t *point, const int64_t tag, const std::vector<uint16_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t, uint16_t>(
+ const float *point, const uint64_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t, uint16_t>(
+ const uint8_t *point, const uint64_t tag, const std::vector<uint16_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t, uint16_t>(
+ const int8_t *point, const uint64_t tag, const std::vector<uint16_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t, uint32_t>(
+ const float *point, const int32_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t, uint32_t>(
+ const uint8_t *point, const int32_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t, uint32_t>(
+ const int8_t *point, const int32_t tag, const std::vector<uint32_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t, uint32_t>(
+ const float *point, const uint32_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t, uint32_t>(
+ const uint8_t *point, const uint32_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t, uint32_t>(
+ const int8_t *point, const uint32_t tag, const std::vector<uint32_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t, uint32_t>(
+ const float *point, const int64_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t, uint32_t>(
+ const uint8_t *point, const int64_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t, uint32_t>(
+ const int8_t *point, const int64_t tag, const std::vector<uint32_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t, uint32_t>(
+ const float *point, const uint64_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t, uint32_t>(
+ const uint8_t *point, const uint64_t tag, const std::vector<uint32_t> &labels);
+template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t, uint32_t>(
+ const int8_t *point, const uint64_t tag, const std::vector<uint32_t> &labels);
+
+template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<int32_t>(const int32_t &tag);
+template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<uint32_t>(const uint32_t &tag);
+template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<int64_t>(const int64_t &tag);
+template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<uint64_t>(const uint64_t &tag);
+
+template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete<int32_t>(const std::vector<int32_t> &tags,
+ std::vector<int32_t> &failed_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete<uint32_t>(const std::vector<uint32_t> &tags,
+ std::vector<uint32_t> &failed_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete<int64_t>(const std::vector<int64_t> &tags,
+ std::vector<int64_t> &failed_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete<uint64_t>(const std::vector<uint64_t> &tags,
+ std::vector<uint64_t> &failed_tags);
+
+template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags<int32_t>(tsl::robin_set<int32_t> &active_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags<uint32_t>(tsl::robin_set<uint32_t> &active_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags<int64_t>(tsl::robin_set<int64_t> &active_tags);
+template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags<uint64_t>(tsl::robin_set<uint64_t> &active_tags);
+
+template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random<float>(float radius, uint32_t random_seed);
+template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random<uint8_t>(uint8_t radius,
+ uint32_t random_seed);
+template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random<int8_t>(int8_t radius, uint32_t random_seed);
+
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int32_t, float>(int32_t &tag, float *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int32_t, uint8_t>(int32_t &tag, uint8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int32_t, int8_t>(int32_t &tag, int8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint32_t, float>(uint32_t &tag, float *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint32_t, uint8_t>(uint32_t &tag, uint8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint32_t, int8_t>(uint32_t &tag, int8_t *vec);
+
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int64_t, float>(int64_t &tag, float *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int64_t, uint8_t>(int64_t &tag, uint8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<int64_t, int8_t>(int64_t &tag, int8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint64_t, float>(uint64_t &tag, float *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint64_t, uint8_t>(uint64_t &tag, uint8_t *vec);
+template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<uint64_t, int8_t>(uint64_t &tag, int8_t *vec);
+
+template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label<uint16_t>(const uint16_t label);
+template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label<uint32_t>(const uint32_t label);
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/ann_exception.cpp b/be/src/extern/diskann/src/ann_exception.cpp
new file mode 100644
index 0000000..ba55e36
--- /dev/null
+++ b/be/src/extern/diskann/src/ann_exception.cpp
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "ann_exception.h"
+#include <sstream>
+#include <string>
+
+namespace diskann
+{
+// Basic DiskANN exception: a runtime_error carrying a numeric error code.
+ANNException::ANNException(const std::string &message, int errorCode)
+    : std::runtime_error(message), _errorCode(errorCode)
+{
+}
+
+// Formats "[item_name: item_val]" for composing exception messages.
+std::string package_string(const std::string &item_name, const std::string &item_val)
+{
+    return std::string("[") + item_name + ": " + std::string(item_val) + std::string("]");
+}
+
+// Convenience constructor that prefixes the message with the throwing
+// function, file, and line: "[FUNC: f][FILE: x.cpp][LINE: 42] msg".
+ANNException::ANNException(const std::string &message, int errorCode, const std::string &funcSig,
+                           const std::string &fileName, uint32_t lineNum)
+    : ANNException(package_string(std::string("FUNC"), funcSig) + package_string(std::string("FILE"), fileName) +
+                       package_string(std::string("LINE"), std::to_string(lineNum)) + " " + message,
+                   errorCode)
+{
+}
+
+// File-open failure: combines the std::system_error's code and message with
+// the offending filename and the throw location.
+FileException::FileException(const std::string &filename, std::system_error &e, const std::string &funcSig,
+                             const std::string &fileName, uint32_t lineNum)
+    : ANNException(std::string(" While opening file \'") + filename + std::string("\', error code: ") +
+                       std::to_string(e.code().value()) + " " + e.code().message(),
+                   e.code().value(), funcSig, fileName, lineNum)
+{
+}
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/combined_file.cpp b/be/src/extern/diskann/src/combined_file.cpp
new file mode 100644
index 0000000..b52453c
--- /dev/null
+++ b/be/src/extern/diskann/src/combined_file.cpp
@@ -0,0 +1,218 @@
+#include <iostream>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <cstring>
+#include <cstdio>
+#include <map>
+#include <memory>
+#include "combined_file.h"
+#include <cstdio>
+
+namespace diskann
+{
+
+    // Creates a reader over a segment of `file_path`: `start` is the byte
+    // offset of the segment within the file and `size` its length; size == 0
+    // means "the whole file".
+    LocalFileReader::LocalFileReader(const std::string& file_path, uint64_t start, uint64_t size) {
+        this->open(file_path);
+        _base_offset = start;
+        _size = size;
+    }
+
+
+    // Opens `file_path` in binary mode if no stream is open yet; returns
+    // whether a stream is open afterwards.
+    // NOTE(review): when a file is already open this returns true without
+    // checking that `file_path` matches _file_path — confirm callers never
+    // reuse a reader for a different path.
+    bool LocalFileReader::open(const std::string& file_path) {
+        if (!in.is_open()) {
+            _file_path = file_path;
+            in.open(file_path, std::ios::binary);
+            return in.is_open();
+        }
+        return true;
+    }
+
+    // Size in bytes of what this reader exposes: the configured segment size,
+    // or — when _size == 0 — the size of the whole underlying file.
+    uint64_t LocalFileReader::get_file_size() {
+        if (_size == 0){
+            // Seek to the end so tellg() reports the file size; read() always
+            // re-seeks, so leaving the get-pointer here is harmless.
+            in.seekg(0, std::ios::end);
+            uint64_t file_size = static_cast<uint64_t>(in.tellg());
+            return file_size;
+        } else {
+            return _size;
+        }
+    }
+
+    // Reads up to `size` bytes starting at `offset` relative to this reader's
+    // segment (_base_offset). Returns the number of bytes actually read, or 0
+    // when the stream is not open or `offset` is past the segment end.
+    size_t LocalFileReader::read(char* buffer, uint64_t offset, size_t size) {
+        if (!in.is_open()) {
+            return 0;
+        }
+        // When this reader covers a segment of a combined file (_size > 0),
+        // never read past the segment's end. The clamp must account for
+        // `offset`: the original capped only at _size, so a read starting
+        // inside the segment could still run into the next file's bytes.
+        if (_size > 0) {
+            if (offset >= _size) {
+                return 0;
+            }
+            if (size > _size - offset) {
+                size = _size - offset;
+            }
+        }
+        std::lock_guard<std::mutex> lock(_mutex);
+        // Clear any eof/fail state left by a previous short read so the
+        // seekg below does not silently fail.
+        in.clear();
+        // Position relative to the segment's start inside the combined file.
+        in.seekg(_base_offset + offset, std::ios::beg);
+        in.read(buffer, size);
+        // gcount() must be sampled while still holding the lock: another
+        // thread's read would otherwise overwrite it before we return.
+        return static_cast<size_t>(in.gcount());
+    }
+
+    // Moves the get-pointer to `offset` from the beginning of the file.
+    // NOTE(review): unlike read(), this does NOT add _base_offset — confirm
+    // whether callers expect segment-relative or absolute positions here.
+    void LocalFileReader::seek(uint64_t offset) {
+        if (in.is_open()) {
+            in.seekg(offset, std::ios::beg);
+        }
+    }
+
+    // Closes the underlying stream if it is open.
+    void LocalFileReader::close() {
+        if (in.is_open()) {
+            in.close();
+        }
+    }
+
+    // Opens `file_path` for binary writing; open failure surfaces only as
+    // the is_open() checks in write()/close() doing nothing.
+    LocalFileIOWriter::LocalFileIOWriter(const std::string& file_path): out(file_path, std::ios::binary){
+
+    }
+
+    // Appends `size` bytes from `data`; silently a no-op when the stream
+    // failed to open.
+    void LocalFileIOWriter::write(const char* data, size_t size) {
+        if (out.is_open()) {
+            out.write(data, size);
+        }
+    }
+    // Flushes and closes the output stream if it is open.
+    void LocalFileIOWriter::close() {
+        if (out.is_open()) {
+            out.close();
+        }
+    }
+
+
+    // Registers a source file under `alias`: records its path and size and
+    // assigns the offset it will occupy in the merged output. Returns false
+    // when the file cannot be opened or is empty.
+    bool FileMerger::add(const std::string& alias, const std::string& path) {
+        FileInfo info;
+        info.alias = alias;
+        info.path = path;
+        info.offset = current_offset;
+
+        // Open positioned at the end so tellg() reports the file size.
+        std::ifstream file(path, std::ios::binary | std::ios::ate);
+        if (!file.is_open()) {
+            std::cerr << "无法打开文件: " << path << std::endl;
+            return false;
+        }
+        info.size = static_cast<uint64_t>(file.tellg());
+        file.close();
+        // info.size is unsigned, so the original `<= 0` was a tautological
+        // comparison; an empty file is exactly size == 0.
+        if (info.size == 0) {
+            std::cerr << "file cannot be empty: " << path << std::endl;
+            return false;
+        }
+        current_offset += info.size;
+        files.push_back(info);
+        return true;
+    }
+
+    // Writes the merged file: all payloads back-to-back, then a meta section
+    // (file count + per-file alias/offset/size records), then the meta
+    // section's total size as a trailing uint64_t so readers can locate the
+    // meta section from the end of the file.
+    void FileMerger::save(IOWriter* writer) {
+        // 1) payloads, in registration order (offsets were assigned in add()).
+        for (const auto& file : files) {
+            std::ifstream in(file.path, std::ios::binary);
+            if (in.is_open()) {
+                char buffer[4096];
+                // Loop on the stream state rather than eof(): `while (!eof())`
+                // spins forever if the stream enters a fail state without ever
+                // reaching eof. read() returns the stream, which tests false
+                // once a full buffer could not be read; the partial tail is
+                // flushed from gcount() afterwards.
+                while (in.read(buffer, sizeof(buffer))) {
+                    writer->write(buffer, sizeof(buffer));
+                }
+                if (in.gcount() > 0) {
+                    writer->write(buffer, in.gcount());
+                }
+                in.close();
+            } else {
+                std::cerr << "无法打开文件进行读取: " << file.path << std::endl;
+            }
+        }
+
+        // 2) file count, so deserialization knows how many meta records follow.
+        uint32_t file_count = files.size();
+        writer->write(reinterpret_cast<const char*>(&file_count), sizeof(uint32_t));
+
+        // Total size of the meta section (count + all records), excluding the
+        // trailing size field itself.
+        uint64_t total_meta_size = sizeof(file_count);
+        for (const auto& file : files) {
+            total_meta_size += sizeof(uint32_t) + file.alias.size() + sizeof(file.offset) + sizeof(file.size);
+        }
+
+        // 3) one meta record per file: alias length, alias bytes, offset, size.
+        for (const auto& file : files) {
+            uint32_t alias_len = file.alias.size();
+            writer->write(reinterpret_cast<const char*>(&alias_len), sizeof(uint32_t));
+            writer->write(file.alias.c_str(), alias_len);
+            writer->write(reinterpret_cast<const char*>(&file.offset), sizeof(file.offset));
+            writer->write(reinterpret_cast<const char*>(&file.size), sizeof(file.size));
+        }
+        // 4) trailing total_meta_size (8 bytes) at the very end of the file.
+        writer->write(reinterpret_cast<const char*>(&total_meta_size), sizeof(total_meta_size));
+        writer->close();
+    }
+
+    // Looks up `alias` in the merged file's trailing meta section and returns
+    // a reader scoped to that file's [offset, offset + size) segment, or
+    // nullptr when the alias is unknown. Readers are cached per alias so
+    // repeated calls share one instance.
+    template <typename ReaderType>
+    std::shared_ptr<ReaderType> FileMerger::get_reader(const std::string& alias, const std::string& merged_file_path) {
+        // Cache hit: share the existing instance. The original wrapped the
+        // cached raw pointer in a brand-new smart pointer, creating a second
+        // owner (a second control block over the same object) and a
+        // guaranteed double-free; static_pointer_cast keeps the original
+        // control block.
+        auto it = reader_cache.find(alias);
+        if (it != reader_cache.end()) {
+            return std::static_pointer_cast<ReaderType>(it->second);
+        }
+
+        ReaderType merged_file(merged_file_path);
+
+        // Total size of the merged file on disk.
+        uint64_t file_size = merged_file.get_file_size();
+
+        // The last 8 bytes hold the size of the meta section.
+        uint64_t total_meta_size;
+        uint64_t total_meta_size_start = file_size - sizeof(total_meta_size);
+        merged_file.read(reinterpret_cast<char*>(&total_meta_size), total_meta_size_start, sizeof(total_meta_size));
+
+        // The meta section sits immediately before the trailing size field.
+        uint64_t meta_start = file_size - total_meta_size - sizeof(total_meta_size);
+
+        // Number of per-file meta records. save() writes this as uint32_t,
+        // so read it with the matching unsigned type.
+        uint32_t file_count;
+        merged_file.read(reinterpret_cast<char*>(&file_count), meta_start, sizeof(file_count));
+
+        // Scan records: alias length, alias bytes, offset, size.
+        uint64_t meta_offset = meta_start + sizeof(file_count);
+        for (uint32_t i = 0; i < file_count; ++i) {
+            uint32_t alias_len;
+            merged_file.read(reinterpret_cast<char*>(&alias_len), meta_offset, sizeof(alias_len));
+            meta_offset += sizeof(alias_len);
+            // RAII buffer: the original new[]/delete[] pair is replaced so no
+            // allocation can leak on an early return.
+            std::vector<char> buffer(alias_len);
+            merged_file.read(buffer.data(), meta_offset, alias_len);
+            meta_offset += alias_len;
+            std::string read_alias(buffer.data(), alias_len);
+
+            uint64_t offset;
+            merged_file.read(reinterpret_cast<char*>(&offset), meta_offset, sizeof(offset));
+            meta_offset += sizeof(offset);
+            uint64_t size;
+            merged_file.read(reinterpret_cast<char*>(&size), meta_offset, sizeof(size));
+            meta_offset += sizeof(size);
+            if (read_alias == alias) {
+                std::shared_ptr<ReaderType> reader = std::make_shared<ReaderType>(merged_file_path, offset, size);
+                // Cache the freshly created reader for subsequent lookups.
+                reader_cache[alias] = reader;
+                merged_file.close();
+                return reader;
+            }
+        }
+
+        std::cerr << "未找到对应别名的文件: " << alias << std::endl;
+        merged_file.close();
+        return nullptr;
+    }
+
+    // Deletes the registered source files from disk (intended for use after
+    // save() has merged them).
+    // NOTE(review): `files` and current_offset are not reset here — confirm
+    // the merger is never reused for a second merge after clear().
+    void FileMerger::clear(){
+        for (const auto& file : files) {
+            remove(file.path.c_str());
+        }
+    }
+
+ template std::shared_ptr<LocalFileReader> FileMerger::get_reader<LocalFileReader>(const std::string& alias, const std::string& merged_file_path);
+}
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/disk_utils.cpp b/be/src/extern/diskann/src/disk_utils.cpp
new file mode 100644
index 0000000..c9db73f
--- /dev/null
+++ b/be/src/extern/diskann/src/disk_utils.cpp
@@ -0,0 +1,1841 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "common_includes.h"
+
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+#include "gperftools/malloc_extension.h"
+#endif
+
+#include "logger.h"
+#include "disk_utils.h"
+#include "cached_io.h"
+#include "index.h"
+#include "mkl.h"
+#include "omp.h"
+#include "percentile_stats.h"
+#include "partition.h"
+#include "pq_flash_index.h"
+#include "timer.h"
+#include "tsl/robin_set.h"
+
+namespace diskann
+{
+
+void add_new_file_to_single_index(std::string index_file, std::string new_file)
+{
+ std::unique_ptr<uint64_t[]> metadata;
+ uint64_t nr, nc;
+ diskann::load_bin<uint64_t>(index_file, metadata, nr, nc);
+ if (nc != 1)
+ {
+ std::stringstream stream;
+ stream << "Error, index file specified does not have correct metadata. " << std::endl;
+ throw diskann::ANNException(stream.str(), -1);
+ }
+ size_t index_ending_offset = metadata[nr - 1];
+ size_t read_blk_size = 64 * 1024 * 1024;
+ cached_ofstream writer(index_file, read_blk_size);
+ size_t check_file_size = get_file_size(index_file);
+ if (check_file_size != index_ending_offset)
+ {
+ std::stringstream stream;
+ stream << "Error, index file specified does not have correct metadata "
+ "(last entry must match the filesize). "
+ << std::endl;
+ throw diskann::ANNException(stream.str(), -1);
+ }
+
+ cached_ifstream reader(new_file, read_blk_size);
+ size_t fsize = reader.get_file_size();
+ if (fsize == 0)
+ {
+ std::stringstream stream;
+ stream << "Error, new file specified is empty. Not appending. " << std::endl;
+ throw diskann::ANNException(stream.str(), -1);
+ }
+
+ size_t num_blocks = DIV_ROUND_UP(fsize, read_blk_size);
+ char *dump = new char[read_blk_size];
+ for (uint64_t i = 0; i < num_blocks; i++)
+ {
+ size_t cur_block_size =
+ read_blk_size > fsize - (i * read_blk_size) ? fsize - (i * read_blk_size) : read_blk_size;
+ reader.read(dump, cur_block_size);
+ writer.write(dump, cur_block_size);
+ }
+ // reader.close();
+ // writer.close();
+
+ delete[] dump;
+ std::vector<uint64_t> new_meta;
+ for (uint64_t i = 0; i < nr; i++)
+ new_meta.push_back(metadata[i]);
+ new_meta.push_back(metadata[nr - 1] + fsize);
+
+ diskann::save_bin<uint64_t>(index_file, new_meta.data(), new_meta.size(), 1);
+}
+
+double get_memory_budget(double search_ram_budget)
+{
+ double final_index_ram_limit = search_ram_budget;
+ if (search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB)
+ { // slack for space used by cached
+ // nodes
+ final_index_ram_limit = search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB;
+ }
+ return final_index_ram_limit * 1024 * 1024 * 1024;
+}
+
+double get_memory_budget(const std::string &mem_budget_str)
+{
+ double search_ram_budget = atof(mem_budget_str.c_str());
+ return get_memory_budget(search_ram_budget);
+}
+
+size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim,
+ const std::vector<std::string> ¶m_list)
+{
+ size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / (double)points_num));
+ diskann::cout << "Calculated num_pq_chunks :" << num_pq_chunks << std::endl;
+ if (param_list.size() >= 6)
+ {
+ float compress_ratio = (float)atof(param_list[5].c_str());
+ if (compress_ratio > 0 && compress_ratio <= 1)
+ {
+ size_t chunks_by_cr = (size_t)(std::floor)(compress_ratio * dim);
+
+ if (chunks_by_cr > 0 && chunks_by_cr < num_pq_chunks)
+ {
+ diskann::cout << "Compress ratio:" << compress_ratio << " new #pq_chunks:" << chunks_by_cr << std::endl;
+ num_pq_chunks = chunks_by_cr;
+ }
+ else
+ {
+ diskann::cout << "Compress ratio: " << compress_ratio << " #new pq_chunks: " << chunks_by_cr
+ << " is either zero or greater than num_pq_chunks: " << num_pq_chunks
+ << ". num_pq_chunks is unchanged. " << std::endl;
+ }
+ }
+ else
+ {
+ diskann::cerr << "Compression ratio: " << compress_ratio << " should be in (0,1]" << std::endl;
+ }
+ }
+
+ num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks;
+ num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks;
+ num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks;
+
+ diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector."
+ << std::endl;
+ return num_pq_chunks;
+}
+
+
+size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim)
+{
+ size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / (double)points_num));
+ diskann::cout << "Calculated num_pq_chunks :" << num_pq_chunks << std::endl;
+ num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks;
+ num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks;
+ num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks;
+
+ diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector."
+ << std::endl;
+ return num_pq_chunks;
+}
+
+template <typename T> T *generateRandomWarmup(uint64_t warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim)
+{
+ T *warmup = nullptr;
+ warmup_num = 100000;
+ diskann::cout << "Generating random warmup file with dim " << warmup_dim << " and aligned dim "
+ << warmup_aligned_dim << std::flush;
+ diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
+ std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::uniform_int_distribution<> dis(-128, 127);
+ for (uint32_t i = 0; i < warmup_num; i++)
+ {
+ for (uint32_t d = 0; d < warmup_dim; d++)
+ {
+ warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
+ }
+ }
+ diskann::cout << "..done" << std::endl;
+ return warmup;
+}
+
+#ifdef EXEC_ENV_OLS
+template <typename T>
+T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, uint64_t &warmup_num,
+ uint64_t warmup_dim, uint64_t warmup_aligned_dim)
+{
+ T *warmup = nullptr;
+ uint64_t file_dim, file_aligned_dim;
+
+ if (files.fileExists(cache_warmup_file))
+ {
+ diskann::load_aligned_bin<T>(files, cache_warmup_file, warmup, warmup_num, file_dim, file_aligned_dim);
+ diskann::cout << "In the warmup file: " << cache_warmup_file << " File dim: " << file_dim
+ << " File aligned dim: " << file_aligned_dim << " Expected dim: " << warmup_dim
+ << " Expected aligned dim: " << warmup_aligned_dim << std::endl;
+
+ if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim)
+ {
+ std::stringstream stream;
+ stream << "Mismatched dimensions in sample file. file_dim = " << file_dim
+ << " file_aligned_dim: " << file_aligned_dim << " index_dim: " << warmup_dim
+ << " index_aligned_dim: " << warmup_aligned_dim << std::endl;
+ diskann::cerr << stream.str();
+ throw diskann::ANNException(stream.str(), -1);
+ }
+ }
+ else
+ {
+ warmup = generateRandomWarmup<T>(warmup_num, warmup_dim, warmup_aligned_dim);
+ }
+ return warmup;
+}
+#endif
+
+template <typename T>
+T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim,
+ uint64_t warmup_aligned_dim)
+{
+ T *warmup = nullptr;
+ uint64_t file_dim, file_aligned_dim;
+
+ if (file_exists(cache_warmup_file))
+ {
+ diskann::load_aligned_bin<T>(cache_warmup_file, warmup, warmup_num, file_dim, file_aligned_dim);
+ if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim)
+ {
+ std::stringstream stream;
+ stream << "Mismatched dimensions in sample file. file_dim = " << file_dim
+ << " file_aligned_dim: " << file_aligned_dim << " index_dim: " << warmup_dim
+ << " index_aligned_dim: " << warmup_aligned_dim << std::endl;
+ throw diskann::ANNException(stream.str(), -1);
+ }
+ }
+ else
+ {
+ warmup = generateRandomWarmup<T>(warmup_num, warmup_dim, warmup_aligned_dim);
+ }
+ return warmup;
+}
+
+/***************************************************
+ Support for Merging Many Vamana Indices
+ ***************************************************/
+
+void read_idmap(const std::string &fname, std::vector<uint32_t> &ivecs)
+{
+ uint32_t npts32, dim;
+ size_t actual_file_size = get_file_size(fname);
+ std::ifstream reader(fname.c_str(), std::ios::binary);
+ reader.read((char *)&npts32, sizeof(uint32_t));
+ reader.read((char *)&dim, sizeof(uint32_t));
+ if (dim != 1 || actual_file_size != ((size_t)npts32) * sizeof(uint32_t) + 2 * sizeof(uint32_t))
+ {
+ std::stringstream stream;
+ stream << "Error reading idmap file. Check if the file is bin file with "
+ "1 dimensional data. Actual: "
+ << actual_file_size << ", expected: " << (size_t)npts32 + 2 * sizeof(uint32_t) << std::endl;
+
+ throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ ivecs.resize(npts32);
+ reader.read((char *)ivecs.data(), ((size_t)npts32) * sizeof(uint32_t));
+ reader.close();
+}
+
+int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, const std::string &idmaps_prefix,
+ const std::string &idmaps_suffix, const uint64_t nshards, uint32_t max_degree,
+ const std::string &output_vamana, const std::string &medoids_file, bool use_filters,
+ const std::string &labels_to_medoids_file)
+{
+ // Read ID maps
+ std::vector<std::string> vamana_names(nshards);
+ std::vector<std::vector<uint32_t>> idmaps(nshards);
+ for (uint64_t shard = 0; shard < nshards; shard++)
+ {
+ vamana_names[shard] = vamana_prefix + std::to_string(shard) + vamana_suffix;
+ read_idmap(idmaps_prefix + std::to_string(shard) + idmaps_suffix, idmaps[shard]);
+ }
+
+ // find max node id
+ size_t nnodes = 0;
+ size_t nelems = 0;
+ for (auto &idmap : idmaps)
+ {
+ for (auto &id : idmap)
+ {
+ nnodes = std::max(nnodes, (size_t)id);
+ }
+ nelems += idmap.size();
+ }
+ nnodes++;
+ diskann::cout << "# nodes: " << nnodes << ", max. degree: " << max_degree << std::endl;
+
+ // compute inverse map: node -> shards
+ std::vector<std::pair<uint32_t, uint32_t>> node_shard;
+ node_shard.reserve(nelems);
+ for (size_t shard = 0; shard < nshards; shard++)
+ {
+ diskann::cout << "Creating inverse map -- shard #" << shard << std::endl;
+ for (size_t idx = 0; idx < idmaps[shard].size(); idx++)
+ {
+ size_t node_id = idmaps[shard][idx];
+ node_shard.push_back(std::make_pair((uint32_t)node_id, (uint32_t)shard));
+ }
+ }
+ std::sort(node_shard.begin(), node_shard.end(), [](const auto &left, const auto &right) {
+ return left.first < right.first || (left.first == right.first && left.second < right.second);
+ });
+ diskann::cout << "Finished computing node -> shards map" << std::endl;
+
+ // will merge all the labels to medoids files of each shard into one
+ // combined file
+ if (use_filters)
+ {
+ std::unordered_map<uint32_t, std::vector<uint32_t>> global_label_to_medoids;
+
+ for (size_t i = 0; i < nshards; i++)
+ {
+ std::ifstream mapping_reader;
+ std::string map_file = vamana_names[i] + "_labels_to_medoids.txt";
+ mapping_reader.open(map_file);
+
+ std::string line, token;
+ uint32_t line_cnt = 0;
+
+ while (std::getline(mapping_reader, line))
+ {
+ std::istringstream iss(line);
+ uint32_t cnt = 0;
+ uint32_t medoid = 0;
+ uint32_t label = 0;
+ while (std::getline(iss, token, ','))
+ {
+ token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+ token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+
+ uint32_t token_as_num = std::stoul(token);
+
+ if (cnt == 0)
+ label = token_as_num;
+ else
+ medoid = token_as_num;
+ cnt++;
+ }
+ global_label_to_medoids[label].push_back(idmaps[i][medoid]);
+ line_cnt++;
+ }
+ mapping_reader.close();
+ }
+
+ std::ofstream mapping_writer(labels_to_medoids_file);
+ assert(mapping_writer.is_open());
+ for (auto iter : global_label_to_medoids)
+ {
+ mapping_writer << iter.first << ", ";
+ auto &vec = iter.second;
+ for (uint32_t idx = 0; idx < vec.size() - 1; idx++)
+ {
+ mapping_writer << vec[idx] << ", ";
+ }
+ mapping_writer << vec[vec.size() - 1] << std::endl;
+ }
+ mapping_writer.close();
+ }
+
+ // create cached vamana readers
+ std::vector<cached_ifstream> vamana_readers(nshards);
+ for (size_t i = 0; i < nshards; i++)
+ {
+ vamana_readers[i].open(vamana_names[i], BUFFER_SIZE_FOR_CACHED_IO);
+ size_t expected_file_size;
+ vamana_readers[i].read((char *)&expected_file_size, sizeof(uint64_t));
+ }
+
+ size_t vamana_metadata_size =
+ sizeof(uint64_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint64_t); // expected file size + max degree +
+ // medoid_id + frozen_point info
+
+ // create cached vamana writers
+ cached_ofstream merged_vamana_writer(output_vamana, BUFFER_SIZE_FOR_CACHED_IO);
+
+ size_t merged_index_size = vamana_metadata_size; // we initialize the size of the merged index to
+ // the metadata size
+ size_t merged_index_frozen = 0;
+ merged_vamana_writer.write((char *)&merged_index_size,
+ sizeof(uint64_t)); // we will overwrite the index size at the end
+
+ uint32_t output_width = max_degree;
+ uint32_t max_input_width = 0;
+ // read width from each vamana to advance buffer by sizeof(uint32_t) bytes
+ for (auto &reader : vamana_readers)
+ {
+ uint32_t input_width;
+ reader.read((char *)&input_width, sizeof(uint32_t));
+ max_input_width = input_width > max_input_width ? input_width : max_input_width;
+ }
+
+ diskann::cout << "Max input width: " << max_input_width << ", output width: " << output_width << std::endl;
+
+ merged_vamana_writer.write((char *)&output_width, sizeof(uint32_t));
+ std::ofstream medoid_writer(medoids_file.c_str(), std::ios::binary);
+ uint32_t nshards_u32 = (uint32_t)nshards;
+ uint32_t one_val = 1;
+ medoid_writer.write((char *)&nshards_u32, sizeof(uint32_t));
+ medoid_writer.write((char *)&one_val, sizeof(uint32_t));
+
+ uint64_t vamana_index_frozen = 0; // as of now the functionality to merge many overlapping vamana
+ // indices is supported only for bulk indices without frozen point.
+ // Hence the final index will also not have any frozen points.
+ for (uint64_t shard = 0; shard < nshards; shard++)
+ {
+ uint32_t medoid;
+ // read medoid
+ vamana_readers[shard].read((char *)&medoid, sizeof(uint32_t));
+ vamana_readers[shard].read((char *)&vamana_index_frozen, sizeof(uint64_t));
+ assert(vamana_index_frozen == false);
+ // rename medoid
+ medoid = idmaps[shard][medoid];
+
+ medoid_writer.write((char *)&medoid, sizeof(uint32_t));
+ // write renamed medoid
+ if (shard == (nshards - 1)) //--> uncomment if running hierarchical
+ merged_vamana_writer.write((char *)&medoid, sizeof(uint32_t));
+ }
+ merged_vamana_writer.write((char *)&merged_index_frozen, sizeof(uint64_t));
+ medoid_writer.close();
+
+ diskann::cout << "Starting merge" << std::endl;
+
+ // Gopal. random_shuffle() is deprecated.
+ std::random_device rng;
+ std::mt19937 urng(rng());
+
+ std::vector<bool> nhood_set(nnodes, 0);
+ std::vector<uint32_t> final_nhood;
+
+ uint32_t nnbrs = 0, shard_nnbrs = 0;
+ uint32_t cur_id = 0;
+ for (const auto &id_shard : node_shard)
+ {
+ uint32_t node_id = id_shard.first;
+ uint32_t shard_id = id_shard.second;
+ if (cur_id < node_id)
+ {
+ // Gopal. random_shuffle() is deprecated.
+ std::shuffle(final_nhood.begin(), final_nhood.end(), urng);
+ nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree);
+ // write into merged ofstream
+ merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t));
+ merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t));
+ merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t));
+ if (cur_id % 499999 == 1)
+ {
+ diskann::cout << "." << std::flush;
+ }
+ cur_id = node_id;
+ nnbrs = 0;
+ for (auto &p : final_nhood)
+ nhood_set[p] = 0;
+ final_nhood.clear();
+ }
+ // read from shard_id ifstream
+ vamana_readers[shard_id].read((char *)&shard_nnbrs, sizeof(uint32_t));
+
+ if (shard_nnbrs == 0)
+ {
+ diskann::cout << "WARNING: shard #" << shard_id << ", node_id " << node_id << " has 0 nbrs" << std::endl;
+ }
+
+ std::vector<uint32_t> shard_nhood(shard_nnbrs);
+ if (shard_nnbrs > 0)
+ vamana_readers[shard_id].read((char *)shard_nhood.data(), shard_nnbrs * sizeof(uint32_t));
+ // rename nodes
+ for (uint64_t j = 0; j < shard_nnbrs; j++)
+ {
+ if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0)
+ {
+ nhood_set[idmaps[shard_id][shard_nhood[j]]] = 1;
+ final_nhood.emplace_back(idmaps[shard_id][shard_nhood[j]]);
+ }
+ }
+ }
+
+ // Gopal. random_shuffle() is deprecated.
+ std::shuffle(final_nhood.begin(), final_nhood.end(), urng);
+ nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree);
+ // write into merged ofstream
+ merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t));
+ if (nnbrs > 0)
+ {
+ merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t));
+ }
+ merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t));
+ for (auto &p : final_nhood)
+ nhood_set[p] = 0;
+ final_nhood.clear();
+
+ diskann::cout << "Expected size: " << merged_index_size << std::endl;
+
+ merged_vamana_writer.reset();
+ merged_vamana_writer.write((char *)&merged_index_size, sizeof(uint64_t));
+
+ diskann::cout << "Finished merge" << std::endl;
+ return 0;
+}
+
+// TODO: Make this a streaming implementation to avoid exceeding the memory
+// budget
+/* If the number of filters per point N exceeds the graph degree R,
+ then it is difficult to have edges to all labels from this point.
+ This function break up such dense points to have only a threshold of maximum
+ T labels per point It divides one graph nodes to multiple nodes and append
+ the new nodes at the end. The dummy map contains the real graph id of the
+ new nodes added to the graph */
+template <typename T>
+void breakup_dense_points(const std::string data_file, const std::string labels_file, uint32_t density,
+ const std::string out_data_file, const std::string out_labels_file,
+ const std::string out_metadata_file)
+{
+ std::string token, line;
+ std::ifstream labels_stream(labels_file);
+ T *data;
+ uint64_t npts, ndims;
+ diskann::load_bin<T>(data_file, data, npts, ndims);
+
+ std::unordered_map<uint32_t, uint32_t> dummy_pt_ids;
+ uint32_t next_dummy_id = (uint32_t)npts;
+
+ uint32_t point_cnt = 0;
+
+ std::vector<std::vector<uint32_t>> labels_per_point;
+ labels_per_point.resize(npts);
+
+ uint32_t dense_pts = 0;
+ if (labels_stream.is_open())
+ {
+ while (getline(labels_stream, line))
+ {
+ std::stringstream iss(line);
+ uint32_t lbl_cnt = 0;
+ uint32_t label_host = point_cnt;
+ while (getline(iss, token, ','))
+ {
+ if (lbl_cnt == density)
+ {
+ if (label_host == point_cnt)
+ dense_pts++;
+ label_host = next_dummy_id;
+ labels_per_point.resize(next_dummy_id + 1);
+ dummy_pt_ids[next_dummy_id] = (uint32_t)point_cnt;
+ next_dummy_id++;
+ lbl_cnt = 0;
+ }
+ token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+ token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+ uint32_t token_as_num = std::stoul(token);
+ labels_per_point[label_host].push_back(token_as_num);
+ lbl_cnt++;
+ }
+ point_cnt++;
+ }
+ }
+ diskann::cout << "fraction of dense points with >= " << density << " labels = " << (float)dense_pts / (float)npts
+ << std::endl;
+
+ if (labels_per_point.size() != 0)
+ {
+ diskann::cout << labels_per_point.size() << " is the new number of points" << std::endl;
+ std::ofstream label_writer(out_labels_file);
+ assert(label_writer.is_open());
+ for (uint32_t i = 0; i < labels_per_point.size(); i++)
+ {
+ for (uint32_t j = 0; j < (labels_per_point[i].size() - 1); j++)
+ {
+ label_writer << labels_per_point[i][j] << ",";
+ }
+ if (labels_per_point[i].size() != 0)
+ label_writer << labels_per_point[i][labels_per_point[i].size() - 1];
+ label_writer << std::endl;
+ }
+ label_writer.close();
+ }
+
+ if (dummy_pt_ids.size() != 0)
+ {
+ diskann::cout << dummy_pt_ids.size() << " is the number of dummy points created" << std::endl;
+
+ T *ptr = (T *)std::realloc((void *)data, labels_per_point.size() * ndims * sizeof(T));
+ if (ptr == nullptr)
+ {
+ diskann::cerr << "Realloc failed while creating dummy points" << std::endl;
+ free(data);
+ data = nullptr;
+            // Throw by value (not `throw new ...`): throwing a pointer leaks the
+            // exception object and is never caught by `catch (ANNException &)`.
+            throw diskann::ANNException("Realloc failed while expanding data.", -1, __FUNCTION__, __FILE__,
+                                        __LINE__);
+ }
+ else
+ {
+ data = ptr;
+ }
+
+ std::ofstream dummy_writer(out_metadata_file);
+ assert(dummy_writer.is_open());
+ for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++)
+ {
+ dummy_writer << i->first << "," << i->second << std::endl;
+ std::memcpy(data + i->first * ndims, data + i->second * ndims, ndims * sizeof(T));
+ }
+ dummy_writer.close();
+ }
+
+ diskann::save_bin<T>(out_data_file, data, labels_per_point.size(), ndims);
+}
+
+void extract_shard_labels(const std::string &in_label_file, const std::string &shard_ids_bin,
+ const std::string &shard_label_file)
+{ // assumes ith row is for ith
+ // point in labels file
+ diskann::cout << "Extracting labels for shard" << std::endl;
+
+ uint32_t *ids = nullptr;
+ uint64_t num_ids, tmp_dim;
+ diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim);
+
+ uint32_t counter = 0, shard_counter = 0;
+ std::string cur_line;
+
+ std::ifstream label_reader(in_label_file);
+ std::ofstream label_writer(shard_label_file);
+ assert(label_reader.is_open());
+    assert(label_writer.is_open());
+ if (label_reader && label_writer)
+ {
+ while (std::getline(label_reader, cur_line))
+ {
+ if (shard_counter >= num_ids)
+ {
+ break;
+ }
+ if (counter == ids[shard_counter])
+ {
+ label_writer << cur_line << "\n";
+ shard_counter++;
+ }
+ counter++;
+ }
+ }
+ if (ids != nullptr)
+ delete[] ids;
+}
+
+template <typename T, typename LabelT>
+int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R,
+ double sampling_rate, double ram_budget, std::string mem_index_path,
+ std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
+ uint32_t num_threads, bool use_filters, const std::string &label_file,
+ const std::string &labels_to_medoids_file, const std::string &universal_label,
+ const uint32_t Lf)
+{
+ size_t base_num, base_dim;
+ diskann::get_bin_metadata(base_file, base_num, base_dim);
+
+ double full_index_ram = estimate_ram_usage(base_num, (uint32_t)base_dim, sizeof(T), R);
+
+ // TODO: Make this honest when there is filter support
+ if (full_index_ram < ram_budget * 1024 * 1024 * 1024)
+ {
+ diskann::cout << "Full index fits in RAM budget, should consume at most "
+ << full_index_ram / (1024 * 1024 * 1024) << "GiBs, so building in one shot" << std::endl;
+
+ diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R)
+ .with_filter_list_size(Lf)
+ .with_saturate_graph(!use_filters)
+ .with_num_threads(num_threads)
+ .build();
+ using TagT = uint32_t;
+ diskann::Index<T, TagT, LabelT> _index(compareMetric, base_dim, base_num,
+ std::make_shared<diskann::IndexWriteParameters>(paras), nullptr,
+ defaults::NUM_FROZEN_POINTS_STATIC, false, false, false,
+ build_pq_bytes > 0, build_pq_bytes, use_opq, use_filters);
+
+ _index.build(base_file.c_str(), base_num);
+
+ _index.save(mem_index_path.c_str());
+
+ if (use_filters)
+ {
+ // need to copy the labels_to_medoids file to the specified input
+ // file
+ std::remove(labels_to_medoids_file.c_str());
+ std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt";
+ copy_file(mem_labels_to_medoid_file, labels_to_medoids_file);
+ std::remove(mem_labels_to_medoid_file.c_str());
+ }
+
+ std::remove(medoids_file.c_str());
+ std::remove(centroids_file.c_str());
+ return 0;
+ }
+
+ // where the universal label is to be saved in the final graph
+ std::string final_index_universal_label_file = mem_index_path + "_universal_label.txt";
+
+ std::string merged_index_prefix = mem_index_path + "_tempFiles";
+
+ Timer timer;
+ int num_parts =
+ partition_with_ram_budget<T>(base_file, sampling_rate, ram_budget, 2 * R / 3, merged_index_prefix, 2);
+ diskann::cout << timer.elapsed_seconds_for_step("partitioning data ") << std::endl;
+
+ std::string cur_centroid_filepath = merged_index_prefix + "_centroids.bin";
+ std::rename(cur_centroid_filepath.c_str(), centroids_file.c_str());
+
+ timer.reset();
+ for (int p = 0; p < num_parts; p++)
+ {
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+ MallocExtension::instance()->ReleaseFreeMemory();
+#endif
+
+ std::string shard_base_file = merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
+
+ std::string shard_ids_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin";
+
+ std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt";
+
+ retrieve_shard_data_from_ids<T>(base_file, shard_ids_file, shard_base_file);
+
+ std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
+
+ diskann::IndexWriteParameters low_degree_params = diskann::IndexWriteParametersBuilder(L, 2 * R / 3)
+ .with_filter_list_size(Lf)
+ .with_saturate_graph(false)
+ .with_num_threads(num_threads)
+ .build();
+
+ uint64_t shard_base_dim, shard_base_pts;
+ get_bin_metadata(shard_base_file, shard_base_pts, shard_base_dim);
+
+ diskann::Index<T> _index(compareMetric, shard_base_dim, shard_base_pts,
+ std::make_shared<diskann::IndexWriteParameters>(low_degree_params), nullptr,
+ defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, build_pq_bytes > 0,
+ build_pq_bytes, use_opq);
+ if (!use_filters)
+ {
+ _index.build(shard_base_file.c_str(), shard_base_pts);
+ }
+ else
+ {
+ diskann::extract_shard_labels(label_file, shard_ids_file, shard_labels_file);
+ if (universal_label != "")
+ { // indicates no universal label
+ LabelT unv_label_as_num = 0;
+ _index.set_universal_label(unv_label_as_num);
+ }
+ _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts);
+ }
+ _index.save(shard_index_file.c_str());
+ // copy universal label file from first shard to the final destination
+ // index, since all shards anyway share the universal label
+ if (p == 0)
+ {
+ std::string shard_universal_label_file = shard_index_file + "_universal_label.txt";
+ if (universal_label != "")
+ {
+ copy_file(shard_universal_label_file, final_index_universal_label_file);
+ }
+ }
+
+ std::remove(shard_base_file.c_str());
+ }
+ diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") << std::endl;
+
+ timer.reset();
+ diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", merged_index_prefix + "_subshard-",
+ "_ids_uint32.bin", num_parts, R, mem_index_path, medoids_file, use_filters,
+ labels_to_medoids_file);
+ diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl;
+
+ // delete tempFiles
+ for (int p = 0; p < num_parts; p++)
+ {
+ std::string shard_base_file = merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
+ std::string shard_id_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin";
+ std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt";
+ std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
+ std::string shard_index_file_data = shard_index_file + ".data";
+
+ std::remove(shard_base_file.c_str());
+ std::remove(shard_id_file.c_str());
+ std::remove(shard_index_file.c_str());
+ std::remove(shard_index_file_data.c_str());
+ if (use_filters)
+ {
+ std::string shard_index_label_file = shard_index_file + "_labels.txt";
+ std::string shard_index_univ_label_file = shard_index_file + "_universal_label.txt";
+ std::string shard_index_label_map_file = shard_index_file + "_labels_to_medoids.txt";
+ std::remove(shard_labels_file.c_str());
+ std::remove(shard_index_label_file.c_str());
+ std::remove(shard_index_label_map_file.c_str());
+ std::remove(shard_index_univ_label_file.c_str());
+ }
+ }
+ return 0;
+}
+
+void print_opts_and_dim(std::stringstream &data_stream, const std::string &point) {
+ int32_t nrows_32 = 0, ncols_32 = 0;
+ data_stream.seekg(0, std::ios::beg);
+ if (!data_stream.read(reinterpret_cast<char*>(&nrows_32), sizeof(int32_t)) ||
+ !data_stream.read(reinterpret_cast<char*>(&ncols_32), sizeof(int32_t))) {
+ std::cerr << "Error: Failed to read nrows_32 and ncols_32 from data_stream." << std::endl;
+ return;
+ }
+ for(int i=0;i<nrows_32;i++){
+ std::cout << "vec" << i <<"->";
+ for(int j=0;j<ncols_32;j++){
+ float vec;
+ data_stream.read(reinterpret_cast<char*>(&vec), 4);
+ std::cout << vec << ",";
+ }
+ std::cout << "\n";
+ }
+    // Restore the stream state so subsequent reads are unaffected
+ data_stream.clear();
+ data_stream.seekg(0, std::ios::beg);
+}
+
+template <typename T, typename LabelT>
+int build_merged_vamana_index(std::stringstream &data_stream, diskann::Metric compareMetric, uint32_t L, uint32_t R,
+ double sampling_rate, double ram_budget, std::stringstream &mem_index_stream,
+ std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
+ uint32_t num_threads, bool use_filters, const std::string &label_file,
+ const std::string &labels_to_medoids_file, const std::string &universal_label,
+ const uint32_t Lf)
+{
+ size_t base_num, base_dim;
+ data_stream.seekg(0, data_stream.beg);
+ diskann::get_bin_metadata(data_stream, base_num, base_dim);
+ data_stream.seekg(0, data_stream.beg);
+
+ double full_index_ram = estimate_ram_usage(base_num, (uint32_t)base_dim, sizeof(T), R);
+ diskann::cout << "Full index fits in RAM budget, should consume at most "
+ << full_index_ram / (1024 * 1024 * 1024) << "GiBs, so building in one shot" << std::endl;
+
+ diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R)
+ .with_filter_list_size(Lf)
+ .with_saturate_graph(!use_filters)
+ .with_num_threads(num_threads)
+ .build();
+ using TagT = uint32_t;
+ diskann::Index<T, TagT, LabelT> _index(compareMetric, base_dim, base_num,
+ std::make_shared<diskann::IndexWriteParameters>(paras), nullptr,
+ defaults::NUM_FROZEN_POINTS_STATIC, false, false, false,
+ build_pq_bytes > 0, build_pq_bytes, use_opq, use_filters);
+
+
+
+    // TODO: optimize — avoid copying the entire dataset into a temporary buffer
+    float* train_data = new float[base_num * base_dim]; // raw allocation; released via delete[] below
+ data_stream.seekg(8, data_stream.beg);
+ data_stream.read(reinterpret_cast<char*>(train_data), base_num * base_dim * sizeof(float));
+
+ std::vector<TagT> tags;
+ _index.build(static_cast<const float*>(train_data), base_num, tags);
+ _index.save(mem_index_stream);
+ delete[] train_data;
+ return 0;
+}
+
+// General purpose support for DiskANN interface
+
+// optimizes the beamwidth to maximize QPS for a given L_search subject to
+// 99.9 latency not blowing up
+template <typename T, typename LabelT>
+uint32_t optimize_beamwidth(std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex, T *tuning_sample,
+ uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L,
+ uint32_t nthreads, uint32_t start_bw)
+{
+ uint32_t cur_bw = start_bw;
+ double max_qps = 0;
+ uint32_t best_bw = start_bw;
+ bool stop_flag = false;
+
+ while (!stop_flag)
+ {
+ std::vector<uint64_t> tuning_sample_result_ids_64(tuning_sample_num, 0);
+ std::vector<float> tuning_sample_result_dists(tuning_sample_num, 0);
+ diskann::QueryStats *stats = new diskann::QueryStats[tuning_sample_num];
+
+ auto s = std::chrono::high_resolution_clock::now();
+#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
+ for (int64_t i = 0; i < (int64_t)tuning_sample_num; i++)
+ {
+ pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L,
+ tuning_sample_result_ids_64.data() + (i * 1),
+ tuning_sample_result_dists.data() + (i * 1), cur_bw, nullptr, stats + i);
+ }
+ auto e = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = e - s;
+ double qps = (1.0f * (float)tuning_sample_num) / (1.0f * (float)diff.count());
+
+ double lat_999 = diskann::get_percentile_stats<float>(
+ stats, tuning_sample_num, 0.999f, [](const diskann::QueryStats &stats) { return stats.total_us; });
+
+ double mean_latency = diskann::get_mean_stats<float>(
+ stats, tuning_sample_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
+
+ if (qps > max_qps && lat_999 < (15000) + mean_latency * 2)
+ {
+ max_qps = qps;
+ best_bw = cur_bw;
+ cur_bw = (uint32_t)(std::ceil)((float)cur_bw * 1.1f);
+ }
+ else
+ {
+ stop_flag = true;
+ }
+ if (cur_bw > 64)
+ stop_flag = true;
+
+ delete[] stats;
+ }
+ return best_bw;
+}
+
// Converts an in-memory Vamana index (mem_index_file) plus its base vectors
// (base_file) into the sector-aligned on-disk layout written to output_file.
// If reorder_data_file is non-empty, full-precision vectors from that file
// are appended after the graph sectors (used for re-ranking with disk PQ).
// Layout: sector 0 holds metadata (saved at the end via save_bin over the
// zeroed placeholder); subsequent sectors hold node records of
// [coords | nnbrs | neighbor ids], either several nodes per sector or one
// node spanning several sectors when max_node_len > SECTOR_LEN.
template <typename T>
void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file,
                        const std::string reorder_data_file)
{
    uint32_t npts, ndims;

    // amount to read or write in one shot
    size_t read_blk_size = 64 * 1024 * 1024;
    size_t write_blk_size = read_blk_size;
    cached_ifstream base_reader(base_file, read_blk_size);
    base_reader.read((char *)&npts, sizeof(uint32_t));
    base_reader.read((char *)&ndims, sizeof(uint32_t));

    size_t npts_64, ndims_64;
    npts_64 = npts;
    ndims_64 = ndims;

    // Check if we need to append data for re-ordering
    bool append_reorder_data = false;
    std::ifstream reorder_data_reader;

    uint32_t npts_reorder_file = 0, ndims_reorder_file = 0;
    if (reorder_data_file != std::string(""))
    {
        append_reorder_data = true;
        size_t reorder_data_file_size = get_file_size(reorder_data_file);
        reorder_data_reader.exceptions(std::ofstream::failbit | std::ofstream::badbit);

        try
        {
            reorder_data_reader.open(reorder_data_file, std::ios::binary);
            reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t));
            reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t));
            // The reorder file must describe exactly the same point set as
            // the base file (same count, size consistent with its header).
            if (npts_reorder_file != npts)
                throw ANNException("Mismatch in num_points between reorder "
                                   "data file and base file",
                                   -1, __FUNCSIG__, __FILE__, __LINE__);
            if (reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file)
                throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__);
        }
        catch (std::system_error &e)
        {
            throw FileException(reorder_data_file, e, __FUNCSIG__, __FILE__, __LINE__);
        }
    }

    // create cached reader + writer
    size_t actual_file_size = get_file_size(mem_index_file);
    diskann::cout << "Vamana index file size=" << actual_file_size << std::endl;
    std::ifstream vamana_reader(mem_index_file, std::ios::binary);
    cached_ofstream diskann_writer(output_file, write_blk_size);

    // metadata: width, medoid
    uint32_t width_u32, medoid_u32;
    size_t index_file_size;

    vamana_reader.read((char *)&index_file_size, sizeof(uint64_t));
    if (index_file_size != actual_file_size)
    {
        std::stringstream stream;
        stream << "Vamana Index file size does not match expected size per "
                  "meta-data."
               << " file size from file: " << index_file_size << " actual file size: " << actual_file_size << std::endl;

        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }
    // NOTE(review): initialized with `false` (converts to 0); the intent
    // appears to be the integer 0.
    uint64_t vamana_frozen_num = false, vamana_frozen_loc = 0;

    vamana_reader.read((char *)&width_u32, sizeof(uint32_t));
    vamana_reader.read((char *)&medoid_u32, sizeof(uint32_t));
    vamana_reader.read((char *)&vamana_frozen_num, sizeof(uint64_t));
    // compute
    uint64_t medoid, max_node_len, nnodes_per_sector;
    npts_64 = (uint64_t)npts;
    medoid = (uint64_t)medoid_u32;
    if (vamana_frozen_num == 1)
        vamana_frozen_loc = medoid;
    // Per-node record: coords (ndims * sizeof(T)) + nnbrs (u32) + up to
    // width_u32 neighbor ids (u32 each).
    max_node_len = (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) + (ndims_64 * sizeof(T));
    nnodes_per_sector = defaults::SECTOR_LEN / max_node_len; // 0 if max_node_len > SECTOR_LEN

    diskann::cout << "medoid: " << medoid << "B" << std::endl;
    diskann::cout << "max_node_len: " << max_node_len << "B" << std::endl;
    diskann::cout << "nnodes_per_sector: " << nnodes_per_sector << "B" << std::endl;

    // defaults::SECTOR_LEN buffer for each sector
    std::unique_ptr<char[]> sector_buf = std::make_unique<char[]>(defaults::SECTOR_LEN);
    std::unique_ptr<char[]> multisector_buf = std::make_unique<char[]>(ROUND_UP(max_node_len, defaults::SECTOR_LEN));
    std::unique_ptr<char[]> node_buf = std::make_unique<char[]>(max_node_len);
    // nnbrs / nhood_buf alias fixed offsets inside node_buf (coords first,
    // then the neighbor count, then the neighbor ids).
    uint32_t &nnbrs = *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T));
    uint32_t *nhood_buf = (uint32_t *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(uint32_t));

    // number of sectors (1 for meta data)
    uint64_t n_sectors = nnodes_per_sector > 0 ? ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector
                                               : npts_64 * DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN);
    uint64_t n_reorder_sectors = 0;
    uint64_t n_data_nodes_per_sector = 0;

    if (append_reorder_data)
    {
        n_data_nodes_per_sector = defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(float));
        n_reorder_sectors = ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector;
    }
    uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * defaults::SECTOR_LEN;

    std::vector<uint64_t> output_file_meta;
    output_file_meta.push_back(npts_64);
    output_file_meta.push_back(ndims_64);
    output_file_meta.push_back(medoid);
    output_file_meta.push_back(max_node_len);
    output_file_meta.push_back(nnodes_per_sector);
    output_file_meta.push_back(vamana_frozen_num);
    output_file_meta.push_back(vamana_frozen_loc);
    output_file_meta.push_back((uint64_t)append_reorder_data);
    if (append_reorder_data)
    {
        output_file_meta.push_back(n_sectors + 1);
        output_file_meta.push_back(ndims_reorder_file);
        output_file_meta.push_back(n_data_nodes_per_sector);
    }
    output_file_meta.push_back(disk_index_file_size);

    // Reserve sector 0 with zeros; the real metadata is written over it at
    // the end by save_bin (offset 0).
    diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN);

    std::unique_ptr<T[]> cur_node_coords = std::make_unique<T[]>(ndims_64);
    diskann::cout << "# sectors: " << n_sectors << std::endl;
    uint64_t cur_node_id = 0;

    if (nnodes_per_sector > 0)
    { // Write multiple nodes per sector
        for (uint64_t sector = 0; sector < n_sectors; sector++)
        {
            if (sector % 100000 == 0)
            {
                diskann::cout << "Sector #" << sector << "written" << std::endl;
            }
            memset(sector_buf.get(), 0, defaults::SECTOR_LEN);
            for (uint64_t sector_node_id = 0; sector_node_id < nnodes_per_sector && cur_node_id < npts_64;
                 sector_node_id++)
            {
                memset(node_buf.get(), 0, max_node_len);
                // read cur node's nnbrs
                vamana_reader.read((char *)&nnbrs, sizeof(uint32_t));

                // sanity checks on nnbrs
                assert(nnbrs > 0);
                assert(nnbrs <= width_u32);

                // read node's nhood; neighbors beyond width_u32 are skipped
                vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
                if (nnbrs > width_u32)
                {
                    vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_reader.cur);
                }

                // write coords of node first
                // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id);
                base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64);
                memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T));

                // write nnbrs
                *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32);

                // write nhood next
                memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf,
                       (std::min)(nnbrs, width_u32) * sizeof(uint32_t));

                // get offset into sector_buf
                char *sector_node_buf = sector_buf.get() + (sector_node_id * max_node_len);

                // copy node buf into sector_node_buf
                memcpy(sector_node_buf, node_buf.get(), max_node_len);
                cur_node_id++;
            }
            // flush sector to disk
            diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN);
        }
    }
    else
    { // Write multi-sector nodes
        uint64_t nsectors_per_node = DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN);
        for (uint64_t i = 0; i < npts_64; i++)
        {
            if ((i * nsectors_per_node) % 100000 == 0)
            {
                diskann::cout << "Sector #" << i * nsectors_per_node << "written" << std::endl;
            }
            memset(multisector_buf.get(), 0, nsectors_per_node * defaults::SECTOR_LEN);

            memset(node_buf.get(), 0, max_node_len);
            // read cur node's nnbrs
            vamana_reader.read((char *)&nnbrs, sizeof(uint32_t));

            // sanity checks on nnbrs
            assert(nnbrs > 0);
            assert(nnbrs <= width_u32);

            // read node's nhood; neighbors beyond width_u32 are skipped
            vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
            if (nnbrs > width_u32)
            {
                vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_reader.cur);
            }

            // write coords of node first
            // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id);
            base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64);
            memcpy(multisector_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T));

            // write nnbrs
            *(uint32_t *)(multisector_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32);

            // write nhood next
            memcpy(multisector_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf,
                   (std::min)(nnbrs, width_u32) * sizeof(uint32_t));

            // flush sector to disk
            diskann_writer.write(multisector_buf.get(), nsectors_per_node * defaults::SECTOR_LEN);
        }
    }

    if (append_reorder_data)
    {
        diskann::cout << "Index written. Appending reorder data..." << std::endl;

        auto vec_len = ndims_reorder_file * sizeof(float);
        std::unique_ptr<char[]> vec_buf = std::make_unique<char[]>(vec_len);

        for (uint64_t sector = 0; sector < n_reorder_sectors; sector++)
        {
            if (sector % 100000 == 0)
            {
                diskann::cout << "Reorder data Sector #" << sector << "written" << std::endl;
            }

            memset(sector_buf.get(), 0, defaults::SECTOR_LEN);

            for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64;
                 sector_node_id++)
            {
                memset(vec_buf.get(), 0, vec_len);
                reorder_data_reader.read(vec_buf.get(), vec_len);

                // copy node buf into sector_node_buf
                memcpy(sector_buf.get() + (sector_node_id * vec_len), vec_buf.get(), vec_len);
            }
            // flush sector to disk
            diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN);
        }
    }
    diskann_writer.close();
    diskann::save_bin<uint64_t>(output_file, output_file_meta.data(), output_file_meta.size(), 1, 0);
    diskann::cout << "Output disk index file written to " << output_file << std::endl;
}
+
+template <typename T>
+void create_disk_layout(std::stringstream &_data_stream, std::stringstream &vamana_index_stream, std::stringstream &disklayout_stream,
+ const std::string reorder_data_file)
+{
+ uint32_t npts, ndims;
+
+ // amount to read or write in one shot
+ size_t read_blk_size = 64 * 1024 * 1024;
+ size_t write_blk_size = read_blk_size;
+ _data_stream.read((char *)&npts, sizeof(uint32_t));
+ _data_stream.read((char *)&ndims, sizeof(uint32_t));
+
+ size_t npts_64, ndims_64;
+ npts_64 = npts;
+ ndims_64 = ndims;
+
+ // Check if we need to append data for re-ordering
+ bool append_reorder_data = false;
+ std::ifstream reorder_data_reader;
+
+ uint32_t npts_reorder_file = 0, ndims_reorder_file = 0;
+ // if (reorder_data_file != std::string(""))
+ // {
+ // append_reorder_data = true;
+ // size_t reorder_data_file_size = get_file_size(reorder_data_file);
+ // reorder_data_reader.exceptions(std::ofstream::failbit | std::ofstream::badbit);
+
+ // try
+ // {
+ // reorder_data_reader.open(reorder_data_file, std::ios::binary);
+ // reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t));
+ // reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t));
+ // if (npts_reorder_file != npts)
+ // throw ANNException("Mismatch in num_points between reorder "
+ // "data file and base file",
+ // -1, __FUNCSIG__, __FILE__, __LINE__);
+ // if (reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file)
+ // throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__);
+ // }
+ // catch (std::system_error &e)
+ // {
+ // throw FileException(reorder_data_file, e, __FUNCSIG__, __FILE__, __LINE__);
+ // }
+ // }
+
+ // create cached reader + writer
+ // size_t actual_file_size = get_file_size(mem_index_file);
+ // diskann::cout << "Vamana index file size=" << actual_file_size << std::endl;
+ // std::ifstream vamana_reader(mem_index_file, std::ios::binary);
+ // cached_ofstream diskann_writer(output_file, write_blk_size);
+
+ // metadata: width, medoid
+ uint32_t width_u32, medoid_u32;
+ size_t index_file_size;
+
+ // vamana_reader.read((char *)&index_file_size, sizeof(uint64_t));
+ // if (index_file_size != actual_file_size)
+ // {
+ // std::stringstream stream;
+ // stream << "Vamana Index file size does not match expected size per "
+ // "meta-data."
+ // << " file size from file: " << index_file_size << " actual file size: " << actual_file_size << std::endl;
+
+ // throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ // }
+ uint64_t vamana_frozen_num = false, vamana_frozen_loc = 0;
+ vamana_index_stream.read((char *)&index_file_size, sizeof(size_t));
+ vamana_index_stream.read((char *)&width_u32, sizeof(uint32_t));
+ vamana_index_stream.read((char *)&medoid_u32, sizeof(uint32_t));
+ vamana_index_stream.read((char *)&vamana_frozen_num, sizeof(uint64_t));
+ // compute
+ uint64_t medoid, max_node_len, nnodes_per_sector;
+ npts_64 = (uint64_t)npts;
+ medoid = (uint64_t)medoid_u32;
+ if (vamana_frozen_num == 1)
+ vamana_frozen_loc = medoid;
+ max_node_len = (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) + (ndims_64 * sizeof(T));
+ nnodes_per_sector = defaults::SECTOR_LEN / max_node_len; // 0 if max_node_len > SECTOR_LEN
+
+ diskann::cout << "medoid: " << medoid << "B" << std::endl;
+ diskann::cout << "max_node_len: " << max_node_len << "B" << std::endl;
+ diskann::cout << "nnodes_per_sector: " << nnodes_per_sector << "B" << std::endl;
+
+ // defaults::SECTOR_LEN buffer for each sector
+ std::unique_ptr<char[]> sector_buf = std::make_unique<char[]>(defaults::SECTOR_LEN);
+ std::unique_ptr<char[]> multisector_buf = std::make_unique<char[]>(ROUND_UP(max_node_len, defaults::SECTOR_LEN));
+ std::unique_ptr<char[]> node_buf = std::make_unique<char[]>(max_node_len);
+ uint32_t &nnbrs = *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T));
+ uint32_t *nhood_buf = (uint32_t *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(uint32_t));
+
+ // number of sectors (1 for meta data)
+ uint64_t n_sectors = nnodes_per_sector > 0 ? ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector
+ : npts_64 * DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN);
+ uint64_t n_reorder_sectors = 0;
+ uint64_t n_data_nodes_per_sector = 0;
+
+ if (append_reorder_data)
+ {
+ n_data_nodes_per_sector = defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(float));
+ n_reorder_sectors = ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector;
+ }
+ uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * defaults::SECTOR_LEN;
+
+ std::vector<uint64_t> output_file_meta;
+ output_file_meta.push_back(npts_64);
+ output_file_meta.push_back(ndims_64);
+ output_file_meta.push_back(medoid);
+ output_file_meta.push_back(max_node_len);
+ output_file_meta.push_back(nnodes_per_sector);
+ output_file_meta.push_back(vamana_frozen_num);
+ output_file_meta.push_back(vamana_frozen_loc);
+ output_file_meta.push_back((uint64_t)append_reorder_data);
+ if (append_reorder_data)
+ {
+ output_file_meta.push_back(n_sectors + 1);
+ output_file_meta.push_back(ndims_reorder_file);
+ output_file_meta.push_back(n_data_nodes_per_sector);
+ }
+ output_file_meta.push_back(disk_index_file_size);
+
+ disklayout_stream.write(sector_buf.get(), defaults::SECTOR_LEN);
+
+ std::unique_ptr<T[]> cur_node_coords = std::make_unique<T[]>(ndims_64);
+ diskann::cout << "# sectors: " << n_sectors << std::endl;
+ uint64_t cur_node_id = 0;
+
+ if (nnodes_per_sector > 0)
+ { // Write multiple nodes per sector
+ for (uint64_t sector = 0; sector < n_sectors; sector++)
+ {
+ if (sector % 100000 == 0)
+ {
+ diskann::cout << "Sector #" << sector << "written" << std::endl;
+ }
+ memset(sector_buf.get(), 0, defaults::SECTOR_LEN);
+ for (uint64_t sector_node_id = 0; sector_node_id < nnodes_per_sector && cur_node_id < npts_64;
+ sector_node_id++)
+ {
+ memset(node_buf.get(), 0, max_node_len);
+ // read cur node's nnbrs
+ vamana_index_stream.read((char *)&nnbrs, sizeof(uint32_t));
+
+ // sanity checks on nnbrs
+ assert(nnbrs > 0);
+ assert(nnbrs <= width_u32);
+
+ // read node's nhood
+ vamana_index_stream.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
+ if (nnbrs > width_u32)
+ {
+ vamana_index_stream.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_index_stream.cur);
+ }
+
+ // write coords of node first
+ // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id);
+ _data_stream.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64);
+ memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T));
+
+ // write nnbrs
+ *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32);
+
+ // write nhood next
+ memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf,
+ (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
+
+ // get offset into sector_buf
+ char *sector_node_buf = sector_buf.get() + (sector_node_id * max_node_len);
+
+ // copy node buf into sector_node_buf
+ memcpy(sector_node_buf, node_buf.get(), max_node_len);
+ cur_node_id++;
+ }
+ // flush sector to disk
+ disklayout_stream.write(sector_buf.get(), defaults::SECTOR_LEN);
+ }
+ }
+ else
+ { // Write multi-sector nodes
+ uint64_t nsectors_per_node = DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN);
+ for (uint64_t i = 0; i < npts_64; i++)
+ {
+ if ((i * nsectors_per_node) % 100000 == 0)
+ {
+ diskann::cout << "Sector #" << i * nsectors_per_node << "written" << std::endl;
+ }
+ memset(multisector_buf.get(), 0, nsectors_per_node * defaults::SECTOR_LEN);
+
+ memset(node_buf.get(), 0, max_node_len);
+ // read cur node's nnbrs
+ vamana_index_stream.read((char *)&nnbrs, sizeof(uint32_t));
+
+ // sanity checks on nnbrs
+ assert(nnbrs > 0);
+ assert(nnbrs <= width_u32);
+
+ // read node's nhood
+ vamana_index_stream.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
+ if (nnbrs > width_u32)
+ {
+ vamana_index_stream.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_index_stream.cur);
+ }
+
+ // write coords of node first
+ // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id);
+ _data_stream.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64);
+ memcpy(multisector_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T));
+
+ // write nnbrs
+ *(uint32_t *)(multisector_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32);
+
+ // write nhood next
+ memcpy(multisector_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf,
+ (std::min)(nnbrs, width_u32) * sizeof(uint32_t));
+
+ // flush sector to disk
+ disklayout_stream.write(multisector_buf.get(), nsectors_per_node * defaults::SECTOR_LEN);
+ }
+ }
+
+ if (append_reorder_data)
+ {
+ diskann::cout << "Index written. Appending reorder data..." << std::endl;
+
+ auto vec_len = ndims_reorder_file * sizeof(float);
+ std::unique_ptr<char[]> vec_buf = std::make_unique<char[]>(vec_len);
+
+ for (uint64_t sector = 0; sector < n_reorder_sectors; sector++)
+ {
+ if (sector % 100000 == 0)
+ {
+ diskann::cout << "Reorder data Sector #" << sector << "written" << std::endl;
+ }
+
+ memset(sector_buf.get(), 0, defaults::SECTOR_LEN);
+
+ for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64;
+ sector_node_id++)
+ {
+ memset(vec_buf.get(), 0, vec_len);
+ reorder_data_reader.read(vec_buf.get(), vec_len);
+
+ // copy node buf into sector_node_buf
+ memcpy(sector_buf.get() + (sector_node_id * vec_len), vec_buf.get(), vec_len);
+ }
+ // flush sector to disk
+ disklayout_stream.write(sector_buf.get(), defaults::SECTOR_LEN);
+ }
+ }
+ // diskann_writer.close();
+ diskann::save_bin<uint64_t>(disklayout_stream, output_file_meta.data(), output_file_meta.size(), 1, 0);
+ //diskann::cout << "Output disk index file written to " << output_file << std::endl;
+}
+
// Builds a complete DiskANN disk index from the raw vectors in dataFilePath.
// indexBuildParameters is a space-separated string "R L B M T [B'] [reorder]
// [build_PQ_bytes] [QD]" (see the usage message below). The pipeline:
// optional metric pre-processing (MIPS/cosine), PQ training + compression,
// merged Vamana build, disk-layout creation, warmup-sample generation, and
// cleanup of intermediate files under the indexFilePath prefix.
// Returns 0 on success, -1 on parameter errors; throws ANNException on
// unsupported metric/type combinations.
template <typename T, typename LabelT>
int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters,
                     diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters,
                     const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold,
                     const uint32_t Lf)
{
    // Tokenize the whitespace-separated parameter string.
    std::stringstream parser;
    parser << std::string(indexBuildParameters);
    std::string cur_param;
    std::vector<std::string> param_list;
    while (parser >> cur_param)
    {
        param_list.push_back(cur_param);
    }
    if (param_list.size() < 5 || param_list.size() > 9)
    {
        diskann::cout << "Correct usage of parameters is R (max degree)\n"
                         "L (indexing list size, better if >= R)\n"
                         "B (RAM limit of final index in GB)\n"
                         "M (memory limit while indexing)\n"
                         "T (number of threads for indexing)\n"
                         "B' (PQ bytes for disk index: optional parameter for "
                         "very large dimensional data)\n"
                         "reorder (set true to include full precision in data file"
                         ": optional paramter, use only when using disk PQ\n"
                         "build_PQ_byte (number of PQ bytes for inde build; set 0 to use "
                         "full precision vectors)\n"
                         "QD Quantized Dimension to overwrite the derived dim from B "
                      << std::endl;
        return -1;
    }

    // MIPS/cosine pre-processing below only works on float data.
    if (!std::is_same<T, float>::value &&
        (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE))
    {
        std::stringstream stream;
        stream << "Disk-index build currently only supports floating point data for Max "
                  "Inner Product Search/ cosine similarity. "
               << std::endl;
        throw diskann::ANNException(stream.str(), -1);
    }

    size_t disk_pq_dims = 0;
    bool use_disk_pq = false;
    size_t build_pq_bytes = 0;

    // if there is a 6th parameter, it means we compress the disk index
    // vectors also using PQ data (for very large dimensionality data). If the
    // provided parameter is 0, it means we store full vectors.
    if (param_list.size() > 5)
    {
        disk_pq_dims = atoi(param_list[5].c_str());
        use_disk_pq = true;
        if (disk_pq_dims == 0)
            use_disk_pq = false;
    }

    // Optional 7th parameter: keep full-precision reorder data on disk.
    bool reorder_data = false;
    if (param_list.size() >= 7)
    {
        if (1 == atoi(param_list[6].c_str()))
        {
            reorder_data = true;
        }
    }

    // Optional 8th parameter: PQ bytes used during graph construction.
    if (param_list.size() >= 8)
    {
        build_pq_bytes = atoi(param_list[7].c_str());
    }

    // Derived file names; every intermediate/final artifact shares the
    // indexFilePath prefix.
    std::string base_file(dataFilePath);
    std::string data_file_to_use = base_file;
    std::string labels_file_original = label_file;
    std::string index_prefix_path(indexFilePath);
    std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt";
    std::string pq_pivots_path_base = codebook_prefix;
    std::string pq_pivots_path = file_exists(pq_pivots_path_base) ? pq_pivots_path_base + "_pq_pivots.bin"
                                                                  : index_prefix_path + "_pq_pivots.bin";
    std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
    std::string mem_index_path = index_prefix_path + "_mem.index";
    std::string disk_index_path = index_prefix_path + "_disk.index";
    std::string medoids_path = disk_index_path + "_medoids.bin";
    std::string centroids_path = disk_index_path + "_centroids.bin";

    std::string labels_to_medoids_path = disk_index_path + "_labels_to_medoids.txt";
    std::string mem_labels_file = mem_index_path + "_labels.txt";
    std::string disk_labels_file = disk_index_path + "_labels.txt";
    std::string mem_univ_label_file = mem_index_path + "_universal_label.txt";
    std::string disk_univ_label_file = disk_index_path + "_universal_label.txt";
    std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt";
    std::string dummy_remap_file = disk_index_path + "_dummy_map.txt"; // remap will be used if we break-up points of
                                                                      // high label-density to create copies

    std::string sample_base_prefix = index_prefix_path + "_sample";
    // optional, used if disk index file must store pq data
    std::string disk_pq_pivots_path = index_prefix_path + "_disk.index_pq_pivots.bin";
    // optional, used if disk index must store pq data
    std::string disk_pq_compressed_vectors_path = index_prefix_path + "_disk.index_pq_compressed.bin";
    std::string prepped_base =
        index_prefix_path +
        "_prepped_base.bin"; // temp file for storing pre-processed base file for cosine/ mips metrics
    bool created_temp_file_for_processed_data = false;

    // output a new base file which contains extra dimension with sqrt(1 -
    // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on
    // disk needed!
    if (compareMetric == diskann::Metric::INNER_PRODUCT)
    {
        Timer timer;
        std::cout << "Using Inner Product search, so need to pre-process base "
                     "data into temp file. Please ensure there is additional "
                     "(n*(d+1)*4) bytes for storing pre-processed base vectors, "
                     "apart from the interim indices created by DiskANN and the final index."
                  << std::endl;
        data_file_to_use = prepped_base;
        float max_norm_of_base = diskann::prepare_base_for_inner_products<T>(base_file, prepped_base);
        std::string norm_file = disk_index_path + "_max_base_norm.bin";
        diskann::save_bin<float>(norm_file, &max_norm_of_base, 1, 1);
        diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for inner product") << std::endl;
        created_temp_file_for_processed_data = true;
    }
    else if (compareMetric == diskann::Metric::COSINE)
    {
        Timer timer;
        std::cout << "Normalizing data for cosine to temporary file, please ensure there is additional "
                     "(n*d*4) bytes for storing normalized base vectors, "
                     "apart from the interim indices created by DiskANN and the final index."
                  << std::endl;
        data_file_to_use = prepped_base;
        diskann::normalize_data_file(base_file, prepped_base);
        diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for cosine") << std::endl;
        created_temp_file_for_processed_data = true;
    }

    uint32_t R = (uint32_t)atoi(param_list[0].c_str());
    uint32_t L = (uint32_t)atoi(param_list[1].c_str());

    double final_index_ram_limit = get_memory_budget(param_list[2]);
    if (final_index_ram_limit <= 0)
    {
        std::cerr << "Insufficient memory budget (or string was not in right "
                     "format). Should be > 0."
                  << std::endl;
        return -1;
    }
    double indexing_ram_budget = (float)atof(param_list[3].c_str());
    if (indexing_ram_budget <= 0)
    {
        std::cerr << "Not building index. Please provide more RAM budget" << std::endl;
        return -1;
    }
    uint32_t num_threads = (uint32_t)atoi(param_list[4].c_str());

    // 0 means "leave OMP/MKL thread counts at their defaults".
    if (num_threads != 0)
    {
        omp_set_num_threads(num_threads);
        mkl_set_num_threads(num_threads);
    }

    diskann::cout << "Starting index build: R=" << R << " L=" << L << " Query RAM budget: " << final_index_ram_limit
                  << " Indexing ram budget: " << indexing_ram_budget << " T: " << num_threads << std::endl;

    auto s = std::chrono::high_resolution_clock::now();

    // If there is filter support, we break-up points which have too many labels
    // into replica dummy points which evenly distribute the filters. The rest
    // of index build happens on the augmented base and labels
    std::string augmented_data_file, augmented_labels_file;
    if (use_filters)
    {
        convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file,
                                     universal_label);
        augmented_data_file = index_prefix_path + "_augmented_data.bin";
        augmented_labels_file = index_prefix_path + "_augmented_labels.txt";
        if (filter_threshold != 0)
        {
            breakup_dense_points<T>(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file,
                                    augmented_labels_file,
                                    dummy_remap_file); // RKNOTE: This has large memory footprint,
                                                       // need to make this streaming
            data_file_to_use = augmented_data_file;
            labels_file_to_use = augmented_labels_file;
        }
    }

    size_t points_num, dim;

    Timer timer;
    diskann::get_bin_metadata(data_file_to_use.c_str(), points_num, dim);
    // Sampling rate for PQ training so at most MAX_PQ_TRAINING_SET_SIZE
    // points are used.
    const double p_val = ((double)MAX_PQ_TRAINING_SET_SIZE / (double)points_num);

    if (use_disk_pq)
    {
        generate_disk_quantized_data<T>(data_file_to_use, disk_pq_pivots_path, disk_pq_compressed_vectors_path,
                                        compareMetric, p_val, disk_pq_dims);
    }
    // Derive PQ bytes/vector from the query-time RAM budget, clamped to
    // [1, min(dim, MAX_PQ_CHUNKS)].
    // NOTE(review): num_pq_chunks is unsigned, so the `<= 0` test below only
    // catches the value 0.
    size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / points_num));

    num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks;
    num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks;
    num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks;

    // Optional 9th parameter (QD) overrides the derived PQ dimension.
    if (param_list.size() >= 9 && atoi(param_list[8].c_str()) <= MAX_PQ_CHUNKS && atoi(param_list[8].c_str()) > 0)
    {
        std::cout << "Use quantized dimension (QD) to overwrite derived quantized "
                     "dimension from search_DRAM_budget (B)"
                  << std::endl;
        num_pq_chunks = atoi(param_list[8].c_str());
    }

    diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector."
                  << std::endl;

    generate_quantized_data<T>(data_file_to_use, pq_pivots_path, pq_compressed_vectors_path, compareMetric, p_val,
                               num_pq_chunks, use_opq, codebook_prefix);
    diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl;

// Gopal. Splitting diskann_dll into separate DLLs for search and build.
// This code should only be available in the "build" DLL.
#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
    MallocExtension::instance()->ReleaseFreeMemory();
#endif
    // Whether it is cosine or inner product, we still L2 metric due to the pre-processing.
    timer.reset();
    diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
                                                  indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
                                                  build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use,
                                                  labels_to_medoids_path, universal_label, Lf);
    diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl;

    // Lay out the graph + vectors (or disk-PQ vectors) in sector-aligned
    // form; with reorder_data the full-precision vectors are appended.
    timer.reset();
    if (!use_disk_pq)
    {
        diskann::create_disk_layout<T>(data_file_to_use.c_str(), mem_index_path, disk_index_path);
    }
    else
    {
        if (!reorder_data)
            diskann::create_disk_layout<uint8_t>(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path);
        else
            diskann::create_disk_layout<uint8_t>(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path,
                                                 data_file_to_use.c_str());
    }
    diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") << std::endl;

    // Sample ~10% of points (capped) as warmup queries for later searches.
    double ten_percent_points = std::ceil(points_num * 0.1);
    double num_sample_points =
        ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP ? MAX_SAMPLE_POINTS_FOR_WARMUP : ten_percent_points;
    double sample_sampling_rate = num_sample_points / points_num;
    gen_random_slice<T>(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate);
    // Promote label artifacts to their disk-index names and remove temps.
    if (use_filters)
    {
        copy_file(labels_file_to_use, disk_labels_file);
        std::remove(mem_labels_file.c_str());
        if (universal_label != "")
        {
            copy_file(mem_univ_label_file, disk_univ_label_file);
            std::remove(mem_univ_label_file.c_str());
        }
        std::remove(augmented_data_file.c_str());
        std::remove(augmented_labels_file.c_str());
        std::remove(labels_file_to_use.c_str());
    }
    if (created_temp_file_for_processed_data)
        std::remove(prepped_base.c_str());
    std::remove(mem_index_path.c_str());
    std::remove((mem_index_path + ".data").c_str());
    std::remove((mem_index_path + ".tags").c_str());
    if (use_disk_pq)
        std::remove(disk_pq_compressed_vectors_path.c_str());

    auto e = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = e - s;
    diskann::cout << "Indexing time: " << diff.count() << std::endl;

    return 0;
}
+
+template DISKANN_DLLEXPORT void create_disk_layout<int8_t>(const std::string base_file, // explicit instantiations: file-path overloads of create_disk_layout
+                                                           const std::string mem_index_file,
+                                                           const std::string output_file,
+                                                           const std::string reorder_data_file);
+template DISKANN_DLLEXPORT void create_disk_layout<uint8_t>(const std::string base_file,
+                                                            const std::string mem_index_file,
+                                                            const std::string output_file,
+                                                            const std::string reorder_data_file);
+template DISKANN_DLLEXPORT void create_disk_layout<float>(const std::string base_file, const std::string mem_index_file,
+                                                          const std::string output_file,
+                                                          const std::string reorder_data_file);
+
+
+template DISKANN_DLLEXPORT void create_disk_layout<float>(std::stringstream & data_stream, std::stringstream &index_stream,
+                                                          std::stringstream &disklayout_stream,
+                                                          const std::string reorder_data_file); // default argument removed: C++ forbids default args in an explicit instantiation; the default lives on the primary template declaration
+
+template DISKANN_DLLEXPORT int8_t *load_warmup<int8_t>(const std::string &cache_warmup_file, uint64_t &warmup_num, // explicit instantiations: warmup-sample loaders
+                                                       uint64_t warmup_dim, uint64_t warmup_aligned_dim);
+template DISKANN_DLLEXPORT uint8_t *load_warmup<uint8_t>(const std::string &cache_warmup_file, uint64_t &warmup_num,
+                                                         uint64_t warmup_dim, uint64_t warmup_aligned_dim);
+template DISKANN_DLLEXPORT float *load_warmup<float>(const std::string &cache_warmup_file, uint64_t &warmup_num,
+                                                     uint64_t warmup_dim, uint64_t warmup_aligned_dim);
+
+#ifdef EXEC_ENV_OLS
+template DISKANN_DLLEXPORT int8_t *load_warmup<int8_t>(MemoryMappedFiles &files, const std::string &cache_warmup_file, // memory-mapped variants, OLS builds only
+                                                       uint64_t &warmup_num, uint64_t warmup_dim,
+                                                       uint64_t warmup_aligned_dim);
+template DISKANN_DLLEXPORT uint8_t *load_warmup<uint8_t>(MemoryMappedFiles &files, const std::string &cache_warmup_file,
+                                                         uint64_t &warmup_num, uint64_t warmup_dim,
+                                                         uint64_t warmup_aligned_dim);
+template DISKANN_DLLEXPORT float *load_warmup<float>(MemoryMappedFiles &files, const std::string &cache_warmup_file,
+                                                     uint64_t &warmup_num, uint64_t warmup_dim,
+                                                     uint64_t warmup_aligned_dim);
+#endif
+
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint32_t>( // explicit instantiations: beamwidth tuners, LabelT = uint32_t
+    std::unique_ptr<diskann::PQFlashIndex<int8_t, uint32_t>> &pFlashIndex, int8_t *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint32_t>(
+    std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint32_t>> &pFlashIndex, uint8_t *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint32_t>(
+    std::unique_ptr<diskann::PQFlashIndex<float, uint32_t>> &pFlashIndex, float *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint16_t>( // LabelT = uint16_t variants
+    std::unique_ptr<diskann::PQFlashIndex<int8_t, uint16_t>> &pFlashIndex, int8_t *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint16_t>(
+    std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint16_t>> &pFlashIndex, uint8_t *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint16_t>(
+    std::unique_ptr<diskann::PQFlashIndex<float, uint16_t>> &pFlashIndex, float *tuning_sample,
+    uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw);
+
+template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint32_t>(const char *dataFilePath, const char *indexFilePath, // explicit instantiations: disk-index builders, LabelT = uint32_t
+                                                                  const char *indexBuildParameters,
+                                                                  diskann::Metric compareMetric, bool use_opq,
+                                                                  const std::string &codebook_prefix, bool use_filters,
+                                                                  const std::string &label_file,
+                                                                  const std::string &universal_label,
+                                                                  const uint32_t filter_threshold, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint32_t>(const char *dataFilePath, const char *indexFilePath,
+                                                                   const char *indexBuildParameters,
+                                                                   diskann::Metric compareMetric, bool use_opq,
+                                                                   const std::string &codebook_prefix, bool use_filters,
+                                                                   const std::string &label_file,
+                                                                   const std::string &universal_label,
+                                                                   const uint32_t filter_threshold, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_disk_index<float, uint32_t>(const char *dataFilePath, const char *indexFilePath,
+                                                                 const char *indexBuildParameters,
+                                                                 diskann::Metric compareMetric, bool use_opq,
+                                                                 const std::string &codebook_prefix, bool use_filters,
+                                                                 const std::string &label_file,
+                                                                 const std::string &universal_label,
+                                                                 const uint32_t filter_threshold, const uint32_t Lf);
+// Explicit instantiations with LabelT = uint16_t
+template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint16_t>(const char *dataFilePath, const char *indexFilePath,
+                                                                  const char *indexBuildParameters,
+                                                                  diskann::Metric compareMetric, bool use_opq,
+                                                                  const std::string &codebook_prefix, bool use_filters,
+                                                                  const std::string &label_file,
+                                                                  const std::string &universal_label,
+                                                                  const uint32_t filter_threshold, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint16_t>(const char *dataFilePath, const char *indexFilePath,
+                                                                   const char *indexBuildParameters,
+                                                                   diskann::Metric compareMetric, bool use_opq,
+                                                                   const std::string &codebook_prefix, bool use_filters,
+                                                                   const std::string &label_file,
+                                                                   const std::string &universal_label,
+                                                                   const uint32_t filter_threshold, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_disk_index<float, uint16_t>(const char *dataFilePath, const char *indexFilePath,
+                                                                 const char *indexBuildParameters,
+                                                                 diskann::Metric compareMetric, bool use_opq,
+                                                                 const std::string &codebook_prefix, bool use_filters,
+                                                                 const std::string &label_file,
+                                                                 const std::string &universal_label,
+                                                                 const uint32_t filter_threshold, const uint32_t Lf);
+
+template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint32_t>( // explicit instantiations: sharded in-memory build drivers
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+
+template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint32_t>(std::stringstream & data_stream, diskann::Metric compareMetric, uint32_t L, uint32_t R, // stream-based overload (in-memory import path)
+                                                   double sampling_rate, double ram_budget, std::stringstream &mem_index_stream,
+                                                   std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
+                                                   uint32_t num_threads, bool use_filters, const std::string &label_file,
+                                                   const std::string &labels_to_medoids_file, const std::string &universal_label,
+                                                   const uint32_t Lf);
+
+template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint32_t>(
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint32_t>(
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+// Explicit instantiations with LabelT = uint16_t
+template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint16_t>(
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint16_t>(
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint16_t>(
+    std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,
+    double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file,
+    size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file,
+    const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf);
+}; // namespace diskann
diff --git a/be/src/extern/diskann/src/distance.cpp b/be/src/extern/diskann/src/distance.cpp
new file mode 100644
index 0000000..c2f88c8
--- /dev/null
+++ b/be/src/extern/diskann/src/distance.cpp
@@ -0,0 +1,733 @@
+// TODO
+// CHECK COSINE ON LINUX
+
+#ifdef _WINDOWS
+#include <immintrin.h>
+#include <smmintrin.h>
+#include <tmmintrin.h>
+#include <intrin.h>
+#else
+#include <immintrin.h>
+#endif
+
+#include "simd_utils.h"
+#include <cosine_similarity.h>
+#include <iostream>
+
+#include "distance.h"
+#include "utils.h"
+#include "logger.h"
+#include "ann_exception.h"
+
+namespace diskann
+{
+
+//
+// Base Class Implementations
+//
+template <typename T>
+float Distance<T>::compare(const T *a, const T *b, const float normA, const float normB, uint32_t length) const
+{
+    throw std::logic_error("This function is not implemented."); // norm-aware overload: unsupported unless a subclass overrides it
+}
+
+template <typename T> uint32_t Distance<T>::post_normalization_dimension(uint32_t orig_dimension) const
+{
+    return orig_dimension; // default: preprocessing does not change dimensionality
+}
+
+template <typename T> diskann::Metric Distance<T>::get_metric() const
+{
+    return _distance_metric;
+}
+
+template <typename T> bool Distance<T>::preprocessing_required() const
+{
+    return false; // default: no base/query preprocessing needed
+}
+
+template <typename T>
+void Distance<T>::preprocess_base_points(T *original_data, const size_t orig_dim, const size_t num_points)
+{
+} // default: no-op; subclasses (e.g. normalized cosine) override
+
+template <typename T> void Distance<T>::preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query)
+{
+    std::memcpy(scratch_query, query_vec, query_dim * sizeof(T)); // default: plain copy, no normalization
+}
+
+template <typename T> size_t Distance<T>::get_required_alignment() const
+{
+    return _alignment_factor;
+}
+
+//
+// Cosine distance functions.
+//
+
+float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t length) const // 1 - cosine similarity
+{
+#ifdef _WINDOWS
+    return diskann::CosineSimilarity2<int8_t>(a, b, length);
+#else
+    int magA = 0, magB = 0, scalarProduct = 0; // NOTE(review): int accumulators could overflow for very long vectors — confirm dims stay modest
+    for (uint32_t i = 0; i < length; i++)
+    {
+        magA += ((int32_t)a[i]) * ((int32_t)a[i]);
+        magB += ((int32_t)b[i]) * ((int32_t)b[i]);
+        scalarProduct += ((int32_t)a[i]) * ((int32_t)b[i]);
+    }
+    // similarity == 1-cosine distance
+    return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); // all-zero input makes the denominator 0 (NaN/inf)
+#endif
+}
+
+float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t length) const // 1 - cosine similarity
+{
+#ifdef _WINDOWS
+    return diskann::CosineSimilarity2<float>(a, b, length);
+#else
+    float magA = 0, magB = 0, scalarProduct = 0;
+    for (uint32_t i = 0; i < length; i++)
+    {
+        magA += (a[i]) * (a[i]);
+        magB += (b[i]) * (b[i]);
+        scalarProduct += (a[i]) * (b[i]);
+    }
+    // similarity == 1-cosine distance
+    return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); // all-zero input makes the denominator 0 (NaN/inf)
+#endif
+}
+
+float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t length) const // 1 - cosine similarity, scalar fallback
+{
+    int magA = 0, magB = 0, scalarProduct = 0; // NOTE(review): signed int accumulating uint32 products — overflow possible for long vectors, confirm
+    for (uint32_t i = 0; i < length; i++)
+    {
+        magA += ((uint32_t)a[i]) * ((uint32_t)a[i]);
+        magB += ((uint32_t)b[i]) * ((uint32_t)b[i]);
+        scalarProduct += ((uint32_t)a[i]) * ((uint32_t)b[i]);
+    }
+    // similarity == 1-cosine distance
+    return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); // all-zero input makes the denominator 0 (NaN/inf)
+}
+
+//
+// L2 distance functions.
+//
+
+float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) const // squared L2 distance (no sqrt)
+{
+#ifdef _WINDOWS
+#ifdef USE_AVX2
+    __m256 r = _mm256_setzero_ps();
+    char *pX = (char *)a, *pY = (char *)b;
+    while (size >= 32)
+    {
+        __m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *)pX), _mm256_loadu_si256((__m256i *)pY));
+        r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1));
+        pX += 32;
+        pY += 32;
+        size -= 32;
+    }
+    while (size > 0)
+    {
+        __m128i r2 = _mm_subs_epi8(_mm_loadu_si128((__m128i *)pX), _mm_loadu_si128((__m128i *)pY));
+        r = _mm256_add_ps(r, _mm256_mul32_pi8(r2, r2));
+        pX += 4;
+        pY += 4;
+        size -= 4;
+    }
+    r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r);
+    return r.m256_f32[0] + r.m256_f32[4];
+#else
+    int32_t result = 0;
+#pragma omp simd reduction(+ : result) aligned(a, b : 8)
+    for (int32_t i = 0; i < (int32_t)size; i++)
+    {
+        result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i]));
+    }
+    return (float)result;
+#endif
+#else
+    int32_t result = 0;
+#pragma omp simd reduction(+ : result) aligned(a, b : 8) // 'aligned' promises 8-byte alignment of a,b to the vectorizer — callers must guarantee it
+    for (int32_t i = 0; i < (int32_t)size; i++)
+    {
+        result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i]));
+    }
+    return (float)result;
+#endif
+}
+
+float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size) const // squared L2 distance (no sqrt)
+{
+    uint32_t result = 0;
+#ifndef _WINDOWS
+#pragma omp simd reduction(+ : result) aligned(a, b : 8) // 8-byte alignment promise to the vectorizer
+#endif
+    for (int32_t i = 0; i < (int32_t)size; i++)
+    {
+        result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); // widen to int16 before subtracting to avoid wrap
+    }
+    return (float)result;
+}
+
+#ifndef _WINDOWS
+float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const // squared L2 distance (no sqrt)
+{
+    a = (const float *)__builtin_assume_aligned(a, 32); // compiler is told a,b are 32-byte aligned — callers must guarantee it
+    b = (const float *)__builtin_assume_aligned(b, 32);
+#else
+float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const
+{
+#endif
+
+    float result = 0;
+#ifdef USE_AVX2
+    // assume size is divisible by 8
+    uint16_t niters = (uint16_t)(size / 8);
+    __m256 sum = _mm256_setzero_ps();
+    for (uint16_t j = 0; j < niters; j++)
+    {
+        // scope is a[8j:8j+7], b[8j:8j+7]
+        // load a_vec
+        if (j < (niters - 1))
+        {
+            _mm_prefetch((char *)(a + 8 * (j + 1)), _MM_HINT_T0); // prefetch the next 8-float chunk
+            _mm_prefetch((char *)(b + 8 * (j + 1)), _MM_HINT_T0);
+        }
+        __m256 a_vec = _mm256_load_ps(a + 8 * j);
+        // load b_vec
+        __m256 b_vec = _mm256_load_ps(b + 8 * j);
+        // a_vec - b_vec
+        __m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec);
+
+        sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum); // sum += (a - b)^2, fused multiply-add
+    }
+
+    // horizontal add sum
+    result = _mm256_reduce_add_ps(sum);
+#else
+#ifndef _WINDOWS
+#pragma omp simd reduction(+ : result) aligned(a, b : 32)
+#endif
+    for (int32_t i = 0; i < (int32_t)size; i++)
+    {
+        result += (a[i] - b[i]) * (a[i] - b[i]);
+    }
+#endif
+    return result;
+}
+
+template <typename T> float SlowDistanceL2<T>::compare(const T *a, const T *b, uint32_t length) const // portable squared L2 fallback (no sqrt)
+{
+    float result = 0.0f;
+    for (uint32_t i = 0; i < length; i++)
+    {
+        result += ((float)(a[i] - b[i])) * (a[i] - b[i]); // integer promotion makes a[i]-b[i] safe for the int8/uint8 instantiations below
+    }
+    return result;
+}
+
+#ifdef _WINDOWS
+float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t length) const // squared L2 via SSE, Windows-only
+{
+    __m128 r = _mm_setzero_ps();
+    __m128i r1;
+    while (length >= 16)
+    {
+        r1 = _mm_subs_epi8(_mm_load_si128((__m128i *)a), _mm_load_si128((__m128i *)b));
+        r = _mm_add_ps(r, _mm_mul_epi8(r1));
+        a += 16;
+        b += 16;
+        length -= 16;
+    }
+    r = _mm_hadd_ps(_mm_hadd_ps(r, r), r);
+    float res = r.m128_f32[0];
+
+    if (length >= 8)
+    {
+        __m128 r2 = _mm_setzero_ps();
+        __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 8)), _mm_load_si128((__m128i *)(b - 8)));
+        r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3));
+        a += 8;
+        b += 8;
+        length -= 8;
+        r2 = _mm_hadd_ps(_mm_hadd_ps(r2, r2), r2);
+        res += r2.m128_f32[0];
+    }
+
+    if (length >= 4)
+    {
+        __m128 r2 = _mm_setzero_ps();
+        __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 12)), _mm_load_si128((__m128i *)(b - 12)));
+        r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3));
+        res += r2.m128_f32[0] + r2.m128_f32[1];
+    }
+
+    return res;
+}
+
+float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length) const // squared L2 via SSE, Windows-only
+{
+    __m128 diff, v1, v2;
+    __m128 sum = _mm_set1_ps(0);
+
+    while (length >= 4)
+    {
+        v1 = _mm_loadu_ps(a);
+        a += 4;
+        v2 = _mm_loadu_ps(b);
+        b += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+        length -= 4;
+    }
+
+    return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3]; // NOTE(review): ignores a tail of length % 4 elements — confirm dims are multiples of 4
+}
+#else
+float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, uint32_t) const // NOTE(review): non-Windows stub returns 0, yet get_distance_function<int8_t> can select this class on AVX-only CPUs — confirm
+{
+    return 0;
+}
+float AVXDistanceL2Float::compare(const float *, const float *, uint32_t) const // NOTE(review): same stub concern as the int8 overload above
+{
+    return 0;
+}
+#endif
+
+template <typename T> float DistanceInnerProduct<T>::inner_product(const T *a, const T *b, uint32_t size) const // dot(a, b); throws for non-float T
+{
+    if (!std::is_floating_point<T>::value)
+    {
+        diskann::cerr << "ERROR: Inner Product only defined for float currently." << std::endl;
+        throw diskann::ANNException("ERROR: Inner Product only defined for float currently.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    float result = 0;
+
+#ifdef __GNUC__
+#ifdef USE_AVX2
+#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
+    tmp1 = _mm256_loadu_ps(addr1); \
+    tmp2 = _mm256_loadu_ps(addr2); \
+    tmp1 = _mm256_mul_ps(tmp1, tmp2); \
+    dest = _mm256_add_ps(dest, tmp1);
+
+    __m256 sum;
+    __m256 l0, l1;
+    __m256 r0, r1;
+    uint32_t D = (size + 7) & ~7U; // size rounded up to a multiple of 8
+    uint32_t DR = D % 16;
+    uint32_t DD = D - DR;
+    const float *l = (float *)a;
+    const float *r = (float *)b;
+    const float *e_l = l + DD;
+    const float *e_r = r + DD;
+    float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
+
+    sum = _mm256_loadu_ps(unpack);
+    if (DR)
+    {
+        AVX_DOT(e_l, e_r, sum, l0, r0); // tail chunk; NOTE(review): reads 8 floats at offset DD — assumes buffers padded to D, confirm
+    }
+
+    for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16)
+    {
+        AVX_DOT(l, r, sum, l0, r0);
+        AVX_DOT(l + 8, r + 8, sum, l1, r1);
+    }
+    _mm256_storeu_ps(unpack, sum);
+    result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
+
+#else
+#ifdef __SSE2__
+#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \
+    tmp1 = _mm128_loadu_ps(addr1); \
+    tmp2 = _mm128_loadu_ps(addr2); \
+    tmp1 = _mm128_mul_ps(tmp1, tmp2); \
+    dest = _mm128_add_ps(dest, tmp1);
+    __m128 sum;
+    __m128 l0, l1, l2, l3;
+    __m128 r0, r1, r2, r3;
+    uint32_t D = (size + 3) & ~3U;
+    uint32_t DR = D % 16;
+    uint32_t DD = D - DR;
+    const float *l = a;
+    const float *r = b;
+    const float *e_l = l + DD;
+    const float *e_r = r + DD;
+    float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
+
+    sum = _mm_load_ps(unpack);
+    switch (DR)
+    {
+    case 12: // deliberate fallthrough through the remaining tail cases
+        SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2);
+    case 8:
+        SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1);
+    case 4:
+        SSE_DOT(e_l, e_r, sum, l0, r0);
+    default:
+        break;
+    }
+    for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16)
+    {
+        SSE_DOT(l, r, sum, l0, r0);
+        SSE_DOT(l + 4, r + 4, sum, l1, r1);
+        SSE_DOT(l + 8, r + 8, sum, l2, r2);
+        SSE_DOT(l + 12, r + 12, sum, l3, r3);
+    }
+    _mm_storeu_ps(unpack, sum);
+    result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
+#else
+
+    float dot0, dot1, dot2, dot3;
+    const float *last = a + size;
+    const float *unroll_group = last - 3;
+
+    /* Process 4 items with each loop for efficiency. */
+    while (a < unroll_group)
+    {
+        dot0 = a[0] * b[0];
+        dot1 = a[1] * b[1];
+        dot2 = a[2] * b[2];
+        dot3 = a[3] * b[3];
+        result += dot0 + dot1 + dot2 + dot3;
+        a += 4;
+        b += 4;
+    }
+    /* Process last 0-3 pixels. Not needed for standard vector lengths. */
+    while (a < last)
+    {
+        result += *a++ * *b++;
+    }
+#endif
+#endif
+#endif
+    return result;
+}
+
+template <typename T> float DistanceFastL2<T>::compare(const T *a, const T *b, float norm, uint32_t size) const // returns norm - 2*dot(a,b); caller supplies the precomputed norm term
+{
+    float result = -2 * DistanceInnerProduct<T>::inner_product(a, b, size);
+    result += norm;
+    return result;
+}
+
+template <typename T> float DistanceFastL2<T>::norm(const T *a, uint32_t size) const // squared L2 norm; throws for non-float T
+{
+    if (!std::is_floating_point<T>::value)
+    {
+        diskann::cerr << "ERROR: FastL2 only defined for float currently." << std::endl;
+        throw diskann::ANNException("ERROR: FastL2 only defined for float currently.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+    float result = 0;
+#ifdef __GNUC__
+#ifdef __AVX__
+#define AVX_L2NORM(addr, dest, tmp) \
+    tmp = _mm256_loadu_ps(addr); \
+    tmp = _mm256_mul_ps(tmp, tmp); \
+    dest = _mm256_add_ps(dest, tmp);
+
+    __m256 sum;
+    __m256 l0, l1;
+    uint32_t D = (size + 7) & ~7U; // size rounded up to a multiple of 8
+    uint32_t DR = D % 16;
+    uint32_t DD = D - DR;
+    const float *l = (float *)a;
+    const float *e_l = l + DD;
+    float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
+
+    sum = _mm256_loadu_ps(unpack);
+    if (DR)
+    {
+        AVX_L2NORM(e_l, sum, l0); // tail chunk; NOTE(review): reads 8 floats at offset DD — assumes padded buffers, confirm
+    }
+    for (uint32_t i = 0; i < DD; i += 16, l += 16)
+    {
+        AVX_L2NORM(l, sum, l0);
+        AVX_L2NORM(l + 8, sum, l1);
+    }
+    _mm256_storeu_ps(unpack, sum);
+    result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
+#else
+#ifdef __SSE2__
+#define SSE_L2NORM(addr, dest, tmp) \
+    tmp = _mm128_loadu_ps(addr); \
+    tmp = _mm128_mul_ps(tmp, tmp); \
+    dest = _mm128_add_ps(dest, tmp);
+
+    __m128 sum;
+    __m128 l0, l1, l2, l3;
+    uint32_t D = (size + 3) & ~3U;
+    uint32_t DR = D % 16;
+    uint32_t DD = D - DR;
+    const float *l = a;
+    const float *e_l = l + DD;
+    float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0};
+
+    sum = _mm_load_ps(unpack);
+    switch (DR)
+    {
+    case 12: // deliberate fallthrough through the remaining tail cases
+        SSE_L2NORM(e_l + 8, sum, l2);
+    case 8:
+        SSE_L2NORM(e_l + 4, sum, l1);
+    case 4:
+        SSE_L2NORM(e_l, sum, l0);
+    default:
+        break;
+    }
+    for (uint32_t i = 0; i < DD; i += 16, l += 16)
+    {
+        SSE_L2NORM(l, sum, l0);
+        SSE_L2NORM(l + 4, sum, l1);
+        SSE_L2NORM(l + 8, sum, l2);
+        SSE_L2NORM(l + 12, sum, l3);
+    }
+    _mm_storeu_ps(unpack, sum);
+    result += unpack[0] + unpack[1] + unpack[2] + unpack[3];
+#else
+    float dot0, dot1, dot2, dot3;
+    const float *last = a + size;
+    const float *unroll_group = last - 3;
+
+    /* Process 4 items with each loop for efficiency. */
+    while (a < unroll_group)
+    {
+        dot0 = a[0] * a[0];
+        dot1 = a[1] * a[1];
+        dot2 = a[2] * a[2];
+        dot3 = a[3] * a[3];
+        result += dot0 + dot1 + dot2 + dot3;
+        a += 4;
+    }
+    /* Process last 0-3 pixels. Not needed for standard vector lengths. */
+    while (a < last)
+    {
+        result += (*a) * (*a);
+        a++;
+    }
+#endif
+#endif
+#endif
+    return result;
+}
+
+float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint32_t size) const // returns NEGATED dot product so smaller == more similar
+{
+    float result = 0.0f;
+#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
+    tmp1 = _mm256_loadu_ps(addr1); \
+    tmp2 = _mm256_loadu_ps(addr2); \
+    tmp1 = _mm256_mul_ps(tmp1, tmp2); \
+    dest = _mm256_add_ps(dest, tmp1);
+
+    __m256 sum;
+    __m256 l0, l1;
+    __m256 r0, r1;
+    uint32_t D = (size + 7) & ~7U; // size rounded up to a multiple of 8
+    uint32_t DR = D % 16;
+    uint32_t DD = D - DR;
+    const float *l = (float *)a;
+    const float *r = (float *)b;
+    const float *e_l = l + DD;
+    const float *e_r = r + DD;
+#ifndef _WINDOWS
+    float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0};
+#else
+    __declspec(align(32)) float unpack[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+#endif
+
+    sum = _mm256_loadu_ps(unpack);
+    if (DR)
+    {
+        AVX_DOT(e_l, e_r, sum, l0, r0); // tail chunk; NOTE(review): reads 8 floats at offset DD — assumes padded buffers, confirm
+    }
+
+    for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16)
+    {
+        AVX_DOT(l, r, sum, l0, r0);
+        AVX_DOT(l + 8, r + 8, sum, l1, r1);
+    }
+    _mm256_storeu_ps(unpack, sum);
+    result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7];
+
+    return -result; // negated: inner-product "distance" must decrease as similarity grows
+}
+
+uint32_t AVXNormalizedCosineDistanceFloat::post_normalization_dimension(uint32_t orig_dimension) const // normalization keeps dimensionality unchanged
+{
+    return orig_dimension;
+}
+bool AVXNormalizedCosineDistanceFloat::preprocessing_required() const
+{
+    return true; // base and query vectors must be L2-normalized before comparison
+}
+void AVXNormalizedCosineDistanceFloat::preprocess_base_points(float *original_data, const size_t orig_dim,
+                                                              const size_t num_points)
+{
+    for (uint32_t i = 0; i < num_points; i++)
+    {
+        normalize((float *)(original_data + i * orig_dim), orig_dim); // in-place normalization of each point
+    }
+}
+
+void AVXNormalizedCosineDistanceFloat::preprocess_query(const float *query_vec, const size_t query_dim,
+                                                        float *query_scratch)
+{
+    normalize_and_copy(query_vec, (uint32_t)query_dim, query_scratch); // query is normalized into scratch, original untouched
+}
+
+void AVXNormalizedCosineDistanceFloat::normalize_and_copy(const float *query_vec, const uint32_t query_dim,
+                                                          float *query_target) const
+{
+    float norm = get_norm(query_vec, query_dim); // NOTE(review): a zero-norm query divides by zero below — confirm upstream guards
+
+    for (uint32_t i = 0; i < query_dim; i++)
+    {
+        query_target[i] = query_vec[i] / norm;
+    }
+}
+
+// Get the right distance function for the given metric, dispatching on runtime CPU capability.
+template <> diskann::Distance<float> *get_distance_function(diskann::Metric m)
+{
+    if (m == diskann::Metric::L2)
+    {
+        if (Avx2SupportedCPU)
+        {
+            diskann::cout << "L2: Using AVX2 distance computation DistanceL2Float" << std::endl;
+            return new diskann::DistanceL2Float();
+        }
+        else if (AvxSupportedCPU)
+        {
+            diskann::cout << "L2: AVX2 not supported. Using AVX distance computation" << std::endl;
+            return new diskann::AVXDistanceL2Float(); // NOTE(review): non-Windows AVXDistanceL2Float::compare is a stub returning 0 — confirm this path
+        }
+        else
+        {
+            diskann::cout << "L2: Older CPU. Using slow distance computation" << std::endl;
+            return new diskann::SlowDistanceL2<float>();
+        }
+    }
+    else if (m == diskann::Metric::COSINE)
+    {
+        diskann::cout << "Cosine: Using either AVX or AVX2 implementation" << std::endl;
+        return new diskann::DistanceCosineFloat();
+    }
+    else if (m == diskann::Metric::INNER_PRODUCT)
+    {
+        diskann::cout << "Inner product: Using AVX2 implementation "
+                         "AVXDistanceInnerProductFloat"
+                      << std::endl;
+        return new diskann::AVXDistanceInnerProductFloat();
+    }
+    else if (m == diskann::Metric::FAST_L2)
+    {
+        diskann::cout << "Fast_L2: Using AVX2 implementation with norm "
+                         "memoization DistanceFastL2<float>"
+                      << std::endl;
+        return new diskann::DistanceFastL2<float>();
+    }
+    else
+    {
+        std::stringstream stream;
+        stream << "Only L2, cosine, and inner product supported for floating "
+                  "point vectors as of now."
+               << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+}
+
+template <> diskann::Distance<int8_t> *get_distance_function(diskann::Metric m) // int8 factory: L2 and cosine only
+{
+    if (m == diskann::Metric::L2)
+    {
+        if (Avx2SupportedCPU)
+        {
+            diskann::cout << "Using AVX2 distance computation DistanceL2Int8." << std::endl;
+            return new diskann::DistanceL2Int8();
+        }
+        else if (AvxSupportedCPU)
+        {
+            diskann::cout << "AVX2 not supported. Using AVX distance computation" << std::endl;
+            return new diskann::AVXDistanceL2Int8(); // NOTE(review): non-Windows AVXDistanceL2Int8::compare is a stub returning 0 — confirm this path
+        }
+        else
+        {
+            diskann::cout << "Older CPU. Using slow distance computation "
+                             "SlowDistanceL2Int<int8_t>."
+                          << std::endl;
+            return new diskann::SlowDistanceL2<int8_t>();
+        }
+    }
+    else if (m == diskann::Metric::COSINE)
+    {
+        diskann::cout << "Using either AVX or AVX2 for Cosine similarity "
+                         "DistanceCosineInt8."
+                      << std::endl;
+        return new diskann::DistanceCosineInt8();
+    }
+    else
+    {
+        std::stringstream stream;
+        stream << "Only L2 and cosine supported for signed byte vectors." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+}
+
+template <> diskann::Distance<uint8_t> *get_distance_function(diskann::Metric m) // uint8 factory: L2 and cosine only
+{
+    if (m == diskann::Metric::L2)
+    {
+#ifdef _WINDOWS
+        diskann::cout << "WARNING: AVX/AVX2 distance function not defined for Uint8. "
+                         "Using "
+                         "slow version. "
+                         "Contact gopalsr@microsoft.com if you need AVX/AVX2 support."
+                      << std::endl;
+#endif
+        return new diskann::DistanceL2UInt8();
+    }
+    else if (m == diskann::Metric::COSINE)
+    {
+        diskann::cout << "AVX/AVX2 distance function not defined for Uint8. Using "
+                         "slow version SlowDistanceCosineUint8() "
+                         "Contact gopalsr@microsoft.com if you need AVX/AVX2 support."
+                      << std::endl;
+        return new diskann::SlowDistanceCosineUInt8();
+    }
+    else
+    {
+        std::stringstream stream;
+        stream << "Only L2 and cosine supported for uint8_t vectors as of now." << std::endl; // fixed: message previously said "uint32_t byte vectors" for the uint8_t overload
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+}
+
+template DISKANN_DLLEXPORT class DistanceInnerProduct<float>; // exported explicit instantiations for the supported element types
+template DISKANN_DLLEXPORT class DistanceInnerProduct<int8_t>;
+template DISKANN_DLLEXPORT class DistanceInnerProduct<uint8_t>;
+
+template DISKANN_DLLEXPORT class DistanceFastL2<float>;
+template DISKANN_DLLEXPORT class DistanceFastL2<int8_t>;
+template DISKANN_DLLEXPORT class DistanceFastL2<uint8_t>;
+
+template DISKANN_DLLEXPORT class SlowDistanceL2<float>;
+template DISKANN_DLLEXPORT class SlowDistanceL2<int8_t>;
+template DISKANN_DLLEXPORT class SlowDistanceL2<uint8_t>;
+
+template DISKANN_DLLEXPORT Distance<float> *get_distance_function(Metric m);
+template DISKANN_DLLEXPORT Distance<int8_t> *get_distance_function(Metric m);
+template DISKANN_DLLEXPORT Distance<uint8_t> *get_distance_function(Metric m);
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/filter_utils.cpp b/be/src/extern/diskann/src/filter_utils.cpp
new file mode 100644
index 0000000..09d740e
--- /dev/null
+++ b/be/src/extern/diskann/src/filter_utils.cpp
@@ -0,0 +1,355 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <chrono>
+#include <cstdio>
+#include <cstring>
+#include <random>
+#include <string>
+#include <tuple>
+
+#include <omp.h>
+#include "filter_utils.h"
+#include "index.h"
+#include "parameters.h"
+#include "utils.h"
+
+namespace diskann
+{
+/*
+ * Using passed in parameters and files generated from step 3,
+ * builds a vanilla diskANN index for each label.
+ *
+ * Each index is saved under the following path:
+ *  final_index_path_prefix + "_" + label
+ *
+ * NOTE: each per-label index is always built with Metric::L2, independent of
+ * the metric the caller intends to use for the final index.
+ */
+template <typename T>
+void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, uint32_t R,
+                            uint32_t L, float alpha, uint32_t num_threads)
+{
+    diskann::IndexWriteParameters label_index_build_parameters = diskann::IndexWriteParametersBuilder(L, R)
+                                                                    .with_saturate_graph(false)
+                                                                    .with_alpha(alpha)
+                                                                    .with_num_threads(num_threads)
+                                                                    .build();
+
+    std::cout << "Generating indices per label..." << std::endl;
+    // for each label, build an index on resp. points
+    double total_indexing_time = 0.0, indexing_percentage = 0.0;
+    // Silence the per-label builds: setting failbit turns subsequent stream
+    // writes into no-ops. Both streams are clear()ed again after the loop.
+    std::cout.setstate(std::ios_base::failbit);
+    diskann::cout.setstate(std::ios_base::failbit);
+    for (const auto &lbl : all_labels)
+    {
+        // Input/output paths follow the "<prefix>_<label>" convention.
+        path curr_label_input_data_path(input_data_path + "_" + lbl);
+        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
+
+        size_t number_of_label_points, dimension;
+        diskann::get_bin_metadata(curr_label_input_data_path, number_of_label_points, dimension);
+
+        diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points,
+                                std::make_shared<diskann::IndexWriteParameters>(label_index_build_parameters), nullptr,
+                                0, false, false, false, false, 0, false);
+
+        auto index_build_timer = std::chrono::high_resolution_clock::now();
+        index.build(curr_label_input_data_path.c_str(), number_of_label_points);
+        std::chrono::duration<double> current_indexing_time =
+            std::chrono::high_resolution_clock::now() - index_build_timer;
+
+        total_indexing_time += current_indexing_time.count();
+        indexing_percentage += (1 / (double)all_labels.size());
+        print_progress(indexing_percentage);
+
+        index.save(curr_label_index_path.c_str());
+    }
+    std::cout.clear();
+    diskann::cout.clear();
+
+    std::cout << "\nDone. Generated per-label indices in " << total_indexing_time << " seconds\n" << std::endl;
+}
+
+// for use on systems without writev (i.e. Windows)
+//
+// Splits the input .bin file into one vector file per label
+// ("<input_data_path>_<label>") and returns, for each label, the mapping from
+// label-local point id to original point id.
+template <typename T>
+tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files_compat(
+    path input_data_path, tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+    std::vector<label_set> point_ids_to_labels, label_set all_labels)
+{
+    auto file_writing_timer = std::chrono::high_resolution_clock::now();
+    std::ifstream input_data_stream(input_data_path);
+
+    uint32_t number_of_points, dimension;
+    input_data_stream.read((char *)&number_of_points, sizeof(uint32_t));
+    input_data_stream.read((char *)&dimension, sizeof(uint32_t));
+    const uint32_t VECTOR_SIZE = dimension * sizeof(T);
+    if (number_of_points != point_ids_to_labels.size())
+    {
+        std::stringstream stream;
+        stream << "Error: number of points in labels file and data file differ." << std::endl;
+        std::cerr << stream.str();
+        // Bug fix: a bare `throw;` with no active exception calls
+        // std::terminate(); raise a real exception the caller can handle.
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Bug fix: buffers are std::vector<char> instead of raw malloc so they
+    // cannot leak if an exception propagates, and an allocation failure
+    // raises std::bad_alloc instead of another bare `throw;`.
+    tsl::robin_map<std::string, std::vector<char>> labels_to_vectors;
+    tsl::robin_map<std::string, uint32_t> labels_to_curr_vector;
+    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
+
+    for (const auto &lbl : all_labels)
+    {
+        uint32_t number_of_label_pts = labels_to_number_of_points[lbl];
+        labels_to_vectors[lbl].resize((size_t)number_of_label_pts * VECTOR_SIZE);
+        labels_to_curr_vector[lbl] = 0;
+        label_id_to_orig_id[lbl].reserve(number_of_label_pts);
+    }
+
+    // One reusable scratch buffer instead of a malloc/free pair per point.
+    std::vector<char> curr_vector(VECTOR_SIZE);
+    for (uint32_t point_id = 0; point_id < number_of_points; point_id++)
+    {
+        input_data_stream.read(curr_vector.data(), VECTOR_SIZE);
+        // Append this point's vector to the buffer of every label it carries.
+        for (const auto &lbl : point_ids_to_labels[point_id])
+        {
+            char *curr_label_vector_ptr =
+                labels_to_vectors[lbl].data() + ((size_t)labels_to_curr_vector[lbl] * VECTOR_SIZE);
+            memcpy(curr_label_vector_ptr, curr_vector.data(), VECTOR_SIZE);
+            labels_to_curr_vector[lbl]++;
+            label_id_to_orig_id[lbl].push_back(point_id);
+        }
+    }
+
+    // Write each label's vectors out in the same .bin layout (npts, dim, data).
+    for (const auto &lbl : all_labels)
+    {
+        path curr_label_input_data_path(input_data_path + "_" + lbl);
+        uint32_t number_of_label_pts = labels_to_number_of_points[lbl];
+
+        std::ofstream label_file_stream;
+        label_file_stream.exceptions(std::ios::badbit | std::ios::failbit);
+        label_file_stream.open(curr_label_input_data_path, std::ios_base::binary);
+        label_file_stream.write((char *)&number_of_label_pts, sizeof(uint32_t));
+        label_file_stream.write((char *)&dimension, sizeof(uint32_t));
+        label_file_stream.write(labels_to_vectors[lbl].data(), (size_t)number_of_label_pts * VECTOR_SIZE);
+
+        label_file_stream.close();
+    }
+    input_data_stream.close();
+
+    std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer;
+    std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time "
+              << file_writing_time.count() << "\n"
+              << std::endl;
+
+    return label_id_to_orig_id;
+}
+
+/*
+ * Manually loads a graph index in from a given file.
+ *
+ * Returns both the graph index and the size of the file in bytes.
+ */
+load_label_index_return_values load_label_index(path label_index_path, uint32_t label_number_of_points)
+{
+    std::ifstream label_index_stream;
+    label_index_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    label_index_stream.open(label_index_path, std::ios::binary);
+
+    // Vamana index header: file size (u64), max degree (u32), entry point
+    // (u32), number of frozen points (u64).
+    uint64_t index_file_size, index_num_frozen_points;
+    uint32_t index_max_observed_degree, index_entry_point;
+    const size_t INDEX_METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
+    label_index_stream.read((char *)&index_file_size, sizeof(uint64_t));
+    label_index_stream.read((char *)&index_max_observed_degree, sizeof(uint32_t));
+    label_index_stream.read((char *)&index_entry_point, sizeof(uint32_t));
+    label_index_stream.read((char *)&index_num_frozen_points, sizeof(uint64_t));
+    size_t bytes_read = INDEX_METADATA;
+
+    std::vector<std::vector<uint32_t>> label_index(label_number_of_points);
+    uint32_t nodes_read = 0;
+    // Adjacency lists are stored back-to-back: degree (u32) followed by that
+    // many neighbour ids; keep reading until the recorded file size is consumed.
+    while (bytes_read != index_file_size)
+    {
+        uint32_t current_node_num_neighbors;
+        // Bug fix: this line was corrupted to "¤t_node_num_neighbors"
+        // (a mangled "&current_..."); restore the address-of expression.
+        label_index_stream.read((char *)&current_node_num_neighbors, sizeof(uint32_t));
+        nodes_read++;
+
+        std::vector<uint32_t> current_node_neighbors(current_node_num_neighbors);
+        label_index_stream.read((char *)current_node_neighbors.data(), current_node_num_neighbors * sizeof(uint32_t));
+        label_index[nodes_read - 1].swap(current_node_neighbors);
+        bytes_read += sizeof(uint32_t) * (current_node_num_neighbors + 1);
+    }
+
+    return std::make_tuple(label_index, index_file_size);
+}
+
+/*
+ * Parses the label datafile, which has comma-separated labels on
+ * each line. Line i corresponds to point id i.
+ *
+ * Returns three objects via std::tuple:
+ * 1. map: key is point id, value is vector of labels said point has
+ * 2. map: key is label, value is number of points with the label
+ * 3. the label universe as a set
+ *
+ * NOTE: a point that has no labels at all causes the process to exit(-1).
+ */
+parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label)
+{
+    std::ifstream label_data_stream(label_data_path);
+    std::string line, token;
+    uint32_t line_cnt = 0;
+
+    // allows us to reserve space for the points_to_labels vector
+    while (std::getline(label_data_stream, line))
+        line_cnt++;
+    label_data_stream.clear();
+    label_data_stream.seekg(0, std::ios::beg);
+
+    // values to return
+    std::vector<label_set> point_ids_to_labels(line_cnt);
+    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
+    label_set all_labels;
+
+    std::vector<uint32_t> points_with_universal_label;
+    line_cnt = 0;
+    while (std::getline(label_data_stream, line))
+    {
+        std::istringstream current_labels_comma_separated(line);
+        label_set current_labels;
+
+        // get point id
+        uint32_t point_id = line_cnt;
+
+        // parse comma separated labels
+        bool current_universal_label_check = false;
+        while (getline(current_labels_comma_separated, token, ','))
+        {
+            // strip any stray newline / carriage-return characters from the token
+            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+
+            // if token is empty, there's no labels for the point
+            if (token == universal_label)
+            {
+                // the universal label is recorded separately and expanded to
+                // the full label set in the second pass below
+                points_with_universal_label.push_back(point_id);
+                current_universal_label_check = true;
+            }
+            else
+            {
+                all_labels.insert(token);
+                current_labels.insert(token);
+                labels_to_number_of_points[token]++;
+            }
+        }
+
+        if (current_labels.size() <= 0 && !current_universal_label_check)
+        {
+            std::cerr << "Error: " << point_id << " has no labels." << std::endl;
+            exit(-1);
+        }
+        point_ids_to_labels[point_id] = current_labels;
+        line_cnt++;
+    }
+
+    // for every point with universal label, set its label set to all labels
+    // also, increment the count for number of points a label has
+    // (note: this replaces any other labels the point listed on its line)
+    for (const auto &point_id : points_with_universal_label)
+    {
+        point_ids_to_labels[point_id] = all_labels;
+        for (const auto &lbl : all_labels)
+            labels_to_number_of_points[lbl]++;
+    }
+
+    std::cout << "Identified " << all_labels.size() << " distinct label(s) for " << point_ids_to_labels.size()
+              << " points\n"
+              << std::endl;
+
+    return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, all_labels);
+}
+
+/*
+ * A templated function to parse a file of labels that are already represented
+ * as either uint16_t or uint32_t
+ *
+ * Returns two objects via std::tuple:
+ * 1. a vector of vectors of labels, where the outer vector is indexed by point id
+ * 2. a set of all labels
+ *
+ * NOTE: only the first tab-separated field of each line is parsed for labels;
+ * a line with no parseable labels causes the process to exit(-1), and a
+ * non-numeric token makes std::stoul throw std::invalid_argument.
+ */
+template <typename LabelT>
+std::tuple<std::vector<std::vector<LabelT>>, tsl::robin_set<LabelT>> parse_formatted_label_file(std::string label_file)
+{
+    std::vector<std::vector<LabelT>> pts_to_labels;
+    tsl::robin_set<LabelT> labels;
+
+    // Format of Label txt file: filters with comma separators
+    std::ifstream infile(label_file);
+    if (infile.fail())
+    {
+        throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
+    }
+
+    std::string line, token;
+    uint32_t line_cnt = 0;
+
+    // first pass: count lines so the outer vector can be sized up front
+    while (std::getline(infile, line))
+    {
+        line_cnt++;
+    }
+    pts_to_labels.resize(line_cnt, std::vector<LabelT>());
+
+    infile.clear();
+    infile.seekg(0, std::ios::beg);
+    line_cnt = 0;
+
+    while (std::getline(infile, line))
+    {
+        std::istringstream iss(line);
+        std::vector<LabelT> lbls(0);
+        // take only the first tab-separated field of the line
+        getline(iss, token, '\t');
+        std::istringstream new_iss(token);
+        while (getline(new_iss, token, ','))
+        {
+            // strip stray newline / carriage-return characters before parsing
+            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+            LabelT token_as_num = static_cast<LabelT>(std::stoul(token));
+            lbls.push_back(token_as_num);
+            labels.insert(token_as_num);
+        }
+        if (lbls.size() <= 0)
+        {
+            diskann::cout << "No label found";
+            exit(-1);
+        }
+        // labels are kept sorted per point
+        std::sort(lbls.begin(), lbls.end());
+        pts_to_labels[line_cnt] = lbls;
+        line_cnt++;
+    }
+    diskann::cout << "Identified " << labels.size() << " distinct label(s)" << std::endl;
+
+    return std::make_tuple(pts_to_labels, labels);
+}
+
+// Explicit instantiations of the templated filter utilities for the label and
+// element types used elsewhere in DiskANN.
+template DISKANN_DLLEXPORT std::tuple<std::vector<std::vector<uint32_t>>, tsl::robin_set<uint32_t>>
+parse_formatted_label_file(path label_file);
+
+template DISKANN_DLLEXPORT std::tuple<std::vector<std::vector<uint16_t>>, tsl::robin_set<uint16_t>>
+parse_formatted_label_file(path label_file);
+
+template DISKANN_DLLEXPORT void generate_label_indices<float>(path input_data_path, path final_index_path_prefix,
+                                                              label_set all_labels, uint32_t R, uint32_t L, float alpha,
+                                                              uint32_t num_threads);
+template DISKANN_DLLEXPORT void generate_label_indices<uint8_t>(path input_data_path, path final_index_path_prefix,
+                                                                label_set all_labels, uint32_t R, uint32_t L,
+                                                                float alpha, uint32_t num_threads);
+template DISKANN_DLLEXPORT void generate_label_indices<int8_t>(path input_data_path, path final_index_path_prefix,
+                                                               label_set all_labels, uint32_t R, uint32_t L,
+                                                               float alpha, uint32_t num_threads);
+
+template DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>>
+generate_label_specific_vector_files_compat<float>(path input_data_path,
+                                                   tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+                                                   std::vector<label_set> point_ids_to_labels, label_set all_labels);
+template DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>>
+generate_label_specific_vector_files_compat<uint8_t>(path input_data_path,
+                                                     tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+                                                     std::vector<label_set> point_ids_to_labels, label_set all_labels);
+template DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>>
+generate_label_specific_vector_files_compat<int8_t>(path input_data_path,
+                                                    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+                                                    std::vector<label_set> point_ids_to_labels, label_set all_labels);
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/in_mem_data_store.cpp b/be/src/extern/diskann/src/in_mem_data_store.cpp
new file mode 100644
index 0000000..cc7acf6
--- /dev/null
+++ b/be/src/extern/diskann/src/in_mem_data_store.cpp
@@ -0,0 +1,401 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <memory>
+#include "abstract_scratch.h"
+#include "in_mem_data_store.h"
+
+#include "utils.h"
+
+namespace diskann
+{
+
+/*
+ * In-memory store of aligned, zero-padded vectors.
+ *
+ * The per-vector stride (_aligned_dim) is `dim` rounded up to the distance
+ * function's required alignment; the full buffer is allocated with
+ * alloc_aligned and zero-initialized so the padding elements are zero.
+ */
+template <typename data_t>
+InMemDataStore<data_t>::InMemDataStore(const location_t num_points, const size_t dim,
+                                       std::unique_ptr<Distance<data_t>> distance_fn)
+    : AbstractDataStore<data_t>(num_points, dim), _distance_fn(std::move(distance_fn))
+{
+    _aligned_dim = ROUND_UP(dim, _distance_fn->get_required_alignment());
+    alloc_aligned(((void **)&_data), this->_capacity * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t));
+    std::memset(_data, 0, this->_capacity * _aligned_dim * sizeof(data_t));
+}
+
+// Releases the aligned data buffer (pairs with alloc_aligned above).
+template <typename data_t> InMemDataStore<data_t>::~InMemDataStore()
+{
+    if (_data != nullptr)
+    {
+        aligned_free(this->_data);
+    }
+}
+
+// Stride (in elements) of one stored vector; >= the logical dimension.
+template <typename data_t> size_t InMemDataStore<data_t>::get_aligned_dim() const
+{
+    return _aligned_dim;
+}
+
+// Alignment factor required by the configured distance function.
+template <typename data_t> size_t InMemDataStore<data_t>::get_alignment_factor() const
+{
+    return _distance_fn->get_required_alignment();
+}
+
+// Loads vectors from a DiskANN .bin file; returns the number of points loaded.
+template <typename data_t> location_t InMemDataStore<data_t>::load(const std::string &filename)
+{
+    return load_impl(filename);
+}
+
+#ifdef EXEC_ENV_OLS
+// Reader-based variant of load_impl (used when running inside OLS):
+// validates that the file's dimension matches the store's, grows the store if
+// the file holds more points than the current capacity, then copies the data
+// in with per-row padding to _aligned_dim.
+template <typename data_t> location_t InMemDataStore<data_t>::load_impl(AlignedFileReader &reader)
+{
+    size_t file_dim, file_num_points;
+
+    diskann::get_bin_metadata(reader, file_num_points, file_dim);
+
+    if (file_dim != this->_dim)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << this->_dim << " dimension,"
+               << "but file has " << file_dim << " dimension." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        // _data is freed before throwing; the store is unusable afterwards
+        aligned_free(_data);
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (file_num_points > this->capacity())
+    {
+        this->resize((location_t)file_num_points);
+    }
+    copy_aligned_data_from_file<data_t>(reader, _data, file_num_points, file_dim, _aligned_dim);
+
+    return (location_t)file_num_points;
+}
+#endif
+
+// File-based load: validates existence and dimension, grows the store if the
+// file holds more points than the current capacity, then copies the data in
+// with per-row padding to _aligned_dim. Returns the number of points loaded.
+template <typename data_t> location_t InMemDataStore<data_t>::load_impl(const std::string &filename)
+{
+    size_t file_dim, file_num_points;
+    if (!file_exists(filename))
+    {
+        std::stringstream stream;
+        stream << "ERROR: data file " << filename << " does not exist." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        // _data is freed before throwing; the store is unusable afterwards
+        aligned_free(_data);
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    diskann::get_bin_metadata(filename, file_num_points, file_dim);
+
+    if (file_dim != this->_dim)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << this->_dim << " dimension,"
+               << "but file has " << file_dim << " dimension." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        aligned_free(_data);
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (file_num_points > this->capacity())
+    {
+        this->resize((location_t)file_num_points);
+    }
+
+    copy_aligned_data_from_file<data_t>(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim);
+
+    return (location_t)file_num_points;
+}
+
+// Writes the first num_points vectors to `filename` in .bin layout, dropping
+// the alignment padding (rows are stored at the logical dimension).
+template <typename data_t> size_t InMemDataStore<data_t>::save(const std::string &filename, const location_t num_points)
+{
+    return save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U);
+}
+
+// Copies num_pts densely-packed vectors (stride = _dim) into the store,
+// re-padding each row to _aligned_dim, then runs the distance function's
+// base-point preprocessing if it needs one.
+// NOTE(review): memmove is used per row — presumably defensive against
+// overlapping source/destination; confirm whether callers can pass aliased
+// buffers.
+template <typename data_t> void InMemDataStore<data_t>::populate_data(const data_t *vectors, const location_t num_pts)
+{
+    memset(_data, 0, _aligned_dim * sizeof(data_t) * num_pts);
+    for (location_t i = 0; i < num_pts; i++)
+    {
+        std::memmove(_data + i * _aligned_dim, vectors + i * this->_dim, this->_dim * sizeof(data_t));
+    }
+
+    if (_distance_fn->preprocessing_required())
+    {
+        _distance_fn->preprocess_base_points(_data, this->_aligned_dim, num_pts);
+    }
+}
+
+// Populates the store from a .bin file at the given byte offset.
+//
+// Bug fix: the original copied the file's contents into _data *before*
+// checking npts against capacity, so an oversized file overflowed the buffer
+// before the exception was thrown. Read the metadata first, validate, then copy.
+template <typename data_t> void InMemDataStore<data_t>::populate_data(const std::string &filename, const size_t offset)
+{
+    size_t npts, ndim;
+    diskann::get_bin_metadata(filename, npts, ndim, offset);
+
+    if ((location_t)npts > this->capacity())
+    {
+        std::stringstream ss;
+        ss << "Number of points in the file: " << filename
+           << " is greater than the capacity of data store: " << this->capacity()
+           << ". Must invoke resize before calling populate_data()" << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+
+    if ((location_t)ndim != this->get_dims())
+    {
+        std::stringstream ss;
+        // Bug fix: the message previously printed capacity() here instead of
+        // the store's dimension.
+        ss << "Number of dimensions of a point in the file: " << filename
+           << " is not equal to dimensions of data store: " << this->get_dims() << "." << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+
+    copy_aligned_data_from_file(filename.c_str(), _data, npts, ndim, _aligned_dim, offset);
+
+    if (_distance_fn->preprocessing_required())
+    {
+        _distance_fn->preprocess_base_points(_data, this->_aligned_dim, this->capacity());
+    }
+}
+
+// Dumps the first num_points vectors to `filename` at the logical dimension
+// (alignment padding stripped), same layout as save().
+template <typename data_t>
+void InMemDataStore<data_t>::extract_data_to_bin(const std::string &filename, const location_t num_points)
+{
+    save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U);
+}
+
+// Copies the vector at slot i into dest (only the _dim logical elements).
+template <typename data_t> void InMemDataStore<data_t>::get_vector(const location_t i, data_t *dest) const
+{
+    // REFACTOR TODO: Should we denormalize and return values?
+    memcpy(dest, _data + i * _aligned_dim, this->_dim * sizeof(data_t));
+}
+
+// Overwrites slot `loc` with `vector`: zero the padded row, copy the logical
+// elements, then run per-point preprocessing in place if the metric needs it.
+template <typename data_t> void InMemDataStore<data_t>::set_vector(const location_t loc, const data_t *const vector)
+{
+    size_t offset_in_data = loc * _aligned_dim;
+    memset(_data + offset_in_data, 0, _aligned_dim * sizeof(data_t));
+    memcpy(_data + offset_in_data, vector, this->_dim * sizeof(data_t));
+    if (_distance_fn->preprocessing_required())
+    {
+        _distance_fn->preprocess_base_points(_data + offset_in_data, _aligned_dim, 1);
+    }
+}
+
+// Hints the CPU to pull slot `loc`'s row into cache ahead of a distance call.
+template <typename data_t> void InMemDataStore<data_t>::prefetch_vector(const location_t loc)
+{
+    diskann::prefetch_vector((const char *)_data + _aligned_dim * (size_t)loc * sizeof(data_t),
+                             sizeof(data_t) * _aligned_dim);
+}
+
+// Copies the raw query into the scratch's aligned query buffer so distance
+// kernels can run against properly aligned memory. Throws if scratch is null.
+template <typename data_t>
+void InMemDataStore<data_t>::preprocess_query(const data_t *query, AbstractScratch<data_t> *query_scratch) const
+{
+    if (query_scratch != nullptr)
+    {
+        memcpy(query_scratch->aligned_query_T(), query, sizeof(data_t) * this->get_dims());
+    }
+    else
+    {
+        std::stringstream ss;
+        ss << "In InMemDataStore::preprocess_query: Query scratch is null";
+        diskann::cerr << ss.str() << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+}
+
+// Distance between a (preprocessed) query and the vector in slot `loc`.
+// Comparisons run over the full padded width; the padding is zero-initialized
+// by the constructor / set_vector.
+template <typename data_t> float InMemDataStore<data_t>::get_distance(const data_t *query, const location_t loc) const
+{
+    return _distance_fn->compare(query, _data + _aligned_dim * loc, (uint32_t)_aligned_dim);
+}
+
+// Batched variant: fills distances[0..location_count) for the given slots.
+// scratch_space is unused in this in-memory implementation.
+template <typename data_t>
+void InMemDataStore<data_t>::get_distance(const data_t *query, const location_t *locations,
+                                          const uint32_t location_count, float *distances,
+                                          AbstractScratch<data_t> *scratch_space) const
+{
+    for (location_t i = 0; i < location_count; i++)
+    {
+        distances[i] = _distance_fn->compare(query, _data + locations[i] * _aligned_dim, (uint32_t)this->_aligned_dim);
+    }
+}
+
+// Distance between two stored vectors.
+template <typename data_t>
+float InMemDataStore<data_t>::get_distance(const location_t loc1, const location_t loc2) const
+{
+    return _distance_fn->compare(_data + loc1 * _aligned_dim, _data + loc2 * _aligned_dim,
+                                 (uint32_t)this->_aligned_dim);
+}
+
+// Batched variant over a vector of ids; caller must size `distances` to at
+// least ids.size(). scratch_space is unused in this in-memory implementation.
+template <typename data_t>
+void InMemDataStore<data_t>::get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
+                                          std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const
+{
+    // Fix: use size_t for the index — a signed `int` compared against
+    // ids.size() is a sign-compare hazard and truncates for huge id lists.
+    for (size_t i = 0; i < ids.size(); i++)
+    {
+        distances[i] =
+            _distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim, (uint32_t)this->_aligned_dim);
+    }
+}
+
+// Grows the aligned buffer to hold new_size points. No-op if new_size equals
+// the current capacity; throws if the caller tries to "expand" to a smaller
+// size. On non-Windows platforms this is alloc + copy + free, since
+// alloc_aligned has no realloc counterpart there.
+template <typename data_t> location_t InMemDataStore<data_t>::expand(const location_t new_size)
+{
+    if (new_size == this->capacity())
+    {
+        return this->capacity();
+    }
+    else if (new_size < this->capacity())
+    {
+        std::stringstream ss;
+        ss << "Cannot 'expand' datastore when new capacity (" << new_size << ") < existing capacity("
+           << this->capacity() << ")" << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+#ifndef _WINDOWS
+    data_t *new_data;
+    alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t));
+    // copy the old (smaller) contents into the new buffer
+    memcpy(new_data, _data, this->capacity() * _aligned_dim * sizeof(data_t));
+    aligned_free(_data);
+    _data = new_data;
+#else
+    realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t));
+#endif
+    this->_capacity = new_size;
+    return this->_capacity;
+}
+
+// Shrinks the aligned buffer to hold new_size points, discarding anything
+// beyond the new capacity. Mirror image of expand(); throws if new_size is
+// larger than the current capacity.
+template <typename data_t> location_t InMemDataStore<data_t>::shrink(const location_t new_size)
+{
+    if (new_size == this->capacity())
+    {
+        return this->capacity();
+    }
+    else if (new_size > this->capacity())
+    {
+        std::stringstream ss;
+        ss << "Cannot 'shrink' datastore when new capacity (" << new_size << ") > existing capacity("
+           << this->capacity() << ")" << std::endl;
+        throw diskann::ANNException(ss.str(), -1);
+    }
+#ifndef _WINDOWS
+    data_t *new_data;
+    alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t));
+    // keep only the first new_size rows
+    memcpy(new_data, _data, new_size * _aligned_dim * sizeof(data_t));
+    aligned_free(_data);
+    _data = new_data;
+#else
+    realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t));
+#endif
+    this->_capacity = new_size;
+    return this->_capacity;
+}
+
+// Moves num_locations rows from old_location_start to new_location_start
+// (ranges may overlap), then zeroes the part of the old range that was not
+// overwritten by the copy, so vacated slots never hold stale vectors.
+template <typename data_t>
+void InMemDataStore<data_t>::move_vectors(const location_t old_location_start, const location_t new_location_start,
+                                          const location_t num_locations)
+{
+    if (num_locations == 0 || old_location_start == new_location_start)
+    {
+        return;
+    }
+
+    /* // Update pointers to the moved nodes. Note: the computation is correct
+    even
+    // when new_location_start < old_location_start given the C++ uint32_t
+    // integer arithmetic rules.
+    const uint32_t location_delta = new_location_start - old_location_start;
+    */
+    // The [start, end) interval which will contain obsolete points to be
+    // cleared.
+    uint32_t mem_clear_loc_start = old_location_start;
+    uint32_t mem_clear_loc_end_limit = old_location_start + num_locations;
+
+    // Shrink the clear interval so it never touches the destination range.
+    if (new_location_start < old_location_start)
+    {
+        // If ranges are overlapping, make sure not to clear the newly copied
+        // data.
+        if (mem_clear_loc_start < new_location_start + num_locations)
+        {
+            // Clear only after the end of the new range.
+            mem_clear_loc_start = new_location_start + num_locations;
+        }
+    }
+    else
+    {
+        // If ranges are overlapping, make sure not to clear the newly copied
+        // data.
+        if (mem_clear_loc_end_limit > new_location_start)
+        {
+            // Clear only up to the beginning of the new range.
+            mem_clear_loc_end_limit = new_location_start;
+        }
+    }
+
+    // Use memmove to handle overlapping ranges.
+    copy_vectors(old_location_start, new_location_start, num_locations);
+    memset(_data + _aligned_dim * mem_clear_loc_start, 0,
+           sizeof(data_t) * _aligned_dim * (mem_clear_loc_end_limit - mem_clear_loc_start));
+}
+
+// Copies num_points rows from from_loc to to_loc; memmove tolerates
+// overlapping source and destination ranges.
+template <typename data_t>
+void InMemDataStore<data_t>::copy_vectors(const location_t from_loc, const location_t to_loc,
+                                          const location_t num_points)
+{
+    assert(from_loc < this->_capacity);
+    assert(to_loc < this->_capacity);
+    assert(num_points < this->_capacity);
+    memmove(_data + _aligned_dim * to_loc, _data + _aligned_dim * from_loc, num_points * _aligned_dim * sizeof(data_t));
+}
+
+// Returns the slot of the stored vector with the smallest squared L2 distance
+// to the centroid of all stored vectors (the medoid).
+//
+// Fixes: uses std::vector instead of raw new[]/delete[] so the buffers cannot
+// leak if an exception propagates, and guards against an empty store, where
+// the original read distances[0] out of bounds.
+template <typename data_t> location_t InMemDataStore<data_t>::calculate_medoid() const
+{
+    const size_t num_points = this->capacity();
+    if (num_points == 0)
+    {
+        return 0; // empty store: nothing to choose from
+    }
+
+    // allocate and accumulate the centroid
+    std::vector<float> center(_aligned_dim, 0.0f);
+
+    for (size_t i = 0; i < num_points; i++)
+        for (size_t j = 0; j < _aligned_dim; j++)
+            center[j] += (float)_data[i * _aligned_dim + j];
+
+    for (size_t j = 0; j < _aligned_dim; j++)
+        center[j] /= (float)num_points;
+
+    // compute all to one distance
+    std::vector<float> distances(num_points);
+
+    // TODO: REFACTOR. Removing pragma might make this slow. Must revisit.
+    // Problem is that we need to pass num_threads here, it is not clear
+    // if data store must be aware of threads!
+    // #pragma omp parallel for schedule(static, 65536)
+    for (int64_t i = 0; i < (int64_t)num_points; i++)
+    {
+        // squared L2 distance from the centroid to point i
+        float &dist = distances[i];
+        const data_t *cur_vec = _data + (i * (size_t)_aligned_dim);
+        dist = 0;
+        for (size_t j = 0; j < _aligned_dim; j++)
+        {
+            float diff = (center[j] - (float)cur_vec[j]) * (center[j] - (float)cur_vec[j]);
+            dist += diff;
+        }
+    }
+    // find the index of the minimum distance
+    uint32_t min_idx = 0;
+    float min_dist = distances[0];
+    for (uint32_t i = 1; i < num_points; i++)
+    {
+        if (distances[i] < min_dist)
+        {
+            min_idx = i;
+            min_dist = distances[i];
+        }
+    }
+
+    return min_idx;
+}
+
+// Non-owning pointer to the distance functor used by this store; the store
+// retains ownership via the unique_ptr member.
+template <typename data_t> Distance<data_t> *InMemDataStore<data_t>::get_dist_fn() const
+{
+    return this->_distance_fn.get();
+}
+
+// Explicit instantiations for the supported element types.
+template DISKANN_DLLEXPORT class InMemDataStore<float>;
+template DISKANN_DLLEXPORT class InMemDataStore<int8_t>;
+template DISKANN_DLLEXPORT class InMemDataStore<uint8_t>;
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/in_mem_graph_store.cpp b/be/src/extern/diskann/src/in_mem_graph_store.cpp
new file mode 100644
index 0000000..cae6459
--- /dev/null
+++ b/be/src/extern/diskann/src/in_mem_graph_store.cpp
@@ -0,0 +1,284 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "in_mem_graph_store.h"
+#include "utils.h"
+
+namespace diskann
+{
+// Adjacency-list graph store: one neighbour vector per point, each
+// pre-reserving reserve_graph_degree slots to avoid reallocation while
+// the graph is built.
+InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree)
+    : AbstractGraphStore(total_pts, reserve_graph_degree)
+{
+    this->resize_graph(total_pts);
+    for (size_t i = 0; i < total_pts; i++)
+    {
+        _graph[i].reserve(reserve_graph_degree);
+    }
+}
+
+// Loads a serialized graph; returns (nodes read, entry point, frozen points).
+std::tuple<uint32_t, uint32_t, size_t> InMemGraphStore::load(const std::string &index_path_prefix,
+                                                             const size_t num_points)
+{
+    return load_impl(index_path_prefix, num_points);
+}
+// Serializes the graph to a file on disk.
+int InMemGraphStore::store(const std::string &index_path_prefix, const size_t num_points,
+                           const size_t num_frozen_points, const uint32_t start)
+{
+    return save_graph(index_path_prefix, num_points, num_frozen_points, start);
+}
+
+// Serializes the graph into an in-memory stream (same layout as the file form).
+int InMemGraphStore::store(std::stringstream &index_stream, const size_t num_points,
+                           const size_t num_frozen_points, const uint32_t start)
+{
+    return save_graph(index_stream, num_points, num_frozen_points, start);
+}
+
+// Read-only access to point i's adjacency list (at() bounds-checks and
+// throws std::out_of_range for an invalid id).
+const std::vector<location_t> &InMemGraphStore::get_neighbours(const location_t i) const
+{
+    return _graph.at(i);
+}
+
+// Appends one neighbour to point i and tracks the largest degree seen so far.
+void InMemGraphStore::add_neighbour(const location_t i, location_t neighbour_id)
+{
+    _graph[i].emplace_back(neighbour_id);
+    if (_max_observed_degree < _graph[i].size())
+    {
+        _max_observed_degree = (uint32_t)(_graph[i].size());
+    }
+}
+
+// Empties point i's adjacency list (capacity is retained).
+void InMemGraphStore::clear_neighbours(const location_t i)
+{
+    _graph[i].clear();
+};
+// Exchanges the adjacency lists of points a and b in O(1).
+void InMemGraphStore::swap_neighbours(const location_t a, location_t b)
+{
+    _graph[a].swap(_graph[b]);
+};
+
+// Replaces point i's adjacency list and updates the max observed degree.
+void InMemGraphStore::set_neighbours(const location_t i, std::vector<location_t> &neighbours)
+{
+    _graph[i].assign(neighbours.begin(), neighbours.end());
+    if (_max_observed_degree < neighbours.size())
+    {
+        _max_observed_degree = (uint32_t)(neighbours.size());
+    }
+}
+
+// Resizes the adjacency-list array and records the new total; returns it.
+size_t InMemGraphStore::resize_graph(const size_t new_size)
+{
+    _graph.resize(new_size);
+    set_total_points(new_size);
+    return _graph.size();
+}
+
+// Drops all adjacency lists.
+void InMemGraphStore::clear_graph()
+{
+    _graph.clear();
+}
+
+#ifdef EXEC_ENV_OLS
+// Reader-based graph load (OLS environment). Header layout: file size (u64),
+// max observed degree (u32), entry point (u32), frozen point count (u64);
+// then back-to-back adjacency lists: degree (u32) + that many u32 neighbour ids.
+// Returns (nodes read, entry point, frozen point count).
+std::tuple<uint32_t, uint32_t, size_t> InMemGraphStore::load_impl(AlignedFileReader &reader, size_t expected_num_points)
+{
+    size_t expected_file_size;
+    size_t file_frozen_pts;
+    uint32_t start;
+
+    auto max_points = get_max_points();
+    int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t);
+    std::unique_ptr<char[]> header = std::make_unique<char[]>(header_size);
+    read_array(reader, header.get(), header_size);
+
+    expected_file_size = *((size_t *)header.get());
+    _max_observed_degree = *((uint32_t *)(header.get() + sizeof(size_t)));
+    start = *((uint32_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t)));
+    file_frozen_pts = *((size_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t)));
+
+    diskann::cout << "From graph header, expected_file_size: " << expected_file_size
+                  << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start
+                  << ", file_frozen_pts: " << file_frozen_pts << std::endl;
+
+    diskann::cout << "Loading vamana graph from reader..." << std::flush;
+
+    // If user provides more points than max_points
+    // resize the _graph to the larger size.
+    if (get_total_points() < expected_num_points)
+    {
+        diskann::cout << "resizing graph to " << expected_num_points << std::endl;
+        this->resize_graph(expected_num_points);
+    }
+
+    uint32_t nodes_read = 0;
+    size_t cc = 0; // running total of out-edges read
+    size_t graph_offset = header_size;
+    while (nodes_read < expected_num_points)
+    {
+        uint32_t k;
+        read_value(reader, k, graph_offset);
+        graph_offset += sizeof(uint32_t);
+        std::vector<uint32_t> tmp(k);
+        tmp.reserve(k);
+        read_array(reader, tmp.data(), k, graph_offset);
+        graph_offset += k * sizeof(uint32_t);
+        cc += k;
+        _graph[nodes_read].swap(tmp);
+        nodes_read++;
+        if (nodes_read % 1000000 == 0)
+        {
+            diskann::cout << "." << std::flush;
+        }
+        if (k > _max_range_of_graph)
+        {
+            _max_range_of_graph = k;
+        }
+    }
+
+    diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start
+                  << std::endl;
+    return std::make_tuple(nodes_read, start, file_frozen_pts);
+}
+#endif
+
+// File-based graph load. Header layout: file size (u64), max observed degree
+// (u32), entry point (u32), frozen point count (u64); then back-to-back
+// adjacency lists: degree (u32) + that many u32 neighbour ids. Reads until the
+// recorded file size is consumed. Returns (nodes read, entry point,
+// frozen point count).
+std::tuple<uint32_t, uint32_t, size_t> InMemGraphStore::load_impl(const std::string &filename,
+                                                                  size_t expected_num_points)
+{
+    size_t expected_file_size;
+    size_t file_frozen_pts;
+    uint32_t start;
+    size_t file_offset = 0; // will need this for single file format support
+
+    std::ifstream in;
+    in.exceptions(std::ios::badbit | std::ios::failbit);
+    in.open(filename, std::ios::binary);
+    in.seekg(file_offset, in.beg);
+    in.read((char *)&expected_file_size, sizeof(size_t));
+    in.read((char *)&_max_observed_degree, sizeof(uint32_t));
+    in.read((char *)&start, sizeof(uint32_t));
+    in.read((char *)&file_frozen_pts, sizeof(size_t));
+    size_t vamana_metadata_size = sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(size_t);
+
+    diskann::cout << "From graph header, expected_file_size: " << expected_file_size
+                  << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start
+                  << ", file_frozen_pts: " << file_frozen_pts << std::endl;
+
+    diskann::cout << "Loading vamana graph " << filename << "..." << std::flush;
+
+    // If user provides more points than max_points
+    // resize the _graph to the larger size.
+    if (get_total_points() < expected_num_points)
+    {
+        diskann::cout << "resizing graph to " << expected_num_points << std::endl;
+        this->resize_graph(expected_num_points);
+    }
+
+    size_t bytes_read = vamana_metadata_size;
+    size_t cc = 0; // running total of out-edges read
+    uint32_t nodes_read = 0;
+    while (bytes_read != expected_file_size)
+    {
+        uint32_t k;
+        in.read((char *)&k, sizeof(uint32_t));
+
+        // A zero out-degree is only reported, not treated as fatal.
+        if (k == 0)
+        {
+            diskann::cerr << "ERROR: Point found with no out-neighbours, point#" << nodes_read << std::endl;
+        }
+
+        cc += k;
+        ++nodes_read;
+        std::vector<uint32_t> tmp(k);
+        tmp.reserve(k);
+        in.read((char *)tmp.data(), k * sizeof(uint32_t));
+        _graph[nodes_read - 1].swap(tmp);
+        bytes_read += sizeof(uint32_t) * ((size_t)k + 1);
+        if (nodes_read % 10000000 == 0)
+            diskann::cout << "." << std::flush;
+        if (k > _max_range_of_graph)
+        {
+            _max_range_of_graph = k;
+        }
+    }
+
+    diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start
+                  << std::endl;
+    return std::make_tuple(nodes_read, start, file_frozen_pts);
+}
+
+// Serializes the graph to a file: a 24-byte header (total size, max
+// observed degree, entry point, #frozen points) followed by one
+// [degree:u32][neighbour ids] record per node. The header's size and
+// max-degree fields are patched in once the adjacency lists are written.
+// Returns the total number of bytes emitted.
+int InMemGraphStore::save_graph(const std::string &index_path_prefix, const size_t num_points,
+                                const size_t num_frozen_points, const uint32_t start)
+{
+    std::ofstream out;
+    open_file_to_write(out, index_path_prefix);
+
+    const size_t header_pos = 0;
+    out.seekp(header_pos, out.beg);
+    size_t total_bytes = 24;
+    uint32_t widest_degree = 0;
+    out.write((char *)&total_bytes, sizeof(uint64_t));
+    out.write((char *)&_max_observed_degree, sizeof(uint32_t));
+    uint32_t entry_point = start;
+    out.write((char *)&entry_point, sizeof(uint32_t));
+    out.write((char *)&num_frozen_points, sizeof(size_t));
+
+    // Note: num_points = _nd + _num_frozen_points
+    for (uint32_t node = 0; node < num_points; node++)
+    {
+        const uint32_t degree = (uint32_t)_graph[node].size();
+        out.write((char *)&degree, sizeof(uint32_t));
+        out.write((char *)_graph[node].data(), degree * sizeof(uint32_t));
+        if (degree > widest_degree)
+        {
+            widest_degree = degree;
+        }
+        total_bytes += (size_t)(sizeof(uint32_t) * (degree + 1));
+    }
+
+    // Rewind and patch the header with the final size and widest degree.
+    out.seekp(header_pos, out.beg);
+    out.write((char *)&total_bytes, sizeof(uint64_t));
+    out.write((char *)&widest_degree, sizeof(uint32_t));
+    out.close();
+    return (int)total_bytes;
+}
+
+
+
+
+// Serializes the graph into an in-memory stream using the same binary
+// layout as the file-based overload: 24-byte header, then per-node
+// adjacency lists; the header is patched afterwards with the real size
+// and max degree. Returns the total number of bytes written.
+// Fix: removed leftover debug code that printed every edge of every node
+// to stdout (O(edges) console output on each save, and it compared a
+// signed loop index against an unsigned degree).
+int InMemGraphStore::save_graph(std::stringstream &out, const size_t num_points,
+                                const size_t num_frozen_points, const uint32_t start)
+{
+    size_t file_offset = 0;
+    out.seekp(file_offset, out.beg);
+    size_t index_size = 24;
+    uint32_t max_degree = 0;
+    out.write((char *)&index_size, sizeof(uint64_t));
+    out.write((char *)&_max_observed_degree, sizeof(uint32_t));
+    uint32_t ep_u32 = start;
+    out.write((char *)&ep_u32, sizeof(uint32_t));
+    out.write((char *)&num_frozen_points, sizeof(size_t));
+
+    // Note: num_points = _nd + _num_frozen_points
+    for (uint32_t i = 0; i < num_points; i++)
+    {
+        uint32_t GK = (uint32_t)_graph[i].size();
+        out.write((char *)&GK, sizeof(uint32_t));
+        out.write((char *)_graph[i].data(), GK * sizeof(uint32_t));
+        max_degree = _graph[i].size() > max_degree ? (uint32_t)_graph[i].size() : max_degree;
+        index_size += (size_t)(sizeof(uint32_t) * (GK + 1));
+    }
+    // Rewind and patch the header with the final size and max degree.
+    out.seekp(file_offset, out.beg);
+    out.write((char *)&index_size, sizeof(uint64_t));
+    out.write((char *)&max_degree, sizeof(uint32_t));
+    return (int)index_size;
+}
+
+// Returns the largest out-degree observed while loading the graph from disk.
+size_t InMemGraphStore::get_max_range_of_graph()
+{
+    return _max_range_of_graph;
+}
+
+// Returns the maximum observed degree recorded in the graph header.
+uint32_t InMemGraphStore::get_max_observed_degree()
+{
+    return _max_observed_degree;
+}
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/index.cpp b/be/src/extern/diskann/src/index.cpp
new file mode 100644
index 0000000..7b037e4
--- /dev/null
+++ b/be/src/extern/diskann/src/index.cpp
@@ -0,0 +1,3627 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <omp.h>
+
+#include <type_traits>
+
+#include "boost/dynamic_bitset.hpp"
+#include "index_factory.h"
+#include "memory_mapper.h"
+#include "timer.h"
+#include "tsl/robin_map.h"
+#include "tsl/robin_set.h"
+#include "windows_customizations.h"
+#include "tag_uint128.h"
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+#include "gperftools/malloc_extension.h"
+#endif
+
+#ifdef _WINDOWS
+#include <xmmintrin.h>
+#endif
+
+#include "index.h"
+
+#define MAX_POINTS_FOR_USING_BITSET 10000000
+
+namespace diskann
+{
+// Primary constructor: builds an index from an IndexConfig plus pre-built
+// data, graph, and (optional) PQ stores. Validates configuration invariants,
+// allocates per-point locks and tag maps, and pre-allocates query scratch
+// when both write and search parameters are supplied.
+template <typename T, typename TagT, typename LabelT>
+Index<T, TagT, LabelT>::Index(const IndexConfig &index_config, std::shared_ptr<AbstractDataStore<T>> data_store,
+                              std::unique_ptr<AbstractGraphStore> graph_store,
+                              std::shared_ptr<AbstractDataStore<T>> pq_data_store)
+    : _dist_metric(index_config.metric), _dim(index_config.dimension), _max_points(index_config.max_points),
+      _num_frozen_pts(index_config.num_frozen_pts), _dynamic_index(index_config.dynamic_index),
+      _enable_tags(index_config.enable_tags), _indexingMaxC(DEFAULT_MAXC), _query_scratch(nullptr),
+      _pq_dist(index_config.pq_dist_build), _use_opq(index_config.use_opq),
+      _filtered_index(index_config.filtered_index), _num_pq_chunks(index_config.num_pq_chunks),
+      _delete_set(new tsl::robin_set<uint32_t>), _conc_consolidate(index_config.concurrent_consolidate)
+{
+    // A dynamic (insert/delete-capable) index needs tags to address points.
+    if (_dynamic_index && !_enable_tags)
+    {
+        throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // PQ-distance builds are incompatible with dynamic indexing and with
+    // the inner-product metric.
+    if (_pq_dist)
+    {
+        if (_dynamic_index)
+            throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based "
+                               "index construction",
+                               -1, __FUNCSIG__, __FILE__, __LINE__);
+        if (_dist_metric == diskann::Metric::INNER_PRODUCT)
+            throw ANNException("ERROR: Inner product metrics not yet supported "
+                               "with PQ distance "
+                               "base index",
+                               -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // A dynamic index always keeps at least one frozen (start) point.
+    if (_dynamic_index && _num_frozen_pts == 0)
+    {
+        _num_frozen_pts = 1;
+    }
+    // Sanity check. While logically it is correct, max_points = 0 causes
+    // downstream problems.
+    if (_max_points == 0)
+    {
+        _max_points = 1;
+    }
+    const size_t total_internal_points = _max_points + _num_frozen_pts;
+
+    // Frozen points occupy slots [_max_points, total_internal_points).
+    _start = (uint32_t)_max_points;
+
+    _data_store = data_store;
+    _pq_data_store = pq_data_store;
+    _graph_store = std::move(graph_store);
+
+    // One fine-grained lock per internal point for concurrent graph updates.
+    _locks = std::vector<non_recursive_mutex>(total_internal_points);
+    if (_enable_tags)
+    {
+        _location_to_tag.reserve(total_internal_points);
+        _tag_to_location.reserve(total_internal_points);
+    }
+
+    if (_dynamic_index)
+    {
+        this->enable_delete(); // enable delete by default for dynamic index
+        if (_filtered_index)
+        {
+            _location_to_labels.resize(total_internal_points);
+        }
+    }
+
+    // Cache indexing parameters and, if search params are also present,
+    // pre-allocate scratch for (search threads + indexing threads).
+    if (index_config.index_write_params != nullptr)
+    {
+        _indexingQueueSize = index_config.index_write_params->search_list_size;
+        _indexingRange = index_config.index_write_params->max_degree;
+        _indexingMaxC = index_config.index_write_params->max_occlusion_size;
+        _indexingAlpha = index_config.index_write_params->alpha;
+        _filterIndexingQueueSize = index_config.index_write_params->filter_list_size;
+        _indexingThreads = index_config.index_write_params->num_threads;
+        _saturate_graph = index_config.index_write_params->saturate_graph;
+
+        if (index_config.index_search_params != nullptr)
+        {
+            uint32_t num_scratch_spaces = index_config.index_search_params->num_search_threads + _indexingThreads;
+            initialize_query_scratch(num_scratch_spaces, index_config.index_search_params->initial_search_list_size,
+                                     _indexingQueueSize, _indexingRange, _indexingMaxC, _data_store->get_dims());
+        }
+    }
+}
+
+// Convenience constructor: builds an IndexConfig via the builder and
+// constructs in-memory data/graph stores sized for max_points plus the
+// implicit frozen point a dynamic index requires. Delegates to the primary
+// constructor, then wires up the PQ data store (shared with _data_store
+// when PQ distance is not used).
+template <typename T, typename TagT, typename LabelT>
+Index<T, TagT, LabelT>::Index(Metric m, const size_t dim, const size_t max_points,
+                              const std::shared_ptr<IndexWriteParameters> index_parameters,
+                              const std::shared_ptr<IndexSearchParams> index_search_params, const size_t num_frozen_pts,
+                              const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate,
+                              const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq,
+                              const bool filtered_index)
+    : Index(
+          IndexConfigBuilder()
+              .with_metric(m)
+              .with_dimension(dim)
+              .with_max_points(max_points)
+              .with_index_write_params(index_parameters)
+              .with_index_search_params(index_search_params)
+              .with_num_frozen_pts(num_frozen_pts)
+              .is_dynamic_index(dynamic_index)
+              .is_enable_tags(enable_tags)
+              .is_concurrent_consolidate(concurrent_consolidate)
+              .is_pq_dist_build(pq_dist_build)
+              .with_num_pq_chunks(num_pq_chunks)
+              .is_use_opq(use_opq)
+              .is_filtered(filtered_index)
+              .with_data_type(diskann_type_to_name<T>())
+              .build(),
+          // Store capacity mirrors the clamping done in the primary ctor:
+          // max_points of 0 becomes 1; dynamic indexes get a frozen point.
+          IndexFactory::construct_datastore<T>(DataStoreStrategy::MEMORY,
+                                               (max_points == 0 ? (size_t)1 : max_points) +
+                                                   (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts),
+                                               dim, m),
+          IndexFactory::construct_graphstore(GraphStoreStrategy::MEMORY,
+                                             (max_points == 0 ? (size_t)1 : max_points) +
+                                                 (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts),
+                                             (size_t)((index_parameters == nullptr ? 0 : index_parameters->max_degree) *
+                                                      defaults::GRAPH_SLACK_FACTOR * 1.05)))
+{
+    if (_pq_dist)
+    {
+        _pq_data_store = IndexFactory::construct_pq_datastore<T>(DataStoreStrategy::MEMORY, max_points + num_frozen_pts,
+                                                                 dim, m, num_pq_chunks, use_opq);
+    }
+    else
+    {
+        _pq_data_store = _data_store;
+    }
+}
+
+// Destructor: acquires every coarse lock, then briefly takes each per-point
+// lock, to ensure no concurrent operation is mid-flight before freeing the
+// optimized-graph buffer and the pooled query scratch.
+template <typename T, typename TagT, typename LabelT> Index<T, TagT, LabelT>::~Index()
+{
+    // Ensure that no other activity is happening before dtor()
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    // Drain any thread still holding a per-point lock.
+    for (auto &lock : _locks)
+    {
+        LockGuard lg(lock);
+    }
+
+    if (_opt_graph != nullptr)
+    {
+        delete[] _opt_graph;
+    }
+
+    if (!_query_scratch.empty())
+    {
+        ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+        manager.destroy();
+    }
+}
+
+// Pre-allocates one InMemQueryScratch per worker thread and parks each in
+// the shared scratch pool so search/insert paths never allocate on the fly.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l,
+                                                      uint32_t r, uint32_t maxc, size_t dim)
+{
+    uint32_t remaining = num_threads;
+    while (remaining-- > 0)
+    {
+        _query_scratch.push(new InMemQueryScratch<T>(search_l, indexing_l, r, maxc, dim,
+                                                     _data_store->get_aligned_dim(),
+                                                     _data_store->get_alignment_factor(), _pq_dist));
+    }
+}
+
+// Persists the location->tag mapping as a single-column binary file.
+// Locations without a tag and all frozen points are written as zeroed tags.
+// Returns the number of bytes written, or 0 when tags are disabled.
+// Fix: the staging buffer previously leaked when save_bin threw.
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::save_tags(std::string tags_file)
+{
+    if (!_enable_tags)
+    {
+        diskann::cout << "Not saving tags as they are not enabled." << std::endl;
+        return 0;
+    }
+
+    size_t tag_bytes_written;
+    TagT *tag_data = new TagT[_nd + _num_frozen_pts];
+    for (uint32_t i = 0; i < _nd; i++)
+    {
+        TagT tag;
+        if (_location_to_tag.try_get(i, tag))
+        {
+            tag_data[i] = tag;
+        }
+        else
+        {
+            // catering to future when tagT can be any type.
+            std::memset((char *)&tag_data[i], 0, sizeof(TagT));
+        }
+    }
+    if (_num_frozen_pts > 0)
+    {
+        std::memset((char *)&tag_data[_start], 0, sizeof(TagT) * _num_frozen_pts);
+    }
+    try
+    {
+        tag_bytes_written = save_bin<TagT>(tags_file, tag_data, _nd + _num_frozen_pts, 1);
+    }
+    catch (std::system_error &e)
+    {
+        delete[] tag_data; // free the staging buffer before rethrowing
+        throw FileException(tags_file, e, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    delete[] tag_data;
+    return tag_bytes_written;
+}
+
+// Writes the raw vectors to `data_file` via the data store; returns bytes
+// written as reported by the store.
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::save_data(std::string data_file)
+{
+    // Note: at this point, either _nd == _max_points or any frozen points have
+    // been temporarily moved to _nd, so _nd + _num_frozen_pts is the valid
+    // location limit.
+    return _data_store->save(data_file, (location_t)(_nd + _num_frozen_pts));
+}
+
+// save the graph index on a file as an adjacency list. For each point,
+// first store the number of neighbors, and then the neighbor list (each as
+// 4 byte uint32_t). Returns the number of bytes written.
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::save_graph(std::string graph_file)
+{
+    return _graph_store->store(graph_file, _nd + _num_frozen_pts, _num_frozen_pts, _start);
+}
+
+// Stream-based variant of save_graph: serializes the adjacency lists into
+// the caller-supplied stringstream instead of a file.
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::save_graph(std::stringstream &graph_stream)
+{
+    return _graph_store->store(graph_stream, _nd + _num_frozen_pts, _num_frozen_pts, _start);
+}
+
+
+// Serializes the set of lazily-deleted internal ids as a single-column
+// binary file. Returns the bytes written, or 0 when nothing is deleted.
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::save_delete_list(const std::string &filename)
+{
+    const size_t num_deleted = _delete_set->size();
+    if (num_deleted == 0)
+    {
+        return 0;
+    }
+    std::unique_ptr<uint32_t[]> delete_list = std::make_unique<uint32_t[]>(num_deleted);
+    size_t idx = 0;
+    for (const auto &loc : *_delete_set)
+    {
+        delete_list[idx] = loc;
+        ++idx;
+    }
+    return save_bin<uint32_t>(filename, delete_list.get(), num_deleted, 1);
+}
+
+// Saves the index to a family of files rooted at `filename`: the graph
+// itself, plus ".data", ".tags", ".del", and (for filtered indexes) the
+// label sidecar files. Takes all coarse locks for the duration; optionally
+// compacts first. Frozen points are repositioned back at the end.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::save(const char *filename, bool compact_before_save)
+{
+    diskann::Timer timer;
+
+    // Block all concurrent mutation while the index is written out.
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    if (compact_before_save)
+    {
+        compact_data();
+        compact_frozen_point();
+    }
+    else
+    {
+        if (!_data_compacted)
+        {
+            throw ANNException("Index save for non-compacted index is not yet implemented", -1, __FUNCSIG__, __FILE__,
+                               __LINE__);
+        }
+    }
+
+    if (!_save_as_one_file)
+    {
+        // Filtered indexes also persist label metadata alongside the graph.
+        if (_filtered_index)
+        {
+            if (_label_to_start_id.size() > 0)
+            {
+                std::ofstream medoid_writer(std::string(filename) + "_labels_to_medoids.txt");
+                if (medoid_writer.fail())
+                {
+                    throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
+                }
+                for (auto iter : _label_to_start_id)
+                {
+                    medoid_writer << iter.first << ", " << iter.second << std::endl;
+                }
+                medoid_writer.close();
+            }
+
+            if (_use_universal_label)
+            {
+                std::ofstream universal_label_writer(std::string(filename) + "_universal_label.txt");
+                assert(universal_label_writer.is_open());
+                universal_label_writer << _universal_label << std::endl;
+                universal_label_writer.close();
+            }
+
+            if (_location_to_labels.size() > 0)
+            {
+                std::ofstream label_writer(std::string(filename) + "_labels.txt");
+                assert(label_writer.is_open());
+                for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++)
+                {
+                    // Comma-separated labels per point, no trailing comma.
+                    for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++)
+                    {
+                        label_writer << _location_to_labels[i][j] << ",";
+                    }
+                    if (_location_to_labels[i].size() != 0)
+                        label_writer << _location_to_labels[i][_location_to_labels[i].size() - 1];
+
+                    label_writer << std::endl;
+                }
+                label_writer.close();
+
+                // write compacted raw_labels if data hence _location_to_labels was also compacted
+                if (compact_before_save && _dynamic_index)
+                {
+                    _label_map = load_label_map(std::string(filename) + "_labels_map.txt");
+                    std::unordered_map<LabelT, std::string> mapped_to_raw_labels;
+                    // invert label map
+                    for (const auto &[key, value] : _label_map)
+                    {
+                        mapped_to_raw_labels.insert({value, key});
+                    }
+
+                    // write updated labels
+                    std::ofstream raw_label_writer(std::string(filename) + "_raw_labels.txt");
+                    assert(raw_label_writer.is_open());
+                    for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++)
+                    {
+                        for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++)
+                        {
+                            raw_label_writer << mapped_to_raw_labels[_location_to_labels[i][j]] << ",";
+                        }
+                        if (_location_to_labels[i].size() != 0)
+                            raw_label_writer
+                                << mapped_to_raw_labels[_location_to_labels[i][_location_to_labels[i].size() - 1]];
+
+                        raw_label_writer << std::endl;
+                    }
+                    raw_label_writer.close();
+                }
+            }
+        }
+
+        std::string graph_file = std::string(filename);
+        std::string tags_file = std::string(filename) + ".tags";
+        std::string data_file = std::string(filename) + ".data";
+        std::string delete_list_file = std::string(filename) + ".del";
+
+        // Because the save_* functions use append mode, ensure that
+        // the files are deleted before save. Ideally, we should check
+        // the error code for delete_file, but will ignore now because
+        // delete should succeed if save will succeed.
+        delete_file(graph_file);
+        save_graph(graph_file);
+        delete_file(data_file);
+        save_data(data_file);
+        delete_file(tags_file);
+        save_tags(tags_file);
+        delete_file(delete_list_file);
+        save_delete_list(delete_list_file);
+    }
+    else
+    {
+        diskann::cout << "Save index in a single file currently not supported. "
+                         "Not saving the index."
+                      << std::endl;
+    }
+
+    // If frozen points were temporarily compacted to _nd, move back to
+    // _max_points.
+    reposition_frozen_point_to_end();
+
+    diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl;
+}
+
+
+// Stream-based save: serializes only the graph into the caller-supplied
+// stringstream. Unlike the file-based save(), ancillary artifacts (data,
+// tags, delete list, label sidecars) are not emitted here — the caller is
+// responsible for persisting the stream. Takes all coarse locks and
+// optionally compacts first, mirroring the file-based overload.
+// Fix: removed ~90 lines of commented-out dead code copied from the
+// file-based save(); behavior is unchanged.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::save(std::stringstream &mem_index_stream, bool compact_before_save)
+{
+    diskann::Timer timer;
+
+    // Block all concurrent mutation while the index is written out.
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    if (compact_before_save)
+    {
+        compact_data();
+        compact_frozen_point();
+    }
+    else
+    {
+        if (!_data_compacted)
+        {
+            throw ANNException("Index save for non-compacted index is not yet implemented", -1, __FUNCSIG__, __FILE__,
+                               __LINE__);
+        }
+    }
+
+    if (!_save_as_one_file)
+    {
+        save_graph(mem_index_stream);
+    }
+    else
+    {
+        diskann::cout << "Save index in a single file currently not supported. "
+                         "Not saving the index."
+                      << std::endl;
+    }
+
+    // If frozen points were temporarily compacted to _nd, move back to
+    // _max_points.
+    reposition_frozen_point_to_end();
+
+    diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl;
+}
+
+// Loads the tag file (1-column binary of TagT) and rebuilds both
+// location->tag and tag->location maps, skipping deleted locations and
+// trailing frozen points. Returns the number of points in the tag file.
+#ifdef EXEC_ENV_OLS
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_tags(AlignedFileReader &reader)
+{
+#else
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_tags(const std::string tag_filename)
+{
+    if (_enable_tags && !file_exists(tag_filename))
+    {
+        diskann::cerr << "Tag file " << tag_filename << " does not exist!" << std::endl;
+        throw diskann::ANNException("Tag file " + tag_filename + " does not exist!", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+#endif
+    if (!_enable_tags)
+    {
+        diskann::cout << "Tags not loaded as tags not enabled." << std::endl;
+        return 0;
+    }
+
+    size_t file_dim, file_num_points;
+    TagT *tag_data;
+#ifdef EXEC_ENV_OLS
+    load_bin<TagT>(reader, tag_data, file_num_points, file_dim);
+#else
+    load_bin<TagT>(std::string(tag_filename), tag_data, file_num_points, file_dim);
+#endif
+
+    // Tags are stored one per point; any other width is a corrupt file.
+    if (file_dim != 1)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Found " << file_dim << " dimensions for tags,"
+               << "but tag file must have 1 dimension." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        delete[] tag_data;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // The last _num_frozen_pts entries are frozen points and carry no tag.
+    const size_t num_data_points = file_num_points - _num_frozen_pts;
+    _location_to_tag.reserve(num_data_points);
+    _tag_to_location.reserve(num_data_points);
+    for (uint32_t i = 0; i < (uint32_t)num_data_points; i++)
+    {
+        TagT tag = *(tag_data + i);
+        if (_delete_set->find(i) == _delete_set->end())
+        {
+            _location_to_tag.set(i, tag);
+            _tag_to_location[tag] = i;
+        }
+    }
+    diskann::cout << "Tags loaded." << std::endl;
+    delete[] tag_data;
+    return file_num_points;
+}
+
+// Loads the raw vectors from the data file into the data store, resizing
+// the index if the file holds more points than currently allocated.
+// Validates the dimensionality against _dim. Returns the number of points
+// in the file.
+template <typename T, typename TagT, typename LabelT>
+#ifdef EXEC_ENV_OLS
+size_t Index<T, TagT, LabelT>::load_data(AlignedFileReader &reader)
+{
+#else
+size_t Index<T, TagT, LabelT>::load_data(std::string filename)
+{
+#endif
+    size_t file_dim, file_num_points;
+#ifdef EXEC_ENV_OLS
+    diskann::get_bin_metadata(reader, file_num_points, file_dim);
+#else
+    if (!file_exists(filename))
+    {
+        std::stringstream stream;
+        stream << "ERROR: data file " << filename << " does not exist." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    diskann::get_bin_metadata(filename, file_num_points, file_dim);
+#endif
+
+    // since we are loading a new dataset, _empty_slots must be cleared
+    _empty_slots.clear();
+
+    if (file_dim != _dim)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << _dim << " dimension,"
+               << "but file has " << file_dim << " dimension." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (file_num_points > _max_points + _num_frozen_pts)
+    {
+        // update and tag lock acquired in load() before calling load_data
+        resize(file_num_points - _num_frozen_pts);
+    }
+
+#ifdef EXEC_ENV_OLS
+    // REFACTOR TODO: Must figure out how to support aligned reader in a clean
+    // manner.
+    copy_aligned_data_from_file<T>(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim());
+#else
+    _data_store->load(filename); // offset == 0.
+#endif
+    return file_num_points;
+}
+
+// Loads the persisted delete list (1-column binary of uint32_t ids) and
+// inserts every id into _delete_set. Returns the number of deleted ids.
+#ifdef EXEC_ENV_OLS
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_delete_set(AlignedFileReader &reader)
+{
+#else
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_delete_set(const std::string &filename)
+{
+#endif
+    std::unique_ptr<uint32_t[]> delete_list;
+    size_t npts, ndim;
+
+#ifdef EXEC_ENV_OLS
+    diskann::load_bin<uint32_t>(reader, delete_list, npts, ndim);
+#else
+    diskann::load_bin<uint32_t>(filename, delete_list, npts, ndim);
+#endif
+    assert(ndim == 1);
+    for (uint32_t i = 0; i < npts; i++)
+    {
+        _delete_set->insert(delete_list[i]);
+    }
+    return npts;
+}
+
+// Loads a previously saved index from the file family rooted at `filename`:
+// data, optional delete list, optional tags, the graph, and (if present)
+// label sidecar files. Validates that point counts agree across the parts,
+// rebuilds the empty-slot set, repositions frozen points, and initializes
+// query scratch if the constructor did not.
+template <typename T, typename TagT, typename LabelT>
+#ifdef EXEC_ENV_OLS
+void Index<T, TagT, LabelT>::load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l)
+{
+#else
+void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, uint32_t search_l)
+{
+#endif
+    // Block all concurrent mutation while the index is being (re)populated.
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    _has_built = true;
+
+    size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0;
+
+    std::string mem_index_file(filename);
+    std::string labels_file = mem_index_file + "_labels.txt";
+    std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt";
+    std::string labels_map_file = mem_index_file + "_labels_map.txt";
+
+    if (!_save_as_one_file)
+    {
+        // For DLVS Store, we will not support saving the index in multiple
+        // files.
+#ifndef EXEC_ENV_OLS
+        std::string data_file = std::string(filename) + ".data";
+        std::string tags_file = std::string(filename) + ".tags";
+        std::string delete_set_file = std::string(filename) + ".del";
+        std::string graph_file = std::string(filename);
+        data_file_num_pts = load_data(data_file);
+        if (file_exists(delete_set_file))
+        {
+            load_delete_set(delete_set_file);
+        }
+        if (_enable_tags)
+        {
+            tags_file_num_pts = load_tags(tags_file);
+        }
+        graph_num_pts = load_graph(graph_file, data_file_num_pts);
+#endif
+    }
+    else
+    {
+        diskann::cout << "Single index file saving/loading support not yet "
+                         "enabled. Not loading the index."
+                      << std::endl;
+        return;
+    }
+
+    // All parts of the index must agree on the number of points.
+    if (data_file_num_pts != graph_num_pts || (data_file_num_pts != tags_file_num_pts && _enable_tags))
+    {
+        std::stringstream stream;
+        stream << "ERROR: When loading index, loaded " << data_file_num_pts << " points from datafile, "
+               << graph_num_pts << " from graph, and " << tags_file_num_pts
+               << " tags, with num_frozen_pts being set to " << _num_frozen_pts << " in constructor." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (file_exists(labels_file))
+    {
+        _label_map = load_label_map(labels_map_file);
+        parse_label_file(labels_file, label_num_pts);
+        assert(label_num_pts == data_file_num_pts - _num_frozen_pts);
+        if (file_exists(labels_to_medoids))
+        {
+            std::ifstream medoid_stream(labels_to_medoids);
+            std::string line, token;
+            uint32_t line_cnt = 0;
+
+            _label_to_start_id.clear();
+
+            // Each line is "label, medoid"; parse both comma-separated fields.
+            while (std::getline(medoid_stream, line))
+            {
+                std::istringstream iss(line);
+                uint32_t cnt = 0;
+                uint32_t medoid = 0;
+                LabelT label;
+                while (std::getline(iss, token, ','))
+                {
+                    token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+                    token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+                    LabelT token_as_num = (LabelT)std::stoul(token);
+                    if (cnt == 0)
+                        label = token_as_num;
+                    else
+                        medoid = token_as_num;
+                    cnt++;
+                }
+                _label_to_start_id[label] = medoid;
+                line_cnt++;
+            }
+        }
+
+        std::string universal_label_file(filename);
+        universal_label_file += "_universal_label.txt";
+        if (file_exists(universal_label_file))
+        {
+            std::ifstream universal_label_reader(universal_label_file);
+            universal_label_reader >> _universal_label;
+            _use_universal_label = true;
+            universal_label_reader.close();
+        }
+    }
+
+    // Rebuild the free-slot list for locations beyond the loaded points.
+    _nd = data_file_num_pts - _num_frozen_pts;
+    _empty_slots.clear();
+    _empty_slots.reserve(_max_points);
+    for (auto i = _nd; i < _max_points; i++)
+    {
+        _empty_slots.insert((uint32_t)i);
+    }
+
+    reposition_frozen_point_to_end();
+    diskann::cout << "Num frozen points:" << _num_frozen_pts << " _nd: " << _nd << " _start: " << _start
+                  << " size(_location_to_tag): " << _location_to_tag.size()
+                  << " size(_tag_to_location):" << _tag_to_location.size() << " Max points: " << _max_points
+                  << std::endl;
+
+    // For incremental index, _query_scratch is initialized in the constructor.
+    // For the bulk index, the params required to initialize _query_scratch
+    // are known only at load time, hence this check and the call to
+    // initialize_q_s().
+    if (_query_scratch.size() == 0)
+    {
+        initialize_query_scratch(num_threads, search_l, search_l, (uint32_t)_graph_store->get_max_range_of_graph(),
+                                 _indexingMaxC, _dim);
+    }
+}
+
+#ifndef EXEC_ENV_OLS
+// Reads only the graph file header and returns its frozen-point count.
+// The size/degree/start fields are read to advance past them and discarded.
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::get_graph_num_frozen_points(const std::string &graph_file)
+{
+    size_t expected_file_size;
+    uint32_t max_observed_degree, start;
+    size_t file_frozen_pts;
+
+    std::ifstream in;
+    in.exceptions(std::ios::badbit | std::ios::failbit);
+
+    in.open(graph_file, std::ios::binary);
+    in.read((char *)&expected_file_size, sizeof(size_t));
+    in.read((char *)&max_observed_degree, sizeof(uint32_t));
+    in.read((char *)&start, sizeof(uint32_t));
+    in.read((char *)&file_frozen_pts, sizeof(size_t));
+
+    return file_frozen_pts;
+}
+#endif
+
+#ifdef EXEC_ENV_OLS
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_graph(AlignedFileReader &reader, size_t expected_num_points)
+{
+#else
+
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::load_graph(std::string filename, size_t expected_num_points)
+{
+#endif
+    // The graph store reports (#nodes read, entry point, #frozen points);
+    // cache the latter two on the index and hand back the node count.
+    const auto [num_nodes, entry_point, frozen_pts] = _graph_store->load(filename, expected_num_points);
+    _start = entry_point;
+    _num_frozen_pts = frozen_pts;
+    return num_nodes;
+}
+
+// Type-erased wrapper around get_vector_by_tag: unpacks the std::any-held
+// tag and destination pointer, translating bad casts into ANNExceptions.
+template <typename T, typename TagT, typename LabelT>
+int Index<T, TagT, LabelT>::_get_vector_by_tag(TagType &tag, DataType &vec)
+{
+    try
+    {
+        TagT tag_val = std::any_cast<TagT>(tag);
+        T *vec_val = std::any_cast<T *>(vec);
+        return this->get_vector_by_tag(tag_val, vec_val);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast while performing _get_vector_by_tags() " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+// Copies the vector associated with `tag` into `vec`.
+// Returns 0 on success, -1 if the tag is unknown.
+// Fix: a single find() replaces the former find() + operator[] pair —
+// one hash lookup instead of two, and no accidental-insert path.
+template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>::get_vector_by_tag(TagT &tag, T *vec)
+{
+    std::shared_lock<std::shared_timed_mutex> lock(_tag_lock);
+    auto it = _tag_to_location.find(tag);
+    if (it == _tag_to_location.end())
+    {
+        diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" << std::endl;
+        return -1;
+    }
+
+    _data_store->get_vector(it->second, vec);
+
+    return 0;
+}
+
+// Picks the graph-search entry point by asking the data store for the point
+// closest to the centroid (medoid) of the current data set.
+template <typename T, typename TagT, typename LabelT> uint32_t Index<T, TagT, LabelT>::calculate_entry_point()
+{
+    // REFACTOR TODO: This function does not support multi-threaded calculation of medoid.
+    // Must revisit if perf is a concern.
+    return _data_store->calculate_medoid();
+}
+
+// Returns the set of entry points for a graph search: the start node plus all
+// frozen points. Frozen points occupy the locations
+// [_max_points, _max_points + _num_frozen_pts); _start is skipped inside the
+// loop so it never appears twice.
+template <typename T, typename TagT, typename LabelT> std::vector<uint32_t> Index<T, TagT, LabelT>::get_init_ids()
+{
+    std::vector<uint32_t> ids;
+    ids.reserve(1 + _num_frozen_pts);
+
+    ids.push_back(_start);
+
+    for (uint32_t loc = (uint32_t)_max_points; loc < _max_points + _num_frozen_pts; loc++)
+    {
+        if (loc != _start)
+            ids.push_back(loc);
+    }
+
+    return ids;
+}
+
+// Returns true when node `point_id` shares at least one filter label with
+// `incoming_labels`, taking the universal label into account. Both label
+// lists are assumed sorted (set_intersection precondition). During index
+// construction (search_invocation == false) a universal label on EITHER side
+// counts as a match; during search only the node's own universal label does.
+template <typename T, typename TagT, typename LabelT>
+bool Index<T, TagT, LabelT>::detect_common_filters(uint32_t point_id, bool search_invocation,
+                                                   const std::vector<LabelT> &incoming_labels)
+{
+    const auto &node_labels = _location_to_labels[point_id];
+
+    // Fast path: a direct intersection means no universal-label check needed.
+    std::vector<LabelT> shared;
+    std::set_intersection(incoming_labels.begin(), incoming_labels.end(), node_labels.begin(), node_labels.end(),
+                          std::back_inserter(shared));
+    if (!shared.empty())
+        return true;
+
+    if (!_use_universal_label)
+        return false;
+
+    auto has_universal = [this](const std::vector<LabelT> &labels) {
+        return std::find(labels.begin(), labels.end(), _universal_label) != labels.end();
+    };
+
+    if (!search_invocation)
+        return has_universal(incoming_labels) || has_universal(node_labels);
+
+    return has_universal(node_labels);
+}
+
+// Greedy best-first graph search toward the query already loaded (and
+// pre-processed) in `scratch`, starting from `init_ids`, with a candidate
+// queue of capacity Lsize. Distances are evaluated in PQ space through
+// _pq_data_store. When `use_filter` is set, nodes that share no label with
+// `filter_labels` are skipped. During index construction
+// (search_invocation == false) every expanded node is also collected into
+// scratch->pool() for later pruning.
+// Returns (hops, cmps). NOTE: `hops` is never incremented in this
+// implementation, so the first element is always 0; only `cmps` (number of
+// distance computations) carries information.
+// Fix over previous revision: removed leftover debug std::cout statements
+// that printed queue sizes on every invocation, and the unused `pq_dists`
+// local.
+template <typename T, typename TagT, typename LabelT>
+std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
+    InMemQueryScratch<T> *scratch, const uint32_t Lsize, const std::vector<uint32_t> &init_ids, bool use_filter,
+    const std::vector<LabelT> &filter_labels, bool search_invocation)
+{
+    std::vector<Neighbor> &expanded_nodes = scratch->pool();
+    NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
+    best_L_nodes.reserve(Lsize);
+    tsl::robin_set<uint32_t> &inserted_into_pool_rs = scratch->inserted_into_pool_rs();
+    boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs();
+    std::vector<uint32_t> &id_scratch = scratch->id_scratch();
+    std::vector<float> &dist_scratch = scratch->dist_scratch();
+    assert(id_scratch.size() == 0);
+
+    T *aligned_query = scratch->aligned_query();
+
+    _pq_data_store->preprocess_query(aligned_query, scratch);
+
+    if (expanded_nodes.size() > 0 || id_scratch.size() > 0)
+    {
+        throw ANNException("ERROR: Clear scratch space before passing.", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Decide whether to use bitset or robin set to mark visited nodes
+    auto total_num_points = _max_points + _num_frozen_pts;
+    bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET;
+
+    if (fast_iterate)
+    {
+        if (inserted_into_pool_bs.size() < total_num_points)
+        {
+            // hopefully using 2X will reduce the number of allocations.
+            auto resize_size =
+                2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points;
+            inserted_into_pool_bs.resize(resize_size);
+        }
+    }
+
+    // Lambda to determine if a node has been visited
+    auto is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](const uint32_t id) {
+        return fast_iterate ? inserted_into_pool_bs[id] == 0
+                            : inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end();
+    };
+
+    // Lambda to batch compute query<-> node distances in PQ space
+    auto compute_dists = [this, scratch](const std::vector<uint32_t> &ids, std::vector<float> &dists_out) {
+        _pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch);
+    };
+
+    // Initialize the candidate pool with starting points
+    for (auto id : init_ids)
+    {
+        if (id >= _max_points + _num_frozen_pts)
+        {
+            diskann::cerr << "Out of range loc found as an edge : " << id << std::endl;
+            throw diskann::ANNException(std::string("Wrong loc") + std::to_string(id), -1, __FUNCSIG__, __FILE__,
+                                        __LINE__);
+        }
+
+        if (use_filter)
+        {
+            if (!detect_common_filters(id, search_invocation, filter_labels))
+                continue;
+        }
+
+        if (is_not_visited(id))
+        {
+            if (fast_iterate)
+            {
+                inserted_into_pool_bs[id] = 1;
+            }
+            else
+            {
+                inserted_into_pool_rs.insert(id);
+            }
+
+            float distance;
+            uint32_t ids[] = {id};
+            float distances[] = {std::numeric_limits<float>::max()};
+            _pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
+            distance = distances[0];
+
+            Neighbor nn = Neighbor(id, distance);
+            best_L_nodes.insert(nn);
+        }
+    }
+
+    uint32_t hops = 0;
+    uint32_t cmps = 0;
+    while (best_L_nodes.has_unexpanded_node())
+    {
+        auto nbr = best_L_nodes.closest_unexpanded();
+        auto n = nbr.id;
+
+        // Add node to expanded nodes to create pool for prune later
+        if (!search_invocation)
+        {
+            if (!use_filter)
+            {
+                expanded_nodes.emplace_back(nbr);
+            }
+            else
+            { // in filter based indexing, the same point might invoke
+                // multiple iterate_to_fixed_points, so need to be careful
+                // not to add the same item to pool multiple times.
+                if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end())
+                {
+                    expanded_nodes.emplace_back(nbr);
+                }
+            }
+        }
+
+        // Find which of the nodes in des have not been visited before
+        id_scratch.clear();
+        dist_scratch.clear();
+        if (_dynamic_index)
+        {
+            // Dynamic index: hold the node's lock across the whole neighbor scan.
+            LockGuard guard(_locks[n]);
+            for (auto id : _graph_store->get_neighbours(n))
+            {
+                assert(id < _max_points + _num_frozen_pts);
+
+                if (use_filter)
+                {
+                    // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS.
+                    if (!detect_common_filters(id, search_invocation, filter_labels))
+                        continue;
+                }
+
+                if (is_not_visited(id))
+                {
+                    id_scratch.push_back(id);
+                }
+            }
+        }
+        else
+        {
+            // Static index: copy the adjacency list under a short lock, then scan.
+            _locks[n].lock();
+            auto nbrs = _graph_store->get_neighbours(n);
+            _locks[n].unlock();
+            for (auto id : nbrs)
+            {
+                assert(id < _max_points + _num_frozen_pts);
+
+                if (use_filter)
+                {
+                    // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS.
+                    if (!detect_common_filters(id, search_invocation, filter_labels))
+                        continue;
+                }
+
+                if (is_not_visited(id))
+                {
+                    id_scratch.push_back(id);
+                }
+            }
+        }
+
+        // Mark nodes visited
+        for (auto id : id_scratch)
+        {
+            if (fast_iterate)
+            {
+                inserted_into_pool_bs[id] = 1;
+            }
+            else
+            {
+                inserted_into_pool_rs.insert(id);
+            }
+        }
+
+        assert(dist_scratch.capacity() >= id_scratch.size());
+        compute_dists(id_scratch, dist_scratch);
+        cmps += (uint32_t)id_scratch.size();
+
+        // Insert <id, dist> pairs into the pool of candidates
+        for (size_t m = 0; m < id_scratch.size(); ++m)
+        {
+            best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m]));
+        }
+    }
+    return std::make_pair(hops, cmps);
+}
+
+// Runs a graph search seeded at the point stored at `location` to gather a
+// candidate pool, then prunes that pool into `pruned_list` (the point's new
+// adjacency list). In filtered mode the candidates from a label-aware search
+// (seeded at per-label start nodes) are merged with the unfiltered ones.
+// `pruned_list` must be empty on entry; it is guaranteed non-empty on exit.
+// Fix over previous revision: removed leftover debug std::cout statements and
+// replaced the O(n^2) self-erase loop with a single remove_if/erase pass.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t Lindex,
+                                                        std::vector<uint32_t> &pruned_list,
+                                                        InMemQueryScratch<T> *scratch, bool use_filter,
+                                                        uint32_t filteredLindex)
+{
+    const std::vector<uint32_t> init_ids = get_init_ids();
+    const std::vector<LabelT> unused_filter_label;
+
+    if (!use_filter)
+    {
+        _data_store->get_vector(location, scratch->aligned_query());
+        iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
+    }
+    else
+    {
+        // Seed the filtered search at the designated start node of each of
+        // this point's labels. _label_to_start_id is read under the tag lock
+        // only for dynamic indices.
+        std::shared_lock<std::shared_timed_mutex> tl(_tag_lock, std::defer_lock);
+        if (_dynamic_index)
+            tl.lock();
+        std::vector<uint32_t> filter_specific_start_nodes;
+        for (auto &x : _location_to_labels[location])
+            filter_specific_start_nodes.emplace_back(_label_to_start_id[x]);
+
+        if (_dynamic_index)
+            tl.unlock();
+
+        _data_store->get_vector(location, scratch->aligned_query());
+        iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true,
+                               _location_to_labels[location], false);
+
+        // combine candidate pools obtained with filter and unfiltered criteria.
+        std::set<Neighbor> best_candidate_pool;
+        for (auto filtered_neighbor : scratch->pool())
+        {
+            best_candidate_pool.insert(filtered_neighbor);
+        }
+
+        // clear scratch for finding unfiltered candidates
+        scratch->clear();
+
+        _data_store->get_vector(location, scratch->aligned_query());
+        iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
+
+        for (auto unfiltered_neighbour : scratch->pool())
+        {
+            // insert if this neighbour is not already in best_candidate_pool
+            if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end())
+            {
+                best_candidate_pool.insert(unfiltered_neighbour);
+            }
+        }
+
+        scratch->pool().clear();
+        std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool()));
+    }
+
+    // Remove the query point itself from its own candidate pool before pruning.
+    auto &pool = scratch->pool();
+    pool.erase(std::remove_if(pool.begin(), pool.end(),
+                              [location](const Neighbor &nbr) { return nbr.id == (uint32_t)location; }),
+               pool.end());
+
+    if (pruned_list.size() > 0)
+    {
+        throw diskann::ANNException("ERROR: non-empty pruned_list passed", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    prune_neighbors(location, pool, pruned_list, scratch);
+
+    assert(!pruned_list.empty());
+    assert(_graph_store->get_total_points() == _max_points + _num_frozen_pts);
+}
+
+// Robust-prune (occlusion) rule of the Vamana algorithm: selects up to
+// `degree` neighbors for `location` from the sorted candidate `pool`
+// (truncated to `maxc`), writing their ids into `result`. A candidate j is
+// occluded by an already-selected candidate i when dist(query,j) / dist(i,j)
+// exceeds the current alpha; alpha is relaxed geometrically (x1.2 per round)
+// up to `alpha` until `degree` neighbors are found. Candidates in
+// `delete_set_ptr` (may be null) are skipped. `pool` must be sorted by
+// distance and `result` empty on entry.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<Neighbor> &pool, const float alpha,
+                                          const uint32_t degree, const uint32_t maxc, std::vector<uint32_t> &result,
+                                          InMemQueryScratch<T> *scratch,
+                                          const tsl::robin_set<uint32_t> *const delete_set_ptr)
+{
+    if (pool.size() == 0)
+        return;
+
+    // Truncate pool at maxc and initialize scratch spaces
+    assert(std::is_sorted(pool.begin(), pool.end()));
+    assert(result.size() == 0);
+    if (pool.size() > maxc)
+        pool.resize(maxc);
+    std::vector<float> &occlude_factor = scratch->occlude_factor();
+    // occlude_list can be called with the same scratch more than once by
+    // search_for_point_and_add_link through inter_insert.
+    occlude_factor.clear();
+    // Initialize occlude_factor to pool.size() many 0.0f values for correctness
+    occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f);
+
+    float cur_alpha = 1;
+    while (cur_alpha <= alpha && result.size() < degree)
+    {
+        // used for MIPS, where we store a value of eps in cur_alpha to
+        // denote pruned out entries which we can skip in later rounds.
+        float eps = cur_alpha + 0.01f;
+
+        for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); ++iter)
+        {
+            // Skip candidates already occluded at this alpha (or selected:
+            // selected entries carry float::max).
+            if (occlude_factor[iter - pool.begin()] > cur_alpha)
+            {
+                continue;
+            }
+            // Set the entry to float::max so that is not considered again
+            occlude_factor[iter - pool.begin()] = std::numeric_limits<float>::max();
+            // Add the entry to the result if its not been deleted, and doesn't
+            // add a self loop
+            if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end())
+            {
+                if (iter->id != location)
+                {
+                    result.push_back(iter->id);
+                }
+            }
+
+            // Update occlude factor for points from iter+1 to pool.end()
+            for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++)
+            {
+                auto t = iter2 - pool.begin();
+                if (occlude_factor[t] > alpha)
+                    continue;
+
+                // In a filtered index, iter may only occlude iter2 if iter's
+                // labels are a superset of iter2's (otherwise iter2 could be
+                // the only route to its labels).
+                bool prune_allowed = true;
+                if (_filtered_index)
+                {
+                    uint32_t a = iter->id;
+                    uint32_t b = iter2->id;
+                    // NOTE(review): `size() < b` looks like an off-by-one —
+                    // index b is valid only when b < size(); confirm whether
+                    // `size() <= b` was intended.
+                    if (_location_to_labels.size() < b || _location_to_labels.size() < a)
+                        continue;
+                    for (auto &x : _location_to_labels[b])
+                    {
+                        if (std::find(_location_to_labels[a].begin(), _location_to_labels[a].end(), x) ==
+                            _location_to_labels[a].end())
+                        {
+                            prune_allowed = false;
+                        }
+                        if (!prune_allowed)
+                            break;
+                    }
+                }
+                if (!prune_allowed)
+                    continue;
+
+                float djk = _data_store->get_distance(iter2->id, iter->id);
+                if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE)
+                {
+                    // djk == 0 means duplicate points: occlude unconditionally.
+                    occlude_factor[t] = (djk == 0) ? std::numeric_limits<float>::max()
+                                                  : std::max(occlude_factor[t], iter2->distance / djk);
+                }
+                else if (_dist_metric == diskann::Metric::INNER_PRODUCT)
+                {
+                    // Improvization for flipping max and min dist for MIPS
+                    float x = -iter2->distance;
+                    float y = -djk;
+                    if (y > cur_alpha * x)
+                    {
+                        occlude_factor[t] = std::max(occlude_factor[t], eps);
+                    }
+                }
+            }
+        }
+        cur_alpha *= 1.2f;
+    }
+}
+
+// Convenience overload: prunes with the index-construction parameters
+// (R, maxC, alpha) configured on this Index instance.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool,
+                                             std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch)
+{
+    prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, pruned_list, scratch);
+}
+
+// Prunes the candidate pool of `location` down to at most `range` neighbors
+// via occlude_list() and writes their ids into `pruned_list`. `pool` holds
+// candidates with distances to `location`; when the index was built with PQ
+// distances these are first replaced by exact distances. When graph
+// saturation is enabled (and alpha > 1), the result is topped up with the
+// nearest remaining candidates.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, const uint32_t range,
+                                             const uint32_t max_candidate_size, const float alpha,
+                                             std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch)
+{
+    if (pool.empty())
+    {
+        // Empty pool: behave like a no-op with an empty result.
+        pruned_list.clear();
+        return;
+    }
+
+    // If using _pq_build, over-write the PQ distances with actual distances
+    // REFACTOR PQ: TODO: How to get rid of this!?
+    if (_pq_dist)
+    {
+        for (auto &nbr : pool)
+            nbr.distance = _data_store->get_distance(nbr.id, location);
+    }
+
+    // Sort candidates by distance to the query point, then apply occlusion.
+    std::sort(pool.begin(), pool.end());
+    pruned_list.clear();
+    pruned_list.reserve(range);
+
+    occlude_list(location, pool, alpha, range, max_candidate_size, pruned_list, scratch);
+    assert(pruned_list.size() <= range);
+
+    if (_saturate_graph && alpha > 1)
+    {
+        for (const auto &node : pool)
+        {
+            if (pruned_list.size() >= range)
+                break;
+            bool already_kept = std::find(pruned_list.begin(), pruned_list.end(), node.id) != pruned_list.end();
+            if (!already_kept && node.id != location)
+                pruned_list.push_back(node.id);
+        }
+    }
+}
+
+// Adds the reverse edges des -> n for every des in n's pruned adjacency list.
+// If des's list has room (within the slack factor), n is appended under des's
+// lock; otherwise a copy of the overfull list plus n is re-pruned OUTSIDE the
+// lock and the result is written back under a fresh lock. The lock is
+// deliberately released during pruning to avoid holding it across distance
+// computations.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, const uint32_t range,
+                                          InMemQueryScratch<T> *scratch)
+{
+    const auto &src_pool = pruned_list;
+
+    assert(!src_pool.empty());
+
+    for (auto des : src_pool)
+    {
+        // des.loc is the loc of the neighbors of n
+        assert(des < _max_points + _num_frozen_pts);
+        // des_pool contains the neighbors of the neighbors of n
+        std::vector<uint32_t> copy_of_neighbors;
+        bool prune_needed = false;
+        {
+            LockGuard guard(_locks[des]);
+            auto &des_pool = _graph_store->get_neighbours(des);
+            // Only add the reverse edge if it is not already present.
+            if (std::find(des_pool.begin(), des_pool.end(), n) == des_pool.end())
+            {
+                if (des_pool.size() < (uint64_t)(defaults::GRAPH_SLACK_FACTOR * range))
+                {
+                    // des_pool.emplace_back(n);
+                    _graph_store->add_neighbour(des, n);
+                    prune_needed = false;
+                }
+                else
+                {
+                    // List is over capacity: snapshot it (plus n) for re-pruning.
+                    copy_of_neighbors.reserve(des_pool.size() + 1);
+                    copy_of_neighbors = des_pool;
+                    copy_of_neighbors.push_back(n);
+                    prune_needed = true;
+                }
+            }
+        } // des lock is released by this point
+
+        if (prune_needed)
+        {
+            // Rebuild des's neighbor list: compute exact distances to every
+            // unique candidate, then prune with the standard parameters.
+            tsl::robin_set<uint32_t> dummy_visited(0);
+            std::vector<Neighbor> dummy_pool(0);
+
+            size_t reserveSize = (size_t)(std::ceil(1.05 * defaults::GRAPH_SLACK_FACTOR * range));
+            dummy_visited.reserve(reserveSize);
+            dummy_pool.reserve(reserveSize);
+
+            for (auto cur_nbr : copy_of_neighbors)
+            {
+                if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != des)
+                {
+                    float dist = _data_store->get_distance(des, cur_nbr);
+                    dummy_pool.emplace_back(Neighbor(cur_nbr, dist));
+                    dummy_visited.insert(cur_nbr);
+                }
+            }
+            std::vector<uint32_t> new_out_neighbors;
+            prune_neighbors(des, dummy_pool, new_out_neighbors, scratch);
+            {
+                LockGuard guard(_locks[des]);
+
+                _graph_store->set_neighbours(des, new_out_neighbors);
+            }
+        }
+    }
+}
+
+// Convenience overload: inter-insert using the configured indexing range R.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch)
+{
+    inter_insert(n, pruned_list, _indexingRange, scratch);
+}
+
+// Builds the Vamana graph over all _nd data points plus any frozen points.
+// For each point (in parallel): search from the entry points, prune the
+// candidates into an adjacency list, install it, and add reverse edges via
+// inter_insert(). A final parallel pass re-prunes any node whose list
+// exceeded _indexingRange due to reverse-edge insertions.
+// Fix over previous revision: removed the unused locals `pool`, `tmp` and
+// `visited` that were never referenced.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::link()
+{
+    uint32_t num_threads = _indexingThreads;
+    if (num_threads != 0)
+        omp_set_num_threads(num_threads);
+
+    /* visit_order is a vector that is initialized to the entire graph */
+    std::vector<uint32_t> visit_order;
+    visit_order.reserve(_nd + _num_frozen_pts);
+    for (uint32_t i = 0; i < (uint32_t)_nd; i++)
+    {
+        visit_order.emplace_back(i);
+    }
+
+    // If there are any frozen points, add them all.
+    for (uint32_t frozen = (uint32_t)_max_points; frozen < _max_points + _num_frozen_pts; frozen++)
+    {
+        visit_order.emplace_back(frozen);
+    }
+
+    // if there are frozen points, the first such one is set to be the _start
+    if (_num_frozen_pts > 0)
+        _start = (uint32_t)_max_points;
+    else
+        _start = calculate_entry_point();
+
+    diskann::Timer link_timer;
+
+#pragma omp parallel for schedule(dynamic, 2048)
+    for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++)
+    {
+        auto node = visit_order[node_ctr];
+
+        // Find and add appropriate graph edges
+        ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+        auto scratch = manager.scratch_space();
+        std::vector<uint32_t> pruned_list;
+        if (_filtered_index)
+        {
+            search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize);
+        }
+        else
+        {
+            search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch);
+        }
+        assert(pruned_list.size() > 0);
+
+        {
+            // Install the pruned adjacency list under the node's lock.
+            LockGuard guard(_locks[node]);
+
+            _graph_store->set_neighbours(node, pruned_list);
+            assert(_graph_store->get_neighbours((location_t)node).size() <= _indexingRange);
+        }
+
+        // Add reverse edges node -> des for every des in pruned_list.
+        inter_insert(node, pruned_list, scratch);
+
+        if (node_ctr % 100000 == 0)
+        {
+            diskann::cout << "\r" << (100.0 * node_ctr) / (visit_order.size()) << "% of index build completed."
+                          << std::flush;
+        }
+    }
+
+    if (_nd > 0)
+    {
+        diskann::cout << "Starting final cleanup.." << std::flush;
+    }
+    // Cleanup pass: reverse-edge insertion may have pushed some adjacency
+    // lists past _indexingRange; re-prune those nodes.
+#pragma omp parallel for schedule(dynamic, 2048)
+    for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++)
+    {
+        auto node = visit_order[node_ctr];
+        if (_graph_store->get_neighbours((location_t)node).size() > _indexingRange)
+        {
+            ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+            auto scratch = manager.scratch_space();
+
+            tsl::robin_set<uint32_t> dummy_visited(0);
+            std::vector<Neighbor> dummy_pool(0);
+            std::vector<uint32_t> new_out_neighbors;
+
+            for (auto cur_nbr : _graph_store->get_neighbours((location_t)node))
+            {
+                if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node)
+                {
+                    float dist = _data_store->get_distance(node, cur_nbr);
+                    dummy_pool.emplace_back(Neighbor(cur_nbr, dist));
+                    dummy_visited.insert(cur_nbr);
+                }
+            }
+            prune_neighbors(node, dummy_pool, new_out_neighbors, scratch);
+
+            _graph_store->clear_neighbours((location_t)node);
+            _graph_store->set_neighbours((location_t)node, new_out_neighbors);
+        }
+    }
+    if (_nd > 0)
+    {
+        diskann::cout << "done. Link time: " << ((double)link_timer.elapsed() / (double)1000000) << "s" << std::endl;
+    }
+}
+
+// Re-prunes every over-degree node in the graph (active points plus frozen
+// points) down to `max_degree` using exact distances, then prints degree
+// statistics. Marks the index as filtered. Runs the pruning loop in parallel;
+// each iteration touches only its own node's adjacency list.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion_size,
+                                                 const float alpha)
+{
+    const uint32_t range = max_degree;
+    const uint32_t maxc = max_occlusion_size;
+
+    _filtered_index = true;
+
+    diskann::Timer timer;
+#pragma omp parallel for
+    for (int64_t node = 0; node < (int64_t)(_max_points + _num_frozen_pts); node++)
+    {
+        // Only active points ([0, _nd)) and frozen points ([_max_points, ...)).
+        if ((size_t)node < _nd || (size_t)node >= _max_points)
+        {
+            if (_graph_store->get_neighbours((location_t)node).size() > range)
+            {
+                tsl::robin_set<uint32_t> dummy_visited(0);
+                std::vector<Neighbor> dummy_pool(0);
+                std::vector<uint32_t> new_out_neighbors;
+
+                ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+                auto scratch = manager.scratch_space();
+
+                // Build a de-duplicated candidate pool with exact distances.
+                for (auto cur_nbr : _graph_store->get_neighbours((location_t)node))
+                {
+                    if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node)
+                    {
+                        float dist = _data_store->get_distance((location_t)node, (location_t)cur_nbr);
+                        dummy_pool.emplace_back(Neighbor(cur_nbr, dist));
+                        dummy_visited.insert(cur_nbr);
+                    }
+                }
+
+                prune_neighbors((uint32_t)node, dummy_pool, range, maxc, alpha, new_out_neighbors, scratch);
+                _graph_store->clear_neighbours((location_t)node);
+                _graph_store->set_neighbours((location_t)node, new_out_neighbors);
+            }
+        }
+    }
+
+    diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" << std::endl;
+    // Degree statistics over active + frozen points.
+    size_t max = 0, min = 1 << 30, total = 0, cnt = 0;
+    for (size_t i = 0; i < _max_points + _num_frozen_pts; i++)
+    {
+        if (i < _nd || i >= _max_points)
+        {
+            const std::vector<uint32_t> &pool = _graph_store->get_neighbours((location_t)i);
+            max = (std::max)(max, pool.size());
+            min = (std::min)(min, pool.size());
+            total += pool.size();
+            if (pool.size() < 2)
+                cnt++;
+        }
+    }
+    if (min > max)
+        min = max;
+    if (_nd > 0)
+    {
+        diskann::cout << "Index built with degree: max:" << max
+                      << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) << " min:" << min
+                      << " count(deg<2):" << cnt << std::endl;
+    }
+}
+
+// REFACTOR
+// Installs user-supplied frozen (start) points into the data store at the
+// reserved locations [_max_points, _max_points + _num_frozen_pts). `data`
+// must contain exactly _num_frozen_pts * _dim values. Only valid on an empty
+// index (_nd == 0); takes the update and tag locks exclusively. Marks the
+// index as built.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::set_start_points(const T *data, size_t data_count)
+{
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    if (_nd > 0)
+        throw ANNException("Can not set starting point for a non-empty index", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    if (data_count != _num_frozen_pts * _dim)
+        throw ANNException("Invalid number of points", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    // memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim *
+    // sizeof(T) * _num_frozen_pts);
+    for (location_t i = 0; i < _num_frozen_pts; i++)
+    {
+        _data_store->set_vector((location_t)(i + _max_points), data + i * _dim);
+    }
+    _has_built = true;
+    diskann::cout << "Index start points set: #" << _num_frozen_pts << std::endl;
+}
+
+// Type-erased wrapper: unpacks `radius` from its std::any holder and forwards
+// to set_start_points_at_random(). Rethrows failures as ANNException.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_set_start_points_at_random(DataType radius, uint32_t random_seed)
+{
+    try
+    {
+        this->set_start_points_at_random(std::any_cast<T>(radius), random_seed);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException(
+            "Error: bad any cast while performing _set_start_points_at_random() " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+// Generates _num_frozen_pts start points sampled uniformly on the sphere of
+// the given radius (Gaussian sample normalized to length `radius`) and
+// installs them via set_start_points(). Deterministic for a fixed seed.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::set_start_points_at_random(T radius, uint32_t random_seed)
+{
+    std::mt19937 rng{random_seed};
+    std::normal_distribution<> gauss{0.0, 1.0};
+
+    std::vector<T> frozen_data;
+    frozen_data.reserve(_dim * _num_frozen_pts);
+    std::vector<double> sample(_dim);
+
+    for (size_t pt = 0; pt < _num_frozen_pts; pt++)
+    {
+        // Draw a _dim-dimensional Gaussian sample, tracking its squared norm.
+        double sq_norm = 0.0;
+        for (size_t i = 0; i < _dim; ++i)
+        {
+            sample[i] = gauss(rng);
+            sq_norm += sample[i] * sample[i];
+        }
+
+        // Project the sample onto the sphere of radius `radius`.
+        const double norm = std::sqrt(sq_norm);
+        for (double coord : sample)
+            frozen_data.push_back(static_cast<T>(coord * radius / norm));
+    }
+
+    set_start_points(frozen_data.data(), frozen_data.size());
+}
+
+// Core build step, called once the data store already holds _nd points:
+// registers tags (when enabled), sizes the per-thread query scratch, creates
+// the frozen point, runs link() to construct the graph, then prints degree
+// statistics and marks the index built. Callers are expected to hold the
+// update lock.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build_with_data_populated(const std::vector<TagT> &tags)
+{
+    diskann::cout << "Starting index build with " << _nd << " points... " << std::endl;
+
+    if (_nd < 1)
+        throw ANNException("Error: Trying to build an index with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    // When tags are enabled, exactly one tag per point is required.
+    if (_enable_tags && tags.size() != _nd)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << _nd << " points from file,"
+               << "but tags vector is of size " << tags.size() << "." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    if (_enable_tags)
+    {
+        // Populate the bidirectional tag <-> location maps.
+        for (size_t i = 0; i < tags.size(); ++i)
+        {
+            _tag_to_location[tags[i]] = (uint32_t)i;
+            _location_to_tag.set(static_cast<uint32_t>(i), tags[i]);
+        }
+    }
+
+    uint32_t index_R = _indexingRange;
+    uint32_t num_threads_index = _indexingThreads;
+    uint32_t index_L = _indexingQueueSize;
+    uint32_t maxc = _indexingMaxC;
+
+    // Lazily size the scratch pool: a few spares beyond one per thread.
+    if (_query_scratch.size() == 0)
+    {
+        initialize_query_scratch(5 + num_threads_index, index_L, index_L, index_R, maxc,
+                                 _data_store->get_aligned_dim());
+    }
+
+    generate_frozen_point();
+    link();
+
+    // Degree statistics over the active points.
+    size_t max = 0, min = SIZE_MAX, total = 0, cnt = 0;
+    for (size_t i = 0; i < _nd; i++)
+    {
+        auto &pool = _graph_store->get_neighbours((location_t)i);
+        max = std::max(max, pool.size());
+        min = std::min(min, pool.size());
+        total += pool.size();
+        if (pool.size() < 2)
+            cnt++;
+    }
+    diskann::cout << "Index built with degree: max:" << max << " avg:" << (float)total / (float)(_nd + _num_frozen_pts)
+                  << " min:" << min << " count(deg<2):" << cnt << std::endl;
+
+    _has_built = true;
+}
+// Type-erased build entry: unpacks the raw data pointer and the tag vector
+// from their any-wrappers and forwards to the typed build() overload.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_build(const DataType &data, const size_t num_points_to_load, TagVector &tags)
+{
+    try
+    {
+        const T *raw_data = std::any_cast<const T *>(data);
+        this->build(raw_data, num_points_to_load, tags.get<const std::vector<TagT>>());
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast in while building index. " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error" + std::string(e.what()), -1);
+    }
+}
+// Builds the index from an in-memory array of `num_points_to_load` vectors.
+// Not usable with PQ-distance builds (use the file-based overload). Takes the
+// update lock for the whole build and the tag lock while populating the data
+// store.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build(const T *data, const size_t num_points_to_load, const std::vector<TagT> &tags)
+{
+    if (num_points_to_load == 0)
+    {
+        throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    if (_pq_dist)
+    {
+        throw ANNException("ERROR: DO not use this build interface with PQ distance", -1, __FUNCSIG__, __FILE__,
+                           __LINE__);
+    }
+
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+
+    {
+        std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+        _nd = num_points_to_load;
+
+        _data_store->populate_data(data, (location_t)num_points_to_load);
+    }
+
+    build_with_data_populated(tags);
+}
+
+// Builds the index from a binary data file: validates the file against the
+// index's capacity and dimensionality, populates the data store(s), then runs
+// the core build. Takes the update lock for the whole build.
+// Fix over previous revision: the null check on `filename` now precedes
+// file_exists(filename), which previously dereferenced a possibly-null
+// pointer before the check could fire.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points_to_load, const std::vector<TagT> &tags)
+{
+    // idealy this should call build_filtered_index based on params passed
+
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+
+    // error checks
+    if (num_points_to_load == 0)
+        throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    if (filename == nullptr)
+    {
+        throw diskann::ANNException("Can not build with an empty file", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (!file_exists(filename))
+    {
+        std::stringstream stream;
+        stream << "ERROR: Data file " << filename << " does not exist." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    size_t file_num_points, file_dim;
+    diskann::get_bin_metadata(filename, file_num_points, file_dim);
+
+    // The file may not contain more points than the index was sized for.
+    if (file_num_points > _max_points)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << num_points_to_load << " points and file has " << file_num_points
+               << " points, but "
+               << "index can support only " << _max_points << " points as specified in constructor." << std::endl;
+
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (num_points_to_load > file_num_points)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << num_points_to_load << " points and file has only "
+               << file_num_points << " points." << std::endl;
+
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    if (file_dim != _dim)
+    {
+        std::stringstream stream;
+        stream << "ERROR: Driver requests loading " << _dim << " dimension,"
+               << "but file has " << file_dim << " dimension." << std::endl;
+        diskann::cerr << stream.str() << std::endl;
+
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // REFACTOR PQ TODO: We can remove this if and add a check in the InMemDataStore
+    // to not populate_data if it has been called once.
+    if (_pq_dist)
+    {
+#ifdef EXEC_ENV_OLS
+        std::stringstream ss;
+        ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS is defined)" << std::endl;
+        diskann::cerr << ss.str() << std::endl;
+        throw ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+#else
+        // REFACTOR TODO: Both in the previous code and in the current PQDataStore,
+        // we are writing the PQ files in the same path as the input file. Now we
+        // may not have write permissions to that folder, but we will always have
+        // write permissions to the output folder. So we should write the PQ files
+        // there. The problem is that the Index class gets the output folder prefix
+        // only at the time of save(), by which time we are too late. So leaving it
+        // as-is for now.
+        _pq_data_store->populate_data(filename, 0U);
+#endif
+    }
+
+    _data_store->populate_data(filename, 0U);
+    diskann::cout << "Using only first " << num_points_to_load << " from file.. " << std::endl;
+
+    {
+        std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+        _nd = num_points_to_load;
+    }
+    build_with_data_populated(tags);
+}
+
+// Builds the index from a data file, loading per-point tags from
+// `tag_filename` when tags are enabled. Requires at least num_points_to_load
+// tags in the file.
+// Fix over previous revision: `tag_data` was leaked when the tag file held
+// fewer tags than requested (the throw skipped the delete[]); the buffer is
+// now released on that error path as well. (Assumes diskann::load_bin
+// allocates with new[] — matching the existing delete[] on the success path.)
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points_to_load, const char *tag_filename)
+{
+    std::vector<TagT> tags;
+
+    if (_enable_tags)
+    {
+        std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+        if (tag_filename == nullptr)
+        {
+            throw ANNException("Tag filename is null, while _enable_tags is set", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+        else
+        {
+            if (file_exists(tag_filename))
+            {
+                diskann::cout << "Loading tags from " << tag_filename << " for vamana index build" << std::endl;
+                TagT *tag_data = nullptr;
+                size_t npts, ndim;
+                diskann::load_bin(tag_filename, tag_data, npts, ndim);
+                if (npts < num_points_to_load)
+                {
+                    std::stringstream sstream;
+                    sstream << "Loaded " << npts << " tags, insufficient to populate tags for " << num_points_to_load
+                            << " points to load";
+                    delete[] tag_data; // release the buffer before throwing
+                    throw diskann::ANNException(sstream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+                }
+                for (size_t i = 0; i < num_points_to_load; i++)
+                {
+                    tags.push_back(tag_data[i]);
+                }
+                delete[] tag_data;
+            }
+            else
+            {
+                throw diskann::ANNException(std::string("Tag file") + tag_filename + " does not exist", -1, __FUNCSIG__,
+                                            __FILE__, __LINE__);
+            }
+        }
+    }
+    build(filename, num_points_to_load, tags);
+}
+
+// Builds the index from `data_file`, optionally as a filtered index when a
+// label file is supplied in `filter_params`. Passing num_points_to_load == 0
+// loads up to the index's capacity. Prints the wall-clock indexing time.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build(const std::string &data_file, const size_t num_points_to_load,
+                                   IndexFilterParams &filter_params)
+{
+    // 0 means "load everything the index was sized for".
+    const size_t points_to_load = (num_points_to_load == 0) ? _max_points : num_points_to_load;
+
+    const auto start_time = std::chrono::high_resolution_clock::now();
+    if (filter_params.label_file.empty())
+    {
+        // Unfiltered build.
+        this->build(data_file.c_str(), points_to_load);
+    }
+    else
+    {
+        // TODO: this should ideally happen in save()
+        std::string labels_file_to_use = filter_params.save_path_prefix + "_label_formatted.txt";
+        std::string mem_labels_int_map_file = filter_params.save_path_prefix + "_labels_map.txt";
+        convert_labels_string_to_int(filter_params.label_file, labels_file_to_use, mem_labels_int_map_file,
+                                     filter_params.universal_label);
+        if (!filter_params.universal_label.empty())
+        {
+            // The universal label is always mapped to numeric id 0.
+            LabelT unv_label_as_num = 0;
+            this->set_universal_label(unv_label_as_num);
+        }
+        this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load);
+    }
+    const std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - start_time;
+    std::cout << "Indexing time: " << elapsed.count() << "\n";
+}
+
+template <typename T, typename TagT, typename LabelT>
+std::unordered_map<std::string, LabelT> Index<T, TagT, LabelT>::load_label_map(const std::string &labels_map_file)
+{
+    // Read the "raw_label \t numeric_id" map produced at build time and
+    // return it as string -> LabelT. Fails loudly if the file cannot be
+    // opened (previously an unreadable file silently yielded an empty map),
+    // consistent with parse_label_file's open check.
+    std::unordered_map<std::string, LabelT> string_to_int_mp;
+    std::ifstream map_reader(labels_map_file);
+    if (map_reader.fail())
+    {
+        throw diskann::ANNException(std::string("Failed to open file ") + labels_map_file, -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+    std::string line, token;
+    LabelT token_as_num;
+    std::string label_str;
+    while (std::getline(map_reader, line))
+    {
+        std::istringstream iss(line);
+        getline(iss, token, '\t');
+        label_str = token;
+        getline(iss, token, '\t');
+        token_as_num = (LabelT)std::stoul(token);
+        string_to_int_mp[label_str] = token_as_num;
+    }
+    return string_to_int_mp;
+}
+
+template <typename T, typename TagT, typename LabelT>
+LabelT Index<T, TagT, LabelT>::get_converted_label(const std::string &raw_label)
+{
+    // Map a raw (string) label to its numeric id; fall back to the universal
+    // label when one is configured, otherwise log and throw.
+    auto it = _label_map.find(raw_label);
+    if (it != _label_map.end())
+    {
+        return it->second;
+    }
+    if (_use_universal_label)
+    {
+        return _universal_label;
+    }
+    std::stringstream stream;
+    stream << "Unable to find label in the Label Map";
+    diskann::cerr << stream.str() << std::endl;
+    throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+}
+
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points)
+{
+    // Format of Label txt file: filters with comma separators
+    // Populates _location_to_labels (sorted label list per point, one line
+    // per point) and _labels (set of distinct labels); num_points is set to
+    // the number of lines read.
+
+    std::ifstream infile(label_file);
+    if (infile.fail())
+    {
+        throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
+    }
+
+    std::string line, token;
+    uint32_t line_cnt = 0;
+
+    // First pass: count lines so _location_to_labels can be sized up front.
+    while (std::getline(infile, line))
+    {
+        line_cnt++;
+    }
+    _location_to_labels.resize(line_cnt, std::vector<LabelT>());
+
+    // Rewind the stream for the second (parsing) pass.
+    infile.clear();
+    infile.seekg(0, std::ios::beg);
+    line_cnt = 0;
+
+    while (std::getline(infile, line))
+    {
+        std::istringstream iss(line);
+        std::vector<LabelT> lbls(0);
+        getline(iss, token, '\t');
+        std::istringstream new_iss(token);
+        while (getline(new_iss, token, ','))
+        {
+            // Strip stray line endings (e.g. CRLF input) before parsing.
+            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+            LabelT token_as_num = (LabelT)std::stoul(token);
+            lbls.push_back(token_as_num);
+            _labels.insert(token_as_num);
+        }
+
+        // Keep each point's label list sorted for later lookups.
+        std::sort(lbls.begin(), lbls.end());
+        _location_to_labels[line_cnt] = lbls;
+        line_cnt++;
+    }
+    num_points = (size_t)line_cnt;
+    diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl;
+}
+
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_set_universal_label(const LabelType universal_label)
+{
+    // Type-erased shim: unwrap the std::any and forward to the typed setter.
+    const LabelT typed_label = std::any_cast<const LabelT>(universal_label);
+    this->set_universal_label(typed_label);
+}
+
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::set_universal_label(const LabelT &label)
+{
+    // Record the universal label and enable universal-label matching.
+    _universal_label = label;
+    _use_universal_label = true;
+}
+
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::build_filtered_index(const char *filename, const std::string &label_file,
+                                                  const size_t num_points_to_load, const std::vector<TagT> &tags)
+{
+    // Build a label-filtered index: parse the per-point label file, pick a
+    // start point (medoid) per label, then delegate to the regular build.
+    _filtered_index = true;
+    _label_to_start_id.clear();
+    size_t num_points_labels = 0;
+
+    parse_label_file(label_file,
+                     num_points_labels); // determines medoid for each label and identifies
+                                         // the points to label mapping
+
+    std::unordered_map<LabelT, std::vector<uint32_t>> label_to_points;
+
+    for (uint32_t point_id = 0; point_id < num_points_to_load; point_id++)
+    {
+        for (auto label : _location_to_labels[point_id])
+        {
+            if (label != _universal_label)
+            {
+                label_to_points[label].emplace_back(point_id);
+            }
+            else
+            {
+                // A point carrying the universal label matches every label.
+                // (The original advanced an iterator from begin() for each
+                // label index, which was accidentally quadratic in the number
+                // of distinct labels.)
+                for (const auto &lbl : _labels)
+                {
+                    label_to_points[lbl].emplace_back(point_id);
+                }
+            }
+        }
+    }
+
+    // For each label, sample num_cands candidate points and pick the one
+    // serving as medoid for the fewest other labels, to spread entry points.
+    const uint32_t num_cands = 25;
+    for (const auto &curr_label : _labels)
+    {
+        auto &labeled_points = label_to_points[curr_label];
+        if (labeled_points.empty())
+        {
+            // No point carries this label within the first num_points_to_load
+            // points: skip it rather than evaluating rand() % 0 (undefined
+            // behavior) and publishing an uninitialized medoid id.
+            continue;
+        }
+        uint32_t best_medoid_count = std::numeric_limits<uint32_t>::max();
+        uint32_t best_medoid = labeled_points[0];
+        for (uint32_t cnd = 0; cnd < num_cands; cnd++)
+        {
+            uint32_t cur_cnd = labeled_points[rand() % labeled_points.size()];
+            uint32_t cur_cnt = std::numeric_limits<uint32_t>::max();
+            if (_medoid_counts.find(cur_cnd) == _medoid_counts.end())
+            {
+                _medoid_counts[cur_cnd] = 0;
+                cur_cnt = 0;
+            }
+            else
+            {
+                cur_cnt = _medoid_counts[cur_cnd];
+            }
+            if (cur_cnt < best_medoid_count)
+            {
+                best_medoid_count = cur_cnt;
+                best_medoid = cur_cnd;
+            }
+        }
+        _label_to_start_id[curr_label] = best_medoid;
+        _medoid_counts[best_medoid]++;
+    }
+
+    this->build(filename, num_points_to_load, tags);
+}
+
+template <typename T, typename TagT, typename LabelT>
+std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search(const DataType &query, const size_t K, const uint32_t L,
+                                                              std::any &indices, float *distances)
+{
+    // Type-erased search entry point: unwraps the query and the output id
+    // buffer (uint32_t* or uint64_t*) and dispatches to the typed search().
+    // All failures are normalized into ANNException.
+    try
+    {
+        auto typed_query = std::any_cast<const T *>(query);
+        const auto &id_type = indices.type();
+        if (id_type == typeid(uint32_t *))
+        {
+            return this->search(typed_query, K, L, std::any_cast<uint32_t *>(indices), distances);
+        }
+        if (id_type == typeid(uint64_t *))
+        {
+            return this->search(typed_query, K, L, std::any_cast<uint64_t *>(indices), distances);
+        }
+        // Thrown inside the try block so it is rewrapped by the generic
+        // handler below, matching the original behavior.
+        throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+template <typename T, typename TagT, typename LabelT>
+template <typename IdType>
+std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, const size_t K, const uint32_t L,
+                                                             IdType *indices, float *distances)
+{
+    // Greedy best-first search for the K nearest neighbors of `query` with
+    // candidate list size L (L >= K). Writes up to K internal ids into
+    // `indices` (and distances when non-null); returns the pair produced by
+    // iterate_to_fixed_point.
+    if (K > (uint64_t)L)
+    {
+        throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+    auto scratch = manager.scratch_space();
+
+    // Grow the scratch candidate list if the caller asked for a larger L
+    // than this scratch was created with.
+    if (L > scratch->get_L())
+    {
+        diskann::cout << "Attempting to expand query scratch_space. Was created "
+                      << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl;
+        scratch->resize_for_new_L(L);
+        diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl;
+    }
+
+    const std::vector<LabelT> unused_filter_label;
+    const std::vector<uint32_t> init_ids = get_init_ids();
+
+    std::shared_lock<std::shared_timed_mutex> lock(_update_lock);
+
+    _data_store->preprocess_query(query, scratch);
+
+    auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true);
+
+    NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
+
+    size_t pos = 0;
+    for (size_t i = 0; i < best_L_nodes.size(); ++i)
+    {
+        // Ids >= _max_points are frozen points, not real results.
+        if (best_L_nodes[i].id < _max_points)
+        {
+            // safe because Index uses uint32_t ids internally
+            // and IDType will be uint32_t or uint64_t
+            indices[pos] = (IdType)best_L_nodes[i].id;
+            if (distances != nullptr)
+            {
+#ifdef EXEC_ENV_OLS
+                // DLVS expects negative distances
+                distances[pos] = best_L_nodes[i].distance;
+#else
+                distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance
+                                                                                : best_L_nodes[i].distance;
+#endif
+            }
+            pos++;
+        }
+        if (pos == K)
+            break;
+    }
+    if (pos < K)
+    {
+        // Fixed log formatting: the original concatenated the pos value and
+        // "fewer" with no separator.
+        diskann::cerr << "Found pos: " << pos << " fewer than K elements " << K << " for query" << std::endl;
+    }
+
+    return retval;
+}
+
+template <typename T, typename TagT, typename LabelT>
+std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const DataType &query,
+                                                                           const std::string &raw_label, const size_t K,
+                                                                           const uint32_t L, std::any &indices,
+                                                                           float *distances)
+{
+    // Type-erased filtered-search entry point; mirrors _search() but first
+    // converts the raw string label to its numeric form.
+    // NOTE(review): the query is unwrapped as T* here while _search() uses
+    // const T* — confirm callers store the same pointer type in both paths.
+    try
+    {
+        auto converted_label = this->get_converted_label(raw_label);
+        if (typeid(uint64_t *) == indices.type())
+        {
+            auto ptr = std::any_cast<uint64_t *>(indices);
+            return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
+        }
+        else if (typeid(uint32_t *) == indices.type())
+        {
+            auto ptr = std::any_cast<uint32_t *>(indices);
+            return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
+        }
+        else
+        {
+            throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1);
+        }
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        // Normalize cast failures into ANNException for consistency with
+        // _search(); previously std::bad_any_cast escaped to the caller.
+        throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1);
+    }
+}
+
+template <typename T, typename TagT, typename LabelT>
+template <typename IdType>
+std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const T *query, const LabelT &filter_label,
+                                                                          const size_t K, const uint32_t L,
+                                                                          IdType *indices, float *distances)
+{
+    // Filtered K-NN search: like search(), but restricts graph traversal to
+    // points carrying filter_label and seeds the search from that label's
+    // precomputed medoid. Throws when no medoid exists for the label.
+    if (K > (uint64_t)L)
+    {
+        throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+    auto scratch = manager.scratch_space();
+
+    // Grow the scratch candidate list if the caller asked for a larger L
+    // than this scratch was created with.
+    if (L > scratch->get_L())
+    {
+        diskann::cout << "Attempting to expand query scratch_space. Was created "
+                      << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl;
+        scratch->resize_for_new_L(L);
+        diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl;
+    }
+
+    std::vector<LabelT> filter_vec;
+    std::vector<uint32_t> init_ids = get_init_ids();
+
+    std::shared_lock<std::shared_timed_mutex> lock(_update_lock);
+    // _label_to_start_id can change concurrently only for dynamic indices,
+    // so the tag lock is taken just for that case.
+    std::shared_lock<std::shared_timed_mutex> tl(_tag_lock, std::defer_lock);
+    if (_dynamic_index)
+        tl.lock();
+
+    if (_label_to_start_id.find(filter_label) != _label_to_start_id.end())
+    {
+        init_ids.emplace_back(_label_to_start_id[filter_label]);
+    }
+    else
+    {
+        diskann::cout << "No filtered medoid found. exitting "
+                      << std::endl; // RKNOTE: If universal label found start there
+        throw diskann::ANNException("No filtered medoid found. exitting ", -1);
+    }
+    if (_dynamic_index)
+        tl.unlock();
+
+    filter_vec.emplace_back(filter_label);
+
+    _data_store->preprocess_query(query, scratch);
+    auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
+
+    auto best_L_nodes = scratch->best_l_nodes();
+
+    size_t pos = 0;
+    for (size_t i = 0; i < best_L_nodes.size(); ++i)
+    {
+        // Ids >= _max_points are frozen points, not real results.
+        if (best_L_nodes[i].id < _max_points)
+        {
+            indices[pos] = (IdType)best_L_nodes[i].id;
+
+            if (distances != nullptr)
+            {
+#ifdef EXEC_ENV_OLS
+                // DLVS expects negative distances
+                distances[pos] = best_L_nodes[i].distance;
+#else
+                distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance
+                                                                                : best_L_nodes[i].distance;
+#endif
+            }
+            pos++;
+        }
+        if (pos == K)
+            break;
+    }
+    if (pos < K)
+    {
+        diskann::cerr << "Found fewer than K elements for query" << std::endl;
+    }
+
+    return retval;
+}
+
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L,
+                                                 const TagType &tags, float *distances, DataVector &res_vectors,
+                                                 bool use_filters, const std::string filter_label)
+{
+    // Type-erased wrapper over search_with_tags(): unwraps the query, the
+    // tag output buffer and the result-vector container, normalizing all
+    // failures into ANNException.
+    try
+    {
+        auto typed_query = std::any_cast<const T *>(query);
+        auto typed_tags = std::any_cast<TagT *>(tags);
+        auto &&typed_vectors = res_vectors.get<std::vector<T *>>();
+        return this->search_with_tags(typed_query, K, L, typed_tags, distances, typed_vectors, use_filters,
+                                      filter_label);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast while performing _search_with_tags() " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
+                                                float *distances, std::vector<T *> &res_vectors, bool use_filters,
+                                                const std::string filter_label)
+{
+    // K-NN search that reports results as user-visible tags (and optionally
+    // raw vectors/distances) instead of internal ids. Returns the number of
+    // results written, which can be < K when fewer tagged points are found
+    // or when res_vectors is smaller than K.
+    if (K > (uint64_t)L)
+    {
+        throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+    auto scratch = manager.scratch_space();
+
+    // Grow the scratch candidate list if the caller asked for a larger L
+    // than this scratch was created with.
+    if (L > scratch->get_L())
+    {
+        diskann::cout << "Attempting to expand query scratch_space. Was created "
+                      << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl;
+        scratch->resize_for_new_L(L);
+        diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl;
+    }
+
+    std::shared_lock<std::shared_timed_mutex> ul(_update_lock);
+
+    const std::vector<uint32_t> init_ids = get_init_ids();
+
+    //_distance->preprocess_query(query, _data_store->get_dims(),
+    // scratch->aligned_query());
+    _data_store->preprocess_query(query, scratch);
+    if (!use_filters)
+    {
+        const std::vector<LabelT> unused_filter_label;
+        iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true);
+    }
+    else
+    {
+        std::vector<LabelT> filter_vec;
+        auto converted_label = this->get_converted_label(filter_label);
+        filter_vec.push_back(converted_label);
+        iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
+    }
+
+    NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
+    assert(best_L_nodes.size() <= L);
+
+    std::shared_lock<std::shared_timed_mutex> tl(_tag_lock);
+
+    size_t pos = 0;
+    for (size_t i = 0; i < best_L_nodes.size(); ++i)
+    {
+        auto node = best_L_nodes[i];
+
+        // Only locations that map to a tag are reported; deleted or frozen
+        // points are silently skipped.
+        TagT tag;
+        if (_location_to_tag.try_get(node.id, tag))
+        {
+            tags[pos] = tag;
+
+            if (res_vectors.size() > 0)
+            {
+                _data_store->get_vector(node.id, res_vectors[pos]);
+            }
+
+            if (distances != nullptr)
+            {
+#ifdef EXEC_ENV_OLS
+                distances[pos] = node.distance; // DLVS expects negative distances
+#else
+                distances[pos] = _dist_metric == INNER_PRODUCT ? -1 * node.distance : node.distance;
+#endif
+            }
+            pos++;
+            // If res_vectors.size() < k, clip at the value.
+            if (pos == K || pos == res_vectors.size())
+                break;
+        }
+    }
+
+    return pos;
+}
+
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::get_num_points()
+{
+    // Number of active (live) points, read under a shared tag lock.
+    std::shared_lock<std::shared_timed_mutex> guard(_tag_lock);
+    return _nd;
+}
+
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::get_max_points()
+{
+    // Index capacity (excluding frozen points), read under a shared tag lock.
+    std::shared_lock<std::shared_timed_mutex> guard(_tag_lock);
+    return _max_points;
+}
+
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::generate_frozen_point()
+{
+    // Copy an existing point (the calculated entry point) into the frozen
+    // slot at _max_points so the graph always has a stable start node.
+    if (_num_frozen_pts == 0)
+        return;
+
+    if (_num_frozen_pts > 1)
+    {
+        throw ANNException("More than one frozen point not supported in generate_frozen_point", -1, __FUNCSIG__,
+                           __FILE__, __LINE__);
+    }
+
+    if (_nd == 0)
+    {
+        throw ANNException("ERROR: Can not pick a frozen point since nd=0", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    size_t res = calculate_entry_point();
+
+    // REFACTOR PQ: Not sure if we should do this for both stores.
+    if (_pq_dist)
+    {
+        // copy the PQ data corresponding to the point returned by
+        // calculate_entry_point
+        // memcpy(_pq_data + _max_points * _num_pq_chunks,
+        //        _pq_data + res * _num_pq_chunks,
+        //        _num_pq_chunks * DIV_ROUND_UP(NUM_PQ_BITS, 8));
+        _pq_data_store->copy_vectors((location_t)res, (location_t)_max_points, 1);
+    }
+    else
+    {
+        _data_store->copy_vectors((location_t)res, (location_t)_max_points, 1);
+    }
+    _frozen_pts_used++;
+}
+
+template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>::enable_delete()
+{
+    // Prepare the index for lazy deletes. Idempotent; returns 0 on success,
+    // -2 when tags are disabled (deletes require tags).
+    assert(_enable_tags);
+
+    if (!_enable_tags)
+    {
+        diskann::cerr << "Tags must be instantiated for deletions" << std::endl;
+        return -2;
+    }
+
+    if (this->_deletes_enabled)
+    {
+        return 0;
+    }
+
+    // Lock acquisition order: update -> tag -> delete.
+    std::unique_lock<std::shared_timed_mutex> update_guard(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> tag_guard(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> delete_guard(_delete_lock);
+
+    if (_data_compacted)
+    {
+        // After compaction all live points occupy [0, _nd); everything above
+        // is free and goes into the empty-slot pool.
+        for (uint32_t slot = (uint32_t)_nd; slot < _max_points; ++slot)
+        {
+            _empty_slots.insert(slot);
+        }
+    }
+    this->_deletes_enabled = true;
+    return 0;
+}
+
+template <typename T, typename TagT, typename LabelT>
+inline void Index<T, TagT, LabelT>::process_delete(const tsl::robin_set<uint32_t> &old_delete_set, size_t loc,
+                                                   const uint32_t range, const uint32_t maxc, const float alpha,
+                                                   InMemQueryScratch<T> *scratch)
+{
+    // Repair the adjacency list of live point `loc` after deletions: edges
+    // pointing into deleted nodes are replaced with edges to those nodes'
+    // live neighbors, re-pruned via occlude_list when the candidate set
+    // exceeds the degree bound `range`.
+    tsl::robin_set<uint32_t> &expanded_nodes_set = scratch->expanded_nodes_set();
+    std::vector<Neighbor> &expanded_nghrs_vec = scratch->expanded_nodes_vec();
+
+    // If this condition were not true, deadlock could result
+    assert(old_delete_set.find((uint32_t)loc) == old_delete_set.end());
+
+    std::vector<uint32_t> adj_list;
+    {
+        // Acquire and release lock[loc] before acquiring locks for neighbors
+        std::unique_lock<non_recursive_mutex> adj_list_lock;
+        if (_conc_consolidate)
+            adj_list_lock = std::unique_lock<non_recursive_mutex>(_locks[loc]);
+        adj_list = _graph_store->get_neighbours((location_t)loc);
+    }
+
+    bool modify = false;
+    for (auto ngh : adj_list)
+    {
+        if (old_delete_set.find(ngh) == old_delete_set.end())
+        {
+            // Live neighbor: keep it as a candidate.
+            expanded_nodes_set.insert(ngh);
+        }
+        else
+        {
+            modify = true;
+
+            // Deleted neighbor: absorb its live neighbors instead.
+            std::unique_lock<non_recursive_mutex> ngh_lock;
+            if (_conc_consolidate)
+                ngh_lock = std::unique_lock<non_recursive_mutex>(_locks[ngh]);
+            for (auto j : _graph_store->get_neighbours((location_t)ngh))
+                if (j != loc && old_delete_set.find(j) == old_delete_set.end())
+                    expanded_nodes_set.insert(j);
+        }
+    }
+
+    if (modify)
+    {
+        if (expanded_nodes_set.size() <= range)
+        {
+            // Candidate set already fits the degree bound: adopt it directly.
+            std::unique_lock<non_recursive_mutex> adj_list_lock(_locks[loc]);
+            _graph_store->clear_neighbours((location_t)loc);
+            for (auto &ngh : expanded_nodes_set)
+                _graph_store->add_neighbour((location_t)loc, ngh);
+        }
+        else
+        {
+            // Create a pool of Neighbor candidates from the expanded_nodes_set
+            expanded_nghrs_vec.reserve(expanded_nodes_set.size());
+            for (auto &ngh : expanded_nodes_set)
+            {
+                expanded_nghrs_vec.emplace_back(ngh, _data_store->get_distance((location_t)loc, (location_t)ngh));
+            }
+            std::sort(expanded_nghrs_vec.begin(), expanded_nghrs_vec.end());
+            std::vector<uint32_t> &occlude_list_output = scratch->occlude_list_output();
+            occlude_list((uint32_t)loc, expanded_nghrs_vec, alpha, range, maxc, occlude_list_output, scratch,
+                         &old_delete_set);
+            std::unique_lock<non_recursive_mutex> adj_list_lock(_locks[loc]);
+            _graph_store->set_neighbours((location_t)loc, occlude_list_output);
+        }
+    }
+}
+
+// Returns number of live points left after consolidation
+template <typename T, typename TagT, typename LabelT>
+consolidation_report Index<T, TagT, LabelT>::consolidate_deletes(const IndexWriteParameters &params)
+{
+    // Physically remove lazily-deleted points: swap out the current delete
+    // set, repair every live point's adjacency list in parallel, then return
+    // the freed slots to the empty pool. Produces a report with counts and
+    // timing.
+    if (!_enable_tags)
+        throw diskann::ANNException("Point tag array not instantiated", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    {
+        // Consistency checks under shared locks before doing any work.
+        std::shared_lock<std::shared_timed_mutex> ul(_update_lock);
+        std::shared_lock<std::shared_timed_mutex> tl(_tag_lock);
+        std::shared_lock<std::shared_timed_mutex> dl(_delete_lock);
+        if (_empty_slots.size() + _nd != _max_points)
+        {
+            std::string err = "#empty slots + nd != max points";
+            diskann::cerr << err << std::endl;
+            throw ANNException(err, -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+
+        if (_location_to_tag.size() + _delete_set->size() != _nd)
+        {
+            diskann::cerr << "Error: _location_to_tag.size (" << _location_to_tag.size() << ") + _delete_set->size ("
+                          << _delete_set->size() << ") != _nd(" << _nd << ") ";
+            return consolidation_report(diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR, 0, 0, 0,
+                                        0, 0, 0, 0);
+        }
+
+        if (_location_to_tag.size() != _tag_to_location.size())
+        {
+            throw diskann::ANNException("_location_to_tag and _tag_to_location not of same size", -1, __FUNCSIG__,
+                                        __FILE__, __LINE__);
+        }
+    }
+
+    // Unless concurrent consolidation is enabled, block all updates.
+    std::unique_lock<std::shared_timed_mutex> update_lock(_update_lock, std::defer_lock);
+    if (!_conc_consolidate)
+        update_lock.lock();
+
+    std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock, std::defer_lock);
+    if (!cl.try_lock())
+    {
+        // Fixed typo in the log message ("Consildate" -> "Consolidate").
+        diskann::cerr << "Consolidate delete function failed to acquire consolidate lock" << std::endl;
+        return consolidation_report(diskann::consolidation_report::status_code::LOCK_FAIL, 0, 0, 0, 0, 0, 0, 0);
+    }
+
+    diskann::cout << "Starting consolidate_deletes... ";
+
+    // Swap out the delete set so new deletes can accumulate while we work.
+    std::unique_ptr<tsl::robin_set<uint32_t>> old_delete_set(new tsl::robin_set<uint32_t>);
+    {
+        std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+        std::swap(_delete_set, old_delete_set);
+    }
+
+    if (old_delete_set->find(_start) != old_delete_set->end())
+    {
+        throw diskann::ANNException("ERROR: start node has been deleted", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    const uint32_t range = params.max_degree;
+    const uint32_t maxc = params.max_occlusion_size;
+    const float alpha = params.alpha;
+    const uint32_t num_threads = params.num_threads == 0 ? omp_get_num_procs() : params.num_threads;
+
+    uint32_t num_calls_to_process_delete = 0;
+    diskann::Timer timer;
+    // Repair live points in parallel; frozen points are handled serially below.
+#pragma omp parallel for num_threads(num_threads) schedule(dynamic, 8192) reduction(+ : num_calls_to_process_delete)
+    for (int64_t loc = 0; loc < (int64_t)_max_points; loc++)
+    {
+        if (old_delete_set->find((uint32_t)loc) == old_delete_set->end() && !_empty_slots.is_in_set((uint32_t)loc))
+        {
+            ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+            auto scratch = manager.scratch_space();
+            process_delete(*old_delete_set, loc, range, maxc, alpha, scratch);
+            num_calls_to_process_delete += 1;
+        }
+    }
+    for (int64_t loc = _max_points; loc < (int64_t)(_max_points + _num_frozen_pts); loc++)
+    {
+        ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+        auto scratch = manager.scratch_space();
+        process_delete(*old_delete_set, loc, range, maxc, alpha, scratch);
+        num_calls_to_process_delete += 1;
+    }
+
+    // Return the freed locations to the empty-slot pool.
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    size_t ret_nd = release_locations(*old_delete_set);
+    size_t max_points = _max_points;
+    size_t empty_slots_size = _empty_slots.size();
+
+    std::shared_lock<std::shared_timed_mutex> dl(_delete_lock);
+    size_t delete_set_size = _delete_set->size();
+    size_t old_delete_set_size = old_delete_set->size();
+
+    if (!_conc_consolidate)
+    {
+        update_lock.unlock();
+    }
+
+    double duration = timer.elapsed() / 1000000.0;
+    diskann::cout << " done in " << duration << " seconds." << std::endl;
+    return consolidation_report(diskann::consolidation_report::status_code::SUCCESS, ret_nd, max_points,
+                                empty_slots_size, old_delete_set_size, delete_set_size, num_calls_to_process_delete,
+                                duration);
+}
+
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::compact_frozen_point()
+{
+    // After data compaction, move the frozen point(s) from the end of the
+    // capacity region (_max_points) down to slot _nd so the index can be
+    // saved compactly; _start follows the frozen point.
+    if (_nd < _max_points && _num_frozen_pts > 0)
+    {
+        reposition_points((uint32_t)_max_points, (uint32_t)_nd, (uint32_t)_num_frozen_pts);
+        _start = (uint32_t)_nd;
+
+        if (_filtered_index && _dynamic_index)
+        {
+            // update medoid id's as frozen points are treated as medoid
+            for (auto &[label, medoid_id] : _label_to_start_id)
+            {
+                /* if (label == _universal_label)
+                     continue;*/
+                _label_to_start_id[label] = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points);
+            }
+        }
+    }
+}
+
+// Should be called after acquiring _update_lock
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::compact_data()
+{
+    // Re-pack live points into the contiguous prefix [0, _nd), remapping
+    // adjacency lists, the tag<->location maps and (for filtered indices)
+    // the per-location label lists. Requires an empty _delete_set.
+    if (!_dynamic_index)
+        throw ANNException("Can not compact a non-dynamic index", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    if (_data_compacted)
+    {
+        diskann::cerr << "Warning! Calling compact_data() when _data_compacted is true!" << std::endl;
+        return;
+    }
+
+    if (_delete_set->size() > 0)
+    {
+        throw ANNException("Can not compact data when index has non-empty _delete_set of "
+                           "size: " +
+                               std::to_string(_delete_set->size()),
+                           -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    diskann::Timer timer;
+
+    // new_location[old] holds the destination slot of each live point;
+    // UINT32_MAX marks locations with no associated tag (i.e. empty).
+    std::vector<uint32_t> new_location = std::vector<uint32_t>(_max_points + _num_frozen_pts, UINT32_MAX);
+
+    uint32_t new_counter = 0;
+    std::set<uint32_t> empty_locations;
+    for (uint32_t old_location = 0; old_location < _max_points; old_location++)
+    {
+        if (_location_to_tag.contains(old_location))
+        {
+            new_location[old_location] = new_counter;
+            new_counter++;
+        }
+        else
+        {
+            empty_locations.insert(old_location);
+        }
+    }
+    // Frozen points keep their positions.
+    for (uint32_t old_location = (uint32_t)_max_points; old_location < _max_points + _num_frozen_pts; old_location++)
+    {
+        new_location[old_location] = old_location;
+    }
+
+    // If start node is removed, throw an exception
+    if (_start < _max_points && !_location_to_tag.contains(_start))
+    {
+        throw diskann::ANNException("ERROR: Start node deleted.", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    size_t num_dangling = 0;
+    for (uint32_t old = 0; old < _max_points + _num_frozen_pts; ++old)
+    {
+        // compact _final_graph
+        std::vector<uint32_t> new_adj_list;
+
+        if ((new_location[old] < _max_points) // If point continues to exist
+            || (old >= _max_points && old < _max_points + _num_frozen_pts))
+        {
+            new_adj_list.reserve(_graph_store->get_neighbours((location_t)old).size());
+            for (auto ngh_iter : _graph_store->get_neighbours((location_t)old))
+            {
+                if (empty_locations.find(ngh_iter) != empty_locations.end())
+                {
+                    // Dangling edge into an untagged location: drop it and log.
+                    ++num_dangling;
+                    diskann::cerr << "Error in compact_data(). _final_graph[" << old << "] has neighbor " << ngh_iter
+                                  << " which is a location not associated with any tag." << std::endl;
+                }
+                else
+                {
+                    new_adj_list.push_back(new_location[ngh_iter]);
+                }
+            }
+            //_graph_store->get_neighbours((location_t)old).swap(new_adj_list);
+            _graph_store->set_neighbours((location_t)old, new_adj_list);
+
+            // Move the data and adj list to the correct position
+            if (new_location[old] != old)
+            {
+                assert(new_location[old] < old);
+                _graph_store->swap_neighbours(new_location[old], (location_t)old);
+
+                if (_filtered_index)
+                {
+                    _location_to_labels[new_location[old]].swap(_location_to_labels[old]);
+                }
+
+                _data_store->copy_vectors(old, new_location[old], 1);
+            }
+        }
+        else
+        {
+            _graph_store->clear_neighbours((location_t)old);
+        }
+    }
+    diskann::cerr << "#dangling references after data compaction: " << num_dangling << std::endl;
+
+    // Rebuild both tag maps against the new locations.
+    _tag_to_location.clear();
+    for (auto pos = _location_to_tag.find_first(); pos.is_valid(); pos = _location_to_tag.find_next(pos))
+    {
+        const auto tag = _location_to_tag.get(pos);
+        _tag_to_location[tag] = new_location[pos._key];
+    }
+    _location_to_tag.clear();
+    for (const auto &iter : _tag_to_location)
+    {
+        _location_to_tag.set(iter.second, iter.first);
+    }
+    // remove all cleared up old
+    for (size_t old = _nd; old < _max_points; ++old)
+    {
+        _graph_store->clear_neighbours((location_t)old);
+    }
+    if (_filtered_index)
+    {
+        for (size_t old = _nd; old < _max_points; old++)
+        {
+            _location_to_labels[old].clear();
+        }
+    }
+
+    _empty_slots.clear();
+    // mark all slots after _nd as empty
+    for (auto i = _nd; i < _max_points; i++)
+    {
+        _empty_slots.insert((uint32_t)i);
+    }
+    _data_compacted = true;
+    diskann::cout << "Time taken for compact_data: " << timer.elapsed() / 1000000. << "s." << std::endl;
+}
+
+//
+// Caller must hold unique _tag_lock and _delete_lock before calling this
+//
+template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>::reserve_location()
+{
+    // Allocate a free slot for a new point. Returns the slot id, or -1 when
+    // the index is at capacity.
+    if (_nd >= _max_points)
+    {
+        return -1;
+    }
+    uint32_t location;
+    if (_data_compacted && _empty_slots.is_empty())
+    {
+        // This code path is encountered when enable_delete hasn't been
+        // called yet, so no points have been deleted and _empty_slots
+        // hasn't been filled in. In that case, just keep assigning
+        // consecutive locations.
+        location = (uint32_t)_nd;
+    }
+    else
+    {
+        assert(_empty_slots.size() != 0);
+        assert(_empty_slots.size() + _nd == _max_points);
+
+        // Reuse a freed slot and ensure it is no longer marked as deleted.
+        location = _empty_slots.pop_any();
+        _delete_set->erase(location);
+    }
+    ++_nd;
+    return location;
+}
+
+template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, LabelT>::release_location(int location)
+{
+    // Return a single slot to the free pool and decrement the live count;
+    // returns the new live count. Throws if the slot is already free.
+    if (_empty_slots.is_in_set(location))
+    {
+        throw ANNException("Trying to release location, but location already in empty slots", -1, __FUNCSIG__,
+                           __FILE__, __LINE__);
+    }
+    _empty_slots.insert(location);
+    --_nd;
+    return _nd;
+}
+
+template <typename T, typename TagT, typename LabelT>
+size_t Index<T, TagT, LabelT>::release_locations(const tsl::robin_set<uint32_t> &locations)
+{
+    // Bulk variant of release_location(): return every given slot to the
+    // free pool, then sanity-check slot accounting. Returns the new live
+    // count.
+    for (const auto location : locations)
+    {
+        if (_empty_slots.is_in_set(location))
+            throw ANNException("Trying to release location, but location "
+                               "already in empty slots",
+                               -1, __FUNCSIG__, __FILE__, __LINE__);
+        _empty_slots.insert(location);
+        --_nd;
+    }
+
+    if (_empty_slots.size() + _nd != _max_points)
+        throw ANNException("#empty slots + nd != max points", -1, __FUNCSIG__, __FILE__, __LINE__);
+
+    return _nd;
+}
+
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::reposition_points(uint32_t old_location_start, uint32_t new_location_start,
+                                               uint32_t num_locations)
+{
+    // Move a contiguous run of num_locations points from old_location_start
+    // to new_location_start, rewriting every adjacency list that references
+    // the moved range. Overlapping source/destination ranges are handled by
+    // choosing the copy direction and restricting the cleared interval.
+    if (num_locations == 0 || old_location_start == new_location_start)
+    {
+        return;
+    }
+
+    // Update pointers to the moved nodes. Note: the computation is correct even
+    // when new_location_start < old_location_start given the C++ uint32_t
+    // integer arithmetic rules.
+    const uint32_t location_delta = new_location_start - old_location_start;
+
+    std::vector<location_t> updated_neighbours_location;
+    for (uint32_t i = 0; i < _max_points + _num_frozen_pts; i++)
+    {
+        auto &i_neighbours = _graph_store->get_neighbours((location_t)i);
+        std::vector<location_t> i_neighbours_copy(i_neighbours.begin(), i_neighbours.end());
+        for (auto &loc : i_neighbours_copy)
+        {
+            if (loc >= old_location_start && loc < old_location_start + num_locations)
+                loc += location_delta;
+        }
+        _graph_store->set_neighbours(i, i_neighbours_copy);
+    }
+
+    // The [start, end) interval which will contain obsolete points to be
+    // cleared.
+    uint32_t mem_clear_loc_start = old_location_start;
+    uint32_t mem_clear_loc_end_limit = old_location_start + num_locations;
+
+    // Move the adjacency lists. Make sure that overlapping ranges are handled
+    // correctly.
+    if (new_location_start < old_location_start)
+    {
+        // New location before the old location: copy the entries in order
+        // to avoid modifying locations that are yet to be copied.
+        for (uint32_t loc_offset = 0; loc_offset < num_locations; loc_offset++)
+        {
+            assert(_graph_store->get_neighbours(new_location_start + loc_offset).empty());
+            _graph_store->swap_neighbours(new_location_start + loc_offset, old_location_start + loc_offset);
+            if (_dynamic_index && _filtered_index)
+            {
+                _location_to_labels[new_location_start + loc_offset].swap(
+                    _location_to_labels[old_location_start + loc_offset]);
+            }
+        }
+        // If ranges are overlapping, make sure not to clear the newly copied
+        // data.
+        if (mem_clear_loc_start < new_location_start + num_locations)
+        {
+            // Clear only after the end of the new range.
+            mem_clear_loc_start = new_location_start + num_locations;
+        }
+    }
+    else
+    {
+        // Old location after the new location: copy from the end of the range
+        // to avoid modifying locations that are yet to be copied.
+        for (uint32_t loc_offset = num_locations; loc_offset > 0; loc_offset--)
+        {
+            assert(_graph_store->get_neighbours(new_location_start + loc_offset - 1u).empty());
+            _graph_store->swap_neighbours(new_location_start + loc_offset - 1u, old_location_start + loc_offset - 1u);
+            if (_dynamic_index && _filtered_index)
+            {
+                _location_to_labels[new_location_start + loc_offset - 1u].swap(
+                    _location_to_labels[old_location_start + loc_offset - 1u]);
+            }
+        }
+
+        // If ranges are overlapping, make sure not to clear the newly copied
+        // data.
+        if (mem_clear_loc_end_limit > new_location_start)
+        {
+            // Clear only up to the beginning of the new range.
+            mem_clear_loc_end_limit = new_location_start;
+        }
+    }
+    _data_store->move_vectors(old_location_start, new_location_start, num_locations);
+}
+
+// Moves the frozen point(s) from their current slots (starting at _nd) to the
+// tail of the point array (starting at _max_points) and repoints _start, the
+// search entry node, to the new frozen-point location.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::reposition_frozen_point_to_end()
+{
+    // No frozen points: nothing to move.
+    if (_num_frozen_pts == 0)
+        return;
+
+    // Index is full: frozen points already occupy the tail slots.
+    if (_nd == _max_points)
+    {
+        diskann::cout << "Not repositioning frozen point as it is already at the end." << std::endl;
+        return;
+    }
+
+    reposition_points((uint32_t)_nd, (uint32_t)_max_points, (uint32_t)_num_frozen_pts);
+    _start = (uint32_t)_max_points;
+
+    // update medoid id's as frozen points are treated as medoid
+    if (_filtered_index && _dynamic_index)
+    {
+        for (auto &[label, medoid_id] : _label_to_start_id)
+        {
+            /*if (label == _universal_label)
+                continue;*/
+            // Shift each label's medoid by the same offset the frozen points moved.
+            _label_to_start_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd);
+        }
+    }
+}
+
+// Grows the index capacity to new_max_points: resizes the data store, the
+// graph store and the per-node lock array, moves the frozen points to the new
+// tail, and registers every slot in [_nd, new_max_points) as empty.
+// NOTE(review): takes no locks itself — callers appear responsible for
+// holding the update/tag/delete locks (see the EXPAND_IF_FULL path in
+// insert_point); confirm before calling from new sites.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::resize(size_t new_max_points)
+{
+    const size_t new_internal_points = new_max_points + _num_frozen_pts;
+    auto start = std::chrono::high_resolution_clock::now();
+    assert(_empty_slots.size() == 0); // should not resize if there are empty slots.
+
+    _data_store->resize((location_t)new_internal_points);
+    _graph_store->resize_graph(new_internal_points);
+    _locks = std::vector<non_recursive_mutex>(new_internal_points);
+
+    // Frozen points live at the tail; shift them from the old tail
+    // (_max_points) to the new one and repoint the search entry node.
+    if (_num_frozen_pts != 0)
+    {
+        reposition_points((uint32_t)_max_points, (uint32_t)new_max_points, (uint32_t)_num_frozen_pts);
+        _start = (uint32_t)new_max_points;
+    }
+
+    _max_points = new_max_points;
+    _empty_slots.reserve(_max_points);
+    // Freshly added slots become available for future inserts.
+    for (auto i = _nd; i < _max_points; i++)
+    {
+        _empty_slots.insert((uint32_t)i);
+    }
+
+    auto stop = std::chrono::high_resolution_clock::now();
+    diskann::cout << "Resizing took: " << std::chrono::duration<double>(stop - start).count() << "s" << std::endl;
+}
+
+// Type-erased wrapper: unpacks the std::any-held point/tag and forwards to
+// the strongly-typed insert_point(const T *, TagT).
+template <typename T, typename TagT, typename LabelT>
+int Index<T, TagT, LabelT>::_insert_point(const DataType &point, const TagType tag)
+{
+    try
+    {
+        return this->insert_point(std::any_cast<const T *>(point), std::any_cast<const TagT>(tag));
+    }
+    catch (const std::bad_any_cast &anycast_e)
+    {
+        // Throw by value (was `throw new ANNException(...)`): throwing a raw
+        // pointer is never caught by the `catch (ANNException &)` / `catch
+        // (std::exception &)` handlers used elsewhere in this file and leaks.
+        throw ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error:" + std::string(e.what()), -1);
+    }
+}
+
+// Type-erased wrapper for the filtered insert: unpacks point, tag and the
+// label vector, then forwards to the strongly-typed insert_point overload.
+template <typename T, typename TagT, typename LabelT>
+int Index<T, TagT, LabelT>::_insert_point(const DataType &point, const TagType tag, Labelvector &labels)
+{
+    try
+    {
+        return this->insert_point(std::any_cast<const T *>(point), std::any_cast<const TagT>(tag),
+                                  labels.get<const std::vector<LabelT>>());
+    }
+    catch (const std::bad_any_cast &anycast_e)
+    {
+        // Throw by value (was `throw new ANNException(...)`): a pointer throw
+        // bypasses every catch-by-reference handler and leaks the exception.
+        throw ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error:" + std::string(e.what()), -1);
+    }
+}
+
+// Unfiltered insert: delegates to the filtered overload with a single
+// placeholder label of value 0.
+template <typename T, typename TagT, typename LabelT>
+int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag)
+{
+    std::vector<LabelT> default_labels{0};
+    return insert_point(point, tag, default_labels);
+}
+
+// Inserts `point` under `tag` with the given filter labels:
+//   1. reserve a free slot (location),
+//   2. for a filtered index, record the labels and lazily create a frozen
+//      "medoid" point for any label seen for the first time,
+//   3. register the tag<->location mapping and copy the vector in,
+//   4. search the graph for candidates, prune them, link the new node and
+//      back-link it into its neighbours (inter_insert).
+// Returns 0 on success, -1 when a slot or the tag cannot be reserved.
+template <typename T, typename TagT, typename LabelT>
+int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const std::vector<LabelT> &labels)
+{
+
+    assert(_has_built);
+    if (tag == 0)
+    {
+        throw diskann::ANNException("Do not insert point with tag 0. That is "
+                                    "reserved for points hidden "
+                                    "from the user.",
+                                    -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Lock order: shared update lock first, then exclusive tag and delete
+    // locks; released in reverse as each structure stops being touched.
+    std::shared_lock<std::shared_timed_mutex> shared_ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    auto location = reserve_location();
+    // NOTE(review): in this filtered branch `location` indexes
+    // _location_to_labels before the `location == -1` (index full) check
+    // below — confirm reserve_location() cannot return -1 on this path.
+    if (_filtered_index)
+    {
+        if (labels.empty())
+        {
+            release_location(location);
+            std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) +
+                             " . there are no labels for the point."
+                      << std::endl;
+            return -1;
+        }
+
+        _location_to_labels[location] = labels;
+
+        for (LabelT label : labels)
+        {
+            // First time this label is seen: dedicate a frozen point as its
+            // medoid (entry point for filtered search).
+            if (_labels.find(label) == _labels.end())
+            {
+                if (_frozen_pts_used >= _num_frozen_pts)
+                {
+                    throw ANNException(
+                        "Error: For dynamic filtered index, the number of frozen points should be atleast equal "
+                        "to number of unique labels.",
+                        -1);
+                }
+
+                auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point
+                _labels.insert(label);
+                _label_to_start_id[label] = (uint32_t)fz_location;
+                _location_to_labels[fz_location] = {label};
+                _data_store->set_vector((location_t)fz_location, point);
+                _frozen_pts_used++;
+            }
+        }
+    }
+
+    if (location == -1)
+    {
+#if EXPAND_IF_FULL
+        // Index full: drop all locks, retake the update lock exclusively,
+        // grow by INDEX_GROWTH_FACTOR, then retry the reservation.
+        dl.unlock();
+        tl.unlock();
+        shared_ul.unlock();
+
+        {
+            std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+            tl.lock();
+            dl.lock();
+
+            // Re-check under the exclusive lock: another thread may have
+            // resized while we were unlocked.
+            if (_nd >= _max_points)
+            {
+                auto new_max_points = (size_t)(_max_points * INDEX_GROWTH_FACTOR);
+                resize(new_max_points);
+            }
+
+            dl.unlock();
+            tl.unlock();
+            ul.unlock();
+        }
+
+        shared_ul.lock();
+        tl.lock();
+        dl.lock();
+
+        location = reserve_location();
+        if (location == -1)
+        {
+            throw diskann::ANNException("Cannot reserve location even after "
+                                        "expanding graph. Terminating.",
+                                        -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+#else
+        return -1;
+#endif
+    } // cant insert as active pts >= max_pts
+    dl.unlock();
+
+    // Insert tag and mapping to location
+    if (_enable_tags)
+    {
+        // if tags are enabled and tag is already inserted. so we can't reuse that tag.
+        if (_tag_to_location.find(tag) != _tag_to_location.end())
+        {
+            release_location(location);
+            return -1;
+        }
+
+        _tag_to_location[tag] = location;
+        _location_to_tag.set(location, tag);
+    }
+    tl.unlock();
+
+    _data_store->set_vector(location, point); // update datastore
+
+    // Find and add appropriate graph edges
+    ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
+    auto scratch = manager.scratch_space();
+    std::vector<uint32_t> pruned_list; // it is the set best candidates to connect to this point
+    if (_filtered_index)
+    {
+        // when filtered the best_candidates will share the same label ( label_present > distance)
+        search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize);
+    }
+    else
+    {
+        search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch);
+    }
+    assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e frozen point acting as medoid)
+
+    {
+        // Under concurrent consolidation, hold the tag lock shared so the
+        // _location_to_tag liveness checks below stay consistent.
+        std::shared_lock<std::shared_timed_mutex> tlock(_tag_lock, std::defer_lock);
+        if (_conc_consolidate)
+            tlock.lock();
+
+        LockGuard guard(_locks[location]);
+        _graph_store->clear_neighbours(location);
+
+        std::vector<uint32_t> neighbor_links;
+        for (auto link : pruned_list)
+        {
+            // Skip candidates whose tag vanished (concurrently deleted).
+            if (_conc_consolidate)
+                if (!_location_to_tag.contains(link))
+                    continue;
+            neighbor_links.emplace_back(link);
+        }
+        _graph_store->set_neighbours(location, neighbor_links);
+        assert(_graph_store->get_neighbours(location).size() <= _indexingRange);
+
+        if (_conc_consolidate)
+            tlock.unlock();
+    }
+
+    // Add reverse edges from the chosen neighbours back to the new point.
+    inter_insert(location, pruned_list, scratch);
+
+    return 0;
+}
+
+// Type-erased wrapper around lazy_delete(const TagT &): unpacks the
+// std::any-held tag and forwards it.
+template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>::_lazy_delete(const TagType &tag)
+{
+    try
+    {
+        const TagT typed_tag = std::any_cast<const TagT>(tag);
+        return lazy_delete(typed_tag);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException(std::string("Error: ").append(e.what()), -1);
+    }
+}
+
+// Type-erased wrapper: unwraps both TagVector holders to concrete
+// std::vector<TagT> references and forwards to the typed batch lazy_delete().
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_lazy_delete(TagVector &tags, TagVector &failed_tags)
+{
+    try
+    {
+        this->lazy_delete(tags.get<const std::vector<TagT>>(), failed_tags.get<std::vector<TagT>>());
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast while performing _lazy_delete() " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+// Marks the point identified by `tag` as deleted without touching the graph;
+// physical removal happens later during consolidation. Returns 0 on success,
+// -1 when the tag is unknown.
+template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>::lazy_delete(const TagT &tag)
+{
+    std::shared_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+    _data_compacted = false;
+
+    // Single lookup instead of find + two operator[] calls; the iterator also
+    // avoids operator[]'s default-insert hazard on a missing key.
+    const auto it = _tag_to_location.find(tag);
+    if (it == _tag_to_location.end())
+    {
+        diskann::cerr << "Delete tag not found " << get_tag_string(tag) << std::endl;
+        return -1;
+    }
+    const auto location = it->second;
+    assert(location < _max_points);
+
+    _delete_set->insert(location);
+    _location_to_tag.erase(location);
+    _tag_to_location.erase(it);
+    return 0;
+}
+
+// Batch variant of lazy_delete: marks every resolvable tag as deleted and
+// collects unknown tags in `failed_tags`, which must arrive empty.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::lazy_delete(const std::vector<TagT> &tags, std::vector<TagT> &failed_tags)
+{
+    if (failed_tags.size() > 0)
+    {
+        throw ANNException("failed_tags should be passed as an empty list", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    std::shared_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::unique_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::unique_lock<std::shared_timed_mutex> dl(_delete_lock);
+    _data_compacted = false;
+
+    for (auto tag : tags)
+    {
+        // One hash lookup per tag instead of find followed by operator[].
+        auto it = _tag_to_location.find(tag);
+        if (it == _tag_to_location.end())
+        {
+            failed_tags.push_back(tag);
+        }
+        else
+        {
+            const auto location = it->second;
+            _delete_set->insert(location);
+            _location_to_tag.erase(location);
+            _tag_to_location.erase(it);
+        }
+    }
+}
+
+// Reports the _is_saved flag (presumably set once the index has been
+// persisted by save(); the writer is outside this chunk — confirm).
+template <typename T, typename TagT, typename LabelT> bool Index<T, TagT, LabelT>::is_index_saved()
+{
+    return _is_saved;
+}
+
+// Type-erased wrapper: extracts the concrete robin_set from the TagRobinSet
+// holder and forwards to get_active_tags().
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_get_active_tags(TagRobinSet &active_tags)
+{
+    try
+    {
+        this->get_active_tags(active_tags.get<tsl::robin_set<TagT>>());
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad_any cast while performing _get_active_tags() " + std::string(e.what()), -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error :" + std::string(e.what()), -1);
+    }
+}
+
+// Fills `active_tags` with every tag currently mapped to a live location.
+// The output set is cleared first so the result is exactly the live set.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::get_active_tags(tsl::robin_set<TagT> &active_tags)
+{
+    active_tags.clear();
+    std::shared_lock<std::shared_timed_mutex> tl(_tag_lock);
+    // Iterate by const reference: the original copied each (tag, location)
+    // pair by value on every iteration.
+    for (const auto &tag_and_location : _tag_to_location)
+    {
+        active_tags.insert(tag_and_location.first);
+    }
+}
+
+// Dumps a human-readable snapshot of the index bookkeeping (point count,
+// graph size, tag maps, empty slots, compaction flag). Takes all four locks
+// shared so the printed numbers are mutually consistent.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::print_status()
+{
+    std::shared_lock<std::shared_timed_mutex> ul(_update_lock);
+    std::shared_lock<std::shared_timed_mutex> cl(_consolidate_lock);
+    std::shared_lock<std::shared_timed_mutex> tl(_tag_lock);
+    std::shared_lock<std::shared_timed_mutex> dl(_delete_lock);
+
+    diskann::cout << "------------------- Index object: " << (uint64_t)this << " -------------------" << std::endl;
+    diskann::cout << "Number of points: " << _nd << std::endl;
+    diskann::cout << "Graph size: " << _graph_store->get_total_points() << std::endl;
+    diskann::cout << "Location to tag size: " << _location_to_tag.size() << std::endl;
+    diskann::cout << "Tag to location size: " << _tag_to_location.size() << std::endl;
+    diskann::cout << "Number of empty slots: " << _empty_slots.size() << std::endl;
+    diskann::cout << std::boolalpha << "Data compacted: " << this->_data_compacted << std::endl;
+    diskann::cout << "---------------------------------------------------------"
+                     "------------"
+                  << std::endl;
+}
+
+// Diagnostic: runs a BFS from the entry point (plus every frozen point) and
+// prints how many nodes sit at each BFS level, up to MAX_BFS_LEVELS.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::count_nodes_at_bfs_levels()
+{
+    std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
+
+    boost::dynamic_bitset<> visited(_max_points + _num_frozen_pts);
+
+    const size_t MAX_BFS_LEVELS = 32;
+    // std::vector instead of raw new[]/delete[]: the original leaked the
+    // array if anything below threw, and vector cleans up automatically.
+    std::vector<tsl::robin_set<uint32_t>> bfs_sets(MAX_BFS_LEVELS);
+
+    bfs_sets[0].insert(_start);
+    visited.set(_start);
+
+    // Seed level 0 with every frozen point as well; they act as entry points.
+    for (uint32_t i = (uint32_t)_max_points; i < _max_points + _num_frozen_pts; ++i)
+    {
+        if (i != _start)
+        {
+            bfs_sets[0].insert(i);
+            visited.set(i);
+        }
+    }
+
+    for (size_t l = 0; l < MAX_BFS_LEVELS - 1; ++l)
+    {
+        diskann::cout << "Number of nodes at BFS level " << l << " is " << bfs_sets[l].size() << std::endl;
+        if (bfs_sets[l].size() == 0)
+            break;
+        for (auto node : bfs_sets[l])
+        {
+            for (auto nghbr : _graph_store->get_neighbours((location_t)node))
+            {
+                if (!visited.test(nghbr))
+                {
+                    visited.set(nghbr);
+                    bfs_sets[l + 1].insert(nghbr);
+                }
+            }
+        }
+    }
+}
+
+// REFACTOR: This should be an OptimizedDataStore class
+// Re-packs each node into one contiguous _opt_graph record of
+// [float norm][aligned vector][uint32 degree][neighbour ids] so search
+// touches a single cache-friendly region per node. Static indices only.
+template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::optimize_index_layout()
+{ // use after build or load
+    if (_dynamic_index)
+    {
+        // Message typo fixed: "dyanmic" -> "dynamic".
+        throw diskann::ANNException("Optimize_index_layout not implemented for dynamic indices", -1, __FUNCSIG__,
+                                    __FILE__, __LINE__);
+    }
+
+    // std::vector instead of raw new[]: no leak if anything below throws,
+    // and value-initialization replaces the explicit memset.
+    std::vector<float> cur_vec(_data_store->get_aligned_dim(), 0.0f);
+    _data_len = (_data_store->get_aligned_dim() + 1) * sizeof(float);
+    _neighbor_len = (_graph_store->get_max_observed_degree() + 1) * sizeof(uint32_t);
+    _node_size = _data_len + _neighbor_len;
+    _opt_graph = new char[_node_size * _nd];
+    auto dist_fast = (DistanceFastL2<T> *)(_data_store->get_dist_fn());
+    for (uint32_t i = 0; i < _nd; i++)
+    {
+        char *cur_node_offset = _opt_graph + i * _node_size;
+        _data_store->get_vector(i, (T *)cur_vec.data());
+        float cur_norm = dist_fast->norm((T *)cur_vec.data(), (uint32_t)_data_store->get_aligned_dim());
+        std::memcpy(cur_node_offset, &cur_norm, sizeof(float));
+        std::memcpy(cur_node_offset + sizeof(float), cur_vec.data(), _data_len - sizeof(float));
+
+        cur_node_offset += _data_len;
+        uint32_t k = (uint32_t)_graph_store->get_neighbours(i).size();
+        std::memcpy(cur_node_offset, &k, sizeof(uint32_t));
+        std::memcpy(cur_node_offset + sizeof(uint32_t), _graph_store->get_neighbours(i).data(), k * sizeof(uint32_t));
+        // std::vector<uint32_t>().swap(_graph_store->get_neighbours(i));
+        _graph_store->clear_neighbours(i);
+    }
+    // Adjacency lists now live inside _opt_graph; release the graph store's.
+    _graph_store->clear_graph();
+    _graph_store->resize_graph(0);
+}
+
+// Type-erased wrapper: unpacks the std::any-held query pointer and forwards
+// to the strongly-typed search_with_optimized_layout().
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::_search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices)
+{
+    try
+    {
+        return this->search_with_optimized_layout(std::any_cast<const T *>(query), K, L, indices);
+    }
+    catch (const std::bad_any_cast &e)
+    {
+        throw ANNException("Error: bad any cast while performing "
+                           "_search_with_optimized_layout() " +
+                               std::string(e.what()),
+                           -1);
+    }
+    catch (const std::exception &e)
+    {
+        throw ANNException("Error: " + std::string(e.what()), -1);
+    }
+}
+
+// Greedy graph search over the flat _opt_graph layout built by
+// optimize_index_layout(): seed with the entry node's neighbours (padded with
+// random ids up to L), then repeatedly expand the closest unexpanded
+// candidate; the best K ids are written to `indices`.
+template <typename T, typename TagT, typename LabelT>
+void Index<T, TagT, LabelT>::search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices)
+{
+    DistanceFastL2<T> *dist_fast = (DistanceFastL2<T> *)(_data_store->get_dist_fn());
+
+    NeighborPriorityQueue retset(L);
+    std::vector<uint32_t> init_ids(L);
+
+    boost::dynamic_bitset<> flags{_nd, 0};
+    uint32_t tmp_l = 0;
+    // Per-node record layout: [float norm][vector (_data_len total)][uint32
+    // degree][neighbour ids]; `neighbors` first points at the degree word.
+    uint32_t *neighbors = (uint32_t *)(_opt_graph + _node_size * _start + _data_len);
+    uint32_t MaxM_ep = *neighbors;
+    neighbors++;
+
+    // Seed with the entry point's neighbours.
+    for (; tmp_l < L && tmp_l < MaxM_ep; tmp_l++)
+    {
+        init_ids[tmp_l] = neighbors[tmp_l];
+        flags[init_ids[tmp_l]] = true;
+    }
+
+    // Pad the seed list with random unvisited ids.
+    // NOTE(review): rand() is not seeded here and is not thread-safe — confirm
+    // this matches the intended determinism/concurrency model.
+    while (tmp_l < L)
+    {
+        uint32_t id = rand() % _nd;
+        if (flags[id])
+            continue;
+        flags[id] = true;
+        init_ids[tmp_l] = id;
+        tmp_l++;
+    }
+
+    for (uint32_t i = 0; i < init_ids.size(); i++)
+    {
+        uint32_t id = init_ids[i];
+        if (id >= _nd)
+            continue;
+        _mm_prefetch(_opt_graph + _node_size * id, _MM_HINT_T0);
+    }
+    // L is reused below as a counter of seeded candidates.
+    L = 0;
+    for (uint32_t i = 0; i < init_ids.size(); i++)
+    {
+        uint32_t id = init_ids[i];
+        if (id >= _nd)
+            continue;
+        // NOTE(review): the norm was stored as a float by
+        // optimize_index_layout(), but is read back through a T*; for
+        // non-float T this reads a truncated value and `x++` advances by
+        // sizeof(T), not sizeof(float) — confirm against upstream DiskANN.
+        T *x = (T *)(_opt_graph + _node_size * id);
+        float norm_x = *x;
+        x++;
+        float dist = dist_fast->compare(x, query, norm_x, (uint32_t)_data_store->get_aligned_dim());
+        retset.insert(Neighbor(id, dist));
+        flags[id] = true;
+        L++;
+    }
+
+    // Main greedy loop: expand the closest unexpanded node, score its
+    // unvisited neighbours and push them into the bounded candidate queue.
+    while (retset.has_unexpanded_node())
+    {
+        auto nbr = retset.closest_unexpanded();
+        auto n = nbr.id;
+        _mm_prefetch(_opt_graph + _node_size * n + _data_len, _MM_HINT_T0);
+        neighbors = (uint32_t *)(_opt_graph + _node_size * n + _data_len);
+        uint32_t MaxM = *neighbors;
+        neighbors++;
+        for (uint32_t m = 0; m < MaxM; ++m)
+            _mm_prefetch(_opt_graph + _node_size * neighbors[m], _MM_HINT_T0);
+        for (uint32_t m = 0; m < MaxM; ++m)
+        {
+            uint32_t id = neighbors[m];
+            if (flags[id])
+                continue;
+            flags[id] = 1;
+            T *data = (T *)(_opt_graph + _node_size * id);
+            float norm = *data;
+            data++;
+            float dist = dist_fast->compare(query, data, norm, (uint32_t)_data_store->get_aligned_dim());
+            Neighbor nn(id, dist);
+            retset.insert(nn);
+        }
+    }
+
+    // Emit the K closest ids.
+    for (size_t i = 0; i < K; i++)
+    {
+        indices[i] = retset[i].id;
+    }
+}
+
+/* Internals of the library */
+// Multiplier applied to _max_points when the index overflows (see the
+// EXPAND_IF_FULL path in insert_point).
+template <typename T, typename TagT, typename LabelT> const float Index<T, TagT, LabelT>::INDEX_GROWTH_FACTOR = 1.5f;
+
+// EXPORTS
+// Explicit instantiations: all Index template code lives in this translation
+// unit, so every supported (data, tag, label) combination must be
+// instantiated here for the linker to find it.
+template DISKANN_DLLEXPORT class Index<float, int32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, int32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, int32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<float, uint32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, uint32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, uint32_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<float, int64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, int64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, int64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<float, uint64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, uint64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, uint64_t, uint32_t>;
+template DISKANN_DLLEXPORT class Index<float, tag_uint128, uint32_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, tag_uint128, uint32_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, tag_uint128, uint32_t>;
+// Label with short int 2 byte
+template DISKANN_DLLEXPORT class Index<float, int32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, int32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, int32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<float, uint32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, uint32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, uint32_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<float, int64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, int64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, int64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<float, uint64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, uint64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, uint64_t, uint16_t>;
+template DISKANN_DLLEXPORT class Index<float, tag_uint128, uint16_t>;
+template DISKANN_DLLEXPORT class Index<int8_t, tag_uint128, uint16_t>;
+template DISKANN_DLLEXPORT class Index<uint8_t, tag_uint128, uint16_t>;
+
+// search<IdType> instantiations, LabelT == uint32_t, TagT == uint64_t.
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search<uint64_t>(
+    const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search<uint32_t>(
+    const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search<uint64_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search<uint32_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search<uint64_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search<uint32_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+// TagT==uint32_t
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search<uint64_t>(
+    const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search<uint32_t>(
+    const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search<uint64_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search<uint32_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search<uint64_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search<uint32_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+
+// search_with_filters<IdType> instantiations, LabelT == uint32_t.
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
+    uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
+    uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
+    uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
+    uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
+    uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
+    uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+// TagT==uint32_t
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
+    uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
+    uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
+    uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
+    uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
+    uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
+    uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+
+// search<IdType> instantiations, LabelT == uint16_t.
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
+    const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint32_t>(
+    const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search<uint64_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search<uint32_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search<uint64_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search<uint32_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+// TagT==uint32_t
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search<uint64_t>(
+    const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search<uint32_t>(
+    const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search<uint64_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search<uint32_t>(
+    const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search<uint64_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search<uint32_t>(
+    const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
+
+// search_with_filters<IdType> instantiations, LabelT == uint16_t.
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
+    uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
+    uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
+    uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
+    uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
+    uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
+    uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+// TagT==uint32_t
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
+    uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
+    uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
+    uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
+    uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
+    uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
+              float *distances);
+template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
+    uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
+              float *distances);
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/index_factory.cpp b/be/src/extern/diskann/src/index_factory.cpp
new file mode 100644
index 0000000..35790f8
--- /dev/null
+++ b/be/src/extern/diskann/src/index_factory.cpp
@@ -0,0 +1,213 @@
+#include "index_factory.h"
+#include "pq_l2_distance.h"
+
+namespace diskann
+{
+
+// Copies the caller's IndexConfig into an owned _config and validates it
+// immediately; an inconsistent configuration throws ANNException from
+// check_config().
+IndexFactory::IndexFactory(const IndexConfig &config) : _config(std::make_unique<IndexConfig>(config))
+{
+    check_config();
+}
+
+// Builds an index from the data/tag/label type names stored in _config by
+// dispatching to the string-keyed overload below.
+std::unique_ptr<AbstractIndex> IndexFactory::create_instance()
+{
+    return create_instance(_config->data_type, _config->tag_type, _config->label_type);
+}
+
+void IndexFactory::check_config()
+{
+ if (_config->dynamic_index && !_config->enable_tags)
+ {
+ throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ if (_config->pq_dist_build)
+ {
+ if (_config->dynamic_index)
+ throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based "
+ "index construction",
+ -1, __FUNCSIG__, __FILE__, __LINE__);
+ if (_config->metric == diskann::Metric::INNER_PRODUCT)
+ throw ANNException("ERROR: Inner product metrics not yet supported "
+ "with PQ distance "
+ "base index",
+ -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8")
+ {
+ throw ANNException("ERROR: invalid data type : + " + _config->data_type +
+ " is not supported. please select from [float, int8, uint8]",
+ -1);
+ }
+
+ if (_config->tag_type != "int32" && _config->tag_type != "uint32" && _config->tag_type != "int64" &&
+ _config->tag_type != "uint64")
+ {
+ throw ANNException("ERROR: invalid data type : + " + _config->tag_type +
+ " is not supported. please select from [int32, uint32, int64, uint64]",
+ -1);
+ }
+}
+
+// Returns a heap-allocated distance functor for `metric`; the caller takes
+// ownership.  float + COSINE gets a specialised AVX implementation, every
+// other combination falls back to get_distance_function<T>().
+template <typename T> Distance<T> *IndexFactory::construct_inmem_distance_fn(Metric metric)
+{
+    if (metric == diskann::Metric::COSINE && std::is_same<T, float>::value)
+    {
+        return (Distance<T> *)new AVXNormalizedCosineDistanceFloat();
+    }
+    else
+    {
+        return (Distance<T> *)get_distance_function<T>(metric);
+    }
+}
+
+// Builds the (uncompressed) vector data store for `total_internal_points`
+// points of `dimension` dims.  Only the MEMORY strategy is implemented;
+// any other strategy yields nullptr (callers must be prepared for that).
+template <typename T>
+std::shared_ptr<AbstractDataStore<T>> IndexFactory::construct_datastore(DataStoreStrategy strategy,
+                                                                        size_t total_internal_points, size_t dimension,
+                                                                        Metric metric)
+{
+    std::unique_ptr<Distance<T>> distance;
+    switch (strategy)
+    {
+    case DataStoreStrategy::MEMORY:
+        // The store takes ownership of the distance functor.
+        distance.reset(construct_inmem_distance_fn<T>(metric));
+        return std::make_shared<diskann::InMemDataStore<T>>((location_t)total_internal_points, dimension,
+                                                            std::move(distance));
+    default:
+        break;
+    }
+    return nullptr;
+}
+
+std::unique_ptr<AbstractGraphStore> IndexFactory::construct_graphstore(const GraphStoreStrategy strategy,
+ const size_t size,
+ const size_t reserve_graph_degree)
+{
+ switch (strategy)
+ {
+ case GraphStoreStrategy::MEMORY:
+ return std::make_unique<InMemGraphStore>(size, reserve_graph_degree);
+ default:
+ throw ANNException("Error : Current GraphStoreStratagy is not supported.", -1);
+ }
+}
+
+template <typename T>
+std::shared_ptr<PQDataStore<T>> IndexFactory::construct_pq_datastore(DataStoreStrategy strategy, size_t num_points,
+ size_t dimension, Metric m, size_t num_pq_chunks,
+ bool use_opq)
+{
+ std::unique_ptr<Distance<T>> distance_fn;
+ std::unique_ptr<QuantizedDistance<T>> quantized_distance_fn;
+
+ quantized_distance_fn = std::move(std::make_unique<PQL2Distance<T>>((uint32_t)num_pq_chunks, use_opq));
+ switch (strategy)
+ {
+ case DataStoreStrategy::MEMORY:
+ distance_fn.reset(construct_inmem_distance_fn<T>(m));
+ return std::make_shared<diskann::PQDataStore<T>>(dimension, (location_t)(num_points), num_pq_chunks,
+ std::move(distance_fn), std::move(quantized_distance_fn));
+ default:
+ // REFACTOR TODO: We do support diskPQ - so we may need to add a new class for SSDPQDataStore!
+ break;
+ }
+ return nullptr;
+}
+
+// Assembles a fully-typed Index: constructs the data store, an optional PQ
+// store (falling back to the plain store when PQ distance is disabled), and
+// the graph store, then hands all three to the Index constructor.
+template <typename data_type, typename tag_type, typename label_type>
+std::unique_ptr<AbstractIndex> IndexFactory::create_instance()
+{
+    // num_points already includes the frozen points.
+    size_t num_points = _config->max_points + _config->num_frozen_pts;
+    size_t dim = _config->dimension;
+    // auto graph_store = construct_graphstore(_config->graph_strategy, num_points);
+    auto data_store = construct_datastore<data_type>(_config->data_strategy, num_points, dim, _config->metric);
+    std::shared_ptr<AbstractDataStore<data_type>> pq_data_store = nullptr;
+
+    // NOTE(review): num_frozen_pts is added again below even though
+    // num_points already contains it, so frozen points appear to be counted
+    // twice when sizing the PQ and graph stores.  Presumably a harmless
+    // over-allocation -- TODO confirm against upstream DiskANN.
+    if (_config->data_strategy == DataStoreStrategy::MEMORY && _config->pq_dist_build)
+    {
+        pq_data_store =
+            construct_pq_datastore<data_type>(_config->data_strategy, num_points + _config->num_frozen_pts, dim,
+                                              _config->metric, _config->num_pq_chunks, _config->use_opq);
+    }
+    else
+    {
+        pq_data_store = data_store;
+    }
+    // Reserve slack beyond max_degree so inserts do not immediately reallocate.
+    size_t max_reserve_degree =
+        (size_t)(defaults::GRAPH_SLACK_FACTOR * 1.05 *
+                 (_config->index_write_params == nullptr ? 0 : _config->index_write_params->max_degree));
+    std::unique_ptr<AbstractGraphStore> graph_store =
+        construct_graphstore(_config->graph_strategy, num_points + _config->num_frozen_pts, max_reserve_degree);
+
+    // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK and must construct
+    // in-mem and on-disk PQDataStore if strategy == ONDISK and diskPQ is required.
+    return std::make_unique<diskann::Index<data_type, tag_type, label_type>>(*_config, data_store,
+                                                                             std::move(graph_store), pq_data_store);
+}
+
+// First stage of the string->type dispatch: resolves the data type name and
+// forwards tag/label resolution to the templated overloads below.  Throws
+// ANNException for unsupported data type names.
+std::unique_ptr<AbstractIndex> IndexFactory::create_instance(const std::string &data_type, const std::string &tag_type,
+                                                             const std::string &label_type)
+{
+    if (data_type == std::string("float"))
+    {
+        return create_instance<float>(tag_type, label_type);
+    }
+    else if (data_type == std::string("uint8"))
+    {
+        return create_instance<uint8_t>(tag_type, label_type);
+    }
+    else if (data_type == std::string("int8"))
+    {
+        return create_instance<int8_t>(tag_type, label_type);
+    }
+    else
+        throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8]", -1);
+}
+
+// Second stage of the dispatch: the data type is fixed as a template
+// parameter and the tag type name is resolved here.  Throws ANNException for
+// unsupported tag type names.
+template <typename data_type>
+std::unique_ptr<AbstractIndex> IndexFactory::create_instance(const std::string &tag_type, const std::string &label_type)
+{
+    if (tag_type == std::string("int32"))
+    {
+        return create_instance<data_type, int32_t>(label_type);
+    }
+    else if (tag_type == std::string("uint32"))
+    {
+        return create_instance<data_type, uint32_t>(label_type);
+    }
+    else if (tag_type == std::string("int64"))
+    {
+        return create_instance<data_type, int64_t>(label_type);
+    }
+    else if (tag_type == std::string("uint64"))
+    {
+        return create_instance<data_type, uint64_t>(label_type);
+    }
+    else
+        throw ANNException("Error: unsupported tag_type please choose from [int32/uint32/int64/uint64]", -1);
+}
+
+template <typename data_type, typename tag_type>
+std::unique_ptr<AbstractIndex> IndexFactory::create_instance(const std::string &label_type)
+{
+ if (label_type == std::string("uint16") || label_type == std::string("ushort"))
+ {
+ return create_instance<data_type, tag_type, uint16_t>();
+ }
+ else if (label_type == std::string("uint32") || label_type == std::string("uint"))
+ {
+ return create_instance<data_type, tag_type, uint32_t>();
+ }
+ else
+ throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1);
+}
+
+// template DISKANN_DLLEXPORT std::shared_ptr<AbstractDataStore<uint8_t>> IndexFactory::construct_datastore(
+// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m);
+// template DISKANN_DLLEXPORT std::shared_ptr<AbstractDataStore<int8_t>> IndexFactory::construct_datastore(
+// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m);
+// template DISKANN_DLLEXPORT std::shared_ptr<AbstractDataStore<float>> IndexFactory::construct_datastore(
+// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m);
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/linux_aligned_file_reader.cpp b/be/src/extern/diskann/src/linux_aligned_file_reader.cpp
new file mode 100644
index 0000000..31bf5f8
--- /dev/null
+++ b/be/src/extern/diskann/src/linux_aligned_file_reader.cpp
@@ -0,0 +1,228 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "linux_aligned_file_reader.h"
+
+#include <cassert>
+#include <cstdio>
+#include <iostream>
+#include "tsl/robin_map.h"
+#include "utils.h"
+#define MAX_EVENTS 1024
+
+namespace
+{
+typedef struct io_event io_event_t;
+typedef struct iocb iocb_t;
+
+// Issues all reads in `read_reqs` against `fd` through the libaio context
+// `ctx`, in batches of at most MAX_EVENTS, and blocks until each batch
+// completes.  Any submit/completion failure logs and terminates the process
+// via exit(-1).
+// NOTE(review): `n_retries` is effectively dead -- n_tries is never
+// incremented and both failure paths call exit(-1), so the while loop runs
+// at most once per batch.  TODO confirm intent against upstream DiskANN.
+void execute_io(io_context_t ctx, int fd, std::vector<AlignedRead> &read_reqs, uint64_t n_retries = 0)
+{
+#ifdef DEBUG
+    // O_DIRECT I/O requires 512-byte aligned offsets, lengths and buffers.
+    for (auto &req : read_reqs)
+    {
+        assert(IS_ALIGNED(req.len, 512));
+        // std::cout << "request:"<<req.offset<<":"<<req.len << std::endl;
+        assert(IS_ALIGNED(req.offset, 512));
+        assert(IS_ALIGNED(req.buf, 512));
+        // assert(malloc_usable_size(req.buf) >= req.len);
+    }
+#endif
+
+    // break-up requests into chunks of size MAX_EVENTS each
+    uint64_t n_iters = ROUND_UP(read_reqs.size(), MAX_EVENTS) / MAX_EVENTS;
+    for (uint64_t iter = 0; iter < n_iters; iter++)
+    {
+        // Number of requests in this batch (last batch may be short).
+        uint64_t n_ops = std::min((uint64_t)read_reqs.size() - (iter * MAX_EVENTS), (uint64_t)MAX_EVENTS);
+        std::vector<iocb_t *> cbs(n_ops, nullptr);
+        std::vector<io_event_t> evts(n_ops);
+        std::vector<struct iocb> cb(n_ops);
+        for (uint64_t j = 0; j < n_ops; j++)
+        {
+            io_prep_pread(cb.data() + j, fd, read_reqs[j + iter * MAX_EVENTS].buf, read_reqs[j + iter * MAX_EVENTS].len,
+                          read_reqs[j + iter * MAX_EVENTS].offset);
+        }
+
+        // initialize `cbs` using `cb` array
+        //
+
+        for (uint64_t i = 0; i < n_ops; i++)
+        {
+            cbs[i] = cb.data() + i;
+        }
+
+        uint64_t n_tries = 0;
+        while (n_tries <= n_retries)
+        {
+            // issue reads
+            int64_t ret = io_submit(ctx, (int64_t)n_ops, cbs.data());
+            // if requests didn't get accepted
+            if (ret != (int64_t)n_ops)
+            {
+                std::cerr << "io_submit() failed; returned " << ret << ", expected=" << n_ops << ", ernno=" << errno
+                          << "=" << ::strerror(-ret) << ", try #" << n_tries + 1;
+                std::cout << "ctx: " << ctx << "\n";
+                exit(-1);
+            }
+            else
+            {
+                // wait on io_getevents
+                ret = io_getevents(ctx, (int64_t)n_ops, (int64_t)n_ops, evts.data(), nullptr);
+                // if requests didn't complete
+                if (ret != (int64_t)n_ops)
+                {
+                    std::cerr << "io_getevents() failed; returned " << ret << ", expected=" << n_ops
+                              << ", ernno=" << errno << "=" << ::strerror(-ret) << ", try #" << n_tries + 1;
+                    exit(-1);
+                }
+                else
+                {
+                    break;
+                }
+            }
+        }
+        // disabled since req.buf could be an offset into another buf
+        /*
+        for (auto &req : read_reqs) {
+            // corruption check
+            assert(malloc_usable_size(req.buf) >= req.len);
+        }
+        */
+    }
+}
+} // namespace
+
+// Starts with an invalid file descriptor; open() must be called before read().
+LinuxAlignedFileReader::LinuxAlignedFileReader()
+{
+    this->file_desc = -1;
+}
+
+// Best-effort cleanup: if the caller forgot to call close(), detect a
+// still-valid descriptor via fcntl(F_GETFD) and close it here.
+LinuxAlignedFileReader::~LinuxAlignedFileReader()
+{
+    int64_t ret;
+    // check to make sure file_desc is closed
+    ret = ::fcntl(this->file_desc, F_GETFD);
+    if (ret == -1)
+    {
+        // EBADF means the descriptor is already closed/invalid -- nothing to do.
+        if (errno != EBADF)
+        {
+            std::cerr << "close() not called" << std::endl;
+            // close file desc
+            ret = ::close(this->file_desc);
+            // error checks
+            if (ret == -1)
+            {
+                std::cerr << "close() failed; returned " << ret << ", errno=" << errno << ":" << ::strerror(errno)
+                          << std::endl;
+            }
+        }
+    }
+}
+
+// Returns the libaio context previously registered for the calling thread.
+// If the thread never called register_thread(), logs an error and returns
+// the sentinel `bad_ctx` member instead.
+// NOTE(review): the comment below says "only in DEBUG mode" but the check is
+// unconditional in this build.
+io_context_t &LinuxAlignedFileReader::get_ctx()
+{
+    std::unique_lock<std::mutex> lk(ctx_mut);
+    // perform checks only in DEBUG mode
+    if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end())
+    {
+        std::cerr << "bad thread access; returning -1 as io_context_t" << std::endl;
+        return this->bad_ctx;
+    }
+    else
+    {
+        return ctx_map[std::this_thread::get_id()];
+    }
+}
+
+void LinuxAlignedFileReader::register_thread()
+{
+ auto my_id = std::this_thread::get_id();
+ std::unique_lock<std::mutex> lk(ctx_mut);
+ if (ctx_map.find(my_id) != ctx_map.end())
+ {
+ std::cerr << "multiple calls to register_thread from the same thread" << std::endl;
+ return;
+ }
+ io_context_t ctx = 0;
+ int ret = io_setup(MAX_EVENTS, &ctx);
+ if (ret != 0)
+ {
+ lk.unlock();
+ if (ret == -EAGAIN)
+ {
+ std::cerr << "io_setup() failed with EAGAIN: Consider increasing /proc/sys/fs/aio-max-nr" << std::endl;
+ }
+ else
+ {
+ std::cerr << "io_setup() failed; returned " << ret << ": " << ::strerror(-ret) << std::endl;
+ }
+ }
+ else
+ {
+ diskann::cout << "allocating ctx: " << ctx << " to thread-id:" << my_id << std::endl;
+ ctx_map[my_id] = ctx;
+ }
+ lk.unlock();
+}
+
+void LinuxAlignedFileReader::deregister_thread()
+{
+ auto my_id = std::this_thread::get_id();
+ std::unique_lock<std::mutex> lk(ctx_mut);
+ assert(ctx_map.find(my_id) != ctx_map.end());
+
+ lk.unlock();
+ io_context_t ctx = this->get_ctx();
+ io_destroy(ctx);
+ // assert(ret == 0);
+ lk.lock();
+ ctx_map.erase(my_id);
+ std::cerr << "returned ctx from thread-id:" << my_id << std::endl;
+ lk.unlock();
+}
+
+// Destroys every registered libaio context and empties ctx_map.  Holds the
+// mutex for the whole sweep, so it must not race with register_thread().
+void LinuxAlignedFileReader::deregister_all_threads()
+{
+    std::unique_lock<std::mutex> lk(ctx_mut);
+    for (auto x = ctx_map.begin(); x != ctx_map.end(); x++)
+    {
+        io_context_t ctx = x.value();
+        io_destroy(ctx);
+        // assert(ret == 0);
+        // lk.lock();
+        // ctx_map.erase(my_id);
+        // std::cerr << "returned ctx from thread-id:" << my_id << std::endl;
+    }
+    ctx_map.clear();
+    // lk.unlock();
+}
+
+// Opens `fname` read-only with O_DIRECT, so every subsequent read must use
+// 512-byte aligned buffers/offsets/lengths (see execute_io).
+// NOTE(review): failure is only caught by the assert, which compiles away in
+// release builds -- a failed open would surface later as read errors.
+void LinuxAlignedFileReader::open(const std::string &fname)
+{
+    int flags = O_DIRECT | O_RDONLY | O_LARGEFILE;
+    this->file_desc = ::open(fname.c_str(), flags);
+    // error checks
+    assert(this->file_desc != -1);
+    std::cerr << "Opened file : " << fname << std::endl;
+}
+
+// Closes the underlying descriptor.  Return values of fcntl/close are
+// deliberately ignored (the commented asserts show the original intent).
+void LinuxAlignedFileReader::close()
+{
+    // int64_t ret;
+
+    // check to make sure file_desc is closed
+    ::fcntl(this->file_desc, F_GETFD);
+    // assert(ret != -1);
+
+    ::close(this->file_desc);
+    // assert(ret != -1);
+}
+
+// Synchronously executes all aligned reads in `read_reqs` on the caller's
+// registered context.  The `async` flag is accepted for interface
+// compatibility but not supported on Linux -- it only logs a notice.
+void LinuxAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, io_context_t &ctx, bool async)
+{
+    if (async == true)
+    {
+        diskann::cout << "Async currently not supported in linux." << std::endl;
+    }
+    assert(this->file_desc != -1);
+    execute_io(ctx, this->file_desc, read_reqs);
+}
diff --git a/be/src/extern/diskann/src/logger.cpp b/be/src/extern/diskann/src/logger.cpp
new file mode 100644
index 0000000..052f548
--- /dev/null
+++ b/be/src/extern/diskann/src/logger.cpp
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <cstring>
+#include <iostream>
+
+#include "logger_impl.h"
+#include "windows_customizations.h"
+
+namespace diskann
+{
+
+#ifdef ENABLE_CUSTOM_LOGGER
+DISKANN_DLLEXPORT ANNStreamBuf coutBuff(stdout);
+DISKANN_DLLEXPORT ANNStreamBuf cerrBuff(stderr);
+
+DISKANN_DLLEXPORT std::basic_ostream<char> cout(&coutBuff);
+DISKANN_DLLEXPORT std::basic_ostream<char> cerr(&cerrBuff);
+std::function<void(LogLevel, const char *)> g_logger;
+
+// Installs the process-wide custom log sink used by ANNStreamBuf::logImpl.
+// Not synchronised: intended to be called once during startup.
+void SetCustomLogger(std::function<void(LogLevel, const char *)> logger)
+{
+    g_logger = logger;
+    diskann::cout << "Set Custom Logger" << std::endl;
+}
+
+// Builds a streambuf that redirects stdout/stderr-style output into the
+// custom logger.  Only stdout (-> LL_Info) and stderr (-> LL_Error) are
+// accepted; anything else throws.
+ANNStreamBuf::ANNStreamBuf(FILE *fp)
+{
+    if (fp == nullptr)
+    {
+        throw diskann::ANNException("File pointer passed to ANNStreamBuf() cannot be null", -1);
+    }
+    if (fp != stdout && fp != stderr)
+    {
+        throw diskann::ANNException("The custom logger only supports stdout and stderr.", -1);
+    }
+    _fp = fp;
+    _logLevel = (_fp == stdout) ? LogLevel::LL_Info : LogLevel::LL_Error;
+    // One extra byte beyond the put area so logImpl can write a terminating
+    // NUL (see logImpl and the header comment).
+    _buf = new char[BUFFER_SIZE + 1]; // See comment in the header
+
+    std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char));
+    // Reserve the last put-area slot for the overflow character.
+    setp(_buf, _buf + BUFFER_SIZE - 1);
+}
+
+// Flushes any buffered output before tearing the buffer down.
+ANNStreamBuf::~ANNStreamBuf()
+{
+    sync();
+    _fp = nullptr; // we'll not close because we can't.
+    delete[] _buf;
+}
+
+// Called by the iostream machinery when the put area is full: appends the
+// pending character (if not EOF) and flushes the buffer to the logger.
+int ANNStreamBuf::overflow(int c)
+{
+    std::lock_guard<std::mutex> lock(_mutex);
+    if (c != EOF)
+    {
+        // Safe: setp() reserved one slot past epptr for exactly this write.
+        *pptr() = (char)c;
+        pbump(1);
+    }
+    flush();
+    return c;
+}
+
+// std::endl / explicit flush entry point: drains the buffer under the lock.
+int ANNStreamBuf::sync()
+{
+    std::lock_guard<std::mutex> lock(_mutex);
+    flush();
+    return 0;
+}
+
+// This streambuf is write-only; any read attempt is a programming error.
+int ANNStreamBuf::underflow()
+{
+    throw diskann::ANNException("Attempt to read on streambuf meant only for writing.", -1);
+}
+
+// Sends the currently buffered characters to the logger and resets the put
+// pointer.  Callers must already hold _mutex.
+int ANNStreamBuf::flush()
+{
+    const int num = (int)(pptr() - pbase());
+    logImpl(pbase(), num);
+    pbump(-num);
+    return num;
+}
+// NUL-terminates the buffered text in place and forwards it to the installed
+// custom logger (no-op when none is set).
+void ANNStreamBuf::logImpl(char *str, int num)
+{
+    str[num] = '\0'; // Safe. See the c'tor.
+    // Invoke the OLS custom logging function.
+    if (g_logger)
+    {
+        g_logger(_logLevel, str);
+    }
+}
+#else
+using std::cerr;
+using std::cout;
+#endif
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/math_utils.cpp b/be/src/extern/diskann/src/math_utils.cpp
new file mode 100644
index 0000000..7481da8
--- /dev/null
+++ b/be/src/extern/diskann/src/math_utils.cpp
@@ -0,0 +1,458 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <limits>
+#include <malloc.h>
+#include <math_utils.h>
+#include <mkl.h>
+#include "logger.h"
+#include "utils.h"
+
+namespace math_utils
+{
+
// Returns the squared Euclidean (L2^2) distance between two dim-dimensional
// float vectors.  Note: no square root is taken -- callers compare squared
// distances directly.
float calc_distance(float *vec_1, float *vec_2, size_t dim)
{
    float total = 0;
    for (size_t idx = 0; idx < dim; ++idx)
    {
        const float delta = vec_1[idx] - vec_2[idx];
        total += delta * delta;
    }
    return total;
}
+
+// compute l2-squared norms of data stored in row major num_points * dim,
+// needs
+// to be pre-allocated
+// compute l2-squared norms of data stored in row major num_points * dim,
+// needs
+// to be pre-allocated
+// Implementation: cblas_snrm2 yields the Euclidean norm of each row, which
+// is then squared; rows are processed in parallel with OpenMP.
+void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim)
+{
+#pragma omp parallel for schedule(static, 8192)
+    for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++)
+    {
+        vecs_l2sq[n_iter] = cblas_snrm2((MKL_INT)dim, (data + (n_iter * dim)), 1);
+        vecs_l2sq[n_iter] *= vecs_l2sq[n_iter];
+    }
+}
+
+// Multiplies data (num_points x dim, row major) by rot_mat (dim x dim),
+// optionally transposed, writing the product into new_mat.
+// NOTE(review): new_mat must be pre-allocated to num_points * dim floats by
+// the caller -- this function does not allocate despite the reference
+// parameter.  TODO confirm against call sites.
+void rotate_data_randomly(float *data, size_t num_points, size_t dim, float *rot_mat, float *&new_mat,
+                          bool transpose_rot)
+{
+    CBLAS_TRANSPOSE transpose = CblasNoTrans;
+    if (transpose_rot)
+    {
+        diskann::cout << "Transposing rotation matrix.." << std::flush;
+        transpose = CblasTrans;
+    }
+    diskann::cout << "done Rotating data with random matrix.." << std::flush;
+
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, transpose, (MKL_INT)num_points, (MKL_INT)dim, (MKL_INT)dim, 1.0, data,
+                (MKL_INT)dim, rot_mat, (MKL_INT)dim, 0, new_mat, (MKL_INT)dim);
+
+    diskann::cout << "done." << std::endl;
+}
+
+// calculate k closest centers to data of num_points * dim (row major)
+// centers is num_centers * dim (row major)
+// docs_l2sq has pre-computed squared norms of data
+// centers_l2sq has pre-computed squared norms of centers
+// pre-allocated center_index will contain id of nearest center
+// pre-allocated dist_matrix should be num_points * num_centers and contain
+// squared distances
+// Default value of k is 1
+
+// Ideally used only by compute_closest_centers
+// Ideally used only by compute_closest_centers
+// Computes, for every point in the block, its k nearest centers.  The full
+// squared-distance matrix is assembled with three sgemm rank updates:
+//   dist(x,c) = ||x||^2 + ||c||^2 - 2 * <x,c>
+// using the pre-computed norms docs_l2sq / centers_l2sq.  Results go into
+// the pre-allocated center_index (num_points * k) and dist_matrix
+// (num_points * num_centers).
+void compute_closest_centers_in_block(const float *const data, const size_t num_points, const size_t dim,
+                                      const float *const centers, const size_t num_centers,
+                                      const float *const docs_l2sq, const float *const centers_l2sq,
+                                      uint32_t *center_index, float *const dist_matrix, size_t k)
+{
+    if (k > num_centers)
+    {
+        diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers << ")" << std::endl;
+        return;
+    }
+
+    // All-ones vectors used to broadcast the norm terms across rows/columns.
+    float *ones_a = new float[num_centers];
+    float *ones_b = new float[num_points];
+
+    for (size_t i = 0; i < num_centers; i++)
+    {
+        ones_a[i] = 1.0;
+    }
+    for (size_t i = 0; i < num_points; i++)
+    {
+        ones_b[i] = 1.0;
+    }
+
+    // dist_matrix = docs_l2sq * ones^T  (adds ||x||^2 to every column)
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)1, 1.0f,
+                docs_l2sq, (MKL_INT)1, ones_a, (MKL_INT)1, 0.0f, dist_matrix, (MKL_INT)num_centers);
+
+    // dist_matrix += ones * centers_l2sq^T  (adds ||c||^2 to every row)
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)1, 1.0f,
+                ones_b, (MKL_INT)1, centers_l2sq, (MKL_INT)1, 1.0f, dist_matrix, (MKL_INT)num_centers);
+
+    // dist_matrix += -2 * data * centers^T  (the cross term)
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)dim, -2.0f,
+                data, (MKL_INT)dim, centers, (MKL_INT)dim, 1.0f, dist_matrix, (MKL_INT)num_centers);
+
+    if (k == 1)
+    {
+        // Fast path: a simple arg-min scan per point.
+#pragma omp parallel for schedule(static, 8192)
+        for (int64_t i = 0; i < (int64_t)num_points; i++)
+        {
+            float min = std::numeric_limits<float>::max();
+            float *current = dist_matrix + (i * num_centers);
+            for (size_t j = 0; j < num_centers; j++)
+            {
+                if (current[j] < min)
+                {
+                    center_index[i] = (uint32_t)j;
+                    min = current[j];
+                }
+            }
+        }
+    }
+    else
+    {
+        // General path: rank all centers via a priority queue and pop the
+        // top k (PivotContainer's ordering puts the nearest on top).
+#pragma omp parallel for schedule(static, 8192)
+        for (int64_t i = 0; i < (int64_t)num_points; i++)
+        {
+            std::priority_queue<PivotContainer> top_k_queue;
+            float *current = dist_matrix + (i * num_centers);
+            for (size_t j = 0; j < num_centers; j++)
+            {
+                PivotContainer this_piv(j, current[j]);
+                top_k_queue.push(this_piv);
+            }
+            for (size_t j = 0; j < k; j++)
+            {
+                PivotContainer this_piv = top_k_queue.top();
+                center_index[i * k + j] = (uint32_t)this_piv.piv_id;
+                top_k_queue.pop();
+            }
+        }
+    }
+    delete[] ones_a;
+    delete[] ones_b;
+}
+
+// Given data in num_points * new_dim row major
+// Pivots stored in full_pivot_data as num_centers * new_dim row major
+// Calculate the k closest pivot for each point and store it in vector
+// closest_centers_ivf (row major, num_points*k) (which needs to be allocated
+// outside) Additionally, if inverted index is not null (and pre-allocated),
+// it
+// will return inverted index for each center, assuming each of the inverted
+// indices is an empty vector. Additionally, if pts_norms_squared is not null,
+// then it will assume that point norms are pre-computed and use those values
+
+// Finds the k closest pivots for each point (see the block comment above).
+// Norms are computed on demand unless pts_norms_squared is supplied; the
+// heavy lifting is delegated to compute_closest_centers_in_block.
+// NOTE(review): PAR_BLOCK_SIZE is set to num_points, so N_BLOCKS is always 1
+// and the blocking loop below runs exactly once -- presumably a leftover
+// from a chunked implementation.
+void compute_closest_centers(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers,
+                             size_t k, uint32_t *closest_centers_ivf, std::vector<size_t> *inverted_index,
+                             float *pts_norms_squared)
+{
+    if (k > num_centers)
+    {
+        diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers << ")" << std::endl;
+        return;
+    }
+
+    bool is_norm_given_for_pts = (pts_norms_squared != NULL);
+
+    float *pivs_norms_squared = new float[num_centers];
+    if (!is_norm_given_for_pts)
+        pts_norms_squared = new float[num_points];
+
+    size_t PAR_BLOCK_SIZE = num_points;
+    size_t N_BLOCKS =
+        (num_points % PAR_BLOCK_SIZE) == 0 ? (num_points / PAR_BLOCK_SIZE) : (num_points / PAR_BLOCK_SIZE) + 1;
+
+    if (!is_norm_given_for_pts)
+        math_utils::compute_vecs_l2sq(pts_norms_squared, data, num_points, dim);
+    math_utils::compute_vecs_l2sq(pivs_norms_squared, pivot_data, num_centers, dim);
+    uint32_t *closest_centers = new uint32_t[PAR_BLOCK_SIZE * k];
+    float *distance_matrix = new float[num_centers * PAR_BLOCK_SIZE];
+
+    for (size_t cur_blk = 0; cur_blk < N_BLOCKS; cur_blk++)
+    {
+        float *data_cur_blk = data + cur_blk * PAR_BLOCK_SIZE * dim;
+        size_t num_pts_blk = std::min(PAR_BLOCK_SIZE, num_points - cur_blk * PAR_BLOCK_SIZE);
+        float *pts_norms_blk = pts_norms_squared + cur_blk * PAR_BLOCK_SIZE;
+
+        math_utils::compute_closest_centers_in_block(data_cur_blk, num_pts_blk, dim, pivot_data, num_centers,
+                                                     pts_norms_blk, pivs_norms_squared, closest_centers,
+                                                     distance_matrix, k);
+
+        // Scatter the block-local results into the global output, and (when
+        // requested) build the inverted center -> points index.  The critical
+        // section guards concurrent push_back into the same center's vector.
+#pragma omp parallel for schedule(static, 1)
+        for (int64_t j = cur_blk * PAR_BLOCK_SIZE;
+             j < std::min((int64_t)num_points, (int64_t)((cur_blk + 1) * PAR_BLOCK_SIZE)); j++)
+        {
+            for (size_t l = 0; l < k; l++)
+            {
+                size_t this_center_id = closest_centers[(j - cur_blk * PAR_BLOCK_SIZE) * k + l];
+                closest_centers_ivf[j * k + l] = (uint32_t)this_center_id;
+                if (inverted_index != NULL)
+                {
+#pragma omp critical
+                    inverted_index[this_center_id].push_back(j);
+                }
+            }
+        }
+    }
+    delete[] closest_centers;
+    delete[] distance_matrix;
+    delete[] pivs_norms_squared;
+    if (!is_norm_given_for_pts)
+        delete[] pts_norms_squared;
+}
+
+// if to_subtract is 1, will subtract nearest center from each row. Else will
+// add. Output will be in data_load itself.
+// Nearest centers need to be provided in closest_centers.
+// In-place residual computation: for each point, subtracts (to_subtract ==
+// true) or adds back (false) its nearest center's coordinates, overwriting
+// data_load.  closest_centers[i] indexes into cur_pivot_data.
+void process_residuals(float *data_load, size_t num_points, size_t dim, float *cur_pivot_data, size_t num_centers,
+                       uint32_t *closest_centers, bool to_subtract)
+{
+    diskann::cout << "Processing residuals of " << num_points << " points in " << dim << " dimensions using "
+                  << num_centers << " centers " << std::endl;
+#pragma omp parallel for schedule(static, 8192)
+    for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++)
+    {
+        for (size_t d_iter = 0; d_iter < dim; d_iter++)
+        {
+            // to_subtract is a bool; "== 1" is equivalent to testing truth.
+            if (to_subtract == 1)
+                data_load[n_iter * dim + d_iter] =
+                    data_load[n_iter * dim + d_iter] - cur_pivot_data[closest_centers[n_iter] * dim + d_iter];
+            else
+                data_load[n_iter * dim + d_iter] =
+                    data_load[n_iter * dim + d_iter] + cur_pivot_data[closest_centers[n_iter] * dim + d_iter];
+        }
+    }
+}
+
+} // namespace math_utils
+
+namespace kmeans
+{
+
+// run Lloyds one iteration
+// Given data in row major num_points * dim, and centers in row major
+// num_centers * dim And squared lengths of data points, output the closest
+// center to each data point, update centers, and also return inverted index.
+// If
+// closest_centers == NULL, will allocate memory and return. Similarly, if
+// closest_docs == NULL, will allocate memory and return.
+
+float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, size_t num_centers, float *docs_l2sq,
+ std::vector<size_t> *closest_docs, uint32_t *&closest_center)
+{
+ bool compute_residual = true;
+ // Timer timer;
+
+ if (closest_center == NULL)
+ closest_center = new uint32_t[num_points];
+ if (closest_docs == NULL)
+ closest_docs = new std::vector<size_t>[num_centers];
+ else
+ for (size_t c = 0; c < num_centers; ++c)
+ closest_docs[c].clear();
+
+ math_utils::compute_closest_centers(data, num_points, dim, centers, num_centers, 1, closest_center, closest_docs,
+ docs_l2sq);
+
+ memset(centers, 0, sizeof(float) * (size_t)num_centers * (size_t)dim);
+
+#pragma omp parallel for schedule(static, 1)
+ for (int64_t c = 0; c < (int64_t)num_centers; ++c)
+ {
+ float *center = centers + (size_t)c * (size_t)dim;
+ double *cluster_sum = new double[dim];
+ for (size_t i = 0; i < dim; i++)
+ cluster_sum[i] = 0.0;
+ for (size_t i = 0; i < closest_docs[c].size(); i++)
+ {
+ float *current = data + ((closest_docs[c][i]) * dim);
+ for (size_t j = 0; j < dim; j++)
+ {
+ cluster_sum[j] += (double)current[j];
+ }
+ }
+ if (closest_docs[c].size() > 0)
+ {
+ for (size_t i = 0; i < dim; i++)
+ center[i] = (float)(cluster_sum[i] / ((double)closest_docs[c].size()));
+ }
+ delete[] cluster_sum;
+ }
+
+ float residual = 0.0;
+ if (compute_residual)
+ {
+ size_t BUF_PAD = 32;
+ size_t CHUNK_SIZE = 2 * 8192;
+ size_t nchunks = num_points / CHUNK_SIZE + (num_points % CHUNK_SIZE == 0 ? 0 : 1);
+ std::vector<float> residuals(nchunks * BUF_PAD, 0.0);
+
+#pragma omp parallel for schedule(static, 32)
+ for (int64_t chunk = 0; chunk < (int64_t)nchunks; ++chunk)
+ for (size_t d = chunk * CHUNK_SIZE; d < num_points && d < (chunk + 1) * CHUNK_SIZE; ++d)
+ residuals[chunk * BUF_PAD] +=
+ math_utils::calc_distance(data + (d * dim), centers + (size_t)closest_center[d] * (size_t)dim, dim);
+
+ for (size_t chunk = 0; chunk < nchunks; ++chunk)
+ residual += residuals[chunk * BUF_PAD];
+ }
+
+ return residual;
+}
+
+// Run Lloyds until max_reps or stopping criterion
+// If you pass NULL for closest_docs and closest_center, it will NOT return
+// the
+// results, else it will assume appropriate allocation as closest_docs = new
+// vector<size_t> [num_centers], and closest_center = new size_t[num_points]
+// Final centers are output in centers as row major num_centers * dim
+//
+// Runs Lloyd's algorithm for up to max_reps iterations, stopping early when
+// the residual improves by less than 0.001% or becomes (near) zero.  When
+// closest_docs / closest_center are NULL they are allocated locally and
+// freed before returning; otherwise the caller-provided buffers are filled.
+float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, const size_t num_centers,
+                 const size_t max_reps, std::vector<size_t> *closest_docs, uint32_t *closest_center)
+{
+    float residual = std::numeric_limits<float>::max();
+    // Track which buffers we own so only local allocations are freed.
+    bool ret_closest_docs = true;
+    bool ret_closest_center = true;
+    if (closest_docs == NULL)
+    {
+        closest_docs = new std::vector<size_t>[num_centers];
+        ret_closest_docs = false;
+    }
+    if (closest_center == NULL)
+    {
+        closest_center = new uint32_t[num_points];
+        ret_closest_center = false;
+    }
+
+    float *docs_l2sq = new float[num_points];
+    math_utils::compute_vecs_l2sq(docs_l2sq, data, num_points, dim);
+
+    float old_residual;
+    // Timer timer;
+    for (size_t i = 0; i < max_reps; ++i)
+    {
+        old_residual = residual;
+
+        residual = lloyds_iter(data, num_points, dim, centers, num_centers, docs_l2sq, closest_docs, closest_center);
+
+        // Early termination: relative improvement below 1e-5, or residual
+        // effectively zero.  Skipped on the first iteration (i == 0).
+        if (((i != 0) && ((old_residual - residual) / residual) < 0.00001) ||
+            (residual < std::numeric_limits<float>::epsilon()))
+        {
+            diskann::cout << "Residuals unchanged: " << old_residual << " becomes " << residual
+                          << ". Early termination." << std::endl;
+            break;
+        }
+    }
+    delete[] docs_l2sq;
+    if (!ret_closest_docs)
+        delete[] closest_docs;
+    if (!ret_closest_center)
+        delete[] closest_center;
+    return residual;
+}
+
+// assumes memory allocated for pivot_data as new
+// float[num_centers*dim]
+// and select randomly num_centers points as pivots
// Assumes pivot_data was allocated by the caller as new float[num_centers * dim].
// Selects num_centers distinct input points uniformly at random and copies
// them into pivot_data (one pivot per row).
void selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers)
{
    // BUGFIX: when the randomly drawn point had already been picked, the
    // original used `continue`, which advanced the loop index and left that
    // pivot row uninitialised (and fewer than num_centers pivots written).
    // We now redraw until an unused point is found, falling back to reuse
    // only when there are fewer distinct points than centers.
    std::vector<size_t> picked;
    std::random_device rd;
    auto x = rd();
    std::mt19937 generator(x);
    std::uniform_int_distribution<size_t> distribution(0, num_points - 1);

    for (size_t j = 0; j < num_centers; j++)
    {
        size_t tmp_pivot = distribution(generator);
        if (picked.size() < num_points)
        {
            // Redraw until we hit a point that has not been used yet.
            while (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end())
            {
                tmp_pivot = distribution(generator);
            }
        }
        picked.push_back(tmp_pivot);
        std::memcpy(pivot_data + j * dim, data + tmp_pivot * dim, dim * sizeof(float));
    }
}
+
+// k-means++ seeding: picks the first pivot uniformly at random, then picks
+// each subsequent pivot with probability proportional to its squared
+// distance from the nearest already-chosen pivot.  Falls back to plain
+// random selection above 2^23 points.
+void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers)
+{
+    if (num_points > 1 << 23)
+    {
+        diskann::cout << "ERROR: n_pts " << num_points
+                      << " currently not supported for k-means++, maximum is "
+                         "8388608. Falling back to random pivot "
+                         "selection."
+                      << std::endl;
+        selecting_pivots(data, num_points, dim, pivot_data, num_centers);
+        return;
+    }
+
+    std::vector<size_t> picked;
+    std::random_device rd;
+    auto x = rd();
+    std::mt19937 generator(x);
+    std::uniform_real_distribution<> distribution(0, 1);
+    std::uniform_int_distribution<size_t> int_dist(0, num_points - 1);
+    size_t init_id = int_dist(generator);
+    size_t num_picked = 1;
+
+    picked.push_back(init_id);
+    std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float));
+
+    // dist[i] = squared distance from point i to its nearest chosen pivot.
+    float *dist = new float[num_points];
+
+#pragma omp parallel for schedule(static, 8192)
+    for (int64_t i = 0; i < (int64_t)num_points; i++)
+    {
+        dist[i] = math_utils::calc_distance(data + i * dim, data + init_id * dim, dim);
+    }
+
+    double dart_val;
+    size_t tmp_pivot;
+    bool sum_flag = false;
+
+    while (num_picked < num_centers)
+    {
+        // Throw a "dart" into the cumulative distance distribution; the point
+        // whose prefix-sum interval contains it becomes the next candidate.
+        dart_val = distribution(generator);
+
+        double sum = 0;
+        for (size_t i = 0; i < num_points; i++)
+        {
+            sum = sum + dist[i];
+        }
+        // sum == 0 means every point coincides with a chosen pivot; allow
+        // duplicates in that degenerate case to avoid looping forever.
+        if (sum == 0)
+            sum_flag = true;
+
+        dart_val *= sum;
+
+        double prefix_sum = 0;
+        for (size_t i = 0; i < (num_points); i++)
+        {
+            tmp_pivot = i;
+            if (dart_val >= prefix_sum && dart_val < prefix_sum + dist[i])
+            {
+                break;
+            }
+
+            prefix_sum += dist[i];
+        }
+
+        if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end() && (sum_flag == false))
+            continue;
+        picked.push_back(tmp_pivot);
+        std::memcpy(pivot_data + num_picked * dim, data + tmp_pivot * dim, dim * sizeof(float));
+
+        // Fold the new pivot into the nearest-pivot distances.
+#pragma omp parallel for schedule(static, 8192)
+        for (int64_t i = 0; i < (int64_t)num_points; i++)
+        {
+            dist[i] = (std::min)(dist[i], math_utils::calc_distance(data + i * dim, data + tmp_pivot * dim, dim));
+        }
+        num_picked++;
+    }
+    delete[] dist;
+}
+
+} // namespace kmeans
diff --git a/be/src/extern/diskann/src/memory_mapper.cpp b/be/src/extern/diskann/src/memory_mapper.cpp
new file mode 100644
index 0000000..d1c5ef9
--- /dev/null
+++ b/be/src/extern/diskann/src/memory_mapper.cpp
@@ -0,0 +1,107 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "logger.h"
+#include "memory_mapper.h"
+#include <iostream>
+#include <sstream>
+
+using namespace diskann;
+
+MemoryMapper::MemoryMapper(const std::string &filename) : MemoryMapper(filename.c_str())
+{
+}
+
+MemoryMapper::MemoryMapper(const char *filename)
+{
+#ifndef _WINDOWS
+    _fd = open(filename, O_RDONLY);
+    if (_fd < 0)
+    {
+        std::cerr << "Failed to open file " << filename << " for reading" << std::endl;
+        return;
+    }
+    struct stat sb;
+    if (fstat(_fd, &sb) != 0)
+    {
+        std::cerr << "Failed to stat file " << filename << std::endl;
+        return;
+    }
+    _fileSize = sb.st_size;
+    diskann::cout << "File Size: " << _fileSize << std::endl;
+    _buf = (char *)mmap(NULL, _fileSize, PROT_READ, MAP_PRIVATE, _fd, 0);
+#else
+    _bareFile =
+        CreateFileA(filename, GENERIC_READ | GENERIC_EXECUTE, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
+    if (_bareFile == INVALID_HANDLE_VALUE)
+    {
+        std::ostringstream message;
+        message << "CreateFileA(" << filename << ") failed with error " << GetLastError() << std::endl;
+        std::cerr << message.str();
+        throw std::exception(message.str().c_str());
+    }
+
+    _fd = CreateFileMapping(_bareFile, NULL, PAGE_EXECUTE_READ, 0, 0, NULL);
+    if (_fd == nullptr)
+    {
+        std::ostringstream message;
+        message << "CreateFileMapping(" << filename << ") failed with error " << GetLastError() << std::endl;
+        std::cerr << message.str() << std::endl;
+        throw std::exception(message.str().c_str());
+    }
+
+    _buf = (char *)MapViewOfFile(_fd, FILE_MAP_READ, 0, 0, 0);
+    if (_buf == nullptr)
+    {
+        std::ostringstream message;
+        message << "MapViewOfFile(" << filename << ") failed with error: " << GetLastError() << std::endl;
+        std::cerr << message.str() << std::endl;
+        throw std::exception(message.str().c_str());
+    }
+
+    LARGE_INTEGER fSize;
+    if (TRUE == GetFileSizeEx(_bareFile, &fSize))
+    {
+        _fileSize = fSize.QuadPart; // take the 64-bit value
+        diskann::cout << "File Size: " << _fileSize << std::endl;
+    }
+    else
+    {
+        std::cerr << "Failed to get size of file " << filename << std::endl;
+    }
+#endif
+}
+char *MemoryMapper::getBuf()
+{
+ return _buf;
+}
+
+size_t MemoryMapper::getFileSize()
+{
+ return _fileSize;
+}
+
+MemoryMapper::~MemoryMapper()
+{
+#ifndef _WINDOWS
+ if (munmap(_buf, _fileSize) != 0)
+ std::cerr << "ERROR unmapping. CHECK!" << std::endl;
+ close(_fd);
+#else
+ if (FALSE == UnmapViewOfFile(_buf))
+ {
+ std::cerr << "Unmap view of file failed. Error: " << GetLastError() << std::endl;
+ }
+
+ if (FALSE == CloseHandle(_fd))
+ {
+ std::cerr << "Failed to close memory mapped file. Error: " << GetLastError() << std::endl;
+ }
+
+ if (FALSE == CloseHandle(_bareFile))
+ {
+ std::cerr << "Failed to close file: " << _fileName << " Error: " << GetLastError() << std::endl;
+ }
+
+#endif
+}
diff --git a/be/src/extern/diskann/src/natural_number_map.cpp b/be/src/extern/diskann/src/natural_number_map.cpp
new file mode 100644
index 0000000..a996dcf
--- /dev/null
+++ b/be/src/extern/diskann/src/natural_number_map.cpp
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <assert.h>
+#include <boost/dynamic_bitset.hpp>
+
+#include "natural_number_map.h"
+#include "tag_uint128.h"
+
+namespace diskann
+{
+static constexpr auto invalid_position = boost::dynamic_bitset<>::npos;
+
+template <typename Key, typename Value>
+natural_number_map<Key, Value>::natural_number_map()
+ : _size(0), _values_bitset(std::make_unique<boost::dynamic_bitset<>>())
+{
+}
+
+template <typename Key, typename Value> void natural_number_map<Key, Value>::reserve(size_t count)
+{
+ _values_vector.reserve(count);
+ _values_bitset->reserve(count);
+}
+
+template <typename Key, typename Value> size_t natural_number_map<Key, Value>::size() const
+{
+ return _size;
+}
+
+template <typename Key, typename Value> void natural_number_map<Key, Value>::set(Key key, Value value)
+{
+ if (key >= _values_bitset->size())
+ {
+ _values_bitset->resize(static_cast<size_t>(key) + 1);
+ _values_vector.resize(_values_bitset->size());
+ }
+
+ _values_vector[key] = value;
+ const bool was_present = _values_bitset->test_set(key, true);
+
+ if (!was_present)
+ {
+ ++_size;
+ }
+}
+
+template <typename Key, typename Value> void natural_number_map<Key, Value>::erase(Key key)
+{
+ if (key < _values_bitset->size())
+ {
+ const bool was_present = _values_bitset->test_set(key, false);
+
+ if (was_present)
+ {
+ --_size;
+ }
+ }
+}
+
+template <typename Key, typename Value> bool natural_number_map<Key, Value>::contains(Key key) const
+{
+ return key < _values_bitset->size() && _values_bitset->test(key);
+}
+
+template <typename Key, typename Value> bool natural_number_map<Key, Value>::try_get(Key key, Value &value) const
+{
+ if (!contains(key))
+ {
+ return false;
+ }
+
+ value = _values_vector[key];
+ return true;
+}
+
+template <typename Key, typename Value>
+typename natural_number_map<Key, Value>::position natural_number_map<Key, Value>::find_first() const
+{
+ return position{_size > 0 ? _values_bitset->find_first() : invalid_position, 0};
+}
+
+template <typename Key, typename Value>
+typename natural_number_map<Key, Value>::position natural_number_map<Key, Value>::find_next(
+ const position &after_position) const
+{
+ return position{after_position._keys_already_enumerated < _size ? _values_bitset->find_next(after_position._key)
+ : invalid_position,
+ after_position._keys_already_enumerated + 1};
+}
+
+template <typename Key, typename Value> bool natural_number_map<Key, Value>::position::is_valid() const
+{
+ return _key != invalid_position;
+}
+
+template <typename Key, typename Value> Value natural_number_map<Key, Value>::get(const position &pos) const
+{
+ assert(pos.is_valid());
+ return _values_vector[pos._key];
+}
+
+template <typename Key, typename Value> void natural_number_map<Key, Value>::clear()
+{
+ _size = 0;
+ _values_vector.clear();
+ _values_bitset->clear();
+}
+
+// Instantiate used templates.
+template class natural_number_map<uint32_t, int32_t>;
+template class natural_number_map<uint32_t, uint32_t>;
+template class natural_number_map<uint32_t, int64_t>;
+template class natural_number_map<uint32_t, uint64_t>;
+template class natural_number_map<uint32_t, tag_uint128>;
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/natural_number_set.cpp b/be/src/extern/diskann/src/natural_number_set.cpp
new file mode 100644
index 0000000..b36cb52
--- /dev/null
+++ b/be/src/extern/diskann/src/natural_number_set.cpp
@@ -0,0 +1,70 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <boost/dynamic_bitset.hpp>
+
+#include "ann_exception.h"
+#include "natural_number_set.h"
+
+namespace diskann
+{
+template <typename T>
+natural_number_set<T>::natural_number_set() : _values_bitset(std::make_unique<boost::dynamic_bitset<>>())
+{
+}
+
+template <typename T> bool natural_number_set<T>::is_empty() const
+{
+ return _values_vector.empty();
+}
+
+template <typename T> void natural_number_set<T>::reserve(size_t count)
+{
+ _values_vector.reserve(count);
+ _values_bitset->reserve(count);
+}
+
+template <typename T> void natural_number_set<T>::insert(T id)
+{
+ _values_vector.emplace_back(id);
+
+ if (id >= _values_bitset->size())
+ _values_bitset->resize(static_cast<size_t>(id) + 1);
+
+ _values_bitset->set(id, true);
+}
+
+template <typename T> T natural_number_set<T>::pop_any()
+{
+ if (_values_vector.empty())
+ {
+ throw diskann::ANNException("No values available", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ const T id = _values_vector.back();
+ _values_vector.pop_back();
+
+ _values_bitset->set(id, false);
+
+ return id;
+}
+
+template <typename T> void natural_number_set<T>::clear()
+{
+ _values_vector.clear();
+ _values_bitset->clear();
+}
+
+template <typename T> size_t natural_number_set<T>::size() const
+{
+ return _values_vector.size();
+}
+
+template <typename T> bool natural_number_set<T>::is_in_set(T id) const
+{
+ return _values_bitset->test(id);
+}
+
+// Instantiate used templates.
+template class natural_number_set<unsigned>;
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/partition.cpp b/be/src/extern/diskann/src/partition.cpp
new file mode 100644
index 0000000..325b1d6
--- /dev/null
+++ b/be/src/extern/diskann/src/partition.cpp
@@ -0,0 +1,706 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <omp.h>
+#include "tsl/robin_map.h"
+#include "tsl/robin_set.h"
+
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+#include "gperftools/malloc_extension.h"
+#endif
+
+#include "utils.h"
+#include "math_utils.h"
+#include "index.h"
+#include "parameters.h"
+#include "memory_mapper.h"
+#include "partition.h"
+#ifdef _WINDOWS
+#include <xmmintrin.h>
+#endif
+
+// block size for reading/ processing large files and matrices in blocks
+#define BLOCK_SIZE 5000000
+
+// #define SAVE_INFLATED_PQ true
+
+template <typename T>
+void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate)
+{
+ size_t read_blk_size = 64 * 1024 * 1024;
+ cached_ifstream base_reader(base_file.c_str(), read_blk_size);
+ std::ofstream sample_writer(std::string(output_prefix + "_data.bin").c_str(), std::ios::binary);
+ std::ofstream sample_id_writer(std::string(output_prefix + "_ids.bin").c_str(), std::ios::binary);
+
+ std::random_device rd; // Will be used to obtain a seed for the random number engine
+ auto x = rd();
+ std::mt19937 generator(x); // Standard mersenne_twister_engine seeded with rd()
+ std::uniform_real_distribution<float> distribution(0, 1);
+
+ size_t npts, nd;
+ uint32_t npts_u32, nd_u32;
+ uint32_t num_sampled_pts_u32 = 0;
+ uint32_t one_const = 1;
+
+ base_reader.read((char *)&npts_u32, sizeof(uint32_t));
+ base_reader.read((char *)&nd_u32, sizeof(uint32_t));
+ diskann::cout << "Loading base " << base_file << ". #points: " << npts_u32 << ". #dim: " << nd_u32 << "."
+ << std::endl;
+ sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t));
+ sample_writer.write((char *)&nd_u32, sizeof(uint32_t));
+ sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t));
+ sample_id_writer.write((char *)&one_const, sizeof(uint32_t));
+
+ npts = npts_u32;
+ nd = nd_u32;
+ std::unique_ptr<T[]> cur_row = std::make_unique<T[]>(nd);
+
+ for (size_t i = 0; i < npts; i++)
+ {
+ base_reader.read((char *)cur_row.get(), sizeof(T) * nd);
+ float sample = distribution(generator);
+ if (sample < sampling_rate)
+ {
+ sample_writer.write((char *)cur_row.get(), sizeof(T) * nd);
+ uint32_t cur_i_u32 = (uint32_t)i;
+ sample_id_writer.write((char *)&cur_i_u32, sizeof(uint32_t));
+ num_sampled_pts_u32++;
+ }
+ }
+ sample_writer.seekp(0, std::ios::beg);
+ sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t));
+ sample_id_writer.seekp(0, std::ios::beg);
+ sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t));
+ sample_writer.close();
+ sample_id_writer.close();
+ diskann::cout << "Wrote " << num_sampled_pts_u32 << " points to sample file: " << output_prefix + "_data.bin"
+ << std::endl;
+}
+
+// streams data from the file, and samples each vector with probability p_val
+// and returns a matrix of size slice_size* ndims as floating point type.
+// the slice_size and ndims are set inside the function.
+
+/***********************************
+ * Reimplement using gen_random_slice(const T* inputdata,...)
+ ************************************/
+
+template <typename T>
+void gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size,
+ size_t &ndims)
+{
+ size_t npts;
+ uint32_t npts32, ndims32;
+ std::vector<std::vector<float>> sampled_vectors;
+
+ // amount to read in one shot
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // create cached reader + writer
+ cached_ifstream base_reader(data_file.c_str(), read_blk_size);
+
+ // metadata: npts, ndims
+ base_reader.read((char *)&npts32, sizeof(uint32_t));
+ base_reader.read((char *)&ndims32, sizeof(uint32_t));
+ npts = npts32;
+ ndims = ndims32;
+
+ std::unique_ptr<T[]> cur_vector_T = std::make_unique<T[]>(ndims);
+ p_val = p_val < 1 ? p_val : 1;
+
+ std::random_device rd; // Will be used to obtain a seed for the random number
+ size_t x = rd();
+ std::mt19937 generator((uint32_t)x);
+ std::uniform_real_distribution<float> distribution(0, 1);
+
+ for (size_t i = 0; i < npts; i++)
+ {
+ base_reader.read((char *)cur_vector_T.get(), ndims * sizeof(T));
+ float rnd_val = distribution(generator);
+ if (rnd_val < p_val)
+ {
+ std::vector<float> cur_vector_float;
+ for (size_t d = 0; d < ndims; d++)
+ cur_vector_float.push_back(cur_vector_T[d]);
+ sampled_vectors.push_back(cur_vector_float);
+ }
+ }
+ slice_size = sampled_vectors.size();
+ sampled_data = new float[slice_size * ndims];
+ for (size_t i = 0; i < slice_size; i++)
+ {
+ for (size_t j = 0; j < ndims; j++)
+ {
+ sampled_data[i * ndims + j] = sampled_vectors[i][j];
+ }
+ }
+}
+
+template <typename T>
+void gen_random_slice(std::stringstream & _data_stream, double p_val, float *&sampled_data, size_t &slice_size,
+ size_t &ndims)
+{
+ size_t npts;
+ uint32_t npts32, ndims32;
+ std::vector<std::vector<float>> sampled_vectors;
+
+ // metadata: npts, ndims
+ _data_stream.seekg(0);
+ _data_stream.read((char *)&npts32, sizeof(uint32_t));
+ _data_stream.read((char *)&ndims32, sizeof(uint32_t));
+ npts = npts32;
+ ndims = ndims32;
+
+ std::unique_ptr<T[]> cur_vector_T = std::make_unique<T[]>(ndims);
+ p_val = p_val < 1 ? p_val : 1;
+
+ std::random_device rd; // Will be used to obtain a seed for the random number
+ size_t x = rd();
+ std::mt19937 generator((uint32_t)x);
+ std::uniform_real_distribution<float> distribution(0, 1);
+
+ for (size_t i = 0; i < npts; i++)
+ {
+ _data_stream.read((char *)cur_vector_T.get(), ndims * sizeof(T));
+ float rnd_val = distribution(generator);
+ if (rnd_val < p_val)
+ {
+ std::vector<float> cur_vector_float;
+ for (size_t d = 0; d < ndims; d++)
+ cur_vector_float.push_back(cur_vector_T[d]);
+ sampled_vectors.push_back(cur_vector_float);
+ }
+ }
+ slice_size = sampled_vectors.size();
+ sampled_data = new float[slice_size * ndims];
+ for (size_t i = 0; i < slice_size; i++)
+ {
+ for (size_t j = 0; j < ndims; j++)
+ {
+ sampled_data[i * ndims + j] = sampled_vectors[i][j];
+ }
+ }
+}
+
+// same as above, but samples from the matrix inputdata instead of a file of
+// npts*ndims to return sampled_data of size slice_size*ndims.
+template <typename T>
+void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data,
+ size_t &slice_size)
+{
+ std::vector<std::vector<float>> sampled_vectors;
+ const T *cur_vector_T;
+
+ p_val = p_val < 1 ? p_val : 1;
+
+ std::random_device rd; // Will be used to obtain a seed for the random number engine
+ size_t x = rd();
+ std::mt19937 generator((uint32_t)x); // Standard mersenne_twister_engine seeded with rd()
+ std::uniform_real_distribution<float> distribution(0, 1);
+
+ for (size_t i = 0; i < npts; i++)
+ {
+ cur_vector_T = inputdata + ndims * i;
+ float rnd_val = distribution(generator);
+ if (rnd_val < p_val)
+ {
+ std::vector<float> cur_vector_float;
+ for (size_t d = 0; d < ndims; d++)
+ cur_vector_float.push_back(cur_vector_T[d]);
+ sampled_vectors.push_back(cur_vector_float);
+ }
+ }
+ slice_size = sampled_vectors.size();
+ sampled_data = new float[slice_size * ndims];
+ for (size_t i = 0; i < slice_size; i++)
+ {
+ for (size_t j = 0; j < ndims; j++)
+ {
+ sampled_data[i * ndims + j] = sampled_vectors[i][j];
+ }
+ }
+}
+
+int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivots, const size_t num_centers,
+ const size_t test_dim, const size_t k_base, std::vector<size_t> &cluster_sizes)
+{
+ cluster_sizes.clear();
+
+ size_t *shard_counts = new size_t[num_centers];
+
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ shard_counts[i] = 0;
+ }
+
+ size_t block_size = num_test <= BLOCK_SIZE ? num_test : BLOCK_SIZE;
+ uint32_t *block_closest_centers = new uint32_t[block_size * k_base];
+ float *block_data_float;
+
+ size_t num_blocks = DIV_ROUND_UP(num_test, block_size);
+
+ for (size_t block = 0; block < num_blocks; block++)
+ {
+ size_t start_id = block * block_size;
+ size_t end_id = (std::min)((block + 1) * block_size, num_test);
+ size_t cur_blk_size = end_id - start_id;
+
+ block_data_float = test_data_float + start_id * test_dim;
+
+ math_utils::compute_closest_centers(block_data_float, cur_blk_size, test_dim, pivots, num_centers, k_base,
+ block_closest_centers);
+
+ for (size_t p = 0; p < cur_blk_size; p++)
+ {
+ for (size_t p1 = 0; p1 < k_base; p1++)
+ {
+ size_t shard_id = block_closest_centers[p * k_base + p1];
+ shard_counts[shard_id]++;
+ }
+ }
+ }
+
+ diskann::cout << "Estimated cluster sizes: ";
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ uint32_t cur_shard_count = (uint32_t)shard_counts[i];
+ cluster_sizes.push_back((size_t)cur_shard_count);
+ diskann::cout << cur_shard_count << " ";
+ }
+ diskann::cout << std::endl;
+ delete[] shard_counts;
+ delete[] block_closest_centers;
+ return 0;
+}
+
+template <typename T>
+int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim,
+ const size_t k_base, std::string prefix_path)
+{
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // uint64_t write_blk_size = 64 * 1024 * 1024;
+ // create cached reader + writer
+ cached_ifstream base_reader(data_file, read_blk_size);
+ uint32_t npts32;
+ uint32_t basedim32;
+ base_reader.read((char *)&npts32, sizeof(uint32_t));
+ base_reader.read((char *)&basedim32, sizeof(uint32_t));
+ size_t num_points = npts32;
+ if (basedim32 != dim)
+ {
+ diskann::cout << "Error. dimensions dont match for train set and base set" << std::endl;
+ return -1;
+ }
+
+ std::unique_ptr<size_t[]> shard_counts = std::make_unique<size_t[]>(num_centers);
+ std::vector<std::ofstream> shard_data_writer(num_centers);
+ std::vector<std::ofstream> shard_idmap_writer(num_centers);
+ uint32_t dummy_size = 0;
+ uint32_t const_one = 1;
+
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ std::string data_filename = prefix_path + "_subshard-" + std::to_string(i) + ".bin";
+ std::string idmap_filename = prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin";
+ shard_data_writer[i] = std::ofstream(data_filename.c_str(), std::ios::binary);
+ shard_idmap_writer[i] = std::ofstream(idmap_filename.c_str(), std::ios::binary);
+ shard_data_writer[i].write((char *)&dummy_size, sizeof(uint32_t));
+ shard_data_writer[i].write((char *)&basedim32, sizeof(uint32_t));
+ shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t));
+ shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t));
+ shard_counts[i] = 0;
+ }
+
+ size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
+ std::unique_ptr<uint32_t[]> block_closest_centers = std::make_unique<uint32_t[]>(block_size * k_base);
+ std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
+ std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
+
+ size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
+
+ for (size_t block = 0; block < num_blocks; block++)
+ {
+ size_t start_id = block * block_size;
+ size_t end_id = (std::min)((block + 1) * block_size, num_points);
+ size_t cur_blk_size = end_id - start_id;
+
+ base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim));
+ diskann::convert_types<T, float>(block_data_T.get(), block_data_float.get(), cur_blk_size, dim);
+
+ math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, dim, pivots, num_centers, k_base,
+ block_closest_centers.get());
+
+ for (size_t p = 0; p < cur_blk_size; p++)
+ {
+ for (size_t p1 = 0; p1 < k_base; p1++)
+ {
+ size_t shard_id = block_closest_centers[p * k_base + p1];
+ uint32_t original_point_map_id = (uint32_t)(start_id + p);
+ shard_data_writer[shard_id].write((char *)(block_data_T.get() + p * dim), sizeof(T) * dim);
+ shard_idmap_writer[shard_id].write((char *)&original_point_map_id, sizeof(uint32_t));
+ shard_counts[shard_id]++;
+ }
+ }
+ }
+
+ size_t total_count = 0;
+ diskann::cout << "Actual shard sizes: " << std::flush;
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ uint32_t cur_shard_count = (uint32_t)shard_counts[i];
+ total_count += cur_shard_count;
+ diskann::cout << cur_shard_count << " ";
+ shard_data_writer[i].seekp(0);
+ shard_data_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t));
+ shard_data_writer[i].close();
+ shard_idmap_writer[i].seekp(0);
+ shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t));
+ shard_idmap_writer[i].close();
+ }
+
+ diskann::cout << "\n Partitioned " << num_points << " with replication factor " << k_base << " to get "
+ << total_count << " points across " << num_centers << " shards " << std::endl;
+ return 0;
+}
+
+// useful for partitioning large dataset. we first generate only the IDS for
+// each shard, and retrieve the actual vectors on demand.
+template <typename T>
+int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers,
+ const size_t dim, const size_t k_base, std::string prefix_path)
+{
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // uint64_t write_blk_size = 64 * 1024 * 1024;
+ // create cached reader + writer
+ cached_ifstream base_reader(data_file, read_blk_size);
+ uint32_t npts32;
+ uint32_t basedim32;
+ base_reader.read((char *)&npts32, sizeof(uint32_t));
+ base_reader.read((char *)&basedim32, sizeof(uint32_t));
+ size_t num_points = npts32;
+ if (basedim32 != dim)
+ {
+ diskann::cout << "Error. dimensions dont match for train set and base set" << std::endl;
+ return -1;
+ }
+
+ std::unique_ptr<size_t[]> shard_counts = std::make_unique<size_t[]>(num_centers);
+
+ std::vector<std::ofstream> shard_idmap_writer(num_centers);
+ uint32_t dummy_size = 0;
+ uint32_t const_one = 1;
+
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ std::string idmap_filename = prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin";
+ shard_idmap_writer[i] = std::ofstream(idmap_filename.c_str(), std::ios::binary);
+ shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t));
+ shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t));
+ shard_counts[i] = 0;
+ }
+
+ size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
+ std::unique_ptr<uint32_t[]> block_closest_centers = std::make_unique<uint32_t[]>(block_size * k_base);
+ std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
+ std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
+
+ size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
+
+ for (size_t block = 0; block < num_blocks; block++)
+ {
+ size_t start_id = block * block_size;
+ size_t end_id = (std::min)((block + 1) * block_size, num_points);
+ size_t cur_blk_size = end_id - start_id;
+
+ base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim));
+ diskann::convert_types<T, float>(block_data_T.get(), block_data_float.get(), cur_blk_size, dim);
+
+ math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, dim, pivots, num_centers, k_base,
+ block_closest_centers.get());
+
+ for (size_t p = 0; p < cur_blk_size; p++)
+ {
+ for (size_t p1 = 0; p1 < k_base; p1++)
+ {
+ size_t shard_id = block_closest_centers[p * k_base + p1];
+ uint32_t original_point_map_id = (uint32_t)(start_id + p);
+ shard_idmap_writer[shard_id].write((char *)&original_point_map_id, sizeof(uint32_t));
+ shard_counts[shard_id]++;
+ }
+ }
+ }
+
+ size_t total_count = 0;
+ diskann::cout << "Actual shard sizes: " << std::flush;
+ for (size_t i = 0; i < num_centers; i++)
+ {
+ uint32_t cur_shard_count = (uint32_t)shard_counts[i];
+ total_count += cur_shard_count;
+ diskann::cout << cur_shard_count << " ";
+ shard_idmap_writer[i].seekp(0);
+ shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t));
+ shard_idmap_writer[i].close();
+ }
+
+ diskann::cout << "\n Partitioned " << num_points << " with replication factor " << k_base << " to get "
+ << total_count << " points across " << num_centers << " shards " << std::endl;
+ return 0;
+}
+
+template <typename T>
+int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename)
+{
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // uint64_t write_blk_size = 64 * 1024 * 1024;
+ // create cached reader + writer
+ cached_ifstream base_reader(data_file, read_blk_size);
+ uint32_t npts32;
+ uint32_t basedim32;
+ base_reader.read((char *)&npts32, sizeof(uint32_t));
+ base_reader.read((char *)&basedim32, sizeof(uint32_t));
+ size_t num_points = npts32;
+ size_t dim = basedim32;
+
+ uint32_t dummy_size = 0;
+
+ std::ofstream shard_data_writer(data_filename.c_str(), std::ios::binary);
+ shard_data_writer.write((char *)&dummy_size, sizeof(uint32_t));
+ shard_data_writer.write((char *)&basedim32, sizeof(uint32_t));
+
+ uint32_t *shard_ids;
+ uint64_t shard_size, tmp;
+ diskann::load_bin<uint32_t>(idmap_filename, shard_ids, shard_size, tmp);
+
+ uint32_t cur_pos = 0;
+ uint32_t num_written = 0;
+ std::cout << "Shard has " << shard_size << " points" << std::endl;
+
+ size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
+ std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
+
+ size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
+
+ for (size_t block = 0; block < num_blocks; block++)
+ {
+ size_t start_id = block * block_size;
+ size_t end_id = (std::min)((block + 1) * block_size, num_points);
+ size_t cur_blk_size = end_id - start_id;
+
+ base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim));
+
+ for (size_t p = 0; p < cur_blk_size; p++)
+ {
+ uint32_t original_point_map_id = (uint32_t)(start_id + p);
+ if (cur_pos == shard_size)
+ break;
+ if (original_point_map_id == shard_ids[cur_pos])
+ {
+ cur_pos++;
+ shard_data_writer.write((char *)(block_data_T.get() + p * dim), sizeof(T) * dim);
+ num_written++;
+ }
+ }
+ if (cur_pos == shard_size)
+ break;
+ }
+
+ diskann::cout << "Written file with " << num_written << " points" << std::endl;
+
+ shard_data_writer.seekp(0);
+ shard_data_writer.write((char *)&num_written, sizeof(uint32_t));
+ shard_data_writer.close();
+ delete[] shard_ids;
+ return 0;
+}
+
+// partitions a large base file into many shards using k-means heuristic
+// on a random sample generated using sampling_rate probability. After this, it
+// assigns each base point to the closest k_base nearest centers and creates
+// the shards.
+// The total number of points across all shards will be k_base * num_points.
+
+template <typename T>
+int partition(const std::string data_file, const float sampling_rate, size_t num_parts, size_t max_k_means_reps,
+ const std::string prefix_path, size_t k_base)
+{
+ size_t train_dim;
+ size_t num_train;
+ float *train_data_float;
+
+ gen_random_slice<T>(data_file, sampling_rate, train_data_float, num_train, train_dim);
+
+ float *pivot_data;
+
+ std::string cur_file = std::string(prefix_path);
+ std::string output_file;
+
+ // kmeans_partitioning on training data
+
+ // cur_file = cur_file + "_kmeans_partitioning-" +
+ // std::to_string(num_parts);
+ output_file = cur_file + "_centroids.bin";
+
+ pivot_data = new float[num_parts * train_dim];
+
+ // Process Global k-means for kmeans_partitioning Step
+ diskann::cout << "Processing global k-means (kmeans_partitioning Step)" << std::endl;
+ kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, pivot_data, num_parts);
+
+ kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, num_parts, max_k_means_reps, NULL, NULL);
+
+ diskann::cout << "Saving global k-center pivots" << std::endl;
+ diskann::save_bin<float>(output_file.c_str(), pivot_data, (size_t)num_parts, train_dim);
+
+ // now pivots are ready. need to stream base points and assign them to
+ // closest clusters.
+
+ shard_data_into_clusters<T>(data_file, pivot_data, num_parts, train_dim, k_base, prefix_path);
+ delete[] pivot_data;
+ delete[] train_data_float;
+ return 0;
+}
+
+template <typename T>
+int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget,
+ size_t graph_degree, const std::string prefix_path, size_t k_base)
+{
+ size_t train_dim;
+ size_t num_train;
+ float *train_data_float;
+ size_t max_k_means_reps = 10;
+
+ int num_parts = 3;
+ bool fit_in_ram = false;
+
+ gen_random_slice<T>(data_file, sampling_rate, train_data_float, num_train, train_dim);
+
+ size_t test_dim;
+ size_t num_test;
+ float *test_data_float;
+ gen_random_slice<T>(data_file, sampling_rate, test_data_float, num_test, test_dim);
+
+ float *pivot_data = nullptr;
+
+ std::string cur_file = std::string(prefix_path);
+ std::string output_file;
+
+ // kmeans_partitioning on training data
+
+ // cur_file = cur_file + "_kmeans_partitioning-" +
+ // std::to_string(num_parts);
+ output_file = cur_file + "_centroids.bin";
+
+ while (!fit_in_ram)
+ {
+ fit_in_ram = true;
+
+ double max_ram_usage = 0;
+ if (pivot_data != nullptr)
+ delete[] pivot_data;
+
+ pivot_data = new float[num_parts * train_dim];
+ // Process Global k-means for kmeans_partitioning Step
+ diskann::cout << "Processing global k-means (kmeans_partitioning Step)" << std::endl;
+ kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, pivot_data, num_parts);
+
+ kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, num_parts, max_k_means_reps, NULL, NULL);
+
+ // now pivots are ready. need to stream base points and assign them to
+ // closest clusters.
+
+ std::vector<size_t> cluster_sizes;
+ estimate_cluster_sizes(test_data_float, num_test, pivot_data, num_parts, train_dim, k_base, cluster_sizes);
+
+ for (auto &p : cluster_sizes)
+ {
+ // to account for the fact that p is the size of the shard over the
+ // testing sample.
+ p = (uint64_t)(p / sampling_rate);
+ double cur_shard_ram_estimate =
+ diskann::estimate_ram_usage(p, (uint32_t)train_dim, sizeof(T), (uint32_t)graph_degree);
+
+ if (cur_shard_ram_estimate > max_ram_usage)
+ max_ram_usage = cur_shard_ram_estimate;
+ }
+ diskann::cout << "With " << num_parts
+ << " parts, max estimated RAM usage: " << max_ram_usage / (1024 * 1024 * 1024)
+ << "GB, budget given is " << ram_budget << std::endl;
+ if (max_ram_usage > 1024 * 1024 * 1024 * ram_budget)
+ {
+ fit_in_ram = false;
+ num_parts += 2;
+ }
+ }
+
+ diskann::cout << "Saving global k-center pivots" << std::endl;
+ diskann::save_bin<float>(output_file.c_str(), pivot_data, (size_t)num_parts, train_dim);
+
+ shard_data_into_clusters_only_ids<T>(data_file, pivot_data, num_parts, train_dim, k_base, prefix_path);
+ delete[] pivot_data;
+ delete[] train_data_float;
+ delete[] test_data_float;
+ return num_parts;
+}
+
+// Instantiations of supported templates
+
+template void DISKANN_DLLEXPORT gen_random_slice<int8_t>(const std::string base_file, const std::string output_prefix,
+ double sampling_rate);
+template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(const std::string base_file, const std::string output_prefix,
+ double sampling_rate);
+template void DISKANN_DLLEXPORT gen_random_slice<float>(const std::string base_file, const std::string output_prefix,
+ double sampling_rate);
+
+template void gen_random_slice<float>(std::stringstream & _data_stream, double p_val, float *&sampled_data, size_t &slice_size,
+ size_t &ndims);
+
+template void DISKANN_DLLEXPORT gen_random_slice<float>(const float *inputdata, size_t npts, size_t ndims, double p_val,
+ float *&sampled_data, size_t &slice_size);
+template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(const uint8_t *inputdata, size_t npts, size_t ndims,
+ double p_val, float *&sampled_data, size_t &slice_size);
+template void DISKANN_DLLEXPORT gen_random_slice<int8_t>(const int8_t *inputdata, size_t npts, size_t ndims,
+ double p_val, float *&sampled_data, size_t &slice_size);
+
+template void DISKANN_DLLEXPORT gen_random_slice<float>(const std::string data_file, double p_val, float *&sampled_data,
+ size_t &slice_size, size_t &ndims);
+template void DISKANN_DLLEXPORT gen_random_slice<uint8_t>(const std::string data_file, double p_val,
+ float *&sampled_data, size_t &slice_size, size_t &ndims);
+template void DISKANN_DLLEXPORT gen_random_slice<int8_t>(const std::string data_file, double p_val,
+ float *&sampled_data, size_t &slice_size, size_t &ndims);
+
+template DISKANN_DLLEXPORT int partition<int8_t>(const std::string data_file, const float sampling_rate,
+ size_t num_centers, size_t max_k_means_reps,
+ const std::string prefix_path, size_t k_base);
+template DISKANN_DLLEXPORT int partition<uint8_t>(const std::string data_file, const float sampling_rate,
+ size_t num_centers, size_t max_k_means_reps,
+ const std::string prefix_path, size_t k_base);
+template DISKANN_DLLEXPORT int partition<float>(const std::string data_file, const float sampling_rate,
+ size_t num_centers, size_t max_k_means_reps,
+ const std::string prefix_path, size_t k_base);
+
+template DISKANN_DLLEXPORT int partition_with_ram_budget<int8_t>(const std::string data_file,
+ const double sampling_rate, double ram_budget,
+ size_t graph_degree, const std::string prefix_path,
+ size_t k_base);
+template DISKANN_DLLEXPORT int partition_with_ram_budget<uint8_t>(const std::string data_file,
+ const double sampling_rate, double ram_budget,
+ size_t graph_degree, const std::string prefix_path,
+ size_t k_base);
+template DISKANN_DLLEXPORT int partition_with_ram_budget<float>(const std::string data_file, const double sampling_rate,
+ double ram_budget, size_t graph_degree,
+ const std::string prefix_path, size_t k_base);
+
+template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids<float>(const std::string data_file,
+ std::string idmap_filename,
+ std::string data_filename);
+template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids<uint8_t>(const std::string data_file,
+ std::string idmap_filename,
+ std::string data_filename);
+template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids<int8_t>(const std::string data_file,
+ std::string idmap_filename,
+ std::string data_filename);
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/pq.cpp b/be/src/extern/diskann/src/pq.cpp
new file mode 100644
index 0000000..876ecdf
--- /dev/null
+++ b/be/src/extern/diskann/src/pq.cpp
@@ -0,0 +1,1821 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "mkl.h"
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+#include "gperftools/malloc_extension.h"
+#endif
+#include "pq.h"
+#include "partition.h"
+#include "math_utils.h"
+#include "tsl/robin_map.h"
+
+#include "vector/stream_wrapper.h"
+
+// block size for reading/processing large files and matrices in blocks
+#define BLOCK_SIZE 5000000
+
+namespace diskann
+{
+// Default-construct an empty table. Pointer members are expected to start as
+// nullptr (the destructor null-checks them) — presumably initialized in the
+// class definition in pq.h; TODO(review): confirm.
+FixedChunkPQTable::FixedChunkPQTable()
+{
+}
+
+FixedChunkPQTable::~FixedChunkPQTable()
+{
+#ifndef EXEC_ENV_OLS
+ if (tables != nullptr)
+ delete[] tables;
+ if (tables_tr != nullptr)
+ delete[] tables_tr;
+ if (chunk_offsets != nullptr)
+ delete[] chunk_offsets;
+ if (centroid != nullptr)
+ delete[] centroid;
+ if (rotmat_tr != nullptr)
+ delete[] rotmat_tr;
+#endif
+}
+
+// Load the PQ pivot table from a pq_pivots file: the pivot centers (tables),
+// the data centroid, the chunk offsets, and an optional OPQ rotation matrix
+// stored alongside as <pq_table_file>_rotation_matrix.bin. Also precomputes
+// the transposed pivot table (tables_tr) used by the distance routines.
+// Throws diskann::ANNException on any metadata mismatch.
+#ifdef EXEC_ENV_OLS
+void FixedChunkPQTable::load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks)
+{
+#else
+void FixedChunkPQTable::load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks)
+{
+#endif
+
+    // nr/nc are reused as the (rows, cols) metadata of each successive load.
+    uint64_t nr, nc;
+    std::string rotmat_file = std::string(pq_table_file) + "_rotation_matrix.bin";
+
+    // The file starts with an offset table: 4 entries (new format) or
+    // 5 entries (old format) of byte offsets into the file.
+#ifdef EXEC_ENV_OLS
+    size_t *file_offset_data; // since load_bin only sets the pointer, no need
+    // to delete.
+    diskann::load_bin<size_t>(files, pq_table_file, file_offset_data, nr, nc);
+#else
+    std::unique_ptr<size_t[]> file_offset_data;
+    diskann::load_bin<size_t>(pq_table_file, file_offset_data, nr, nc);
+#endif
+
+    bool use_old_filetype = false;
+
+    if (nr != 4 && nr != 5)
+    {
+        diskann::cout << "Error reading pq_pivots file " << pq_table_file
+                      << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting " << 4
+                      << " or " << 5;
+        throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    if (nr == 4)
+    {
+        diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2]
+                      << " " << file_offset_data[3] << std::endl;
+    }
+    else if (nr == 5)
+    {
+        use_old_filetype = true;
+        diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2]
+                      << " " << file_offset_data[3] << file_offset_data[4] << std::endl;
+    }
+    else
+    {
+        // NOTE(review): unreachable — nr is already constrained to 4 or 5 by
+        // the check above.
+        throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Section 1: pivot centers, NUM_PQ_CENTROIDS rows x ndims cols of float.
+#ifdef EXEC_ENV_OLS
+
+    diskann::load_bin<float>(files, pq_table_file, tables, nr, nc, file_offset_data[0]);
+#else
+    diskann::load_bin<float>(pq_table_file, tables, nr, nc, file_offset_data[0]);
+#endif
+
+    if ((nr != NUM_PQ_CENTROIDS))
+    {
+        diskann::cout << "Error reading pq_pivots file " << pq_table_file << ". file_num_centers = " << nr
+                      << " but expecting " << NUM_PQ_CENTROIDS << " centers";
+        throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    this->ndims = nc;
+
+    // Section 2: data centroid, ndims x 1 floats.
+#ifdef EXEC_ENV_OLS
+    diskann::load_bin<float>(files, pq_table_file, centroid, nr, nc, file_offset_data[1]);
+#else
+    diskann::load_bin<float>(pq_table_file, centroid, nr, nc, file_offset_data[1]);
+#endif
+
+    if ((nr != this->ndims) || (nc != 1))
+    {
+        diskann::cerr << "Error reading centroids from pq_pivots file " << pq_table_file << ". file_dim = " << nr
+                      << ", file_cols = " << nc << " but expecting " << this->ndims << " entries in 1 dimension.";
+        throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    // Section 3: chunk offsets ((n_chunks + 1) x 1 uint32). In the old
+    // 5-offset format this section sits at offset index 3 instead of 2.
+    int chunk_offsets_index = 2;
+    if (use_old_filetype)
+    {
+        chunk_offsets_index = 3;
+    }
+#ifdef EXEC_ENV_OLS
+    diskann::load_bin<uint32_t>(files, pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]);
+#else
+    diskann::load_bin<uint32_t>(pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]);
+#endif
+
+    // num_chunks == 0 means "infer the chunk count from the file".
+    if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0))
+    {
+        diskann::cerr << "Error loading chunk offsets file. numc: " << nc << " (should be 1). numr: " << nr
+                      << " (should be " << num_chunks + 1 << " or 0 if we need to infer)" << std::endl;
+        throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    this->n_chunks = nr - 1;
+    diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS << ", #dims: " << this->ndims
+                  << ", #chunks: " << this->n_chunks << std::endl;
+
+    // Optional OPQ rotation matrix (ndims x ndims) stored as a sibling file.
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(rotmat_file))
+    {
+        diskann::load_bin<float>(files, rotmat_file, (float *&)rotmat_tr, nr, nc);
+#else
+    if (file_exists(rotmat_file))
+    {
+        diskann::load_bin<float>(rotmat_file, rotmat_tr, nr, nc);
+#endif
+        if (nr != this->ndims || nc != this->ndims)
+        {
+            diskann::cerr << "Error loading rotation matrix file" << std::endl;
+            throw diskann::ANNException("Error loading rotation matrix file", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+        use_rotation = true;
+    }
+
+    // alloc and compute transpose: tables is 256 x ndims (row-major per
+    // center); tables_tr is ndims x 256 so the distance loops can walk one
+    // dimension's 256 center values contiguously.
+    tables_tr = new float[256 * this->ndims];
+    for (size_t i = 0; i < 256; i++)
+    {
+        for (size_t j = 0; j < this->ndims; j++)
+        {
+            tables_tr[j * 256 + i] = tables[i * this->ndims + j];
+        }
+    }
+}
+
+void FixedChunkPQTable::load_pq_centroid_bin(IReaderWrapperSPtr reader, size_t num_chunks)
+{
+
+ uint64_t nr, nc;
+ size_t *file_offset_data;
+ diskann::load_bin<size_t>(reader, file_offset_data, nr, nc);
+ std::unique_ptr<size_t[]> tmp(file_offset_data);
+ diskann::load_bin<float>(reader, tables, nr, nc, file_offset_data[0]);
+
+ if ((nr != NUM_PQ_CENTROIDS))
+ {
+ diskann::cout << "Error reading pq_pivots file " << ". file_num_centers = " << nr
+ << " but expecting " << NUM_PQ_CENTROIDS << " centers";
+ throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ this->ndims = nc;
+ diskann::load_bin<float>(reader, centroid, nr, nc, file_offset_data[1]);
+ if ((nr != this->ndims) || (nc != 1))
+ {
+ diskann::cerr << "Error reading centroids from pq_pivots file " << ". file_dim = " << nr
+ << ", file_cols = " << nc << " but expecting " << this->ndims << " entries in 1 dimension.";
+ throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ diskann::load_bin<uint32_t>(reader, chunk_offsets, nr, nc, file_offset_data[2]);
+ if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0))
+ {
+ diskann::cerr << "Error loading chunk offsets file. numc: " << nc << " (should be 1). numr: " << nr
+ << " (should be " << num_chunks + 1 << " or 0 if we need to infer)" << std::endl;
+ throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ this->n_chunks = nr - 1;
+ diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS << ", #dims: " << this->ndims
+ << ", #chunks: " << this->n_chunks << std::endl;
+
+ // alloc and compute transpose
+ tables_tr = new float[256 * this->ndims];
+ for (size_t i = 0; i < 256; i++)
+ {
+ for (size_t j = 0; j < this->ndims; j++)
+ {
+ tables_tr[j * 256 + i] = tables[i * this->ndims + j];
+ //std::cout << tables_tr[j * 256 + i] << ",";
+ }
+ //std::cout << std::endl;
+ }
+}
+
+uint32_t FixedChunkPQTable::get_num_chunks()
+{
+ return static_cast<uint32_t>(n_chunks);
+}
+
+void FixedChunkPQTable::preprocess_query(float *query_vec)
+{
+ for (uint32_t d = 0; d < ndims; d++)
+ {
+ query_vec[d] -= centroid[d];
+ }
+ std::vector<float> tmp(ndims, 0);
+ if (use_rotation)
+ {
+ for (uint32_t d = 0; d < ndims; d++)
+ {
+ for (uint32_t d1 = 0; d1 < ndims; d1++)
+ {
+ tmp[d] += query_vec[d1] * rotmat_tr[d1 * ndims + d];
+ }
+ }
+ std::memcpy(query_vec, tmp.data(), ndims * sizeof(float));
+ }
+}
+
+// assumes pre-processed query
+void FixedChunkPQTable::populate_chunk_distances(const float *query_vec, float *dist_vec)
+{
+ memset(dist_vec, 0, 256 * n_chunks * sizeof(float));
+ // chunk wise distance computation
+ for (size_t chunk = 0; chunk < n_chunks; chunk++)
+ {
+ // sum (q-c)^2 for the dimensions associated with this chunk
+ float *chunk_dists = dist_vec + (256 * chunk);
+ for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = tables_tr + (256 * j);
+ for (size_t idx = 0; idx < 256; idx++)
+ {
+ double diff = centers_dim_vec[idx] - (query_vec[j]);
+ chunk_dists[idx] += (float)(diff * diff);
+ }
+ }
+ }
+}
+
+float FixedChunkPQTable::l2_distance(const float *query_vec, uint8_t *base_vec)
+{
+ float res = 0;
+ for (size_t chunk = 0; chunk < n_chunks; chunk++)
+ {
+ for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = tables_tr + (256 * j);
+ float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]);
+ res += diff * diff;
+ }
+ }
+ return res;
+}
+
+float FixedChunkPQTable::inner_product(const float *query_vec, uint8_t *base_vec)
+{
+ float res = 0;
+ for (size_t chunk = 0; chunk < n_chunks; chunk++)
+ {
+ for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = tables_tr + (256 * j);
+ float diff = centers_dim_vec[base_vec[chunk]] * query_vec[j]; // assumes centroid is 0 to
+ // prevent translation errors
+ res += diff;
+ }
+ }
+ return -res; // returns negative value to simulate distances (max -> min
+ // conversion)
+}
+
+// assumes no rotation is involved
+void FixedChunkPQTable::inflate_vector(uint8_t *base_vec, float *out_vec)
+{
+ for (size_t chunk = 0; chunk < n_chunks; chunk++)
+ {
+ for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = tables_tr + (256 * j);
+ out_vec[j] = centers_dim_vec[base_vec[chunk]] + centroid[j];
+ }
+ }
+}
+
+void FixedChunkPQTable::populate_chunk_inner_products(const float *query_vec, float *dist_vec)
+{
+ memset(dist_vec, 0, 256 * n_chunks * sizeof(float));
+ // chunk wise distance computation
+ for (size_t chunk = 0; chunk < n_chunks; chunk++)
+ {
+ // sum (q-c)^2 for the dimensions associated with this chunk
+ float *chunk_dists = dist_vec + (256 * chunk);
+ for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = tables_tr + (256 * j);
+ for (size_t idx = 0; idx < 256; idx++)
+ {
+ double prod = centers_dim_vec[idx] * query_vec[j]; // assumes that we are not
+ // shifting the vectors to
+ // mean zero, i.e., centroid
+ // array should be all zeros
+ chunk_dists[idx] -= (float)prod; // returning negative to keep the search code
+ // clean (max inner product vs min distance)
+ }
+ }
+ }
+}
+
// Gather the ndims-byte PQ codes of the given point ids from all_coords into
// a contiguous output buffer (out must hold ids.size() * ndims bytes).
void aggregate_coords(const std::vector<uint32_t> &ids, const uint8_t *all_coords, const size_t ndims, uint8_t *out)
{
    uint8_t *dst = out;
    for (const uint32_t id : ids)
    {
        std::memcpy(dst, all_coords + (size_t)id * ndims, ndims);
        dst += ndims;
    }
}
+
// Accumulate PQ asymmetric distances for n_pts points. pq_ids holds
// pq_nchunks code bytes per point; pq_dists holds one 256-entry lookup table
// per chunk (as built by populate_chunk_*). dists_out is resized to n_pts
// and fully overwritten.
void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
                    std::vector<float> &dists_out)
{
    _mm_prefetch((char *)pq_ids, _MM_HINT_T0);
    _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0);
    _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0);
    dists_out.assign(n_pts, 0);
    for (size_t c = 0; c < pq_nchunks; c++)
    {
        const float *lut = pq_dists + 256 * c;
        if (c + 1 < pq_nchunks)
        {
            _mm_prefetch((char *)(lut + 256), _MM_HINT_T0);
        }
        // Codes for chunk c are strided by pq_nchunks across points.
        const uint8_t *code = pq_ids + c;
        for (size_t p = 0; p < n_pts; p++, code += pq_nchunks)
        {
            dists_out[p] += lut[*code];
        }
    }
}
+
// Need to replace calls to these functions with calls to vector& based
// functions above
// Raw-pointer variant: gather n_ids codes of ndims bytes each into out.
void aggregate_coords(const uint32_t *ids, const size_t n_ids, const uint8_t *all_coords, const size_t ndims,
                      uint8_t *out)
{
    for (size_t pos = 0; pos < n_ids; ++pos)
    {
        const uint8_t *src = all_coords + (size_t)ids[pos] * ndims;
        std::memcpy(out + pos * ndims, src, ndims);
    }
}
+
// Raw-pointer variant of pq_dist_lookup: dists_out must hold n_pts floats
// and is fully overwritten with the accumulated per-point LUT sums.
void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
                    float *dists_out)
{
    _mm_prefetch((char *)dists_out, _MM_HINT_T0);
    _mm_prefetch((char *)pq_ids, _MM_HINT_T0);
    _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0);
    _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0);
    memset(dists_out, 0, n_pts * sizeof(float));
    for (size_t c = 0; c < pq_nchunks; c++)
    {
        const float *lut = pq_dists + 256 * c;
        if (c + 1 < pq_nchunks)
        {
            _mm_prefetch((char *)(lut + 256), _MM_HINT_T0);
        }
        // Codes for chunk c are strided by pq_nchunks across points.
        const uint8_t *code = pq_ids + c;
        for (size_t p = 0; p < n_pts; p++, code += pq_nchunks)
        {
            dists_out[p] += lut[*code];
        }
    }
}
+
+// generate_pq_pivots_simplified is a simplified version of generate_pq_pivots.
+// Input is provided in the in-memory buffer train_data.
+// Output is stored in the in-memory buffer pivot_data_vector.
+// Simplification is based on the following assumptions:
+// dim % num_pq_chunks == 0
+// num_centers == 256 by default
+// KMEANS_ITERS_FOR_PQ == 15 by default
+// make_zero_mean is false by default.
+// These assumptions allow to make the function much simpler and avoid storing
+// array of chunk_offsets and centroids.
+// The compiler pragma for multi-threading support is removed from this implementation
+// for the purpose of integration into systems that strictly control resource allocation.
+int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, size_t num_pq_chunks,
+ std::vector<float> &pivot_data_vector)
+{
+ if (num_pq_chunks > dim || dim % num_pq_chunks != 0)
+ {
+ return -1;
+ }
+
+ const size_t num_centers = 256;
+ const size_t cur_chunk_size = dim / num_pq_chunks;
+ const uint32_t KMEANS_ITERS_FOR_PQ = 15;
+
+ pivot_data_vector.resize(num_centers * dim);
+ std::vector<float> cur_pivot_data_vector(num_centers * cur_chunk_size);
+ std::vector<float> cur_data_vector(num_train * cur_chunk_size);
+ std::vector<uint32_t> closest_center_vector(num_train);
+
+ float *pivot_data = &pivot_data_vector[0];
+ float *cur_pivot_data = &cur_pivot_data_vector[0];
+ float *cur_data = &cur_data_vector[0];
+ uint32_t *closest_center = &closest_center_vector[0];
+
+ for (size_t i = 0; i < num_pq_chunks; i++)
+ {
+ size_t chunk_offset = cur_chunk_size * i;
+
+ for (int32_t j = 0; j < num_train; j++)
+ {
+ std::memcpy(cur_data + j * cur_chunk_size, train_data + j * dim + chunk_offset,
+ cur_chunk_size * sizeof(float));
+ }
+
+ kmeans::kmeanspp_selecting_pivots(cur_data, num_train, cur_chunk_size, cur_pivot_data, num_centers);
+
+ kmeans::run_lloyds(cur_data, num_train, cur_chunk_size, cur_pivot_data, num_centers, KMEANS_ITERS_FOR_PQ, NULL,
+ closest_center);
+
+ for (uint64_t j = 0; j < num_centers; j++)
+ {
+ std::memcpy(pivot_data + j * dim + chunk_offset, cur_pivot_data + j * cur_chunk_size,
+ cur_chunk_size * sizeof(float));
+ }
+ }
+
+ return 0;
+}
+
+// given training data in train_data of dimensions num_train * dim, generate
+// PQ pivots using k-means algorithm to partition the co-ordinates into
+// num_pq_chunks (if it divides dimension, else rounded) chunks, and runs
+// k-means in each chunk to compute the PQ pivots and stores in bin format in
+// file pq_pivots_path as a num_centers*dim floating point binary file.
+// Returns 0 on success; -1 if num_pq_chunks > dim or if a pivot file with
+// matching metadata already exists (nothing is regenerated in that case).
+int generate_pq_pivots(const float *const passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers,
+                       uint32_t num_pq_chunks, uint32_t max_k_means_reps, std::string pq_pivots_path,
+                       bool make_zero_mean)
+{
+    if (num_pq_chunks > dim)
+    {
+        diskann::cout << " Error: number of chunks more than dimension" << std::endl;
+        return -1;
+    }
+
+    // Work on a private copy: the zero-mean translation below mutates it.
+    std::unique_ptr<float[]> train_data = std::make_unique<float[]>(num_train * dim);
+    std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float));
+
+    std::unique_ptr<float[]> full_pivot_data;
+
+    // NOTE(review): an existing, shape-matching pivot file short-circuits
+    // with -1 — the same value as the error path above.
+    if (file_exists(pq_pivots_path))
+    {
+        size_t file_dim, file_num_centers;
+        diskann::load_bin<float>(pq_pivots_path, full_pivot_data, file_num_centers, file_dim, METADATA_SIZE);
+        if (file_dim == dim && file_num_centers == num_centers)
+        {
+            diskann::cout << "PQ pivot file exists. Not generating again" << std::endl;
+            return -1;
+        }
+    }
+
+    // Calculate centroid and center the training data
+    std::unique_ptr<float[]> centroid = std::make_unique<float[]>(dim);
+    for (uint64_t d = 0; d < dim; d++)
+    {
+        centroid[d] = 0;
+    }
+    if (make_zero_mean)
+    { // If we use L2 distance, there is an option to
+      // translate all vectors to make them centered and
+      // then compute PQ. This needs to be set to false
+      // when using PQ for MIPS as such translations dont
+      // preserve inner products.
+        for (uint64_t d = 0; d < dim; d++)
+        {
+            for (uint64_t p = 0; p < num_train; p++)
+            {
+                centroid[d] += train_data[p * dim + d];
+            }
+            centroid[d] /= num_train;
+        }
+
+        for (uint64_t d = 0; d < dim; d++)
+        {
+            for (uint64_t p = 0; p < num_train; p++)
+            {
+                train_data[p * dim + d] -= centroid[d];
+            }
+        }
+    }
+
+    std::vector<uint32_t> chunk_offsets;
+
+    // Greedy, load-balanced assignment of dimensions to chunks: chunks get
+    // either low_val or high_val dimensions so their sizes differ by at most 1.
+    size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks);
+    size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks);
+    size_t max_num_high = dim - (low_val * num_pq_chunks);
+    size_t cur_num_high = 0;
+    size_t cur_bin_threshold = high_val;
+
+    std::vector<std::vector<uint32_t>> bin_to_dims(num_pq_chunks);
+    tsl::robin_map<uint32_t, uint32_t> dim_to_bin;
+    std::vector<float> bin_loads(num_pq_chunks, 0);
+
+    // Process dimensions not inserted by previous loop
+    // NOTE(review): dim_to_bin is never populated in this function, so the
+    // lookup below always misses — kept as-is from upstream DiskANN.
+    for (uint32_t d = 0; d < dim; d++)
+    {
+        if (dim_to_bin.find(d) != dim_to_bin.end())
+            continue;
+        auto cur_best = num_pq_chunks + 1;
+        float cur_best_load = std::numeric_limits<float>::max();
+        for (uint32_t b = 0; b < num_pq_chunks; b++)
+        {
+            if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold)
+            {
+                cur_best = b;
+                cur_best_load = bin_loads[b];
+            }
+        }
+        bin_to_dims[cur_best].push_back(d);
+        if (bin_to_dims[cur_best].size() == high_val)
+        {
+            cur_num_high++;
+            if (cur_num_high == max_num_high)
+                cur_bin_threshold = low_val;
+        }
+    }
+
+    chunk_offsets.clear();
+    chunk_offsets.push_back(0);
+
+    for (uint32_t b = 0; b < num_pq_chunks; b++)
+    {
+        if (b > 0)
+            chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size());
+    }
+    chunk_offsets.push_back(dim);
+
+    full_pivot_data.reset(new float[num_centers * dim]);
+
+    // Run k-means++ seeding plus Lloyd's iterations independently per chunk.
+    for (size_t i = 0; i < num_pq_chunks; i++)
+    {
+        size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+
+        if (cur_chunk_size == 0)
+            continue;
+        std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+        std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(num_train * cur_chunk_size);
+        std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(num_train);
+
+        diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", "
+                      << chunk_offsets[i + 1] << ")" << std::endl;
+
+#pragma omp parallel for schedule(static, 65536)
+        for (int64_t j = 0; j < (int64_t)num_train; j++)
+        {
+            std::memcpy(cur_data.get() + j * cur_chunk_size, train_data.get() + j * dim + chunk_offsets[i],
+                        cur_chunk_size * sizeof(float));
+        }
+
+        kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers);
+
+        kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers,
+                           max_k_means_reps, NULL, closest_center.get());
+
+        for (uint64_t j = 0; j < num_centers; j++)
+        {
+            std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], cur_pivot_data.get() + j * cur_chunk_size,
+                        cur_chunk_size * sizeof(float));
+        }
+    }
+
+    // File layout: METADATA_SIZE offset table first, then pivots, centroid,
+    // and chunk offsets; the offset table itself is written last at offset 0.
+    std::vector<size_t> cumul_bytes(4, 0);
+    cumul_bytes[0] = METADATA_SIZE;
+    cumul_bytes[1] = cumul_bytes[0] + diskann::save_bin<float>(pq_pivots_path.c_str(), full_pivot_data.get(),
+                                                               (size_t)num_centers, dim, cumul_bytes[0]);
+    cumul_bytes[2] = cumul_bytes[1] +
+                     diskann::save_bin<float>(pq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]);
+    cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin<uint32_t>(pq_pivots_path.c_str(), chunk_offsets.data(),
+                                                                  chunk_offsets.size(), 1, cumul_bytes[2]);
+    diskann::save_bin<size_t>(pq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0);
+
+    diskann::cout << "Saved pq pivot data to " << pq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1]
+                  << "B." << std::endl;
+
+    return 0;
+}
+
+
+int generate_pq_pivots(const float *const passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers,
+ uint32_t num_pq_chunks, uint32_t max_k_means_reps, std::stringstream &pq_pivots_stream,
+ bool make_zero_mean)
+{
+ if (num_pq_chunks > dim)
+ {
+ diskann::cout << " Error: number of chunks more than dimension" << std::endl;
+ return -1;
+ }
+
+ std::unique_ptr<float[]> train_data = std::make_unique<float[]>(num_train * dim);
+ std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float));
+
+ std::unique_ptr<float[]> full_pivot_data;
+
+ // Calculate centroid and center the training data
+ std::unique_ptr<float[]> centroid = std::make_unique<float[]>(dim);
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ centroid[d] = 0;
+ }
+ if (make_zero_mean)
+ { // If we use L2 distance, there is an option to
+ // translate all vectors to make them centered and
+ // then compute PQ. This needs to be set to false
+ // when using PQ for MIPS as such translations dont
+ // preserve inner products.
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ for (uint64_t p = 0; p < num_train; p++)
+ {
+ centroid[d] += train_data[p * dim + d];
+ }
+ centroid[d] /= num_train;
+ }
+
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ for (uint64_t p = 0; p < num_train; p++)
+ {
+ train_data[p * dim + d] -= centroid[d];
+ }
+ }
+ }
+
+ std::vector<uint32_t> chunk_offsets;
+
+ size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks);
+ size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks);
+ size_t max_num_high = dim - (low_val * num_pq_chunks);
+ size_t cur_num_high = 0;
+ size_t cur_bin_threshold = high_val;
+
+ std::vector<std::vector<uint32_t>> bin_to_dims(num_pq_chunks);
+ tsl::robin_map<uint32_t, uint32_t> dim_to_bin;
+ std::vector<float> bin_loads(num_pq_chunks, 0);
+
+ // Process dimensions not inserted by previous loop
+ for (uint32_t d = 0; d < dim; d++)
+ {
+ if (dim_to_bin.find(d) != dim_to_bin.end())
+ continue;
+ auto cur_best = num_pq_chunks + 1;
+ float cur_best_load = std::numeric_limits<float>::max();
+ for (uint32_t b = 0; b < num_pq_chunks; b++)
+ {
+ if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold)
+ {
+ cur_best = b;
+ cur_best_load = bin_loads[b];
+ }
+ }
+ bin_to_dims[cur_best].push_back(d);
+ if (bin_to_dims[cur_best].size() == high_val)
+ {
+ cur_num_high++;
+ if (cur_num_high == max_num_high)
+ cur_bin_threshold = low_val;
+ }
+ }
+
+ chunk_offsets.clear();
+ chunk_offsets.push_back(0);
+
+ for (uint32_t b = 0; b < num_pq_chunks; b++)
+ {
+ if (b > 0)
+ chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size());
+ }
+ chunk_offsets.push_back(dim);
+
+ full_pivot_data.reset(new float[num_centers * dim]);
+
+ for (size_t i = 0; i < num_pq_chunks; i++)
+ {
+ size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+
+ if (cur_chunk_size == 0)
+ continue;
+ std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+ std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(num_train * cur_chunk_size);
+ std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(num_train);
+
+ diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", "
+ << chunk_offsets[i + 1] << ")" << std::endl;
+
+#pragma omp parallel for schedule(static, 65536)
+ for (int64_t j = 0; j < (int64_t)num_train; j++)
+ {
+ std::memcpy(cur_data.get() + j * cur_chunk_size, train_data.get() + j * dim + chunk_offsets[i],
+ cur_chunk_size * sizeof(float));
+ }
+
+ kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers);
+
+ kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers,
+ max_k_means_reps, NULL, closest_center.get());
+
+ for (uint64_t j = 0; j < num_centers; j++)
+ {
+ std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], cur_pivot_data.get() + j * cur_chunk_size,
+ cur_chunk_size * sizeof(float));
+ }
+ }
+ std::vector<size_t> cumul_bytes(4, 0);
+ cumul_bytes[0] = cumul_bytes.size() * sizeof(size_t) + + 2 * sizeof(uint32_t);
+ cumul_bytes[1] = cumul_bytes[0] + (size_t)num_centers * dim * sizeof(float) + + 2 * sizeof(uint32_t);
+ cumul_bytes[2] = cumul_bytes[1] + (size_t)dim * 1 * sizeof(float) + + 2 * sizeof(uint32_t);
+ cumul_bytes[3] = cumul_bytes[2] + (size_t)chunk_offsets.size() * 1 * sizeof(uint32_t) + + 2 * sizeof(uint32_t);
+ diskann::save_bin<size_t>(pq_pivots_stream, cumul_bytes.data(), cumul_bytes.size(), 1, 0);
+ diskann::save_bin<float>(pq_pivots_stream, full_pivot_data.get(),
+ (size_t)num_centers, dim, cumul_bytes[0]);
+ diskann::save_bin<float>(pq_pivots_stream, centroid.get(), (size_t)dim, 1, cumul_bytes[1]);
+ diskann::save_bin<uint32_t>(pq_pivots_stream, chunk_offsets.data(),
+ chunk_offsets.size(), 1, cumul_bytes[2]);
+ pq_pivots_stream.seekp(0, pq_pivots_stream.beg);
+ return 0;
+}
+
+int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers,
+ uint32_t num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean)
+{
+ if (num_pq_chunks > dim)
+ {
+ diskann::cout << " Error: number of chunks more than dimension" << std::endl;
+ return -1;
+ }
+
+ std::unique_ptr<float[]> train_data = std::make_unique<float[]>(num_train * dim);
+ std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float));
+
+ std::unique_ptr<float[]> rotated_train_data = std::make_unique<float[]>(num_train * dim);
+ std::unique_ptr<float[]> rotated_and_quantized_train_data = std::make_unique<float[]>(num_train * dim);
+
+ std::unique_ptr<float[]> full_pivot_data;
+
+ // rotation matrix for OPQ
+ std::unique_ptr<float[]> rotmat_tr;
+
+ // matrices for SVD
+ std::unique_ptr<float[]> Umat = std::make_unique<float[]>(dim * dim);
+ std::unique_ptr<float[]> Vmat_T = std::make_unique<float[]>(dim * dim);
+ std::unique_ptr<float[]> singular_values = std::make_unique<float[]>(dim);
+ std::unique_ptr<float[]> correlation_matrix = std::make_unique<float[]>(dim * dim);
+
+ // Calculate centroid and center the training data
+ std::unique_ptr<float[]> centroid = std::make_unique<float[]>(dim);
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ centroid[d] = 0;
+ }
+ if (make_zero_mean)
+ { // If we use L2 distance, there is an option to
+ // translate all vectors to make them centered and
+ // then compute PQ. This needs to be set to false
+ // when using PQ for MIPS as such translations dont
+ // preserve inner products.
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ for (uint64_t p = 0; p < num_train; p++)
+ {
+ centroid[d] += train_data[p * dim + d];
+ }
+ centroid[d] /= num_train;
+ }
+ for (uint64_t d = 0; d < dim; d++)
+ {
+ for (uint64_t p = 0; p < num_train; p++)
+ {
+ train_data[p * dim + d] -= centroid[d];
+ }
+ }
+ }
+
+ std::vector<uint32_t> chunk_offsets;
+
+ size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks);
+ size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks);
+ size_t max_num_high = dim - (low_val * num_pq_chunks);
+ size_t cur_num_high = 0;
+ size_t cur_bin_threshold = high_val;
+
+ std::vector<std::vector<uint32_t>> bin_to_dims(num_pq_chunks);
+ tsl::robin_map<uint32_t, uint32_t> dim_to_bin;
+ std::vector<float> bin_loads(num_pq_chunks, 0);
+
+ // Process dimensions not inserted by previous loop
+ for (uint32_t d = 0; d < dim; d++)
+ {
+ if (dim_to_bin.find(d) != dim_to_bin.end())
+ continue;
+ auto cur_best = num_pq_chunks + 1;
+ float cur_best_load = std::numeric_limits<float>::max();
+ for (uint32_t b = 0; b < num_pq_chunks; b++)
+ {
+ if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold)
+ {
+ cur_best = b;
+ cur_best_load = bin_loads[b];
+ }
+ }
+ bin_to_dims[cur_best].push_back(d);
+ if (bin_to_dims[cur_best].size() == high_val)
+ {
+ cur_num_high++;
+ if (cur_num_high == max_num_high)
+ cur_bin_threshold = low_val;
+ }
+ }
+
+ chunk_offsets.clear();
+ chunk_offsets.push_back(0);
+
+ for (uint32_t b = 0; b < num_pq_chunks; b++)
+ {
+ if (b > 0)
+ chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size());
+ }
+ chunk_offsets.push_back(dim);
+
+ full_pivot_data.reset(new float[num_centers * dim]);
+ rotmat_tr.reset(new float[dim * dim]);
+
+ std::memset(rotmat_tr.get(), 0, dim * dim * sizeof(float));
+ for (uint32_t d1 = 0; d1 < dim; d1++)
+ *(rotmat_tr.get() + d1 * dim + d1) = 1;
+
+ for (uint32_t rnd = 0; rnd < MAX_OPQ_ITERS; rnd++)
+ {
+ // rotate the training data using the current rotation matrix
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)num_train, (MKL_INT)dim, (MKL_INT)dim, 1.0f,
+ train_data.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, rotated_train_data.get(),
+ (MKL_INT)dim);
+
+ // compute the PQ pivots on the rotated space
+ for (size_t i = 0; i < num_pq_chunks; i++)
+ {
+ size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+
+ if (cur_chunk_size == 0)
+ continue;
+ std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+ std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(num_train * cur_chunk_size);
+ std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(num_train);
+
+ diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", "
+ << chunk_offsets[i + 1] << ")" << std::endl;
+
+#pragma omp parallel for schedule(static, 65536)
+ for (int64_t j = 0; j < (int64_t)num_train; j++)
+ {
+ std::memcpy(cur_data.get() + j * cur_chunk_size, rotated_train_data.get() + j * dim + chunk_offsets[i],
+ cur_chunk_size * sizeof(float));
+ }
+
+ if (rnd == 0)
+ {
+ kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(),
+ num_centers);
+ }
+ else
+ {
+ for (uint64_t j = 0; j < num_centers; j++)
+ {
+ std::memcpy(cur_pivot_data.get() + j * cur_chunk_size,
+ full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float));
+ }
+ }
+
+ uint32_t num_lloyds_iters = 8;
+ kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers,
+ num_lloyds_iters, NULL, closest_center.get());
+
+ for (uint64_t j = 0; j < num_centers; j++)
+ {
+ std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i],
+ cur_pivot_data.get() + j * cur_chunk_size, cur_chunk_size * sizeof(float));
+ }
+
+ for (size_t j = 0; j < num_train; j++)
+ {
+ std::memcpy(rotated_and_quantized_train_data.get() + j * dim + chunk_offsets[i],
+ cur_pivot_data.get() + (size_t)closest_center[j] * cur_chunk_size,
+ cur_chunk_size * sizeof(float));
+ }
+ }
+
+ // compute the correlation matrix between the original data and the
+ // quantized data to compute the new rotation
+ cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)num_train, 1.0f,
+ train_data.get(), (MKL_INT)dim, rotated_and_quantized_train_data.get(), (MKL_INT)dim, 0.0f,
+ correlation_matrix.get(), (MKL_INT)dim);
+
+ // compute the SVD of the correlation matrix to help determine the new
+ // rotation matrix
+ uint32_t errcode = (uint32_t)LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim,
+ correlation_matrix.get(), (MKL_INT)dim, singular_values.get(),
+ Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim);
+
+ if (errcode > 0)
+ {
+ std::cout << "SVD failed to converge." << std::endl;
+ exit(-1);
+ }
+
+ // compute the new rotation matrix from the singular vectors as R^T = U
+ // V^T
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)dim, 1.0f,
+ Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim, 0.0f, rotmat_tr.get(), (MKL_INT)dim);
+ }
+
+ std::vector<size_t> cumul_bytes(4, 0);
+ cumul_bytes[0] = METADATA_SIZE;
+ cumul_bytes[1] = cumul_bytes[0] + diskann::save_bin<float>(opq_pivots_path.c_str(), full_pivot_data.get(),
+ (size_t)num_centers, dim, cumul_bytes[0]);
+ cumul_bytes[2] = cumul_bytes[1] +
+ diskann::save_bin<float>(opq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]);
+ cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin<uint32_t>(opq_pivots_path.c_str(), chunk_offsets.data(),
+ chunk_offsets.size(), 1, cumul_bytes[2]);
+ diskann::save_bin<size_t>(opq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0);
+
+ diskann::cout << "Saved opq pivot data to " << opq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1]
+ << "B." << std::endl;
+
+ std::string rotmat_path = opq_pivots_path + "_rotation_matrix.bin";
+ diskann::save_bin<float>(rotmat_path.c_str(), rotmat_tr.get(), dim, dim);
+
+ return 0;
+}
+
+
+
+// Trains OPQ (Optimized Product Quantization) pivots on the sampled training
+// data and serializes the result (pivot table, centroid, chunk offsets, and
+// the offset-metadata block) into opq_pivots_stream.
+// Returns 0 on success, -1 when num_pq_chunks > dim.
+//
+// NOTE(review): opq_pivots_stream is taken BY VALUE. std::stringstream is
+// move-only, so callers must std::move() it in, and everything save_bin
+// writes here is destroyed with the local copy when this function returns.
+// The sibling stream-based functions take std::stringstream& -- presumably
+// this one should too; TODO confirm.
+// NOTE(review): unlike the path-based overload, the trained rotation matrix
+// is not persisted (see the commented-out save at the end).
+int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers,
+                        uint32_t num_pq_chunks, std::stringstream opq_pivots_stream, bool make_zero_mean)
+{
+    if (num_pq_chunks > dim)
+    {
+        diskann::cout << " Error: number of chunks more than dimension" << std::endl;
+        return -1;
+    }
+
+    // Work on a private copy so centering does not mutate the caller's buffer.
+    std::unique_ptr<float[]> train_data = std::make_unique<float[]>(num_train * dim);
+    std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float));
+
+    std::unique_ptr<float[]> rotated_train_data = std::make_unique<float[]>(num_train * dim);
+    std::unique_ptr<float[]> rotated_and_quantized_train_data = std::make_unique<float[]>(num_train * dim);
+
+    std::unique_ptr<float[]> full_pivot_data;
+
+    // rotation matrix for OPQ
+    std::unique_ptr<float[]> rotmat_tr;
+
+    // matrices for SVD
+    std::unique_ptr<float[]> Umat = std::make_unique<float[]>(dim * dim);
+    std::unique_ptr<float[]> Vmat_T = std::make_unique<float[]>(dim * dim);
+    std::unique_ptr<float[]> singular_values = std::make_unique<float[]>(dim);
+    std::unique_ptr<float[]> correlation_matrix = std::make_unique<float[]>(dim * dim);
+
+    // Calculate centroid and center the training data
+    std::unique_ptr<float[]> centroid = std::make_unique<float[]>(dim);
+    for (uint64_t d = 0; d < dim; d++)
+    {
+        centroid[d] = 0;
+    }
+    if (make_zero_mean)
+    { // If we use L2 distance, there is an option to
+        // translate all vectors to make them centered and
+        // then compute PQ. This needs to be set to false
+        // when using PQ for MIPS as such translations dont
+        // preserve inner products.
+        for (uint64_t d = 0; d < dim; d++)
+        {
+            for (uint64_t p = 0; p < num_train; p++)
+            {
+                centroid[d] += train_data[p * dim + d];
+            }
+            centroid[d] /= num_train;
+        }
+        for (uint64_t d = 0; d < dim; d++)
+        {
+            for (uint64_t p = 0; p < num_train; p++)
+            {
+                train_data[p * dim + d] -= centroid[d];
+            }
+        }
+    }
+
+    std::vector<uint32_t> chunk_offsets;
+
+    // Distribute the dim dimensions across num_pq_chunks bins as evenly as
+    // possible: max_num_high bins get high_val dims each, the rest low_val.
+    size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks);
+    size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks);
+    size_t max_num_high = dim - (low_val * num_pq_chunks);
+    size_t cur_num_high = 0;
+    size_t cur_bin_threshold = high_val;
+
+    std::vector<std::vector<uint32_t>> bin_to_dims(num_pq_chunks);
+    tsl::robin_map<uint32_t, uint32_t> dim_to_bin;
+    std::vector<float> bin_loads(num_pq_chunks, 0);
+
+    // Process dimensions not inserted by previous loop
+    for (uint32_t d = 0; d < dim; d++)
+    {
+        if (dim_to_bin.find(d) != dim_to_bin.end())
+            continue;
+        // Greedily place dimension d into the least-loaded bin that still has
+        // room under the current per-bin size threshold.
+        auto cur_best = num_pq_chunks + 1;
+        float cur_best_load = std::numeric_limits<float>::max();
+        for (uint32_t b = 0; b < num_pq_chunks; b++)
+        {
+            if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold)
+            {
+                cur_best = b;
+                cur_best_load = bin_loads[b];
+            }
+        }
+        bin_to_dims[cur_best].push_back(d);
+        if (bin_to_dims[cur_best].size() == high_val)
+        {
+            cur_num_high++;
+            if (cur_num_high == max_num_high)
+                cur_bin_threshold = low_val;
+        }
+    }
+
+    // chunk_offsets[i] is the first dimension of chunk i; the final entry is
+    // dim, so chunk i spans [chunk_offsets[i], chunk_offsets[i + 1]).
+    chunk_offsets.clear();
+    chunk_offsets.push_back(0);
+
+    for (uint32_t b = 0; b < num_pq_chunks; b++)
+    {
+        if (b > 0)
+            chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size());
+    }
+    chunk_offsets.push_back(dim);
+
+    full_pivot_data.reset(new float[num_centers * dim]);
+    rotmat_tr.reset(new float[dim * dim]);
+
+    // Start from the identity rotation.
+    std::memset(rotmat_tr.get(), 0, dim * dim * sizeof(float));
+    for (uint32_t d1 = 0; d1 < dim; d1++)
+        *(rotmat_tr.get() + d1 * dim + d1) = 1;
+
+    // Alternate between (a) fitting PQ pivots in the rotated space and
+    // (b) refitting the rotation from the SVD of the correlation between the
+    // original and the quantized data.
+    for (uint32_t rnd = 0; rnd < MAX_OPQ_ITERS; rnd++)
+    {
+        // rotate the training data using the current rotation matrix
+        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)num_train, (MKL_INT)dim, (MKL_INT)dim, 1.0f,
+                    train_data.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, rotated_train_data.get(),
+                    (MKL_INT)dim);
+
+        // compute the PQ pivots on the rotated space
+        for (size_t i = 0; i < num_pq_chunks; i++)
+        {
+            size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+
+            if (cur_chunk_size == 0)
+                continue;
+            std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+            std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(num_train * cur_chunk_size);
+            std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(num_train);
+
+            diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", "
+                          << chunk_offsets[i + 1] << ")" << std::endl;
+
+            // Gather this chunk's columns into a contiguous buffer.
+#pragma omp parallel for schedule(static, 65536)
+            for (int64_t j = 0; j < (int64_t)num_train; j++)
+            {
+                std::memcpy(cur_data.get() + j * cur_chunk_size, rotated_train_data.get() + j * dim + chunk_offsets[i],
+                            cur_chunk_size * sizeof(float));
+            }
+
+            if (rnd == 0)
+            {
+                // First round: seed the pivots with k-means++.
+                kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(),
+                                                  num_centers);
+            }
+            else
+            {
+                // Later rounds: warm-start from the previous round's pivots.
+                for (uint64_t j = 0; j < num_centers; j++)
+                {
+                    std::memcpy(cur_pivot_data.get() + j * cur_chunk_size,
+                                full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float));
+                }
+            }
+
+            uint32_t num_lloyds_iters = 8;
+            kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers,
+                               num_lloyds_iters, NULL, closest_center.get());
+
+            // Copy the refined chunk pivots back into the full pivot table.
+            for (uint64_t j = 0; j < num_centers; j++)
+            {
+                std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i],
+                            cur_pivot_data.get() + j * cur_chunk_size, cur_chunk_size * sizeof(float));
+            }
+
+            // Reconstruct (quantize) each training point from its assigned pivot.
+            for (size_t j = 0; j < num_train; j++)
+            {
+                std::memcpy(rotated_and_quantized_train_data.get() + j * dim + chunk_offsets[i],
+                            cur_pivot_data.get() + (size_t)closest_center[j] * cur_chunk_size,
+                            cur_chunk_size * sizeof(float));
+            }
+        }
+
+        // compute the correlation matrix between the original data and the
+        // quantized data to compute the new rotation
+        cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)num_train, 1.0f,
+                    train_data.get(), (MKL_INT)dim, rotated_and_quantized_train_data.get(), (MKL_INT)dim, 0.0f,
+                    correlation_matrix.get(), (MKL_INT)dim);
+
+        // compute the SVD of the correlation matrix to help determine the new
+        // rotation matrix
+        uint32_t errcode = (uint32_t)LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim,
+                                                    correlation_matrix.get(), (MKL_INT)dim, singular_values.get(),
+                                                    Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim);
+
+        if (errcode > 0)
+        {
+            std::cout << "SVD failed to converge." << std::endl;
+            exit(-1);
+        }
+
+        // compute the new rotation matrix from the singular vectors as R^T = U
+        // V^T
+        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)dim, 1.0f,
+                    Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim, 0.0f, rotmat_tr.get(), (MKL_INT)dim);
+    }
+
+    // Serialize [metadata | pivots | centroid | chunk offsets], then write the
+    // four cumulative byte offsets at the head of the stream.
+    std::vector<size_t> cumul_bytes(4, 0);
+    cumul_bytes[0] = METADATA_SIZE;
+    cumul_bytes[1] = cumul_bytes[0] + diskann::save_bin<float>(opq_pivots_stream, full_pivot_data.get(),
+                                                               (size_t)num_centers, dim, cumul_bytes[0]);
+    cumul_bytes[2] = cumul_bytes[1] +
+                     diskann::save_bin<float>(opq_pivots_stream, centroid.get(), (size_t)dim, 1, cumul_bytes[1]);
+    cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin<uint32_t>(opq_pivots_stream, chunk_offsets.data(),
+                                                                  chunk_offsets.size(), 1, cumul_bytes[2]);
+    diskann::save_bin<size_t>(opq_pivots_stream, cumul_bytes.data(), cumul_bytes.size(), 1, 0);
+
+    // diskann::cout << "Saved opq pivot data to " << opq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1]
+    //               << "B." << std::endl;
+
+    // std::string rotmat_path = opq_pivots_path + "_rotation_matrix.bin";
+    // diskann::save_bin<float>(rotmat_path.c_str(), rotmat_tr.get(), dim, dim);
+
+    return 0;
+}
+
+// generate_pq_data_from_pivots_simplified is a simplified version of generate_pq_data_from_pivots.
+// Input is provided in the in-memory buffers data and pivot_data.
+// Output is stored in the in-memory buffer pq (one byte code per chunk per point).
+// Simplification is based on the following assumptions:
+//     supporting only float data type
+//     dim % num_pq_chunks == 0, which results in a fixed chunk_size
+//     num_centers == 256 by default
+//     make_zero_mean is false by default.
+// These assumptions allow to make the function much simpler and avoid using
+// array of chunk_offsets and centroids.
+// The compiler pragma for multi-threading support is removed from this implementation
+// for the purpose of integration into systems that strictly control resource allocation.
+// Returns 0 on success; -1 on invalid chunking or pivot-table size.
+int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, const float *pivot_data,
+                                            const size_t pivots_num, const size_t dim, const size_t num_pq_chunks,
+                                            std::vector<uint8_t> &pq)
+{
+    // Chunks must evenly partition the dimensions.
+    if (num_pq_chunks == 0 || num_pq_chunks > dim || dim % num_pq_chunks != 0)
+    {
+        return -1;
+    }
+
+    const size_t num_centers = 256;
+    const size_t chunk_size = dim / num_pq_chunks;
+
+    // The pivot table must hold num_centers full-dimensional pivots.
+    if (pivots_num != num_centers * dim)
+    {
+        return -1;
+    }
+
+    pq.resize(num * num_pq_chunks);
+
+    std::vector<float> cur_pivot_vector(num_centers * chunk_size);
+    std::vector<float> cur_data_vector(num * chunk_size);
+    std::vector<uint32_t> closest_center_vector(num);
+
+    // Use data() rather than &v[0]: &v[0] is undefined behavior when the
+    // vector is empty (e.g. num == 0).
+    float *cur_pivot_data = cur_pivot_vector.data();
+    float *cur_data = cur_data_vector.data();
+    uint32_t *closest_center = closest_center_vector.data();
+
+    for (size_t i = 0; i < num_pq_chunks; i++)
+    {
+        const size_t chunk_offset = chunk_size * i;
+
+        // Gather this chunk's slice of every pivot. Loop indices are size_t
+        // (previously int) to avoid signed/unsigned comparisons and overflow
+        // for inputs with more than INT_MAX points.
+        for (size_t j = 0; j < num_centers; j++)
+        {
+            std::memcpy(cur_pivot_data + j * chunk_size, pivot_data + j * dim + chunk_offset,
+                        chunk_size * sizeof(float));
+        }
+
+        // Gather this chunk's slice of every data point.
+        for (size_t j = 0; j < num; j++)
+        {
+            for (size_t k = 0; k < chunk_size; k++)
+            {
+                cur_data[j * chunk_size + k] = data[j * dim + chunk_offset + k];
+            }
+        }
+
+        math_utils::compute_closest_centers(cur_data, num, chunk_size, cur_pivot_data, num_centers, 1, closest_center);
+
+        // Record the 1-byte code for chunk i of every point.
+        for (size_t j = 0; j < num; j++)
+        {
+            assert(closest_center[j] < num_centers);
+            pq[j * num_pq_chunks + i] = (uint8_t)closest_center[j];
+        }
+    }
+
+    return 0;
+}
+
+// streams the base file (data_file), and computes the closest centers in each
+// chunk to generate the compressed data_file and stores it in
+// pq_compressed_vectors_path.
+// If the number of centers is <= 256, it stores codes as a byte vector, else
+// as a 4-byte vector, in binary format.
+// Pivots, centroid, and chunk offsets are loaded from pq_pivots_path; for
+// OPQ the rotation matrix is loaded from pq_pivots_path + "_rotation_matrix.bin".
+// Returns 0 on success; throws ANNException when pivot files are missing or
+// malformed.
+template <typename T>
+int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks,
+                                 const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path,
+                                 bool use_opq)
+{
+    size_t read_blk_size = 64 * 1024 * 1024;
+    cached_ifstream base_reader(data_file, read_blk_size);
+    uint32_t npts32;
+    uint32_t basedim32;
+    base_reader.read((char *)&npts32, sizeof(uint32_t));
+    base_reader.read((char *)&basedim32, sizeof(uint32_t));
+    size_t num_points = npts32;
+    size_t dim = basedim32;
+
+    std::unique_ptr<float[]> full_pivot_data;
+    std::unique_ptr<float[]> rotmat_tr;
+    std::unique_ptr<float[]> centroid;
+    std::unique_ptr<uint32_t[]> chunk_offsets;
+
+    // Only written when SAVE_INFLATED_PQ is defined (see below).
+    std::string inflated_pq_file = pq_compressed_vectors_path + "_inflated.bin";
+
+    if (!file_exists(pq_pivots_path))
+    {
+        std::cout << "ERROR: PQ k-means pivot file not found" << std::endl;
+        throw diskann::ANNException("PQ k-means pivot file not found", -1);
+    }
+    else
+    {
+        size_t nr, nc;
+        std::unique_ptr<size_t[]> file_offset_data;
+
+        // The pivot file starts with 4 cumulative byte offsets locating the
+        // pivot-table, centroid, and chunk-offset sections.
+        diskann::load_bin<size_t>(pq_pivots_path.c_str(), file_offset_data, nr, nc, 0);
+
+        if (nr != 4)
+        {
+            diskann::cout << "Error reading pq_pivots file " << pq_pivots_path
+                          << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting 4.";
+            throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__,
+                                        __LINE__);
+        }
+
+        diskann::load_bin<float>(pq_pivots_path.c_str(), full_pivot_data, nr, nc, file_offset_data[0]);
+
+        if ((nr != num_centers) || (nc != dim))
+        {
+            diskann::cout << "Error reading pq_pivots file " << pq_pivots_path << ". file_num_centers  = " << nr
+                          << ", file_dim = " << nc << " but expecting " << num_centers << " centers in " << dim
+                          << " dimensions.";
+            throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__,
+                                        __LINE__);
+        }
+
+        diskann::load_bin<float>(pq_pivots_path.c_str(), centroid, nr, nc, file_offset_data[1]);
+
+        if ((nr != dim) || (nc != 1))
+        {
+            diskann::cout << "Error reading pq_pivots file " << pq_pivots_path << ". file_dim = " << nr
+                          << ", file_cols = " << nc << " but expecting " << dim << " entries in 1 dimension.";
+            throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__,
+                                        __LINE__);
+        }
+
+        diskann::load_bin<uint32_t>(pq_pivots_path.c_str(), chunk_offsets, nr, nc, file_offset_data[2]);
+
+        if (nr != (uint64_t)num_pq_chunks + 1 || nc != 1)
+        {
+            diskann::cout << "Error reading pq_pivots file at chunk offsets; file has nr=" << nr << ",nc=" << nc
+                          << ", expecting nr=" << num_pq_chunks + 1 << ", nc=1." << std::endl;
+            throw diskann::ANNException("Error reading pq_pivots file at chunk offsets.", -1, __FUNCSIG__, __FILE__,
+                                        __LINE__);
+        }
+
+        if (use_opq)
+        {
+            // OPQ additionally needs the learned rotation matrix.
+            std::string rotmat_path = pq_pivots_path + "_rotation_matrix.bin";
+            diskann::load_bin<float>(rotmat_path.c_str(), rotmat_tr, nr, nc);
+            if (nr != (uint64_t)dim || nc != dim)
+            {
+                diskann::cout << "Error reading rotation matrix file." << std::endl;
+                throw diskann::ANNException("Error reading rotation matrix file.", -1, __FUNCSIG__, __FILE__, __LINE__);
+            }
+        }
+
+        diskann::cout << "Loaded PQ pivot information" << std::endl;
+    }
+
+    std::ofstream compressed_file_writer(pq_compressed_vectors_path, std::ios::binary);
+    uint32_t num_pq_chunks_u32 = num_pq_chunks;
+
+    // NOTE(review): num_points is a size_t but only sizeof(uint32_t) bytes of
+    // it are written -- presumably assumes a little-endian host; confirm.
+    compressed_file_writer.write((char *)&num_points, sizeof(uint32_t));
+    compressed_file_writer.write((char *)&num_pq_chunks_u32, sizeof(uint32_t));
+
+    size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
+
+#ifdef SAVE_INFLATED_PQ
+    std::ofstream inflated_file_writer(inflated_pq_file, std::ios::binary);
+    inflated_file_writer.write((char *)&num_points, sizeof(uint32_t));
+    inflated_file_writer.write((char *)&basedim32, sizeof(uint32_t));
+
+    std::unique_ptr<float[]> block_inflated_base = std::make_unique<float[]>(block_size * dim);
+    std::memset(block_inflated_base.get(), 0, block_size * dim * sizeof(float));
+#endif
+
+    std::unique_ptr<uint32_t[]> block_compressed_base =
+        std::make_unique<uint32_t[]>(block_size * (size_t)num_pq_chunks);
+    std::memset(block_compressed_base.get(), 0, block_size * (size_t)num_pq_chunks * sizeof(uint32_t));
+
+    std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
+    std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
+    std::unique_ptr<float[]> block_data_tmp = std::make_unique<float[]>(block_size * dim);
+
+    size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
+
+    // Process the base file one block at a time to bound memory usage.
+    for (size_t block = 0; block < num_blocks; block++)
+    {
+        size_t start_id = block * block_size;
+        size_t end_id = (std::min)((block + 1) * block_size, num_points);
+        size_t cur_blk_size = end_id - start_id;
+
+        base_reader.read((char *)(block_data_T.get()), sizeof(T) * (cur_blk_size * dim));
+        diskann::convert_types<T, float>(block_data_T.get(), block_data_tmp.get(), cur_blk_size, dim);
+
+        diskann::cout << "Processing points  [" << start_id << ", " << end_id << ").." << std::flush;
+
+        // Subtract the centroid (all-zero when pivots were trained without
+        // zero-mean centering).
+        for (size_t p = 0; p < cur_blk_size; p++)
+        {
+            for (uint64_t d = 0; d < dim; d++)
+            {
+                block_data_tmp[p * dim + d] -= centroid[d];
+            }
+        }
+
+        for (size_t p = 0; p < cur_blk_size; p++)
+        {
+            for (uint64_t d = 0; d < dim; d++)
+            {
+                block_data_float[p * dim + d] = block_data_tmp[p * dim + d];
+            }
+        }
+
+        if (use_opq)
+        {
+            // rotate the current block with the trained rotation matrix before
+            // PQ
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)cur_blk_size, (MKL_INT)dim, (MKL_INT)dim,
+                        1.0f, block_data_float.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f,
+                        block_data_tmp.get(), (MKL_INT)dim);
+            std::memcpy(block_data_float.get(), block_data_tmp.get(), cur_blk_size * dim * sizeof(float));
+        }
+
+        // For each chunk, assign every point in the block to its closest pivot.
+        for (size_t i = 0; i < num_pq_chunks; i++)
+        {
+            size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+            if (cur_chunk_size == 0)
+                continue;
+
+            std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+            std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(cur_blk_size * cur_chunk_size);
+            std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(cur_blk_size);
+
+#pragma omp parallel for schedule(static, 8192)
+            for (int64_t j = 0; j < (int64_t)cur_blk_size; j++)
+            {
+                for (size_t k = 0; k < cur_chunk_size; k++)
+                    cur_data[j * cur_chunk_size + k] = block_data_float[j * dim + chunk_offsets[i] + k];
+            }
+
+#pragma omp parallel for schedule(static, 1)
+            for (int64_t j = 0; j < (int64_t)num_centers; j++)
+            {
+                std::memcpy(cur_pivot_data.get() + j * cur_chunk_size,
+                            full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float));
+            }
+
+            math_utils::compute_closest_centers(cur_data.get(), cur_blk_size, cur_chunk_size, cur_pivot_data.get(),
+                                                num_centers, 1, closest_center.get());
+
+#pragma omp parallel for schedule(static, 8192)
+            for (int64_t j = 0; j < (int64_t)cur_blk_size; j++)
+            {
+                block_compressed_base[j * num_pq_chunks + i] = closest_center[j];
+#ifdef SAVE_INFLATED_PQ
+                // Reconstruct the (approximate) original values for debugging.
+                for (size_t k = 0; k < cur_chunk_size; k++)
+                    block_inflated_base[j * dim + chunk_offsets[i] + k] =
+                        cur_pivot_data[closest_center[j] * cur_chunk_size + k] + centroid[chunk_offsets[i] + k];
+#endif
+            }
+        }
+
+        // Narrow codes to one byte when 256 or fewer centers are in use.
+        if (num_centers > 256)
+        {
+            compressed_file_writer.write((char *)(block_compressed_base.get()),
+                                         cur_blk_size * num_pq_chunks * sizeof(uint32_t));
+        }
+        else
+        {
+            std::unique_ptr<uint8_t[]> pVec = std::make_unique<uint8_t[]>(cur_blk_size * num_pq_chunks);
+            diskann::convert_types<uint32_t, uint8_t>(block_compressed_base.get(), pVec.get(), cur_blk_size,
+                                                      num_pq_chunks);
+            compressed_file_writer.write((char *)(pVec.get()), cur_blk_size * num_pq_chunks * sizeof(uint8_t));
+        }
+#ifdef SAVE_INFLATED_PQ
+        inflated_file_writer.write((char *)(block_inflated_base.get()), cur_blk_size * dim * sizeof(float));
+#endif
+        diskann::cout << ".done." << std::endl;
+    }
+// Gopal. Splitting diskann_dll into separate DLLs for search and build.
+// This code should only be available in the "build" DLL.
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+    MallocExtension::instance()->ReleaseFreeMemory();
+#endif
+    compressed_file_writer.close();
+#ifdef SAVE_INFLATED_PQ
+    inflated_file_writer.close();
+#endif
+    return 0;
+}
+
+
+// Stream-based variant of generate_pq_data_from_pivots: reads the base data
+// and the pivot blob from in-memory stringstreams and writes the compressed
+// codes to pq_compressed_stream (same on-disk layout as the file variant).
+// Returns 0 on success; throws ANNException on malformed pivot data.
+//
+// NOTE(review): unlike the file-based overload, this function never loads a
+// rotation matrix, yet when use_opq == true the sgemm below dereferences
+// rotmat_tr, which is still null here -- this would crash. Confirm callers
+// only reach this overload with use_opq == false (the visible stream caller
+// trains plain PQ pivots only).
+template <typename T>
+int generate_pq_data_from_pivots(std::stringstream &data_stream, unsigned num_centers, unsigned num_pq_chunks,
+                                 std::stringstream &pq_pivots_stream, std::stringstream &pq_compressed_stream,
+                                 bool use_opq)
+{
+    uint32_t npts32;
+    uint32_t basedim32;
+    // Rewind and read the [num_points, dim] header.
+    data_stream.seekg(0, data_stream.beg);
+    data_stream.read((char *)&npts32, sizeof(uint32_t));
+    data_stream.read((char *)&basedim32, sizeof(uint32_t));
+    size_t num_points = npts32;
+    size_t dim = basedim32;
+
+    std::unique_ptr<float[]> full_pivot_data;
+    std::unique_ptr<float[]> rotmat_tr;
+    std::unique_ptr<float[]> centroid;
+    std::unique_ptr<uint32_t[]> chunk_offsets;
+
+    size_t nr, nc;
+    std::unique_ptr<size_t[]> file_offset_data;
+
+    //SampleStringStreamReaderWrapperSPtr pq_pivots(new SampleStringStreamReaderWrapper(pq_pivots_stream));
+    // The pivot blob starts with 4 cumulative byte offsets locating the
+    // pivot-table, centroid, and chunk-offset sections.
+    diskann::load_bin<size_t>(pq_pivots_stream, file_offset_data, nr, nc, 0);
+
+
+    if (nr != 4)
+    {
+        diskann::cout << "Error reading pq_pivots file "
+                      << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting 4.";
+        throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    diskann::load_bin<float>(pq_pivots_stream, full_pivot_data, nr, nc, file_offset_data[0]);
+
+    if ((nr != num_centers) || (nc != dim))
+    {
+        diskann::cout << "Error reading pq_pivots. file_num_centers  = " << nr
+                      << ", file_dim = " << nc << " but expecting " << num_centers << " centers in " << dim
+                      << " dimensions.";
+        throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    diskann::load_bin<float>(pq_pivots_stream, centroid, nr, nc, file_offset_data[1]);
+
+    if ((nr != dim) || (nc != 1))
+    {
+        diskann::cout << "Error reading pq_pivots file file_dim  = " << nr
+                      << ", file_cols = " << nc << " but expecting " << dim << " entries in 1 dimension.";
+        throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    diskann::load_bin<uint32_t>(pq_pivots_stream, chunk_offsets, nr, nc, file_offset_data[2]);
+
+    if (nr != (uint64_t)num_pq_chunks + 1 || nc != 1)
+    {
+        diskann::cout << "Error reading pq_pivots file at chunk offsets; file has nr=" << nr << ",nc=" << nc
+                      << ", expecting nr=" << num_pq_chunks + 1 << ", nc=1." << std::endl;
+        throw diskann::ANNException("Error reading pq_pivots file at chunk offsets.", -1, __FUNCSIG__, __FILE__,
+                                    __LINE__);
+    }
+
+    diskann::cout << "Loaded PQ pivot information" << std::endl;
+
+
+    uint32_t num_pq_chunks_u32 = num_pq_chunks;
+
+    // NOTE(review): num_points is a size_t but only sizeof(uint32_t) bytes of
+    // it are written -- presumably assumes a little-endian host; confirm.
+    pq_compressed_stream.write((char *)&num_points, sizeof(uint32_t));
+    pq_compressed_stream.write((char *)&num_pq_chunks_u32, sizeof(uint32_t));
+
+    size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
+
+    std::unique_ptr<uint32_t[]> block_compressed_base =
+        std::make_unique<uint32_t[]>(block_size * (size_t)num_pq_chunks);
+    std::memset(block_compressed_base.get(), 0, block_size * (size_t)num_pq_chunks * sizeof(uint32_t));
+
+    std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
+    std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
+    std::unique_ptr<float[]> block_data_tmp = std::make_unique<float[]>(block_size * dim);
+
+    size_t num_blocks = DIV_ROUND_UP(num_points, block_size);
+
+    // Process the base data one block at a time to bound memory usage.
+    for (size_t block = 0; block < num_blocks; block++)
+    {
+        size_t start_id = block * block_size;
+        size_t end_id = (std::min)((block + 1) * block_size, num_points);
+        size_t cur_blk_size = end_id - start_id;
+
+        data_stream.read((char *)(block_data_T.get()), sizeof(T) * (cur_blk_size * dim));
+        diskann::convert_types<T, float>(block_data_T.get(), block_data_tmp.get(), cur_blk_size, dim);
+
+        diskann::cout << "Processing points  [" << start_id << ", " << end_id << ").." << std::flush;
+
+        // Subtract the centroid (all-zero when pivots were trained without
+        // zero-mean centering).
+        for (size_t p = 0; p < cur_blk_size; p++)
+        {
+            for (uint64_t d = 0; d < dim; d++)
+            {
+                block_data_tmp[p * dim + d] -= centroid[d];
+            }
+        }
+
+        for (size_t p = 0; p < cur_blk_size; p++)
+        {
+            for (uint64_t d = 0; d < dim; d++)
+            {
+                block_data_float[p * dim + d] = block_data_tmp[p * dim + d];
+            }
+        }
+
+        if (use_opq)
+        {
+            // rotate the current block with the trained rotation matrix before
+            // PQ
+            // NOTE(review): rotmat_tr is never loaded in this overload -- see
+            // function-level note above.
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)cur_blk_size, (MKL_INT)dim, (MKL_INT)dim,
+                        1.0f, block_data_float.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f,
+                        block_data_tmp.get(), (MKL_INT)dim);
+            std::memcpy(block_data_float.get(), block_data_tmp.get(), cur_blk_size * dim * sizeof(float));
+        }
+
+        // For each chunk, assign every point in the block to its closest pivot.
+        for (size_t i = 0; i < num_pq_chunks; i++)
+        {
+            size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i];
+            if (cur_chunk_size == 0)
+                continue;
+
+            std::unique_ptr<float[]> cur_pivot_data = std::make_unique<float[]>(num_centers * cur_chunk_size);
+            std::unique_ptr<float[]> cur_data = std::make_unique<float[]>(cur_blk_size * cur_chunk_size);
+            std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(cur_blk_size);
+
+#pragma omp parallel for schedule(static, 8192)
+            for (int64_t j = 0; j < (int64_t)cur_blk_size; j++)
+            {
+                for (size_t k = 0; k < cur_chunk_size; k++)
+                    cur_data[j * cur_chunk_size + k] = block_data_float[j * dim + chunk_offsets[i] + k];
+            }
+
+#pragma omp parallel for schedule(static, 1)
+            for (int64_t j = 0; j < (int64_t)num_centers; j++)
+            {
+                std::memcpy(cur_pivot_data.get() + j * cur_chunk_size,
+                            full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float));
+            }
+
+            math_utils::compute_closest_centers(cur_data.get(), cur_blk_size, cur_chunk_size, cur_pivot_data.get(),
+                                                num_centers, 1, closest_center.get());
+
+//#pragma omp parallel for schedule(static, 8192)
+            for (int64_t j = 0; j < (int64_t)cur_blk_size; j++)
+            {
+                block_compressed_base[j * num_pq_chunks + i] = closest_center[j];
+            }
+        }
+
+        // Narrow codes to one byte when 256 or fewer centers are in use.
+        if (num_centers > 256)
+        {
+            pq_compressed_stream.write((char *)(block_compressed_base.get()),
+                                       cur_blk_size * num_pq_chunks * sizeof(uint32_t));
+        }
+        else
+        {
+            std::unique_ptr<uint8_t[]> pVec = std::make_unique<uint8_t[]>(cur_blk_size * num_pq_chunks);
+            diskann::convert_types<uint32_t, uint8_t>(block_compressed_base.get(), pVec.get(), cur_blk_size,
+                                                      num_pq_chunks);
+            pq_compressed_stream.write((char *)(pVec.get()), cur_blk_size * num_pq_chunks * sizeof(uint8_t));
+        }
+        diskann::cout << ".done." << std::endl;
+    }
+// Gopal. Splitting diskann_dll into separate DLLs for search and build.
+// This code should only be available in the "build" DLL.
+#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
+    MallocExtension::instance()->ReleaseFreeMemory();
+#endif
+    return 0;
+}
+
+
+// Trains a 256-center disk-PQ codebook on a random sample (fraction p_val) of
+// the base file and compresses the entire base file with it.
+// disk_pq_dims (in/out) is the requested number of PQ chunks; it is clamped
+// to the data dimensionality.
+template <typename T>
+void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
+                                  const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric,
+                                  const double p_val, size_t &disk_pq_dims)
+{
+    size_t train_size, train_dim;
+    float *train_data;
+
+    // instantiates train_data with random sample updates train_size
+    gen_random_slice<T>(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim);
+    // Own the sampled buffer immediately so it is released even if pivot
+    // training or compression throws (it was previously delete[]'d manually,
+    // leaking on any exception in between).
+    std::unique_ptr<float[]> train_data_guard(train_data);
+    diskann::cout << "Training data with " << train_size << " samples loaded1." << std::endl;
+
+    // Cannot have more PQ chunks than dimensions.
+    if (disk_pq_dims > train_dim)
+        disk_pq_dims = train_dim;
+
+    std::cout << "Compressing base for disk-PQ into " << disk_pq_dims << " chunks " << std::endl;
+    generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ,
+                       disk_pq_pivots_path, false);
+    // NOTE(review): for INNER_PRODUCT the base is compressed as float
+    // regardless of T -- presumably the base data was pre-processed into
+    // floats upstream for that metric; confirm.
+    if (compareMetric == diskann::Metric::INNER_PRODUCT)
+        generate_pq_data_from_pivots<float>(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path,
+                                            disk_pq_compressed_vectors_path);
+    else
+        generate_pq_data_from_pivots<T>(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path,
+                                        disk_pq_compressed_vectors_path);
+}
+
+// Trains (or reuses) PQ/OPQ pivots and compresses the base file with them.
+// When codebook_prefix names an existing file, pivot training is skipped and
+// the pivots already at pq_pivots_path are used as-is.
+// make_zero_mean centering is applied only for L2-style metrics: translations
+// do not preserve inner products, and OPQ handles rotation itself.
+template <typename T>
+void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path,
+                             const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric,
+                             const double p_val, const size_t num_pq_chunks, const bool use_opq,
+                             const std::string &codebook_prefix)
+{
+    size_t train_size, train_dim;
+    float *train_data;
+    if (!file_exists(codebook_prefix))
+    {
+        // instantiates train_data with random sample updates train_size
+        gen_random_slice<T>(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim);
+        // Own the sampled buffer so it is released even if pivot generation
+        // throws (previously delete[]'d manually, leaking on exceptions).
+        std::unique_ptr<float[]> train_data_guard(train_data);
+        diskann::cout << "Training data with " << train_size << " samples loaded2." << ",p_val:" << p_val << std::endl;
+
+        bool make_zero_mean = true;
+        if (compareMetric == diskann::Metric::INNER_PRODUCT)
+            make_zero_mean = false;
+        if (use_opq) // we also do not center the data for OPQ
+            make_zero_mean = false;
+
+        if (!use_opq)
+        {
+            generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks,
+                               NUM_KMEANS_REPS_PQ, pq_pivots_path, make_zero_mean);
+        }
+        else
+        {
+            generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks,
+                                pq_pivots_path, make_zero_mean);
+        }
+    }
+    else
+    {
+        diskann::cout << "Skip Training with predefined pivots in: " << pq_pivots_path << std::endl;
+    }
+    generate_pq_data_from_pivots<T>(data_file_to_use, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, pq_pivots_path,
+                                    pq_compressed_vectors_path, use_opq);
+}
+
+// Stream-based variant: trains PQ pivots on a random sample of data_stream
+// and compresses the full stream, writing pivots to pq_pivots_stream and
+// codes to pq_compressed_stream.
+// codebook_prefix is unused in this overload (kept for signature parity with
+// the path-based variant).
+template <typename T>
+void generate_quantized_data(std::stringstream &data_stream, std::stringstream &pq_pivots_stream,
+                             std::stringstream &pq_compressed_stream, diskann::Metric compareMetric,
+                             const double p_val, const size_t num_pq_chunks, const bool use_opq,
+                             const std::string &codebook_prefix)
+{
+    size_t train_size, train_dim;
+    float *train_data;
+    // instantiates train_data with random sample updates train_size
+    gen_random_slice<T>(data_stream, p_val, train_data, train_size, train_dim);
+    // Own the sampled buffer so it is released even if pivot generation
+    // throws (previously delete[]'d manually, leaking on exceptions).
+    std::unique_ptr<float[]> train_data_guard(train_data);
+
+    diskann::cout << "Training data with " << train_size << " samples loaded2." << ",p_val:" << p_val << std::endl;
+
+    bool make_zero_mean = true;
+    if (compareMetric == diskann::Metric::INNER_PRODUCT)
+        make_zero_mean = false;
+    if (use_opq) // we also do not center the data for OPQ
+        make_zero_mean = false;
+
+    // NOTE(review): unlike the path-based overload, this always trains plain
+    // PQ pivots -- the OPQ branch was removed. The stream overload of
+    // generate_pq_data_from_pivots also never loads a rotation matrix, so
+    // use_opq=true would dereference a null rotmat_tr there; confirm callers
+    // only pass use_opq=false on this path.
+    generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks,
+                       (uint32_t)NUM_KMEANS_REPS_PQ, pq_pivots_stream, make_zero_mean);
+    // Release the training sample before the (memory-hungry) compression pass,
+    // matching the original deallocation point.
+    train_data_guard.reset();
+    generate_pq_data_from_pivots<T>(data_stream, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, pq_pivots_stream,
+                                    pq_compressed_stream, use_opq);
+}
+
+// Instantiations of supported templates
+
+// Stream-based build path (float only).
+template DISKANN_DLLEXPORT void generate_quantized_data<float>(std::stringstream & data_stream, std::stringstream &pq_pivots_stream,
+                             std::stringstream &pq_compressed_stream, diskann::Metric compareMetric,
+                             const double p_val, const size_t num_pq_chunks, const bool use_opq,
+                             const std::string &codebook_prefix);
+
+template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<float>(std::stringstream & data_stream, unsigned num_centers, unsigned num_pq_chunks,
+                                 std::stringstream &pq_pivots_stream, std::stringstream &pq_compressed_stream,
+                                 bool use_opq);
+
+// File-based compression for all supported element types.
+template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<int8_t>(const std::string &data_file, uint32_t num_centers,
+                                                                    uint32_t num_pq_chunks,
+                                                                    const std::string &pq_pivots_path,
+                                                                    const std::string &pq_compressed_vectors_path,
+                                                                    bool use_opq);
+template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<uint8_t>(const std::string &data_file, uint32_t num_centers,
+                                                                     uint32_t num_pq_chunks,
+                                                                     const std::string &pq_pivots_path,
+                                                                     const std::string &pq_compressed_vectors_path,
+                                                                     bool use_opq);
+template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<float>(const std::string &data_file, uint32_t num_centers,
+                                                                   uint32_t num_pq_chunks,
+                                                                   const std::string &pq_pivots_path,
+                                                                   const std::string &pq_compressed_vectors_path,
+                                                                   bool use_opq);
+
+// Disk-PQ pipeline for all supported element types.
+template DISKANN_DLLEXPORT void generate_disk_quantized_data<int8_t>(const std::string &data_file_to_use,
+                                                                     const std::string &disk_pq_pivots_path,
+                                                                     const std::string &disk_pq_compressed_vectors_path,
+                                                                     diskann::Metric compareMetric, const double p_val,
+                                                                     size_t &disk_pq_dims);
+
+template DISKANN_DLLEXPORT void generate_disk_quantized_data<uint8_t>(
+    const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
+    const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val,
+    size_t &disk_pq_dims);
+
+
+
+template DISKANN_DLLEXPORT void generate_disk_quantized_data<float>(const std::string &data_file_to_use,
+                                                                    const std::string &disk_pq_pivots_path,
+                                                                    const std::string &disk_pq_compressed_vectors_path,
+                                                                    diskann::Metric compareMetric, const double p_val,
+                                                                    size_t &disk_pq_dims);
+
+// File-based full quantization pipeline for all supported element types.
+template DISKANN_DLLEXPORT void generate_quantized_data<int8_t>(const std::string &data_file_to_use,
+                                                                const std::string &pq_pivots_path,
+                                                                const std::string &pq_compressed_vectors_path,
+                                                                diskann::Metric compareMetric, const double p_val,
+                                                                const size_t num_pq_chunks, const bool use_opq,
+                                                                const std::string &codebook_prefix);
+
+template DISKANN_DLLEXPORT void generate_quantized_data<uint8_t>(const std::string &data_file_to_use,
+                                                                 const std::string &pq_pivots_path,
+                                                                 const std::string &pq_compressed_vectors_path,
+                                                                 diskann::Metric compareMetric, const double p_val,
+                                                                 const size_t num_pq_chunks, const bool use_opq,
+                                                                 const std::string &codebook_prefix);
+
+template DISKANN_DLLEXPORT void generate_quantized_data<float>(const std::string &data_file_to_use,
+                                                               const std::string &pq_pivots_path,
+                                                               const std::string &pq_compressed_vectors_path,
+                                                               diskann::Metric compareMetric, const double p_val,
+                                                               const size_t num_pq_chunks, const bool use_opq,
+                                                               const std::string &codebook_prefix);
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/pq_data_store.cpp b/be/src/extern/diskann/src/pq_data_store.cpp
new file mode 100644
index 0000000..c47c167
--- /dev/null
+++ b/be/src/extern/diskann/src/pq_data_store.cpp
@@ -0,0 +1,260 @@
+#include <exception>
+
+#include "pq_data_store.h"
+#include "pq.h"
+#include "pq_scratch.h"
+#include "utils.h"
+#include "distance.h"
+
+namespace diskann
+{
+
+// REFACTOR TODO: Assuming that num_pq_chunks is known already. Must verify if
+// this is true.
+// Construct a PQ-compressed data store over `num_points` vectors of `dim`
+// dimensions, each quantized into `num_pq_chunks` one-byte codes. Takes
+// ownership of both distance functors. Throws when more chunks than
+// dimensions are requested, since each chunk must cover at least one
+// dimension. The compressed buffer itself is allocated later (populate/load).
+template <typename data_t>
+PQDataStore<data_t>::PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks,
+                                 std::unique_ptr<Distance<data_t>> distance_fn,
+                                 std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn)
+    : AbstractDataStore<data_t>(num_points, dim), _quantized_data(nullptr), _num_chunks(num_pq_chunks),
+      _distance_metric(distance_fn->get_metric())
+{
+    if (num_pq_chunks > dim)
+    {
+        throw diskann::ANNException("ERROR: num_pq_chunks > dim", -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+    _distance_fn = std::move(distance_fn);
+    _pq_distance_fn = std::move(pq_distance_fn);
+}
+
+// Release the aligned buffer holding the PQ codes (allocated with
+// alloc_aligned / load_aligned_bin, so it must be freed with aligned_free).
+template <typename data_t> PQDataStore<data_t>::~PQDataStore()
+{
+    if (_quantized_data != nullptr)
+    {
+        aligned_free(_quantized_data);
+        _quantized_data = nullptr;
+    }
+}
+
+// Load the compressed vectors and pivot data derived from `filename`
+// (see load_impl for the actual file-name derivation). Returns the number
+// of points loaded.
+template <typename data_t> location_t PQDataStore<data_t>::load(const std::string &filename)
+{
+    return load_impl(filename);
+}
+// Persist the PQ codes to `filename` as a bin file of capacity() rows by
+// _num_chunks uint8 columns.
+// NOTE(review): the `num_points` parameter is ignored and this->capacity()
+// is written instead — confirm callers always pass num_points == capacity().
+template <typename data_t> size_t PQDataStore<data_t>::save(const std::string &filename, const location_t num_points)
+{
+    return diskann::save_bin(filename, _quantized_data, this->capacity(), _num_chunks, 0);
+}
+
+// PQ codes are byte-packed with no alignment padding, so the aligned
+// dimension equals the plain dimension.
+template <typename data_t> size_t PQDataStore<data_t>::get_aligned_dim() const
+{
+    return this->get_dims();
+}
+
+// Populate quantized data from regular data.
+// In-memory population (compressing raw vectors directly) is not supported
+// yet; only the file-based overload below is implemented.
+template <typename data_t> void PQDataStore<data_t>::populate_data(const data_t *vectors, const location_t num_pts)
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+// Build the PQ representation from the raw vectors in `filename`:
+// 1. read the file's point count / dimension and adopt them as capacity/dim;
+// 2. train pivots and compress all vectors to disk (generate_quantized_data),
+//    sampling a p_val fraction of points (capped by MAX_PQ_TRAINING_SET_SIZE)
+//    for training;
+// 3. load the compressed codes into the aligned in-memory buffer and load the
+//    pivot tables into the PQ distance functor.
+// Any previously held compressed buffer is released first.
+template <typename data_t> void PQDataStore<data_t>::populate_data(const std::string &filename, const size_t offset)
+{
+    if (_quantized_data != nullptr)
+    {
+        aligned_free(_quantized_data);
+    }
+
+    uint64_t file_num_points = 0, file_dim = 0;
+    get_bin_metadata(filename, file_num_points, file_dim, offset);
+    this->_capacity = (location_t)file_num_points;
+    this->_dim = file_dim;
+
+    // Fraction of points used to train the PQ pivots.
+    double p_val = std::min(1.0, ((double)MAX_PQ_TRAINING_SET_SIZE / (double)file_num_points));
+
+    auto pivots_file = _pq_distance_fn->get_pivot_data_filename(filename);
+    auto compressed_file = _pq_distance_fn->get_quantized_vectors_filename(filename);
+
+    generate_quantized_data<data_t>(filename, pivots_file, compressed_file, _distance_metric, p_val, _num_chunks,
+                                    _pq_distance_fn->is_opq());
+
+    // REFACTOR TODO: Not sure of the alignment. Just copying from index.cpp
+    alloc_aligned(((void **)&_quantized_data), file_num_points * _num_chunks * sizeof(uint8_t), 1);
+    copy_aligned_data_from_file<uint8_t>(compressed_file.c_str(), _quantized_data, file_num_points, _num_chunks,
+                                         _num_chunks);
+#ifdef EXEC_ENV_OLS
+    throw ANNException("load_pq_centroid_bin should not be called when "
+                       "EXEC_ENV_OLS is defined.",
+                       -1, __FUNCSIG__, __FILE__, __LINE__);
+#else
+    _pq_distance_fn->load_pivot_data(pivots_file.c_str(), _num_chunks);
+#endif
+}
+
+// Exporting the (decompressed) data back to a bin file is not supported yet.
+template <typename data_t>
+void PQDataStore<data_t>::extract_data_to_bin(const std::string &filename, const location_t num_pts)
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+// Fetching a single vector always throws: for a valid location the required
+// decompression (inflating PQ codes back to a full vector) is not implemented
+// yet, and an out-of-range location is reported as such.
+template <typename data_t> void PQDataStore<data_t>::get_vector(const location_t i, data_t *target) const
+{
+    // REFACTOR TODO: Should we inflate the compressed vector here?
+    if (i < this->capacity())
+    {
+        throw std::logic_error("Not implemented yet.");
+    }
+    else
+    {
+        std::stringstream ss;
+        ss << "Requested vector " << i << " but only " << this->capacity() << " vectors are present";
+        throw diskann::ANNException(ss.str(), -1);
+    }
+}
+// Writing a single vector (which would require compressing it on the fly)
+// is not supported yet.
+template <typename data_t> void PQDataStore<data_t>::set_vector(const location_t i, const data_t *const vector)
+{
+    // REFACTOR TODO: Should we accept a normal vector and compress here?
+    // memcpy (_data + i * _num_chunks, vector, _num_chunks * sizeof(data_t));
+    throw std::logic_error("Not implemented yet");
+}
+
+// Prefetch the PQ codes of location `loc` into cache.
+// _quantized_data holds _num_chunks one-byte codes per point (populate_data
+// allocates file_num_points * _num_chunks * sizeof(uint8_t) and copies with
+// copy_aligned_data_from_file<uint8_t>), so the per-point stride is
+// _num_chunks bytes. The previous code additionally scaled both the offset
+// and the length by sizeof(data_t), prefetching the wrong span for any
+// data_t wider than one byte.
+template <typename data_t> void PQDataStore<data_t>::prefetch_vector(const location_t loc)
+{
+    const uint8_t *ptr = _quantized_data + ((size_t)loc) * _num_chunks;
+    diskann::prefetch_vector((const char *)ptr, _num_chunks);
+}
+
+// Bulk relocation of vectors is an in-memory-fresh-index operation and is
+// not supported by this store yet.
+template <typename data_t>
+void PQDataStore<data_t>::move_vectors(const location_t old_location_start, const location_t new_location_start,
+                                       const location_t num_points)
+{
+    // REFACTOR TODO: Moving vectors is only for in-mem fresh.
+    throw std::logic_error("Not implemented yet");
+}
+
+// Copy `num_points` compressed vectors from `from_loc` to `to_loc`.
+// Each vector occupies _num_chunks uint8_t codes, so the byte count below is
+// exactly _num_chunks * num_points. Use memmove rather than memcpy: for bulk
+// copies within the same buffer the source and destination ranges may
+// overlap, which is undefined behavior with memcpy.
+template <typename data_t>
+void PQDataStore<data_t>::copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points)
+{
+    memmove(_quantized_data + to_loc * _num_chunks, _quantized_data + from_loc * _num_chunks,
+            _num_chunks * num_points);
+}
+
+// REFACTOR TODO: Currently, we take aligned_query as parameter, but this
+// function should also do the alignment.
+// Prepare per-query PQ state (e.g. the query-to-pivot lookup tables) in the
+// caller-provided scratch. Requires the scratch to carry a PQScratch; throws
+// otherwise.
+template <typename data_t>
+void PQDataStore<data_t>::preprocess_query(const data_t *aligned_query, AbstractScratch<data_t> *scratch) const
+{
+    if (scratch == nullptr)
+    {
+        throw diskann::ANNException("Scratch space is null", -1);
+    }
+
+    PQScratch<data_t> *pq_scratch = scratch->pq_scratch();
+
+    if (pq_scratch == nullptr)
+    {
+        throw diskann::ANNException("PQScratch space has not been set in the scratch object.", -1);
+    }
+
+    _pq_distance_fn->preprocess_query(aligned_query, (location_t)this->get_dims(), *pq_scratch);
+}
+
+// Single query-to-point distance is not supported; use the batched
+// scratch-based overloads below.
+template <typename data_t> float PQDataStore<data_t>::get_distance(const data_t *query, const location_t loc) const
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+// Point-to-point distance on compressed data is not supported yet.
+template <typename data_t> float PQDataStore<data_t>::get_distance(const location_t loc1, const location_t loc2) const
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+// Batched distances from the (already preprocessed) query to
+// `location_count` points: gather each point's PQ codes into the scratch's
+// aligned buffer, then evaluate all distances via the precomputed lookup
+// tables in the PQ distance functor. Requires a scratch with PQScratch set.
+template <typename data_t>
+void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const location_t *locations,
+                                       const uint32_t location_count, float *distances,
+                                       AbstractScratch<data_t> *scratch_space) const
+{
+    if (scratch_space == nullptr)
+    {
+        throw diskann::ANNException("Scratch space is null", -1);
+    }
+    PQScratch<data_t> *pq_scratch = scratch_space->pq_scratch();
+    if (pq_scratch == nullptr)
+    {
+        throw diskann::ANNException("PQScratch not set in scratch space.", -1);
+    }
+    diskann::aggregate_coords(locations, location_count, _quantized_data, this->_num_chunks,
+                              pq_scratch->aligned_pq_coord_scratch);
+    _pq_distance_fn->preprocessed_distance(*pq_scratch, location_count, distances);
+}
+
+// Same as the pointer/count overload above, but taking std::vector inputs
+// and outputs. Requires a scratch with PQScratch set.
+template <typename data_t>
+void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
+                                       std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const
+{
+    if (scratch_space == nullptr)
+    {
+        throw diskann::ANNException("Scratch space is null", -1);
+    }
+    PQScratch<data_t> *pq_scratch = scratch_space->pq_scratch();
+    if (pq_scratch == nullptr)
+    {
+        throw diskann::ANNException("PQScratch not set in scratch space.", -1);
+    }
+    diskann::aggregate_coords(ids, _quantized_data, this->_num_chunks, pq_scratch->aligned_pq_coord_scratch);
+    _pq_distance_fn->preprocessed_distance(*pq_scratch, (location_t)ids.size(), distances);
+}
+
+// Placeholder medoid: returns a pseudo-random location instead of the true
+// centroid-nearest point.
+// NOTE(review): rand()-based sampling is not uniform over [0, capacity) and
+// rand() is not thread safe — acceptable only as long as this remains a stub.
+template <typename data_t> location_t PQDataStore<data_t>::calculate_medoid() const
+{
+    // REFACTOR TODO: Must calculate this just like we do with data store.
+    size_t r = (size_t)rand() * (size_t)RAND_MAX + (size_t)rand();
+    return (uint32_t)(r % (size_t)this->capacity());
+}
+
+// Byte-packed codes need no special alignment; factor is 1.
+template <typename data_t> size_t PQDataStore<data_t>::get_alignment_factor() const
+{
+    return 1;
+}
+
+// Expose the full-precision distance functor (non-owning pointer; the store
+// retains ownership).
+template <typename data_t> Distance<data_t> *PQDataStore<data_t>::get_dist_fn() const
+{
+    return _distance_fn.get();
+}
+
+// Load previously saved PQ artifacts: the compressed code matrix (sets
+// capacity from the file) and the pivot tables for the distance functor.
+// File names are derived from `file_prefix` by the PQ distance functor.
+// Frees any compressed buffer held from a prior load/populate first.
+template <typename data_t> location_t PQDataStore<data_t>::load_impl(const std::string &file_prefix)
+{
+    if (_quantized_data != nullptr)
+    {
+        aligned_free(_quantized_data);
+    }
+    auto quantized_vectors_file = _pq_distance_fn->get_quantized_vectors_filename(file_prefix);
+
+    size_t num_points;
+    load_aligned_bin(quantized_vectors_file, _quantized_data, num_points, _num_chunks, _num_chunks);
+    this->_capacity = (location_t)num_points;
+
+    auto pivots_file = _pq_distance_fn->get_pivot_data_filename(file_prefix);
+    _pq_distance_fn->load_pivot_data(pivots_file, _num_chunks);
+
+    return this->_capacity;
+}
+
+// Growing the store (fresh-index operation) is not supported yet.
+template <typename data_t> location_t PQDataStore<data_t>::expand(const location_t new_size)
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+// Shrinking the store (fresh-index operation) is not supported yet.
+template <typename data_t> location_t PQDataStore<data_t>::shrink(const location_t new_size)
+{
+    throw std::logic_error("Not implemented yet");
+}
+
+#ifdef EXEC_ENV_OLS
+// Loading from an AlignedFileReader is not supported for PQDataStore yet.
+// The previous definition had an empty body — flowing off the end of a
+// non-void function is undefined behavior if it is ever called — so throw
+// explicitly, matching the other unimplemented methods in this file.
+template <typename data_t> location_t PQDataStore<data_t>::load_impl(AlignedFileReader &reader)
+{
+    throw std::logic_error("Not implemented yet");
+}
+#endif
+
+// Explicit class instantiations exported for the supported element types.
+template DISKANN_DLLEXPORT class PQDataStore<int8_t>;
+template DISKANN_DLLEXPORT class PQDataStore<float>;
+template DISKANN_DLLEXPORT class PQDataStore<uint8_t>;
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/pq_flash_index.cpp b/be/src/extern/diskann/src/pq_flash_index.cpp
new file mode 100644
index 0000000..b71f375
--- /dev/null
+++ b/be/src/extern/diskann/src/pq_flash_index.cpp
@@ -0,0 +1,1851 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "common_includes.h"
+
+#include "timer.h"
+#include "pq.h"
+#include "pq_scratch.h"
+#include "pq_flash_index.h"
+#include "cosine_similarity.h"
+#include "disk_utils.h"
+#include "ThreadPool.h"
+
+#ifdef _WINDOWS
+#include "windows_aligned_file_reader.h"
+#else
+#include "linux_aligned_file_reader.h"
+#endif
+
+#define READ_U64(stream, val) stream.read((char *)&val, sizeof(uint64_t))
+#define READ_U32(stream, val) stream.read((char *)&val, sizeof(uint32_t))
+#define READ_UNSIGNED(stream, val) stream.read((char *)&val, sizeof(unsigned))
+
+// sector # beyond the end of graph where data for id is present for reordering
+#define VECTOR_SECTOR_NO(id) (((uint64_t)(id)) / _nvecs_per_sector + _reorder_data_start_sector)
+
+// sector # beyond the end of graph where data for id is present for reordering
+#define VECTOR_SECTOR_OFFSET(id) ((((uint64_t)(id)) % _nvecs_per_sector) * _data_dim * sizeof(float))
+
+#include "vector/stream_wrapper.h"
+
+namespace diskann
+{
+// Construct a disk-backed PQ index using a Doris reader wrapper for I/O.
+// For COSINE / inner-product metrics on floating-point data the index
+// assumes inputs were pre-processed (normalized / MIPS-to-L2 converted) and
+// therefore falls back to an L2 comparator; for integral data it only warns,
+// since normalization is not possible.
+template <typename T, typename LabelT>
+PQFlashIndex<T, LabelT>::PQFlashIndex(IReaderWrapperSPtr fileReader, diskann::Metric m)
+    : customReader(fileReader), metric(m), _thread_data(nullptr)
+{
+    diskann::Metric metric_to_invoke = m;
+    if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT)
+    {
+        if (std::is_floating_point<T>::value)
+        {
+            diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed "
+                             "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we "
+                             "shall invoke an l2 distance function."
+                          << std::endl;
+            metric_to_invoke = diskann::Metric::L2;
+        }
+        else
+        {
+            diskann::cerr << "WARNING: Cannot normalize integral data types."
+                          << " This may result in erroneous results or poor recall."
+                          << " Consider using L2 distance with integral data types." << std::endl;
+        }
+    }
+    this->_dist_cmp.reset(diskann::get_distance_function<T>(metric_to_invoke));
+    this->_dist_cmp_float.reset(diskann::get_distance_function<float>(metric_to_invoke));
+}
+
+
+template <typename T, typename LabelT>
+PQFlashIndex<T, LabelT>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader, diskann::Metric m)
+ : reader(fileReader), metric(m), _thread_data(nullptr)
+{
+ diskann::Metric metric_to_invoke = m;
+ if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT)
+ {
+ if (std::is_floating_point<T>::value)
+ {
+ diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed "
+ "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we "
+ "shall invoke an l2 distance function."
+ << std::endl;
+ metric_to_invoke = diskann::Metric::L2;
+ }
+ else
+ {
+ diskann::cerr << "WARNING: Cannot normalize integral data types."
+ << " This may result in erroneous results or poor recall."
+ << " Consider using L2 distance with integral data types." << std::endl;
+ }
+ }
+ this->_dist_cmp.reset(diskann::get_distance_function<T>(metric_to_invoke));
+ this->_dist_cmp_float.reset(diskann::get_distance_function<float>(metric_to_invoke));
+}
+
+// Tear down all owned buffers: centroid data, the nhood/coord cache backing
+// stores, per-thread scratch (only if load() ran, per _load_flag), and the
+// filter-label arrays.
+// NOTE(review): reader deregistration/close is commented out — presumably the
+// reader's lifetime is managed by the caller in Doris; confirm.
+template <typename T, typename LabelT> PQFlashIndex<T, LabelT>::~PQFlashIndex()
+{
+    if (_centroid_data != nullptr)
+        aligned_free(_centroid_data);
+    // delete backing bufs for nhood and coord cache
+    if (_nhood_cache_buf != nullptr)
+    {
+        delete[] _nhood_cache_buf;
+        diskann::aligned_free(_coord_cache_buf);
+    }
+
+    if (_load_flag)
+    {
+        diskann::cout << "Clearing scratch" << std::endl;
+        ScratchStoreManager<SSDThreadData<T>> manager(this->_thread_data);
+        manager.destroy();
+        //this->reader->deregister_all_threads();
+        //reader->close();
+    }
+    if (_pts_to_label_offsets != nullptr)
+    {
+        delete[] _pts_to_label_offsets;
+    }
+    if (_pts_to_label_counts != nullptr)
+    {
+        delete[] _pts_to_label_counts;
+    }
+    if (_pts_to_labels != nullptr)
+    {
+        delete[] _pts_to_labels;
+    }
+    if (_medoids != nullptr)
+    {
+        delete[] _medoids;
+    }
+}
+
+// Sector index of the sector where `node_id` starts. The leading +1 skips
+// the index metadata sector. When _nnodes_per_sector == 0 a node spans
+// multiple sectors, so the stride is the node's rounded-up sector count.
+template <typename T, typename LabelT> inline uint64_t PQFlashIndex<T, LabelT>::get_node_sector(uint64_t node_id)
+{
+    return 1 + (_nnodes_per_sector > 0 ? node_id / _nnodes_per_sector
+                                       : node_id * DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN));
+}
+
+// Pointer to `node_id`'s record inside an already-read sector buffer
+// (offset 0 for multi-sector nodes, otherwise its slot within the sector).
+template <typename T, typename LabelT>
+inline char *PQFlashIndex<T, LabelT>::offset_to_node(char *sector_buf, uint64_t node_id)
+{
+    return sector_buf + (_nnodes_per_sector == 0 ? 0 : (node_id % _nnodes_per_sector) * _max_node_len);
+}
+
+// The neighbor list follows the coordinates in a node record; its first
+// uint32 is the neighbor count (see read_nodes).
+template <typename T, typename LabelT> inline uint32_t *PQFlashIndex<T, LabelT>::offset_to_node_nhood(char *node_buf)
+{
+    return (unsigned *)(node_buf + _disk_bytes_per_point);
+}
+
+// The node record starts with the vector coordinates.
+template <typename T, typename LabelT> inline T *PQFlashIndex<T, LabelT>::offset_to_node_coords(char *node_buf)
+{
+    return (T *)(node_buf);
+}
+
+// Allocate one SSDThreadData (scratch buffers + an I/O context registered
+// with the reader) per thread and push them into the shared pool. The
+// parallel-for exists so each OS thread registers its own context; the
+// critical section serializes the non-thread-safe register/push calls.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve)
+{
+    diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl;
+// omp parallel for to generate unique thread IDs
+#pragma omp parallel for num_threads((int)nthreads)
+    for (int64_t thread = 0; thread < (int64_t)nthreads; thread++)
+    {
+#pragma omp critical
+        {
+            SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve);
+            this->reader->register_thread();
+            data->ctx = this->reader->get_ctx();
+            this->_thread_data.push(data);
+        }
+    }
+    _load_flag = true;
+}
+
+// This per-thread data exists mainly for memory reuse: scratch buffers are
+// pooled and handed out per query instead of being reallocated. Unlike
+// setup_thread_data above, no reader I/O context is attached — used with the
+// custom (wrapper-based) reader path.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::setup_thread_data_without_ctx(uint64_t nthreads, uint64_t visited_reserve)
+{
+    diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << ", dim:" << this->_aligned_dim << ", visited_reserve:" << visited_reserve << std::endl;
+    for (int64_t thread = 0; thread < (int64_t)nthreads; thread++)
+    {
+        SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve);
+        this->_thread_data.push(data);
+    }
+    _load_flag = true;
+}
+
+// Read full node records for `node_ids` from disk in one batch.
+// For each node i: if coord_buffers[i] is non-null, its coordinates are
+// copied there; if nbr_buffers[i].second is non-null, its neighbor count and
+// ids are copied there. Returns a per-node success flag (always true except
+// on the Windows/Bing failed-read path).
+// NOTE(review): the aligned sector buffer is freed only on the normal path —
+// if a read throws, `buf` leaks; consider an RAII guard.
+template <typename T, typename LabelT>
+std::vector<bool> PQFlashIndex<T, LabelT>::read_nodes(const std::vector<uint32_t> &node_ids,
+                                                      std::vector<T *> &coord_buffers,
+                                                      std::vector<std::pair<uint32_t, uint32_t *>> &nbr_buffers)
+{
+    std::vector<AlignedRead> read_reqs;
+    std::vector<bool> retval(node_ids.size(), true);
+
+    char *buf = nullptr;
+    auto num_sectors = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
+    alloc_aligned((void **)&buf, node_ids.size() * num_sectors * defaults::SECTOR_LEN, defaults::SECTOR_LEN);
+
+    // create read requests
+    for (size_t i = 0; i < node_ids.size(); ++i)
+    {
+        auto node_id = node_ids[i];
+
+        AlignedRead read;
+        read.len = num_sectors * defaults::SECTOR_LEN;
+        read.buf = buf + i * num_sectors * defaults::SECTOR_LEN;
+        read.offset = get_node_sector(node_id) * defaults::SECTOR_LEN;
+        read_reqs.push_back(read);
+    }
+
+    // borrow thread data and issue reads
+    ScratchStoreManager<SSDThreadData<T>> manager(this->_thread_data);
+    auto this_thread_data = manager.scratch_space();
+    IOContext &ctx = this_thread_data->ctx;
+    _batch_reader->read(read_reqs);
+
+    // copy reads into buffers
+    for (uint32_t i = 0; i < read_reqs.size(); i++)
+    {
+#if defined(_WINDOWS) && defined(USE_BING_INFRA) // this block is to handle failed reads in
+                                                 // production settings
+        if ((*ctx.m_pRequestsStatus)[i] != IOContext::READ_SUCCESS)
+        {
+            retval[i] = false;
+            continue;
+        }
+#endif
+
+        char *node_buf = offset_to_node((char *)read_reqs[i].buf, node_ids[i]);
+
+        if (coord_buffers[i] != nullptr)
+        {
+            T *node_coords = offset_to_node_coords(node_buf);
+            memcpy(coord_buffers[i], node_coords, _disk_bytes_per_point);
+        }
+
+        if (nbr_buffers[i].second != nullptr)
+        {
+            // First uint32 of the nhood region is the neighbor count.
+            uint32_t *node_nhood = offset_to_node_nhood(node_buf);
+            auto num_nbrs = *node_nhood;
+            nbr_buffers[i].first = num_nbrs;
+            memcpy(nbr_buffers[i].second, node_nhood + 1, num_nbrs * sizeof(uint32_t));
+        }
+    }
+
+    aligned_free(buf);
+
+    return retval;
+}
+
+// Read the nodes listed in `node_list` from disk (in blocks of 8) and insert
+// their coordinates and neighbor lists into the in-memory caches backed by
+// _coord_cache_buf / _nhood_cache_buf.
+template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_cache_list(std::vector<uint32_t> &node_list)
+{
+    diskann::cout << "Loading the cache list into memory.." << std::flush;
+    size_t num_cached_nodes = node_list.size();
+
+    // Allocate space for neighborhood cache: per node, 1 count + _max_degree ids.
+    _nhood_cache_buf = new uint32_t[num_cached_nodes * (_max_degree + 1)];
+    // Fix: memset takes a byte count, but the previous code passed the element
+    // count, zeroing only a quarter of this uint32_t buffer.
+    memset(_nhood_cache_buf, 0, num_cached_nodes * (_max_degree + 1) * sizeof(uint32_t));
+
+    // Allocate space for coordinate cache
+    size_t coord_cache_buf_len = num_cached_nodes * _aligned_dim;
+    diskann::alloc_aligned((void **)&_coord_cache_buf, coord_cache_buf_len * sizeof(T), 8 * sizeof(T));
+    memset(_coord_cache_buf, 0, coord_cache_buf_len * sizeof(T));
+
+    size_t BLOCK_SIZE = 8;
+    size_t num_blocks = DIV_ROUND_UP(num_cached_nodes, BLOCK_SIZE);
+    for (size_t block = 0; block < num_blocks; block++)
+    {
+        size_t start_idx = block * BLOCK_SIZE;
+        size_t end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE);
+
+        // Point each node's cache slots into the shared backing buffers.
+        std::vector<uint32_t> nodes_to_read;
+        std::vector<T *> coord_buffers;
+        std::vector<std::pair<uint32_t, uint32_t *>> nbr_buffers;
+        for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++)
+        {
+            nodes_to_read.push_back(node_list[node_idx]);
+            coord_buffers.push_back(_coord_cache_buf + node_idx * _aligned_dim);
+            nbr_buffers.emplace_back(0, _nhood_cache_buf + node_idx * (_max_degree + 1));
+        }
+
+        // issue the reads
+        auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers);
+
+        // check for success and insert into the cache.
+        for (size_t i = 0; i < read_status.size(); i++)
+        {
+            if (read_status[i] == true)
+            {
+                _coord_cache.insert(std::make_pair(nodes_to_read[i], coord_buffers[i]));
+                _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i]));
+            }
+        }
+    }
+    diskann::cout << "..done." << std::endl;
+}
+
+// Decide which nodes to cache by replaying sample queries: run a beam search
+// per sample while _count_visited_nodes tracks per-node visit counts, then
+// pick the num_nodes_to_cache most-visited nodes. If the request covers the
+// whole index, short-circuit and cache every node.
+#ifdef EXEC_ENV_OLS
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin,
+                                                                      uint64_t l_search, uint64_t beamwidth,
+                                                                      uint64_t num_nodes_to_cache, uint32_t nthreads,
+                                                                      std::vector<uint32_t> &node_list)
+{
+#else
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search,
+                                                                      uint64_t beamwidth, uint64_t num_nodes_to_cache,
+                                                                      uint32_t nthreads,
+                                                                      std::vector<uint32_t> &node_list)
+{
+#endif
+    if (num_nodes_to_cache >= this->_num_points)
+    {
+        // for small num_points and big num_nodes_to_cache, use below way to get the node_list quickly
+        node_list.resize(this->_num_points);
+        for (uint32_t i = 0; i < this->_num_points; ++i)
+        {
+            node_list[i] = i;
+        }
+        return;
+    }
+
+    this->_count_visited_nodes = true;
+    this->_node_visit_counter.clear();
+    this->_node_visit_counter.resize(this->_num_points);
+    for (uint32_t i = 0; i < _node_visit_counter.size(); i++)
+    {
+        this->_node_visit_counter[i].first = i;
+        this->_node_visit_counter[i].second = 0;
+    }
+
+    uint64_t sample_num, sample_dim, sample_aligned_dim;
+    T *samples;
+
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(sample_bin))
+    {
+        diskann::load_aligned_bin<T>(files, sample_bin, samples, sample_num, sample_dim, sample_aligned_dim);
+    }
+#else
+    if (file_exists(sample_bin))
+    {
+        diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim, sample_aligned_dim);
+    }
+#endif
+    else
+    {
+        diskann::cerr << "Sample bin file not found. Not generating cache." << std::endl;
+        return;
+    }
+
+    std::vector<uint64_t> tmp_result_ids_64(sample_num, 0);
+    std::vector<float> tmp_result_dists(sample_num, 0);
+
+    bool filtered_search = false;
+    std::vector<LabelT> random_query_filters(sample_num);
+    if (_filter_to_medoid_ids.size() != 0)
+    {
+        filtered_search = true;
+        generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads);
+    }
+
+#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
+    for (int64_t i = 0; i < (int64_t)sample_num; i++)
+    {
+        auto &label_for_search = random_query_filters[i];
+        // run a search on the sample query with a random label (sampled from base label distribution), and it will
+        // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the
+        // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data.
+        cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i,
+                           tmp_result_dists.data() + i, beamwidth);
+    }
+
+    // Most-visited nodes first.
+    std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(),
+              [](std::pair<uint32_t, uint32_t> &left, std::pair<uint32_t, uint32_t> &right) {
+                  return left.second > right.second;
+              });
+    node_list.clear();
+    node_list.shrink_to_fit();
+    num_nodes_to_cache = std::min(num_nodes_to_cache, this->_node_visit_counter.size());
+    node_list.reserve(num_nodes_to_cache);
+    for (uint64_t i = 0; i < num_nodes_to_cache; i++)
+    {
+        node_list.push_back(this->_node_visit_counter[i].first);
+    }
+    this->_count_visited_nodes = false;
+
+    diskann::aligned_free(samples);
+}
+
+// Pick nodes to cache by breadth-first expansion from the medoids (and any
+// filter medoids): level by level, add each frontier node to `node_set` and
+// enqueue its unseen neighbors until num_nodes_to_cache nodes are collected.
+// Caching is capped at 10% of the index. `shuffle` randomizes, rather than
+// sorts, the per-level expansion order.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector<uint32_t> &node_list,
+                                               const bool shuffle)
+{
+    std::random_device rng;
+    std::mt19937 urng(rng());
+
+    tsl::robin_set<uint32_t> node_set;
+
+    // Do not cache more than 10% of the nodes in the index
+    uint64_t tenp_nodes = (uint64_t)(std::round(this->_num_points * 0.1));
+    if (num_nodes_to_cache > tenp_nodes)
+    {
+        diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache << " to: " << tenp_nodes
+                      << "(10 percent of total nodes:" << this->_num_points << ")" << std::endl;
+        num_nodes_to_cache = tenp_nodes == 0 ? 1 : tenp_nodes;
+    }
+    diskann::cout << "Caching " << num_nodes_to_cache << "..." << std::endl;
+
+    std::unique_ptr<tsl::robin_set<uint32_t>> cur_level, prev_level;
+    cur_level = std::make_unique<tsl::robin_set<uint32_t>>();
+    prev_level = std::make_unique<tsl::robin_set<uint32_t>>();
+
+    // Seed the BFS frontier with the medoids.
+    for (uint64_t miter = 0; miter < _num_medoids && cur_level->size() < num_nodes_to_cache; miter++)
+    {
+        cur_level->insert(_medoids[miter]);
+    }
+
+    // Also seed with per-filter medoids when filtered search is configured.
+    if ((_filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache))
+    {
+        for (auto &x : _filter_to_medoid_ids)
+        {
+            for (auto &y : x.second)
+            {
+                cur_level->insert(y);
+                if (cur_level->size() == num_nodes_to_cache)
+                    break;
+            }
+            if (cur_level->size() == num_nodes_to_cache)
+                break;
+        }
+    }
+
+    uint64_t lvl = 1;
+    uint64_t prev_node_set_size = 0;
+    while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && cur_level->size() != 0)
+    {
+        // swap prev_level and cur_level
+        std::swap(prev_level, cur_level);
+        // clear cur_level
+        cur_level->clear();
+
+        std::vector<uint32_t> nodes_to_expand;
+
+        for (const uint32_t &id : *prev_level)
+        {
+            if (node_set.find(id) != node_set.end())
+            {
+                continue;
+            }
+            node_set.insert(id);
+            nodes_to_expand.push_back(id);
+        }
+
+        if (shuffle)
+            std::shuffle(nodes_to_expand.begin(), nodes_to_expand.end(), urng);
+        else
+            std::sort(nodes_to_expand.begin(), nodes_to_expand.end());
+
+        diskann::cout << "Level: " << lvl << std::flush;
+        bool finish_flag = false;
+
+        // Read the frontier's neighbor lists from disk in blocks.
+        uint64_t BLOCK_SIZE = 1024;
+        uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE);
+        for (size_t block = 0; block < nblocks && !finish_flag; block++)
+        {
+            diskann::cout << "." << std::flush;
+            size_t start = block * BLOCK_SIZE;
+            size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size());
+
+            std::vector<uint32_t> nodes_to_read;
+            std::vector<T *> coord_buffers(end - start, nullptr);
+            std::vector<std::pair<uint32_t, uint32_t *>> nbr_buffers;
+
+            for (size_t cur_pt = start; cur_pt < end; cur_pt++)
+            {
+                nodes_to_read.push_back(nodes_to_expand[cur_pt]);
+                nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]);
+            }
+
+            // issue read requests
+            auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers);
+
+            // process each nhood buf
+            for (uint32_t i = 0; i < read_status.size(); i++)
+            {
+                if (read_status[i] == false)
+                {
+                    continue;
+                }
+                else
+                {
+                    uint32_t nnbrs = nbr_buffers[i].first;
+                    uint32_t *nbrs = nbr_buffers[i].second;
+
+                    // explore next level
+                    for (uint32_t j = 0; j < nnbrs && !finish_flag; j++)
+                    {
+                        if (node_set.find(nbrs[j]) == node_set.end())
+                        {
+                            cur_level->insert(nbrs[j]);
+                        }
+                        if (cur_level->size() + node_set.size() >= num_nodes_to_cache)
+                        {
+                            finish_flag = true;
+                        }
+                    }
+                }
+                delete[] nbr_buffers[i].second;
+            }
+        }
+
+        diskann::cout << ". #nodes: " << node_set.size() - prev_node_set_size
+                      << ", #nodes thus far: " << node_set.size() << std::endl;
+        prev_node_set_size = node_set.size();
+        lvl++;
+    }
+
+    assert(node_set.size() + cur_level->size() == num_nodes_to_cache || cur_level->size() == 0);
+
+    // Emit everything collected plus the (possibly partial) final frontier.
+    node_list.clear();
+    node_list.reserve(node_set.size() + cur_level->size());
+    for (auto node : node_set)
+        node_list.push_back(node);
+    for (auto node : *cur_level)
+        node_list.push_back(node);
+
+    diskann::cout << "Level: " << lvl << std::flush;
+    diskann::cout << ". #nodes: " << node_list.size() - prev_node_set_size << ", #nodes thus far: " << node_list.size()
+                  << std::endl;
+    diskann::cout << "done" << std::endl;
+}
+
+// Fill _centroid_data with the full-precision coordinates of every medoid,
+// read from disk; when the on-disk index stores PQ data, the codes are
+// inflated through _disk_pq_table instead. Throws if any medoid read fails.
+// NOTE(review): on that throw, the medoid_bufs for the remaining iterations
+// leak (delete[] only runs after the if/else) — consider vector<unique_ptr>.
+template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::use_medoids_data_as_centroids()
+{
+    if (_centroid_data != nullptr)
+        aligned_free(_centroid_data);
+    alloc_aligned(((void **)&_centroid_data), _num_medoids * _aligned_dim * sizeof(float), 32);
+    std::memset(_centroid_data, 0, _num_medoids * _aligned_dim * sizeof(float));
+
+    diskann::cout << "Loading centroid data from medoids vector data of " << _num_medoids << " medoid(s)" << std::endl;
+
+    std::vector<uint32_t> nodes_to_read;
+    std::vector<T *> medoid_bufs;
+    std::vector<std::pair<uint32_t, uint32_t *>> nbr_bufs;
+
+    for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++)
+    {
+        nodes_to_read.push_back(_medoids[cur_m]);
+        medoid_bufs.push_back(new T[_data_dim]);
+        nbr_bufs.emplace_back(0, nullptr);
+    }
+
+    auto read_status = read_nodes(nodes_to_read, medoid_bufs, nbr_bufs);
+
+    for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++)
+    {
+        if (read_status[cur_m] == true)
+        {
+            if (!_use_disk_index_pq)
+            {
+                for (uint32_t i = 0; i < _data_dim; i++)
+                    _centroid_data[cur_m * _aligned_dim + i] = medoid_bufs[cur_m][i];
+            }
+            else
+            {
+                _disk_pq_table.inflate_vector((uint8_t *)medoid_bufs[cur_m], (_centroid_data + cur_m * _aligned_dim));
+            }
+        }
+        else
+        {
+            throw ANNException("Unable to read a medoid", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+        delete[] medoid_bufs[cur_m];
+    }
+}
+
+// Fill `labels` with `num_labels` labels sampled uniformly from the label
+// multiset of the whole index (i.e. each label is drawn with probability
+// proportional to its frequency across points). Throws if the index holds no
+// labels at all.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
+                                                     const uint32_t nthreads)
+{
+    labels.clear();
+    labels.resize(num_labels);
+
+    uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1];
+    if (num_total_labels == 0)
+    {
+        std::stringstream stream;
+        stream << "No labels found in data. Not sampling random labels ";
+        diskann::cerr << stream.str() << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Fix: the previous code shared a single std::mt19937 across all OpenMP
+    // worker threads, a data race (mt19937 is not thread safe, and concurrent
+    // use corrupts its state). Give each thread its own seeded engine and
+    // distribution instead.
+#pragma omp parallel num_threads(nthreads)
+    {
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        std::uniform_int_distribution<uint64_t> dis(0, num_total_labels - 1);
+#pragma omp for schedule(dynamic, 1)
+        for (int64_t i = 0; i < (int64_t)num_labels; i++)
+        {
+            labels[i] = (LabelT)_pts_to_labels[dis(gen)];
+        }
+    }
+}
+
+// Parse the label-map stream: one "<label-string>\t<numeric-id>" pair per
+// line, returned as a string -> LabelT map.
+template <typename T, typename LabelT>
+std::unordered_map<std::string, LabelT> PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader)
+{
+    std::unordered_map<std::string, LabelT> string_to_int_mp;
+    std::string line, token;
+    LabelT token_as_num;
+    std::string label_str;
+    while (std::getline(map_reader, line))
+    {
+        std::istringstream iss(line);
+        getline(iss, token, '\t');
+        label_str = token;
+        getline(iss, token, '\t');
+        token_as_num = (LabelT)std::stoul(token);
+        string_to_int_mp[label_str] = token_as_num;
+    }
+    return string_to_int_mp;
+}
+
+// Translate a textual filter label to its numeric id via _label_map; fall
+// back to the universal label when one is configured, and throw for a
+// completely unknown label.
+template <typename T, typename LabelT>
+LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_label)
+{
+    if (_label_map.find(filter_label) != _label_map.end())
+    {
+        return _label_map[filter_label];
+    }
+    if (_use_universal_label)
+    {
+        return _universal_filter_label;
+    }
+    std::stringstream stream;
+    stream << "Unable to find label in the Label Map";
+    diskann::cerr << stream.str() << std::endl;
+    throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+}
+
+// Rewind `infile` so it can be read again from the start: first drop any
+// error/EOF state (a seek on a failed stream would be a no-op), then seek
+// back to the beginning.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::reset_stream_for_reading(std::basic_istream<char> &infile)
+{
+    infile.clear();
+    infile.seekg(0, std::ios::beg);
+}
+
+// Scan the whole label file content (newline-separated points, each a
+// comma-separated label list) and count the number of points and the total
+// number of label occurrences, without storing anything.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
+                                                      uint32_t &num_total_labels)
+{
+    num_pts = 0;
+    num_total_labels = 0;
+
+    size_t file_size = fileContent.length();
+
+    std::string label_str;
+    size_t cur_pos = 0;
+    size_t next_pos = 0;
+    while (cur_pos < file_size && cur_pos != std::string::npos)
+    {
+        next_pos = fileContent.find('\n', cur_pos);
+        if (next_pos == std::string::npos)
+        {
+            // No trailing newline: the final (unterminated) line is ignored,
+            // matching parse_label_file below.
+            break;
+        }
+
+        size_t lbl_pos = cur_pos;
+        size_t next_lbl_pos = 0;
+        while (lbl_pos < next_pos && lbl_pos != std::string::npos)
+        {
+            next_lbl_pos = fileContent.find(',', lbl_pos);
+            if (next_lbl_pos == std::string::npos) // the last label
+            {
+                next_lbl_pos = next_pos;
+            }
+
+            num_total_labels++;
+
+            lbl_pos = next_lbl_pos + 1;
+        }
+
+        cur_pos = next_pos + 1;
+
+        num_pts++;
+    }
+
+    diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
+                  << std::endl;
+}
+
+// Return whether `point_id` carries `label_id`. Each point's labels live in
+// a slice of the flat _pts_to_labels array: the slice begins at
+// _pts_to_label_offsets[point_id] and holds _pts_to_label_counts[point_id]
+// entries.
+template <typename T, typename LabelT>
+inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, LabelT label_id)
+{
+    const uint32_t slice_begin = _pts_to_label_offsets[point_id];
+    const uint32_t slice_count = _pts_to_label_counts[point_id];
+    for (uint32_t idx = 0; idx < slice_count; idx++)
+    {
+        if (_pts_to_labels[slice_begin + idx] == label_id)
+        {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Parse the numeric label file into the three flat arrays used by
+// point_has_label: per-point offset, per-point count, and the concatenated
+// label values. The stream is read fully into a buffer, sized via the
+// counting pass in get_label_file_metadata, then parsed line by line
+// (comma-separated labels per point). Sets num_points_labels to the number
+// of points parsed and rewinds the stream on exit.
+// NOTE(review): a point with zero labels calls exit(-1) — an exception would
+// be friendlier inside a server process; confirm intended.
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
+{
+    infile.seekg(0, std::ios::end);
+    size_t file_size = infile.tellg();
+
+    std::string buffer(file_size, ' ');
+
+    infile.seekg(0, std::ios::beg);
+    infile.read(&buffer[0], file_size);
+
+    std::string line;
+    uint32_t line_cnt = 0;
+
+    // First pass: size the arrays.
+    uint32_t num_pts_in_label_file;
+    uint32_t num_total_labels;
+    get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels);
+
+    _pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
+    _pts_to_label_counts = new uint32_t[num_pts_in_label_file];
+    _pts_to_labels = new LabelT[num_total_labels];
+    uint32_t labels_seen_so_far = 0;
+
+    std::string label_str;
+    size_t cur_pos = 0;
+    size_t next_pos = 0;
+    while (cur_pos < file_size && cur_pos != std::string::npos)
+    {
+        next_pos = buffer.find('\n', cur_pos);
+        if (next_pos == std::string::npos)
+        {
+            break;
+        }
+
+        _pts_to_label_offsets[line_cnt] = labels_seen_so_far;
+        uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt];
+        num_lbls_in_cur_pt = 0;
+
+        size_t lbl_pos = cur_pos;
+        size_t next_lbl_pos = 0;
+        while (lbl_pos < next_pos && lbl_pos != std::string::npos)
+        {
+            next_lbl_pos = buffer.find(',', lbl_pos);
+            if (next_lbl_pos == std::string::npos) // the last label in the whole file
+            {
+                next_lbl_pos = next_pos;
+            }
+
+            if (next_lbl_pos > next_pos) // the last label in one line, just read to the end
+            {
+                next_lbl_pos = next_pos;
+            }
+
+            label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
+            if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file?
+            {
+                label_str.erase(label_str.length() - 1);
+            }
+
+            LabelT token_as_num = (LabelT)std::stoul(label_str);
+            _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num;
+            num_lbls_in_cur_pt++;
+
+            // move to next label
+            lbl_pos = next_lbl_pos + 1;
+        }
+
+        // move to next line
+        cur_pos = next_pos + 1;
+
+        if (num_lbls_in_cur_pt == 0)
+        {
+            diskann::cout << "No label found for point " << line_cnt << std::endl;
+            exit(-1);
+        }
+
+        line_cnt++;
+    }
+
+    num_points_labels = line_cnt;
+    reset_stream_for_reading(infile);
+}
+
+template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label)
+{
+    // Record the universal label: points tagged with it satisfy any filter.
+    _universal_filter_label = label;
+    _use_universal_label = true;
+}
+
+#ifdef EXEC_ENV_OLS
+template <typename T, typename LabelT>
+int PQFlashIndex<T, LabelT>::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix)
+{
+#else
+template <typename T, typename LabelT> int PQFlashIndex<T, LabelT>::load(uint32_t num_threads, const char *index_prefix)
+{
+#endif
+    // Derive the standard artifact names from the index prefix and delegate
+    // to load_from_separate_paths().
+    const std::string prefix(index_prefix);
+    const std::string pivots_path = prefix + "_pq_pivots.bin";
+    const std::string compressed_path = prefix + "_pq_compressed.bin";
+    const std::string index_path = prefix + "_disk.index";
+#ifdef EXEC_ENV_OLS
+    return load_from_separate_paths(files, num_threads, index_path.c_str(), pivots_path.c_str(),
+                                    compressed_path.c_str());
+#else
+    return load_from_separate_paths(num_threads, index_path.c_str(), pivots_path.c_str(), compressed_path.c_str());
+#endif
+}
+
+// Loads a disk index from explicitly-given file paths: PQ codebook,
+// compressed PQ vectors and the graph/layout file, plus optional side files
+// (labels, labels-to-medoids, universal label, dummy map, medoids, centroids,
+// max-base-norm). Returns 0 on success, -1 on validation failure; throws
+// diskann::ANNException on malformed files.
+#ifdef EXEC_ENV_OLS
+template <typename T, typename LabelT>
+int PQFlashIndex<T, LabelT>::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
+                                                      const char *index_filepath, const char *pivots_filepath,
+                                                      const char *compressed_filepath)
+{
+#else
+template <typename T, typename LabelT>
+int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
+                                                      const char *pivots_filepath, const char *compressed_filepath)
+{
+#endif
+    std::string pq_table_bin = pivots_filepath;
+    std::string pq_compressed_vectors = compressed_filepath;
+    std::string _disk_index_file = index_filepath;
+    // Optional side files are derived from the disk-index path.
+    std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin";
+    std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin";
+
+    std::string labels_file = std ::string(_disk_index_file) + "_labels.txt";
+    std::string labels_to_medoids = std ::string(_disk_index_file) + "_labels_to_medoids.txt";
+    std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt";
+    std::string labels_map_file = std ::string(_disk_index_file) + "_labels_map.txt";
+    size_t num_pts_in_label_file = 0;
+
+    // Validate the codebook shape: the index is built with 256 centroids per
+    // chunk (one byte of compressed code per chunk).
+    size_t pq_file_dim, pq_file_num_centroids;
+#ifdef EXEC_ENV_OLS
+    get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE);
+#else
+    get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE);
+#endif
+
+    this->_disk_index_file = _disk_index_file;
+
+    if (pq_file_num_centroids != 256)
+    {
+        diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl;
+        return -1;
+    }
+
+    this->_data_dim = pq_file_dim;
+    // will change later if we use PQ on disk or if we are using
+    // inner product without PQ
+    this->_disk_bytes_per_point = this->_data_dim * sizeof(T);
+    this->_aligned_dim = ROUND_UP(pq_file_dim, 8);
+
+    // Load all in-memory compressed PQ vectors in one shot.
+    size_t npts_u64, nchunks_u64;
+#ifdef EXEC_ENV_OLS
+    diskann::load_bin<uint8_t>(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64);
+#else
+    diskann::load_bin<uint8_t>(pq_compressed_vectors, this->data, npts_u64, nchunks_u64);
+#endif
+
+    this->_num_points = npts_u64;
+    this->_n_chunks = nchunks_u64;
+    // Optional filter data: per-point labels, label-name map,
+    // label-to-medoid map, universal label and dummy-point map.
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(labels_file))
+    {
+        FileContent &content_labels = files.getContent(labels_file);
+        std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size));
+#else
+    if (file_exists(labels_file))
+    {
+        std::ifstream infile(labels_file, std::ios::binary);
+        if (infile.fail())
+        {
+            throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1);
+        }
+#endif
+        parse_label_file(infile, num_pts_in_label_file);
+        assert(num_pts_in_label_file == this->_num_points);
+
+#ifndef EXEC_ENV_OLS
+        infile.close();
+#endif
+
+#ifdef EXEC_ENV_OLS
+        FileContent &content_labels_map = files.getContent(labels_map_file);
+        std::stringstream map_reader(std::string((const char *)content_labels_map._content, content_labels_map._size));
+#else
+        std::ifstream map_reader(labels_map_file);
+#endif
+        _label_map = load_label_map(map_reader);
+
+#ifndef EXEC_ENV_OLS
+        map_reader.close();
+#endif
+
+#ifdef EXEC_ENV_OLS
+        if (files.fileExists(labels_to_medoids))
+        {
+            FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids);
+            std::stringstream medoid_stream(
+                std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size));
+#else
+        if (file_exists(labels_to_medoids))
+        {
+            std::ifstream medoid_stream(labels_to_medoids);
+            assert(medoid_stream.is_open());
+#endif
+            std::string line, token;
+
+            // Each line is: <label>,<medoid_id>[,<medoid_id>...]
+            _filter_to_medoid_ids.clear();
+            try
+            {
+                while (std::getline(medoid_stream, line))
+                {
+                    std::istringstream iss(line);
+                    uint32_t cnt = 0;
+                    std::vector<uint32_t> medoids;
+                    LabelT label;
+                    while (std::getline(iss, token, ','))
+                    {
+                        if (cnt == 0)
+                            label = (LabelT)std::stoul(token);
+                        else
+                            medoids.push_back((uint32_t)stoul(token));
+                        cnt++;
+                    }
+                    _filter_to_medoid_ids[label].swap(medoids);
+                }
+            }
+            catch (std::system_error &e)
+            {
+                throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__);
+            }
+        }
+        std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt";
+
+#ifdef EXEC_ENV_OLS
+        if (files.fileExists(univ_label_file))
+        {
+            FileContent &content_univ_label = files.getContent(univ_label_file);
+            std::stringstream universal_label_reader(
+                std::string((const char *)content_univ_label._content, content_univ_label._size));
+#else
+        if (file_exists(univ_label_file))
+        {
+            std::ifstream universal_label_reader(univ_label_file);
+            assert(universal_label_reader.is_open());
+#endif
+            std::string univ_label;
+            universal_label_reader >> univ_label;
+#ifndef EXEC_ENV_OLS
+            universal_label_reader.close();
+#endif
+            LabelT label_as_num = (LabelT)std::stoul(univ_label);
+            set_universal_label(label_as_num);
+        }
+
+#ifdef EXEC_ENV_OLS
+        if (files.fileExists(dummy_map_file))
+        {
+            FileContent &content_dummy_map = files.getContent(dummy_map_file);
+            std::stringstream dummy_map_stream(
+                std::string((const char *)content_dummy_map._content, content_dummy_map._size));
+#else
+        if (file_exists(dummy_map_file))
+        {
+            std::ifstream dummy_map_stream(dummy_map_file);
+            assert(dummy_map_stream.is_open());
+#endif
+            std::string line, token;
+
+            // Each line is: <dummy_id>,<real_id>
+            while (std::getline(dummy_map_stream, line))
+            {
+                std::istringstream iss(line);
+                uint32_t cnt = 0;
+                uint32_t dummy_id;
+                uint32_t real_id;
+                while (std::getline(iss, token, ','))
+                {
+                    if (cnt == 0)
+                        dummy_id = (uint32_t)stoul(token);
+                    else
+                        real_id = (uint32_t)stoul(token);
+                    cnt++;
+                }
+                _dummy_pts.insert(dummy_id);
+                _has_dummy_pts.insert(real_id);
+                _dummy_to_real_map[dummy_id] = real_id;
+
+                if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end())
+                    _real_to_dummy_map[real_id] = std::vector<uint32_t>();
+
+                _real_to_dummy_map[real_id].emplace_back(dummy_id);
+            }
+#ifndef EXEC_ENV_OLS
+            dummy_map_stream.close();
+#endif
+            diskann::cout << "Loaded dummy map" << std::endl;
+        }
+    }
+
+    // Load the PQ codebook (pivots) itself.
+#ifdef EXEC_ENV_OLS
+    _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64);
+#else
+    _pq_table.load_pq_centroid_bin(pq_table_bin.c_str(), nchunks_u64);
+#endif
+
+    diskann::cout << "Loaded PQ centroids and in-memory compressed vectors. #points: " << _num_points
+                  << " #dim: " << _data_dim << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks
+                  << std::endl;
+
+    if (_n_chunks > MAX_PQ_CHUNKS)
+    {
+        std::stringstream stream;
+        stream << "Error loading index. Ensure that max PQ bytes for in-memory "
+                  "PQ data does not exceed "
+               << MAX_PQ_CHUNKS << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Optional disk-resident PQ data (vectors on disk stored PQ-compressed).
+    std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin";
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(disk_pq_pivots_path))
+    {
+        _use_disk_index_pq = true;
+        // giving 0 chunks to make the _pq_table infer from the
+        // chunk_offsets file the correct value
+        _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0);
+#else
+    if (file_exists(disk_pq_pivots_path))
+    {
+        _use_disk_index_pq = true;
+        // giving 0 chunks to make the _pq_table infer from the
+        // chunk_offsets file the correct value
+        _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0);
+#endif
+        _disk_pq_n_chunks = _disk_pq_table.get_num_chunks();
+        _disk_bytes_per_point =
+            _disk_pq_n_chunks * sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used.
+        diskann::cout << "Disk index uses PQ data compressed down to " << _disk_pq_n_chunks << " bytes per point."
+                      << std::endl;
+    }
+
+// read index metadata
+#ifdef EXEC_ENV_OLS
+    // This is a bit tricky. We have to read the header from the
+    // disk_index_file. But this is now exclusively a preserve of the
+    // DiskPriorityIO class. So, we need to estimate how many
+    // bytes are needed to store the header and read in that many using our
+    // 'standard' aligned file reader approach.
+    reader->open(_disk_index_file);
+    this->setup_thread_data(num_threads);
+    this->_max_nthreads = num_threads;
+
+    char *bytes = getHeaderBytes();
+    ContentBuf buf(bytes, HEADER_SIZE);
+    std::basic_istream<char> index_metadata(&buf);
+#else
+    std::ifstream index_metadata(_disk_index_file, std::ios::binary);
+#endif
+
+    uint32_t nr, nc; // metadata itself is stored as bin format (nr is number of
+                     // metadata, nc should be 1)
+    READ_U32(index_metadata, nr);
+    READ_U32(index_metadata, nc);
+
+    uint64_t disk_nnodes;
+    uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true
+    READ_U64(index_metadata, disk_nnodes);
+    READ_U64(index_metadata, disk_ndims);
+
+    if (disk_nnodes != _num_points)
+    {
+        diskann::cout << "Mismatch in #points for compressed data file and disk "
+                         "index file: "
+                      << disk_nnodes << " vs " << _num_points << std::endl;
+        return -1;
+    }
+
+    size_t medoid_id_on_file;
+    READ_U64(index_metadata, medoid_id_on_file);
+    READ_U64(index_metadata, _max_node_len);
+    READ_U64(index_metadata, _nnodes_per_sector);
+    _max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1;
+
+    if (_max_degree > defaults::MAX_GRAPH_DEGREE)
+    {
+        std::stringstream stream;
+        stream << "Error loading index. Ensure that max graph degree (R) does "
+                  "not exceed "
+               << defaults::MAX_GRAPH_DEGREE << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // setting up concept of frozen points in disk index for streaming-DiskANN
+    READ_U64(index_metadata, this->_num_frozen_points);
+    uint64_t file_frozen_id;
+    READ_U64(index_metadata, file_frozen_id);
+    if (this->_num_frozen_points == 1)
+        this->_frozen_location = file_frozen_id;
+    if (this->_num_frozen_points == 1)
+    {
+        diskann::cout << " Detected frozen point in index at location " << this->_frozen_location
+                      << ". Will not output it at search time." << std::endl;
+    }
+
+    READ_U64(index_metadata, this->_reorder_data_exists);
+    if (this->_reorder_data_exists)
+    {
+        if (this->_use_disk_index_pq == false)
+        {
+            throw ANNException("Reordering is designed for used with disk PQ "
+                               "compression option",
+                               -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+        READ_U64(index_metadata, this->_reorder_data_start_sector);
+        READ_U64(index_metadata, this->_ndims_reorder_vecs);
+        READ_U64(index_metadata, this->_nvecs_per_sector);
+    }
+
+    diskann::cout << "Disk-Index File Meta-data: ";
+    diskann::cout << "# nodes per sector: " << _nnodes_per_sector;
+    diskann::cout << ", max node len (bytes): " << _max_node_len;
+    diskann::cout << ", max node degree: " << _max_degree << std::endl;
+
+#ifdef EXEC_ENV_OLS
+    delete[] bytes;
+#else
+    index_metadata.close();
+#endif
+
+#ifndef EXEC_ENV_OLS
+    // open AlignedFileReader handle to index_file
+    std::string index_fname(_disk_index_file);
+    reader->open(index_fname);
+    this->setup_thread_data(num_threads);
+    this->_max_nthreads = num_threads;
+
+#endif
+
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(medoids_file))
+    {
+        size_t tmp_dim;
+        // Fixed: this overload takes (files, filename, data, npts, dim); the
+        // previous code passed a not-yet-declared `norm_file` as an extra
+        // argument, which cannot compile under EXEC_ENV_OLS.
+        diskann::load_bin<uint32_t>(files, medoids_file, _medoids, _num_medoids, tmp_dim);
+#else
+    if (file_exists(medoids_file))
+    {
+        size_t tmp_dim;
+        diskann::load_bin<uint32_t>(medoids_file, _medoids, _num_medoids, tmp_dim);
+#endif
+
+        if (tmp_dim != 1)
+        {
+            std::stringstream stream;
+            stream << "Error loading medoids file. Expected bin format of m times "
+                      "1 vector of uint32_t."
+                   << std::endl;
+            throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+#ifdef EXEC_ENV_OLS
+        if (!files.fileExists(centroids_file))
+        {
+#else
+        if (!file_exists(centroids_file))
+        {
+#endif
+            diskann::cout << "Centroid data file not found. Using corresponding vectors "
+                             "for the medoids "
+                          << std::endl;
+            use_medoids_data_as_centroids();
+        }
+        else
+        {
+            size_t num_centroids, aligned_tmp_dim;
+#ifdef EXEC_ENV_OLS
+            diskann::load_aligned_bin<float>(files, centroids_file, _centroid_data, num_centroids, tmp_dim,
+                                             aligned_tmp_dim);
+#else
+            diskann::load_aligned_bin<float>(centroids_file, _centroid_data, num_centroids, tmp_dim, aligned_tmp_dim);
+#endif
+            if (aligned_tmp_dim != _aligned_dim || num_centroids != _num_medoids)
+            {
+                std::stringstream stream;
+                stream << "Error loading centroids data file. Expected bin format "
+                          "of "
+                          "m times data_dim vector of float, where m is number of "
+                          "medoids "
+                          "in medoids file.";
+                diskann::cerr << stream.str() << std::endl;
+                throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+            }
+        }
+    }
+    else
+    {
+        // No medoids file: fall back to the single medoid from the header.
+        _num_medoids = 1;
+        _medoids = new uint32_t[1];
+        _medoids[0] = (uint32_t)(medoid_id_on_file);
+        use_medoids_data_as_centroids();
+    }
+
+    std::string norm_file = std::string(_disk_index_file) + "_max_base_norm.bin";
+
+#ifdef EXEC_ENV_OLS
+    if (files.fileExists(norm_file) && metric == diskann::Metric::INNER_PRODUCT)
+    {
+        uint64_t dumr, dumc;
+        float *norm_val;
+        // Fixed: pass the file name; the previous call omitted `norm_file`.
+        diskann::load_bin<float>(files, norm_file, norm_val, dumr, dumc);
+#else
+    if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT)
+    {
+        uint64_t dumr, dumc;
+        float *norm_val;
+        diskann::load_bin<float>(norm_file, norm_val, dumr, dumc);
+#endif
+        this->_max_base_norm = norm_val[0];
+        diskann::cout << "Setting re-scaling factor of base vectors to " << this->_max_base_norm << std::endl;
+        delete[] norm_val;
+    }
+    diskann::cout << "done.." << std::endl;
+    return 0;
+}
+
+
+template <typename T, typename LabelT>
+int PQFlashIndex<T, LabelT>::load(uint32_t num_threads,
+                                  IReaderWrapperSPtr pq_pivots_reader,
+                                  IReaderWrapperSPtr pq_compressed_reader,
+                                  IReaderWrapperSPtr vamana_index_reader,
+                                  IReaderWrapperSPtr disk_layout_reader,
+                                  IReaderWrapperSPtr tag_reader)
+{
+    // Loads the index through abstract reader wrappers (e.g. remote/storage-
+    // vault streams) instead of local file paths: reads the PQ codebook
+    // metadata, the compressed PQ vectors and the disk-layout header, then
+    // hands the layout reader to a BatchReader for search-time node fetches.
+    // Returns 0 on success, -1 on validation failure.
+
+    // Codebook metadata: number of cluster centers and chunked dimension.
+    size_t pq_file_dim, pq_file_num_centroids;
+    size_t rows, cols;
+    std::unique_ptr<size_t[]> cumu_offsets;
+    std::unique_ptr<float[]> codebook;
+    diskann::load_bin<size_t>(pq_pivots_reader, cumu_offsets, rows, cols, 0);
+    diskann::load_bin<float>(pq_pivots_reader, codebook, pq_file_num_centroids, pq_file_dim, cumu_offsets[0]);
+    diskann::cout << "codebook:" << pq_file_num_centroids << ", pq_file_dim:" << pq_file_dim << std::endl;
+
+    if (pq_file_num_centroids != 256)
+    {
+        diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl;
+        return -1;
+    }
+
+    this->_data_dim = pq_file_dim;
+    // will change later if we use PQ on disk or if we are using
+    // inner product without PQ
+    this->_disk_bytes_per_point = this->_data_dim * sizeof(T);
+    this->_aligned_dim = ROUND_UP(pq_file_dim, 8);
+
+    // Load all compressed PQ vectors into memory in one shot.
+    size_t npts_u64, nchunks_u64;
+    diskann::load_bin<uint8_t>(pq_compressed_reader, this->data, npts_u64, nchunks_u64, 0);
+    diskann::cout << "pq_compressed:" << npts_u64 << ", dim:" << nchunks_u64 << std::endl;
+
+    this->_num_points = npts_u64;
+    this->_n_chunks = nchunks_u64;
+
+    _pq_table.load_pq_centroid_bin(pq_pivots_reader, nchunks_u64);
+
+    diskann::cout << "Loaded PQ centroids and in-memory compressed vectors. #points: " << _num_points
+                  << " #dim: " << _data_dim << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks
+                  << std::endl;
+
+    if (_n_chunks > MAX_PQ_CHUNKS)
+    {
+        std::stringstream stream;
+        stream << "Error loading index. Ensure that max PQ bytes for in-memory "
+                  "PQ data does not exceed "
+               << MAX_PQ_CHUNKS << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // Parse the disk-layout header; it is stored in bin format (nr is number
+    // of metadata entries, nc should be 1).
+    uint32_t nr, nc;
+    uint64_t offset = 0;
+    disk_layout_reader->read((char *)&nr, 4, offset);
+    offset += 4;
+    disk_layout_reader->read((char *)&nc, 4, offset);
+    offset += 4;
+
+    uint64_t disk_nnodes;
+    uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true
+    disk_layout_reader->read((char *)&disk_nnodes, 8, offset);
+    offset += 8;
+    disk_layout_reader->read((char *)&disk_ndims, 8, offset);
+    offset += 8;
+
+    if (disk_nnodes != _num_points)
+    {
+        diskann::cout << "Mismatch in #points for compressed data file and disk "
+                         "index file: "
+                      << disk_nnodes << " vs " << _num_points << std::endl;
+        return -1;
+    }
+
+    size_t medoid_id_on_file;
+    disk_layout_reader->read((char *)&medoid_id_on_file, 8, offset);
+    offset += 8;
+    disk_layout_reader->read((char *)&_max_node_len, 8, offset);
+    offset += 8;
+    disk_layout_reader->read((char *)&_nnodes_per_sector, 8, offset);
+    offset += 8;
+    _max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1;
+    if (_max_degree > defaults::MAX_GRAPH_DEGREE)
+    {
+        std::stringstream stream;
+        stream << "Error loading index. Ensure that max graph degree (R) does "
+                  "not exceed "
+               << defaults::MAX_GRAPH_DEGREE << std::endl;
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    // setting up concept of frozen points in disk index for streaming-DiskANN
+    disk_layout_reader->read((char *)&_num_frozen_points, 8, offset);
+    offset += 8;
+    uint64_t file_frozen_id;
+    disk_layout_reader->read((char *)&file_frozen_id, 8, offset);
+    offset += 8;
+    if (this->_num_frozen_points == 1)
+        this->_frozen_location = file_frozen_id;
+    if (this->_num_frozen_points == 1)
+    {
+        diskann::cout << " Detected frozen point in index at location " << this->_frozen_location
+                      << ". Will not output it at search time." << std::endl;
+    }
+    disk_layout_reader->read((char *)&_reorder_data_exists, 8, offset);
+    offset += 8;
+    if (this->_reorder_data_exists)
+    {
+        if (this->_use_disk_index_pq == false)
+        {
+            throw ANNException("Reordering is designed for used with disk PQ "
+                               "compression option",
+                               -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+        disk_layout_reader->read((char *)&(this->_reorder_data_start_sector), 8, offset);
+        offset += 8;
+        disk_layout_reader->read((char *)&(this->_ndims_reorder_vecs), 8, offset);
+        offset += 8;
+        disk_layout_reader->read((char *)&(this->_nvecs_per_sector), 8, offset);
+    }
+    this->setup_thread_data_without_ctx(num_threads);
+    _batch_reader.reset(new BatchReader(num_threads));
+    _batch_reader->set_reader(disk_layout_reader);
+
+    // No medoids side file in this path: use the single medoid recorded in
+    // the header and its vector as the navigation centroid.
+    _num_medoids = 1;
+    _medoids = new uint32_t[1];
+    _medoids[0] = (uint32_t)(medoid_id_on_file);
+    use_medoids_data_as_centroids();
+
+    diskann::cout << "Disk-Index File Meta-data: ";
+    diskann::cout << "# nodes per sector: " << _nnodes_per_sector;
+    diskann::cout << ", max node len (bytes): " << _max_node_len;
+    diskann::cout << ", max node degree: " << _max_degree << std::endl;
+    diskann::cout << "done.." << std::endl;
+    return 0;
+}
+
+#ifdef USE_BING_INFRA
+// Waits until one of the pending reads in `ctx` completes and reports its
+// index through `completedIndex`. Returns false (completedIndex = -1) when no
+// request can complete successfully.
+bool getNextCompletedRequest(std::shared_ptr<AlignedFileReader> &reader, IOContext &ctx, size_t size,
+                             int &completedIndex)
+{
+    if ((*ctx.m_pRequests)[0].m_callback)
+    {
+        long completeCount = ctx.m_completeCount;
+        bool waitsRemaining;
+        do
+        {
+            // Re-evaluate the outstanding requests on every pass. Resetting
+            // the flag here fixes a potential infinite loop: previously a
+            // stale `true` from an earlier pass kept us spinning even after
+            // every request had left the READ_WAIT state without success.
+            waitsRemaining = false;
+            for (size_t i = 0; i < size; i++)
+            {
+                auto ithStatus = (*ctx.m_pRequestsStatus)[i];
+                if (ithStatus == IOContext::Status::READ_SUCCESS)
+                {
+                    completedIndex = (int)i;
+                    return true;
+                }
+                else if (ithStatus == IOContext::Status::READ_WAIT)
+                {
+                    waitsRemaining = true;
+                }
+            }
+
+            // if we didn't find one in READ_SUCCESS, wait for one to complete.
+            if (waitsRemaining)
+            {
+                WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100);
+                // this assumes the knowledge of the reader behavior (implicit
+                // contract). need better factoring?
+            }
+        } while (waitsRemaining);
+
+        completedIndex = -1;
+        return false;
+    }
+    else
+    {
+        // Non-callback mode: delegate the wait to the reader itself.
+        reader->wait(ctx, completedIndex);
+        return completedIndex != -1;
+    }
+}
+#endif
+
+
+// Beam search over the disk-resident Vamana graph.
+// query1     : raw query vector (_data_dim values of T)
+// k_search   : number of results to return via indices/distances
+// l_search   : candidate-list size (>= k_search)
+// indices    : out, caller-allocated, at least k_search entries
+// distances  : out, optional (may be nullptr)
+// beam_width : number of node sectors fetched per IO round
+// filter     : optional; nodes for which filter->is_member(id) is true are
+//              still expanded for navigation but excluded from the results
+// stats      : optional per-query statistics accumulator
+// Returns the number of results written (<= k_search).
+template <typename T, typename LabelT>
+uint32_t PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search,
+                                                     uint64_t *indices, float *distances, const uint64_t beam_width,
+                                                     Filter *filter,
+                                                     QueryStats *stats)
+{
+    uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
+    if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS)
+        throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__,
+                           __LINE__);
+
+    ScratchStoreManager<SSDThreadData<T>> manager(this->_thread_data);
+    auto data = manager.scratch_space();
+    auto query_scratch = &(data->scratch);
+    auto pq_query_scratch = query_scratch->pq_scratch();
+
+    // reset query scratch
+    query_scratch->reset();
+
+    // copy query to thread specific aligned and allocated memory (for distance
+    // calculations we need aligned data)
+    float query_norm = 0;
+    T *aligned_query_T = query_scratch->aligned_query_T();
+    float *query_float = pq_query_scratch->aligned_query_float;
+    float *query_rotated = pq_query_scratch->rotated_query;
+
+    // normalization step. for cosine, we simply normalize the query
+    // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to
+    // convert MIPS to L2 search
+    if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE)
+    {
+        uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1);
+        for (size_t i = 0; i < inherent_dim; i++)
+        {
+            aligned_query_T[i] = query1[i];
+            query_norm += query1[i] * query1[i];
+        }
+        if (metric == diskann::Metric::INNER_PRODUCT)
+            aligned_query_T[this->_data_dim - 1] = 0;
+
+        query_norm = std::sqrt(query_norm);
+
+        for (size_t i = 0; i < inherent_dim; i++)
+        {
+            aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm);
+        }
+        pq_query_scratch->initialize(this->_data_dim, aligned_query_T);
+    }
+    else
+    {
+        for (size_t i = 0; i < this->_data_dim; i++)
+        {
+            aligned_query_T[i] = query1[i];
+        }
+        pq_query_scratch->initialize(this->_data_dim, aligned_query_T);
+    }
+
+    // pointers to buffers for data
+    T *data_buf = query_scratch->coord_scratch;
+    _mm_prefetch((char *)data_buf, _MM_HINT_T1);
+
+    // sector scratch
+    char *sector_scratch = query_scratch->sector_scratch;
+    // NOTE(review): fixed a garbled identifier here ("§or_scratch_idx", a
+    // mangled "&sector_scratch_idx") — the variable is used below by name.
+    uint64_t &sector_scratch_idx = query_scratch->sector_idx;
+    const uint64_t num_sectors_per_node =
+        _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
+
+    // query <-> PQ chunk centers distances
+    _pq_table.preprocess_query(query_rotated); // center the query and rotate if
+                                               // we have a rotation matrix
+    float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch;
+    _pq_table.populate_chunk_distances(query_rotated, pq_dists);
+
+    // query <-> neighbor list
+    float *dist_scratch = pq_query_scratch->aligned_dist_scratch;
+    uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch;
+
+    // lambda to batch compute query<-> node distances in PQ space
+    auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, const uint64_t n_ids,
+                                                            float *dists_out) {
+        diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, pq_coord_scratch);
+        diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, dists_out);
+    };
+    Timer query_timer, io_timer, cpu_timer;
+
+    tsl::robin_set<uint64_t> &visited = query_scratch->visited;
+    NeighborPriorityQueue &retset = query_scratch->retset;
+    retset.reserve(l_search);
+    std::vector<Neighbor> &full_retset = query_scratch->full_retset;
+
+    // Seed the search with the medoid closest to the query.
+    uint32_t best_medoid = 0;
+    float best_dist = (std::numeric_limits<float>::max)();
+    for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++)
+    {
+        float cur_expanded_dist =
+            _dist_cmp_float->compare(query_float, _centroid_data + _aligned_dim * cur_m, (uint32_t)_aligned_dim);
+        if (cur_expanded_dist < best_dist)
+        {
+            best_medoid = _medoids[cur_m];
+            best_dist = cur_expanded_dist;
+        }
+    }
+
+    compute_dists(&best_medoid, 1, dist_scratch);
+    retset.insert(Neighbor(best_medoid, dist_scratch[0]));
+    visited.insert(best_medoid);
+
+    uint32_t cmps = 0;
+    uint32_t hops = 0;
+    uint32_t num_ios = 0;
+
+    // cleared every iteration
+    std::vector<uint32_t> frontier;
+    frontier.reserve(2 * beam_width);
+    std::vector<std::pair<uint32_t, char *>> frontier_nhoods;
+    frontier_nhoods.reserve(2 * beam_width);
+    std::vector<AlignedRead> frontier_read_reqs;
+    frontier_read_reqs.reserve(2 * beam_width);
+    std::vector<std::pair<uint32_t, std::pair<uint32_t, uint32_t *>>> cached_nhoods;
+    cached_nhoods.reserve(2 * beam_width);
+
+    while (retset.has_unexpanded_node())
+    {
+        // clear iteration state
+        frontier.clear();
+        frontier_nhoods.clear();
+        frontier_read_reqs.clear();
+        cached_nhoods.clear();
+        sector_scratch_idx = 0;
+        // find new beam
+        uint32_t num_seen = 0;
+        while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width)
+        {
+            auto nbr = retset.closest_unexpanded();
+            num_seen++;
+            auto iter = _nhood_cache.find(nbr.id);
+            if (iter != _nhood_cache.end())
+            {
+                cached_nhoods.push_back(std::make_pair(nbr.id, iter->second));
+                if (stats != nullptr)
+                {
+                    stats->n_cache_hits++;
+                }
+            }
+            else
+            {
+                frontier.push_back(nbr.id);
+            }
+            if (this->_count_visited_nodes)
+            {
+                reinterpret_cast<std::atomic<uint32_t> &>(this->_node_visit_counter[nbr.id].second).fetch_add(1);
+            }
+            // Filtered-out nodes are expanded for navigation but dropped from
+            // the result candidates.
+            if (filter != nullptr && filter->is_member(nbr.id))
+            {
+                retset.remove_pre_expanded_node();
+            }
+        }
+
+        // read nhoods of frontier ids
+        if (!frontier.empty())
+        {
+            if (stats != nullptr)
+                stats->n_hops++;
+            for (uint64_t i = 0; i < frontier.size(); i++)
+            {
+                auto id = frontier[i];
+                std::pair<uint32_t, char *> fnhood;
+                fnhood.first = id;
+                fnhood.second = sector_scratch + num_sectors_per_node * sector_scratch_idx * defaults::SECTOR_LEN;
+                sector_scratch_idx++;
+                frontier_nhoods.push_back(fnhood);
+                frontier_read_reqs.emplace_back(get_node_sector((size_t)id) * defaults::SECTOR_LEN,
+                                                num_sectors_per_node * defaults::SECTOR_LEN, fnhood.second);
+                if (stats != nullptr)
+                {
+                    stats->n_4k++;
+                    stats->n_ios++;
+                }
+                num_ios++;
+            }
+            io_timer.reset();
+            _batch_reader->read(frontier_read_reqs); // synchronous IO linux
+            if (stats != nullptr)
+            {
+                stats->io_us += (float)io_timer.elapsed();
+            }
+        }
+
+        // process cached nhoods
+        for (auto &cached_nhood : cached_nhoods)
+        {
+            auto global_cache_iter = _coord_cache.find(cached_nhood.first);
+            T *node_fp_coords_copy = global_cache_iter->second;
+            float cur_expanded_dist;
+            if (filter == nullptr || !filter->is_member(cached_nhood.first))
+            {
+                if (!_use_disk_index_pq)
+                {
+                    cur_expanded_dist =
+                        _dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (uint32_t)_aligned_dim);
+                }
+                else
+                {
+                    if (metric == diskann::Metric::INNER_PRODUCT)
+                        cur_expanded_dist = _disk_pq_table.inner_product(query_float, (uint8_t *)node_fp_coords_copy);
+                    else
+                        cur_expanded_dist = _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet
+                            query_float, (uint8_t *)node_fp_coords_copy);
+                }
+                full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist));
+            }
+
+            uint64_t nnbrs = cached_nhood.second.first;
+            uint32_t *node_nbrs = cached_nhood.second.second;
+
+            // compute node_nbrs <-> query dists in PQ space
+            cpu_timer.reset();
+            compute_dists(node_nbrs, nnbrs, dist_scratch);
+            if (stats != nullptr)
+            {
+                stats->n_cmps += (uint32_t)nnbrs;
+                stats->cpu_us += (float)cpu_timer.elapsed();
+            }
+
+            // process prefetched nhood
+            for (uint64_t m = 0; m < nnbrs; ++m)
+            {
+                uint32_t id = node_nbrs[m];
+                if (visited.insert(id).second)
+                {
+                    cmps++;
+                    float dist = dist_scratch[m];
+                    Neighbor nn(id, dist);
+                    retset.insert(nn);
+                }
+            }
+        }
+        for (auto &frontier_nhood : frontier_nhoods)
+        {
+            char *node_disk_buf = offset_to_node(frontier_nhood.second, frontier_nhood.first);
+            uint32_t *node_buf = offset_to_node_nhood(node_disk_buf);
+            uint64_t nnbrs = (uint64_t)(*node_buf);
+            T *node_fp_coords = offset_to_node_coords(node_disk_buf);
+            memcpy(data_buf, node_fp_coords, _disk_bytes_per_point);
+            float cur_expanded_dist;
+            if (filter == nullptr || !filter->is_member(frontier_nhood.first))
+            {
+                if (!_use_disk_index_pq)
+                {
+                    cur_expanded_dist = _dist_cmp->compare(aligned_query_T, data_buf, (uint32_t)_aligned_dim);
+                }
+                else
+                {
+                    if (metric == diskann::Metric::INNER_PRODUCT)
+                        cur_expanded_dist = _disk_pq_table.inner_product(query_float, (uint8_t *)data_buf);
+                    else
+                        cur_expanded_dist = _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf);
+                }
+                full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist));
+            }
+            uint32_t *node_nbrs = (node_buf + 1);
+            // compute node_nbrs <-> query dist in PQ space
+            cpu_timer.reset();
+            compute_dists(node_nbrs, nnbrs, dist_scratch);
+            if (stats != nullptr)
+            {
+                stats->n_cmps += (uint32_t)nnbrs;
+                stats->cpu_us += (float)cpu_timer.elapsed();
+            }
+
+            cpu_timer.reset();
+            // process prefetch-ed nhood
+            for (uint64_t m = 0; m < nnbrs; ++m)
+            {
+                uint32_t id = node_nbrs[m];
+                if (visited.insert(id).second)
+                {
+                    cmps++;
+                    float dist = dist_scratch[m];
+                    if (stats != nullptr)
+                    {
+                        stats->n_cmps++;
+                    }
+
+                    Neighbor nn(id, dist);
+                    retset.insert(nn);
+                }
+            }
+
+            if (stats != nullptr)
+            {
+                stats->cpu_us += (float)cpu_timer.elapsed();
+            }
+        }
+
+        hops++;
+    }
+
+    // re-sort by distance
+    std::sort(full_retset.begin(), full_retset.end());
+
+    // copy k_search values
+    uint32_t result_size = 0;
+    for (uint64_t i = 0; i < full_retset.size(); i++)
+    {
+        indices[i] = full_retset[i].id;
+        auto key = (uint32_t)indices[i];
+        // Map dummy (duplicated) points back to their real ids.
+        if (_dummy_pts.find(key) != _dummy_pts.end())
+        {
+            indices[i] = _dummy_to_real_map[key];
+        }
+
+        if (distances != nullptr)
+        {
+            distances[i] = full_retset[i].distance;
+            if (metric == diskann::Metric::INNER_PRODUCT)
+            {
+                // flip the sign to convert min to max
+                distances[i] = (-distances[i]);
+                // rescale to revert back to original norms (cancelling the
+                // effect of base and query pre-processing)
+                if (_max_base_norm != 0)
+                    distances[i] *= (_max_base_norm * query_norm);
+            }
+        }
+        result_size++;
+        if (result_size >= k_search)
+        {
+            break;
+        }
+    }
+    if (stats != nullptr)
+    {
+        stats->total_us = (float)query_timer.elapsed();
+    }
+    return result_size;
+}
+
+// range search returns results of all neighbors within distance of range.
+// indices and distances are resized internally (the candidate list grows from
+// min_l_search up to max_l_search); the return value is the number of hits.
+template <typename T, typename LabelT>
+uint32_t PQFlashIndex<T, LabelT>::range_search(const T *query1, const double range, const uint64_t min_l_search,
+ const uint64_t max_l_search, std::vector<uint64_t> &indices,
+ std::vector<float> &distances, const uint64_t min_beam_width,
+ QueryStats *stats)
+{
+ uint32_t res_count = 0;
+
+ bool stop_flag = false;
+
+ uint32_t l_search = (uint32_t)min_l_search; // starting size of the candidate list
+ while (!stop_flag)
+ {
+ indices.resize(l_search);
+ distances.resize(l_search);
+ uint64_t cur_bw = min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5;
+ cur_bw = (cur_bw > 100) ? 100 : cur_bw;
+ for (auto &x : distances)
+ x = std::numeric_limits<float>::max();
+ this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, nullptr, stats);
+ for (uint32_t i = 0; i < l_search; i++)
+ {
+ if (distances[i] > (float)range)
+ {
+ res_count = i;
+ break;
+ }
+ else if (i == l_search - 1)
+ res_count = l_search;
+ }
+ if (res_count < (uint32_t)(l_search / 2.0))
+ stop_flag = true;
+ l_search = l_search * 2;
+ if (l_search > max_l_search)
+ stop_flag = true;
+ }
+ indices.resize(res_count);
+ distances.resize(res_count);
+ return res_count;
+}
+
+template <typename T, typename LabelT> uint64_t PQFlashIndex<T, LabelT>::get_data_dim()
+{
+ return _data_dim;
+}
+
+template <typename T, typename LabelT> diskann::Metric PQFlashIndex<T, LabelT>::get_metric()
+{
+ return this->metric;
+}
+
+#ifdef EXEC_ENV_OLS
+template <typename T, typename LabelT> char *PQFlashIndex<T, LabelT>::getHeaderBytes()
+{
+ IOContext &ctx = reader->get_ctx();
+ AlignedRead readReq;
+ readReq.buf = new char[PQFlashIndex<T, LabelT>::HEADER_SIZE];
+ readReq.len = PQFlashIndex<T, LabelT>::HEADER_SIZE;
+ readReq.offset = 0;
+
+ std::vector<AlignedRead> readReqs;
+ readReqs.push_back(readReq);
+
+ reader->read(readReqs, ctx, false);
+
+ return (char *)readReq.buf;
+}
+#endif
+
+template <typename T, typename LabelT>
+std::vector<std::uint8_t> PQFlashIndex<T, LabelT>::get_pq_vector(std::uint64_t vid)
+{
+ std::uint8_t *pqVec = &this->data[vid * this->_n_chunks];
+ return std::vector<std::uint8_t>(pqVec, pqVec + this->_n_chunks);
+}
+
+template <typename T, typename LabelT> std::uint64_t PQFlashIndex<T, LabelT>::get_num_points()
+{
+ return _num_points;
+}
+
+// instantiations
+template class PQFlashIndex<uint8_t>;
+template class PQFlashIndex<int8_t>;
+template class PQFlashIndex<float>;
+template class PQFlashIndex<uint8_t, uint16_t>;
+template class PQFlashIndex<int8_t, uint16_t>;
+template class PQFlashIndex<float, uint16_t>;
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/pq_l2_distance.cpp b/be/src/extern/diskann/src/pq_l2_distance.cpp
new file mode 100644
index 0000000..c08744c
--- /dev/null
+++ b/be/src/extern/diskann/src/pq_l2_distance.cpp
@@ -0,0 +1,284 @@
+
+#include "pq.h"
+#include "pq_l2_distance.h"
+#include "pq_scratch.h"
+
+// block size used when reading/processing large files and matrices in chunks
+#define BLOCK_SIZE 5000000
+
+namespace diskann
+{
+
+template <typename data_t>
+PQL2Distance<data_t>::PQL2Distance(uint32_t num_chunks, bool use_opq) : _num_chunks(num_chunks), _is_opq(use_opq)
+{
+}
+
+template <typename data_t> PQL2Distance<data_t>::~PQL2Distance()
+{
+#ifndef EXEC_ENV_OLS
+ if (_tables != nullptr)
+ delete[] _tables;
+ if (_chunk_offsets != nullptr)
+ delete[] _chunk_offsets;
+ if (_centroid != nullptr)
+ delete[] _centroid;
+ if (_rotmat_tr != nullptr)
+ delete[] _rotmat_tr;
+#endif
+ if (_tables_tr != nullptr)
+ delete[] _tables_tr;
+}
+
+template <typename data_t> bool PQL2Distance<data_t>::is_opq() const
+{
+ return this->_is_opq;
+}
+
+template <typename data_t>
+std::string PQL2Distance<data_t>::get_quantized_vectors_filename(const std::string &prefix) const
+{
+ if (_num_chunks == 0)
+ {
+ throw diskann::ANNException("Must set num_chunks before calling get_quantized_vectors_filename", -1,
+ __FUNCSIG__, __FILE__, __LINE__);
+ }
+ return diskann::get_quantized_vectors_filename(prefix, _is_opq, (uint32_t)_num_chunks);
+}
+template <typename data_t> std::string PQL2Distance<data_t>::get_pivot_data_filename(const std::string &prefix) const
+{
+ if (_num_chunks == 0)
+ {
+ throw diskann::ANNException("Must set num_chunks before calling get_pivot_data_filename", -1, __FUNCSIG__,
+ __FILE__, __LINE__);
+ }
+ return diskann::get_pivot_data_filename(prefix, _is_opq, (uint32_t)_num_chunks);
+}
+template <typename data_t>
+std::string PQL2Distance<data_t>::get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const
+{
+ return diskann::get_rotation_matrix_suffix(pq_pivots_filename);
+}
+
+#ifdef EXEC_ENV_OLS
+template <typename data_t>
+void PQL2Distance<data_t>::load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file,
+ size_t num_chunks)
+{
+#else
+template <typename data_t>
+void PQL2Distance<data_t>::load_pivot_data(const std::string &pq_table_file, size_t num_chunks)
+{
+#endif
+ uint64_t nr, nc;
+ // std::string rotmat_file = get_opq_rot_matrix_filename(pq_table_file,
+ // false);
+
+#ifdef EXEC_ENV_OLS
+ size_t *file_offset_data; // since load_bin only sets the pointer, no need
+ // to delete.
+ diskann::load_bin<size_t>(files, pq_table_file, file_offset_data, nr, nc);
+#else
+ std::unique_ptr<size_t[]> file_offset_data;
+ diskann::load_bin<size_t>(pq_table_file, file_offset_data, nr, nc);
+#endif
+
+ bool use_old_filetype = false;
+
+ if (nr != 4 && nr != 5)
+ {
+ diskann::cout << "Error reading pq_pivots file " << pq_table_file
+ << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting " << 4
+ << " or " << 5;
+ throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ if (nr == 4)
+ {
+ diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2]
+ << " " << file_offset_data[3] << std::endl;
+ }
+ else if (nr == 5)
+ {
+ use_old_filetype = true;
+ diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2]
+ << " " << file_offset_data[3] << file_offset_data[4] << std::endl;
+ }
+ else
+ {
+ throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+#ifdef EXEC_ENV_OLS
+ diskann::load_bin<float>(files, pq_table_file, tables, nr, nc, file_offset_data[0]);
+#else
+ diskann::load_bin<float>(pq_table_file, _tables, nr, nc, file_offset_data[0]);
+#endif
+
+ if ((nr != NUM_PQ_CENTROIDS))
+ {
+ diskann::cout << "Error reading pq_pivots file " << pq_table_file << ". file_num_centers = " << nr
+ << " but expecting " << NUM_PQ_CENTROIDS << " centers";
+ throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ this->_ndims = nc;
+
+#ifdef EXEC_ENV_OLS
+ diskann::load_bin<float>(files, pq_table_file, centroid, nr, nc, file_offset_data[1]);
+#else
+ diskann::load_bin<float>(pq_table_file, _centroid, nr, nc, file_offset_data[1]);
+#endif
+
+ if ((nr != this->_ndims) || (nc != 1))
+ {
+ diskann::cerr << "Error reading centroids from pq_pivots file " << pq_table_file << ". file_dim = " << nr
+ << ", file_cols = " << nc << " but expecting " << this->_ndims << " entries in 1 dimension.";
+ throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ int chunk_offsets_index = 2;
+ if (use_old_filetype)
+ {
+ chunk_offsets_index = 3;
+ }
+#ifdef EXEC_ENV_OLS
+ diskann::load_bin<uint32_t>(files, pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]);
+#else
+ diskann::load_bin<uint32_t>(pq_table_file, _chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]);
+#endif
+
+ if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0))
+ {
+ diskann::cerr << "Error loading chunk offsets file. numc: " << nc << " (should be 1). numr: " << nr
+ << " (should be " << num_chunks + 1 << " or 0 if we need to infer)" << std::endl;
+ throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ this->_num_chunks = nr - 1;
+ diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS << ", #dims: " << this->_ndims
+ << ", #chunks: " << this->_num_chunks << std::endl;
+
+ // For OPQ there will be a rotation matrix to load.
+ if (this->_is_opq)
+ {
+ std::string rotmat_file = get_rotation_matrix_suffix(pq_table_file);
+#ifdef EXEC_ENV_OLS
+ diskann::load_bin<float>(files, rotmat_file, (float *&)rotmat_tr, nr, nc);
+#else
+ diskann::load_bin<float>(rotmat_file, _rotmat_tr, nr, nc);
+#endif
+ if (nr != this->_ndims || nc != this->_ndims)
+ {
+ diskann::cerr << "Error loading rotation matrix file" << std::endl;
+ throw diskann::ANNException("Error loading rotation matrix file", -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ }
+
+ // alloc and compute transpose
+ _tables_tr = new float[256 * this->_ndims];
+ for (size_t i = 0; i < 256; i++)
+ {
+ for (size_t j = 0; j < this->_ndims; j++)
+ {
+ _tables_tr[j * 256 + i] = _tables[i * this->_ndims + j];
+ }
+ }
+}
+
+template <typename data_t> uint32_t PQL2Distance<data_t>::get_num_chunks() const
+{
+ return static_cast<uint32_t>(_num_chunks);
+}
+
+// REFACTOR: instead of splitting the work between the caller and this
+// function, this function now performs all of the query preprocessing
+// itself, which simplifies every call site.
+template <typename data_t>
+void PQL2Distance<data_t>::preprocess_query(const data_t *aligned_query, uint32_t dim, PQScratch<data_t> &scratch)
+{
+ // Copy query vector to float and then to "rotated" query
+ for (size_t d = 0; d < dim; d++)
+ {
+ scratch.aligned_query_float[d] = (float)aligned_query[d];
+ }
+ scratch.initialize(dim, aligned_query);
+
+ for (uint32_t d = 0; d < _ndims; d++)
+ {
+ scratch.rotated_query[d] -= _centroid[d];
+ }
+ std::vector<float> tmp(_ndims, 0);
+ if (_is_opq)
+ {
+ for (uint32_t d = 0; d < _ndims; d++)
+ {
+ for (uint32_t d1 = 0; d1 < _ndims; d1++)
+ {
+ tmp[d] += scratch.rotated_query[d1] * _rotmat_tr[d1 * _ndims + d];
+ }
+ }
+ std::memcpy(scratch.rotated_query, tmp.data(), _ndims * sizeof(float));
+ }
+ this->prepopulate_chunkwise_distances(scratch.rotated_query, scratch.aligned_pqtable_dist_scratch);
+}
+
+template <typename data_t>
+void PQL2Distance<data_t>::preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids, float *dists_out)
+{
+ pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, pq_scratch.aligned_pqtable_dist_scratch,
+ dists_out);
+}
+
+template <typename data_t>
+void PQL2Distance<data_t>::preprocessed_distance(PQScratch<data_t> &pq_scratch, const uint32_t n_ids,
+ std::vector<float> &dists_out)
+{
+ pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, pq_scratch.aligned_pqtable_dist_scratch,
+ dists_out);
+}
+
+template <typename data_t> float PQL2Distance<data_t>::brute_force_distance(const float *query_vec, uint8_t *base_vec)
+{
+ float res = 0;
+ for (size_t chunk = 0; chunk < _num_chunks; chunk++)
+ {
+ for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = _tables_tr + (256 * j);
+ float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]);
+ res += diff * diff;
+ }
+ }
+ return res;
+}
+
+template <typename data_t>
+void PQL2Distance<data_t>::prepopulate_chunkwise_distances(const float *query_vec, float *dist_vec)
+{
+ memset(dist_vec, 0, 256 * _num_chunks * sizeof(float));
+ // chunk wise distance computation
+ for (size_t chunk = 0; chunk < _num_chunks; chunk++)
+ {
+ // sum (q-c)^2 for the dimensions associated with this chunk
+ float *chunk_dists = dist_vec + (256 * chunk);
+ for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++)
+ {
+ const float *centers_dim_vec = _tables_tr + (256 * j);
+ for (size_t idx = 0; idx < 256; idx++)
+ {
+ double diff = centers_dim_vec[idx] - (query_vec[j]);
+ chunk_dists[idx] += (float)(diff * diff);
+ }
+ }
+ }
+}
+
+template DISKANN_DLLEXPORT class PQL2Distance<int8_t>;
+template DISKANN_DLLEXPORT class PQL2Distance<uint8_t>;
+template DISKANN_DLLEXPORT class PQL2Distance<float>;
+
+} // namespace diskann
\ No newline at end of file
diff --git a/be/src/extern/diskann/src/scratch.cpp b/be/src/extern/diskann/src/scratch.cpp
new file mode 100644
index 0000000..f8a857a
--- /dev/null
+++ b/be/src/extern/diskann/src/scratch.cpp
@@ -0,0 +1,181 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <vector>
+#include <boost/dynamic_bitset.hpp>
+
+#include "scratch.h"
+#include "pq_scratch.h"
+
+namespace diskann
+{
+//
+// Functions to manage scratch space for in-memory index based search
+//
+template <typename T>
+InMemQueryScratch<T>::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim,
+ size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch)
+ : _L(0), _R(r), _maxc(maxc)
+{
+ if (search_l == 0 || indexing_l == 0 || r == 0 || dim == 0)
+ {
+ std::stringstream ss;
+ ss << "In InMemQueryScratch, one of search_l = " << search_l << ", indexing_l = " << indexing_l
+ << ", dim = " << dim << " or r = " << r << " is zero." << std::endl;
+ throw diskann::ANNException(ss.str(), -1);
+ }
+
+ alloc_aligned(((void **)&this->_aligned_query_T), aligned_dim * sizeof(T), alignment_factor * sizeof(T));
+ memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T));
+
+ if (init_pq_scratch)
+ this->_pq_scratch = new PQScratch<T>(defaults::MAX_GRAPH_DEGREE, aligned_dim);
+ else
+ this->_pq_scratch = nullptr;
+
+ _occlude_factor.reserve(maxc);
+ _inserted_into_pool_bs = new boost::dynamic_bitset<>();
+ _id_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R));
+ _dist_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R));
+
+ resize_for_new_L(std::max(search_l, indexing_l));
+}
+
+template <typename T> void InMemQueryScratch<T>::clear()
+{
+ _pool.clear();
+ _best_l_nodes.clear();
+ _occlude_factor.clear();
+
+ _inserted_into_pool_rs.clear();
+ _inserted_into_pool_bs->reset();
+
+ _id_scratch.clear();
+ _dist_scratch.clear();
+
+ _expanded_nodes_set.clear();
+ _expanded_nghrs_vec.clear();
+ _occlude_list_output.clear();
+}
+
+template <typename T> void InMemQueryScratch<T>::resize_for_new_L(uint32_t new_l)
+{
+ if (new_l > _L)
+ {
+ _L = new_l;
+ _pool.reserve(3 * _L + _R);
+ _best_l_nodes.reserve(_L);
+
+ _inserted_into_pool_rs.reserve(20 * _L);
+ }
+}
+
+template <typename T> InMemQueryScratch<T>::~InMemQueryScratch()
+{
+ if (this->_aligned_query_T != nullptr)
+ {
+ aligned_free(this->_aligned_query_T);
+ this->_aligned_query_T = nullptr;
+ }
+
+ delete this->_pq_scratch;
+ delete _inserted_into_pool_bs;
+}
+
+//
+// Functions to manage scratch space for SSD based search
+//
+template <typename T> void SSDQueryScratch<T>::reset()
+{
+ sector_idx = 0;
+ visited.clear();
+ retset.clear();
+ full_retset.clear();
+}
+
+template <typename T> SSDQueryScratch<T>::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve)
+{
+ size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256);
+
+ diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256);
+ diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN,
+ defaults::SECTOR_LEN);
+ diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T));
+ this->_pq_scratch = new PQScratch<T>(defaults::MAX_GRAPH_DEGREE, aligned_dim);
+
+ memset(coord_scratch, 0, coord_alloc_size);
+ memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T));
+
+ visited.reserve(visited_reserve);
+ full_retset.reserve(visited_reserve);
+}
+
+template <typename T> SSDQueryScratch<T>::~SSDQueryScratch()
+{
+ diskann::aligned_free((void *)coord_scratch);
+ diskann::aligned_free((void *)sector_scratch);
+ diskann::aligned_free((void *)this->_aligned_query_T);
+
+ delete this->_pq_scratch;
+}
+
+template <typename T>
+SSDThreadData<T>::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve)
+{
+}
+
+template <typename T> void SSDThreadData<T>::clear()
+{
+ scratch.reset();
+}
+
+template <typename T> PQScratch<T>::PQScratch(size_t graph_degree, size_t aligned_dim)
+{
+ diskann::alloc_aligned((void **)&aligned_pq_coord_scratch,
+ (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256);
+ diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256);
+ diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256);
+ diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float));
+ diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float));
+
+ memset(aligned_query_float, 0, aligned_dim * sizeof(float));
+ memset(rotated_query, 0, aligned_dim * sizeof(float));
+}
+
+template <typename T> PQScratch<T>::~PQScratch()
+{
+ diskann::aligned_free((void *)aligned_pq_coord_scratch);
+ diskann::aligned_free((void *)aligned_pqtable_dist_scratch);
+ diskann::aligned_free((void *)aligned_dist_scratch);
+ diskann::aligned_free((void *)aligned_query_float);
+ diskann::aligned_free((void *)rotated_query);
+}
+
+template <typename T> void PQScratch<T>::initialize(size_t dim, const T *query, const float norm)
+{
+ for (size_t d = 0; d < dim; ++d)
+ {
+ if (norm != 1.0f)
+ rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]) / norm;
+ else
+ rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]);
+ }
+}
+
+template DISKANN_DLLEXPORT class InMemQueryScratch<int8_t>;
+template DISKANN_DLLEXPORT class InMemQueryScratch<uint8_t>;
+template DISKANN_DLLEXPORT class InMemQueryScratch<float>;
+
+template DISKANN_DLLEXPORT class SSDQueryScratch<int8_t>;
+template DISKANN_DLLEXPORT class SSDQueryScratch<uint8_t>;
+template DISKANN_DLLEXPORT class SSDQueryScratch<float>;
+
+template DISKANN_DLLEXPORT class PQScratch<int8_t>;
+template DISKANN_DLLEXPORT class PQScratch<uint8_t>;
+template DISKANN_DLLEXPORT class PQScratch<float>;
+
+template DISKANN_DLLEXPORT class SSDThreadData<int8_t>;
+template DISKANN_DLLEXPORT class SSDThreadData<uint8_t>;
+template DISKANN_DLLEXPORT class SSDThreadData<float>;
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/utils.cpp b/be/src/extern/diskann/src/utils.cpp
new file mode 100644
index 0000000..3773cda
--- /dev/null
+++ b/be/src/extern/diskann/src/utils.cpp
@@ -0,0 +1,477 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "utils.h"
+
+#include <stdio.h>
+
+#ifdef EXEC_ENV_OLS
+#include "aligned_file_reader.h"
+#endif
+
+const uint32_t MAX_REQUEST_SIZE = 1024 * 1024 * 1024; // 1GB
+const uint32_t MAX_SIMULTANEOUS_READ_REQUESTS = 128;
+
+#ifdef _WINDOWS
+#include <intrin.h>
+
+// Taken from:
+// https://insufficientlycomplicated.wordpress.com/2011/11/07/detecting-intel-advanced-vector-extensions-avx-in-visual-studio/
+bool cpuHasAvxSupport()
+{
+ bool avxSupported = false;
+
+ // Checking for AVX requires 3 things:
+ // 1) CPUID indicates that the OS uses XSAVE and XRSTORE
+ // instructions (allowing saving YMM registers on context
+ // switch)
+ // 2) CPUID indicates support for AVX
+ // 3) XGETBV indicates the AVX registers will be saved and
+ // restored on context switch
+ //
+ // Note that XGETBV is only available on 686 or later CPUs, so
+ // the instruction needs to be conditionally run.
+ int cpuInfo[4];
+ __cpuid(cpuInfo, 1);
+
+ bool osUsesXSAVE_XRSTORE = cpuInfo[2] & (1 << 27) || false;
+ bool cpuAVXSuport = cpuInfo[2] & (1 << 28) || false;
+
+ if (osUsesXSAVE_XRSTORE && cpuAVXSuport)
+ {
+ // Check if the OS will save the YMM registers
+ unsigned long long xcrFeatureMask = _xgetbv(_XCR_XFEATURE_ENABLED_MASK);
+ avxSupported = (xcrFeatureMask & 0x6) || false;
+ }
+
+ return avxSupported;
+}
+
+bool cpuHasAvx2Support()
+{
+ int cpuInfo[4];
+ __cpuid(cpuInfo, 0);
+ int n = cpuInfo[0];
+ if (n >= 7)
+ {
+ __cpuidex(cpuInfo, 7, 0);
+ static int avx2Mask = 0x20;
+ return (cpuInfo[1] & avx2Mask) > 0;
+ }
+ return false;
+}
+
+bool AvxSupportedCPU = cpuHasAvxSupport();
+bool Avx2SupportedCPU = cpuHasAvx2Support();
+
+#else
+
+bool Avx2SupportedCPU = true;
+bool AvxSupportedCPU = false;
+#endif
+
+namespace diskann
+{
+
+void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, size_t npts, size_t ndims)
+{
+ readr.read((char *)read_buf, npts * ndims * sizeof(float));
+ uint32_t ndims_u32 = (uint32_t)ndims;
+#pragma omp parallel for
+ for (int64_t i = 0; i < (int64_t)npts; i++)
+ {
+ float norm_pt = std::numeric_limits<float>::epsilon();
+ for (uint32_t dim = 0; dim < ndims_u32; dim++)
+ {
+ norm_pt += *(read_buf + i * ndims + dim) * *(read_buf + i * ndims + dim);
+ }
+ norm_pt = std::sqrt(norm_pt);
+ for (uint32_t dim = 0; dim < ndims_u32; dim++)
+ {
+ *(read_buf + i * ndims + dim) = *(read_buf + i * ndims + dim) / norm_pt;
+ }
+ }
+ writr.write((char *)read_buf, npts * ndims * sizeof(float));
+}
+
+void normalize_data_file(const std::string &inFileName, const std::string &outFileName)
+{
+ std::ifstream readr(inFileName, std::ios::binary);
+ std::ofstream writr(outFileName, std::ios::binary);
+
+ int npts_s32, ndims_s32;
+ readr.read((char *)&npts_s32, sizeof(int32_t));
+ readr.read((char *)&ndims_s32, sizeof(int32_t));
+
+ writr.write((char *)&npts_s32, sizeof(int32_t));
+ writr.write((char *)&ndims_s32, sizeof(int32_t));
+
+ size_t npts = (size_t)npts_s32;
+ size_t ndims = (size_t)ndims_s32;
+ diskann::cout << "Normalizing FLOAT vectors in file: " << inFileName << std::endl;
+ diskann::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
+
+ size_t blk_size = 131072;
+ size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
+ diskann::cout << "# blks: " << nblks << std::endl;
+
+ float *read_buf = new float[npts * ndims];
+ for (size_t i = 0; i < nblks; i++)
+ {
+ size_t cblk_size = std::min(npts - i * blk_size, blk_size);
+ block_convert(writr, readr, read_buf, cblk_size, ndims);
+ }
+ delete[] read_buf;
+
+ diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl;
+}
+
+double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs,
+ uint32_t *our_results, uint32_t dim_or, uint32_t recall_at)
+{
+ double total_recall = 0;
+ std::set<uint32_t> gt, res;
+
+ for (size_t i = 0; i < num_queries; i++)
+ {
+ gt.clear();
+ res.clear();
+ uint32_t *gt_vec = gold_std + dim_gs * i;
+ uint32_t *res_vec = our_results + dim_or * i;
+ size_t tie_breaker = recall_at;
+ if (gs_dist != nullptr)
+ {
+ tie_breaker = recall_at - 1;
+ float *gt_dist_vec = gs_dist + dim_gs * i;
+ while (tie_breaker < dim_gs && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1])
+ tie_breaker++;
+ }
+
+ gt.insert(gt_vec, gt_vec + tie_breaker);
+ res.insert(res_vec,
+ res_vec + recall_at); // change to recall_at for recall k@k
+ // or dim_or for k@dim_or
+ uint32_t cur_recall = 0;
+ for (auto &v : gt)
+ {
+ if (res.find(v) != res.end())
+ {
+ cur_recall++;
+ }
+ }
+ total_recall += cur_recall;
+ }
+ return total_recall / (num_queries) * (100.0 / recall_at);
+}
+
+double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs,
+ uint32_t *our_results, uint32_t dim_or, uint32_t recall_at,
+ const tsl::robin_set<uint32_t> &active_tags)
+{
+ double total_recall = 0;
+ std::set<uint32_t> gt, res;
+ bool printed = false;
+ for (size_t i = 0; i < num_queries; i++)
+ {
+ gt.clear();
+ res.clear();
+ uint32_t *gt_vec = gold_std + dim_gs * i;
+ uint32_t *res_vec = our_results + dim_or * i;
+ size_t tie_breaker = recall_at;
+ uint32_t active_points_count = 0;
+ uint32_t cur_counter = 0;
+ while (active_points_count < recall_at && cur_counter < dim_gs)
+ {
+ if (active_tags.find(*(gt_vec + cur_counter)) != active_tags.end())
+ {
+ active_points_count++;
+ }
+ cur_counter++;
+ }
+ if (active_tags.empty())
+ cur_counter = recall_at;
+
+ if ((active_points_count < recall_at && !active_tags.empty()) && !printed)
+ {
+ diskann::cout << "Warning: Couldn't find enough closest neighbors " << active_points_count << "/"
+ << recall_at
+ << " from "
+ "truthset for query # "
+ << i << ". Will result in under-reported value of recall." << std::endl;
+ printed = true;
+ }
+ if (gs_dist != nullptr)
+ {
+ tie_breaker = cur_counter - 1;
+ float *gt_dist_vec = gs_dist + dim_gs * i;
+ while (tie_breaker < dim_gs && gt_dist_vec[tie_breaker] == gt_dist_vec[cur_counter - 1])
+ tie_breaker++;
+ }
+
+ gt.insert(gt_vec, gt_vec + tie_breaker);
+ res.insert(res_vec, res_vec + recall_at);
+ uint32_t cur_recall = 0;
+ for (auto &v : res)
+ {
+ if (gt.find(v) != gt.end())
+ {
+ cur_recall++;
+ }
+ }
+ total_recall += cur_recall;
+ }
+ return ((double)(total_recall / (num_queries))) * ((double)(100.0 / recall_at));
+}
+
+double calculate_range_search_recall(uint32_t num_queries, std::vector<std::vector<uint32_t>> &groundtruth,
+ std::vector<std::vector<uint32_t>> &our_results)
+{
+ double total_recall = 0;
+ std::set<uint32_t> gt, res;
+
+ for (size_t i = 0; i < num_queries; i++)
+ {
+ gt.clear();
+ res.clear();
+
+ gt.insert(groundtruth[i].begin(), groundtruth[i].end());
+ res.insert(our_results[i].begin(), our_results[i].end());
+ uint32_t cur_recall = 0;
+ for (auto &v : gt)
+ {
+ if (res.find(v) != res.end())
+ {
+ cur_recall++;
+ }
+ }
+ if (gt.size() != 0)
+ total_recall += ((100.0 * cur_recall) / gt.size());
+ else
+ total_recall += 100;
+ }
+ return total_recall / (num_queries);
+}
+
+#ifdef EXEC_ENV_OLS
+void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset)
+{
+ std::vector<AlignedRead> readReqs;
+ AlignedRead readReq;
+ uint32_t buf[2]; // npts/ndim are uint32_ts.
+
+ readReq.buf = buf;
+ readReq.offset = offset;
+ readReq.len = 2 * sizeof(uint32_t);
+ readReqs.push_back(readReq);
+
+ IOContext &ctx = reader.get_ctx();
+ reader.read(readReqs, ctx); // synchronous
+ if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS)
+ {
+ npts = buf[0];
+ ndim = buf[1];
+ diskann::cout << "File has: " << npts << " points, " << ndim << " dimensions at offset: " << offset
+ << std::endl;
+ }
+ else
+ {
+ std::stringstream str;
+ str << "Could not read binary metadata from index file at offset: " << offset << std::endl;
+ throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+}
+
+template <typename T> void load_bin(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, size_t offset)
+{
+ // Code assumes that the reader is already setup correctly.
+ get_bin_metadata(reader, npts, ndim, offset);
+ data = new T[npts * ndim];
+
+ size_t data_size = npts * ndim * sizeof(T);
+ size_t write_offset = 0;
+ size_t read_start = offset + 2 * sizeof(uint32_t);
+
+    // BingAlignedFileReader can only read up to uint32_t bytes per request,
+    // so we further cap each read at MAX_REQUEST_SIZE (1GB).
+ std::vector<AlignedRead> readReqs;
+ while (data_size > 0)
+ {
+ AlignedRead readReq;
+ readReq.buf = data + write_offset;
+ readReq.offset = read_start + write_offset;
+ readReq.len = data_size > MAX_REQUEST_SIZE ? MAX_REQUEST_SIZE : data_size;
+ readReqs.push_back(readReq);
+ // in the corner case, the loop will not execute
+ data_size -= readReq.len;
+ write_offset += readReq.len;
+ }
+ IOContext &ctx = reader.get_ctx();
+ reader.read(readReqs, ctx);
+ for (int i = 0; i < readReqs.size(); i++)
+ {
+ // Since we are making sync calls, no request will be in the
+ // READ_WAIT state.
+ if ((*(ctx.m_pRequestsStatus))[i] != IOContext::READ_SUCCESS)
+ {
+ std::stringstream str;
+ str << "Could not read binary data from index file at offset: " << readReqs[i].offset << std::endl;
+ throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ }
+}
+template <typename T>
+void load_bin(AlignedFileReader &reader, std::unique_ptr<T[]> &data, size_t &npts, size_t &ndim, size_t offset)
+{
+ T *ptr = nullptr;
+ load_bin(reader, ptr, npts, ndim, offset);
+ data.reset(ptr);
+}
+
+template <typename T>
+void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim,
+ const size_t &rounded_dim, size_t offset)
+{
+ if (data == nullptr)
+ {
+ diskann::cerr << "Memory was not allocated for " << data << " before calling the load function. Exiting..."
+ << std::endl;
+ throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file()", -1, __FUNCSIG__, __FILE__,
+ __LINE__);
+ }
+
+ size_t pts, dim;
+ get_bin_metadata(reader, pts, dim, offset);
+
+ if (ndim != dim || npts != pts)
+ {
+ std::stringstream ss;
+ ss << "Either file dimension: " << dim << " is != passed dimension: " << ndim << " or file #pts: " << pts
+ << " is != passed #pts: " << npts << std::endl;
+ throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+
+ // Instead of reading one point of ndim size and setting (rounded_dim - dim)
+ // values to zero We'll set everything to zero and read in chunks of data at
+ // the appropriate locations.
+ size_t read_offset = offset + 2 * sizeof(uint32_t);
+ memset(data, 0, npts * rounded_dim * sizeof(T));
+ int i = 0;
+ std::vector<AlignedRead> read_requests;
+
+ while (i < npts)
+ {
+ int j = 0;
+ read_requests.clear();
+ while (j < MAX_SIMULTANEOUS_READ_REQUESTS && i < npts)
+ {
+ AlignedRead read_req;
+ read_req.buf = data + i * rounded_dim;
+ read_req.len = dim * sizeof(T);
+ read_req.offset = read_offset + i * dim * sizeof(T);
+ read_requests.push_back(read_req);
+ i++;
+ j++;
+ }
+ IOContext &ctx = reader.get_ctx();
+ reader.read(read_requests, ctx);
+ for (int k = 0; k < read_requests.size(); k++)
+ {
+ if ((*ctx.m_pRequestsStatus)[k] != IOContext::READ_SUCCESS)
+ {
+ throw diskann::ANNException("Load data from file using AlignedReader failed.", -1, __FUNCSIG__,
+ __FILE__, __LINE__);
+ }
+ }
+ }
+}
+
+// Unlike load_bin, assumes that data is already allocated 'size' entries
+template <typename T> void read_array(AlignedFileReader &reader, T *data, size_t size, size_t offset)
+{
+ if (data == nullptr)
+ {
+ throw diskann::ANNException("read_array requires an allocated buffer.", -1);
+ }
+
+ if (size * sizeof(T) > MAX_REQUEST_SIZE)
+ {
+ std::stringstream ss;
+ ss << "Cannot read more than " << MAX_REQUEST_SIZE << " bytes. Current request size: " << std::to_string(size)
+ << " sizeof(T): " << sizeof(T) << std::endl;
+ throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ std::vector<AlignedRead> read_requests;
+ AlignedRead read_req;
+ read_req.buf = data;
+ read_req.len = size * sizeof(T);
+ read_req.offset = offset;
+ read_requests.push_back(read_req);
+ IOContext &ctx = reader.get_ctx();
+ reader.read(read_requests, ctx);
+
+ if ((*(ctx.m_pRequestsStatus))[0] != IOContext::READ_SUCCESS)
+ {
+ std::stringstream ss;
+ ss << "Failed to read_array() of size: " << size * sizeof(T) << " at offset: " << offset << " from reader. "
+ << std::endl;
+ throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+}
+
+template <typename T> void read_value(AlignedFileReader &reader, T &value, size_t offset)
+{
+ read_array(reader, &value, 1, offset);
+}
+
+template DISKANN_DLLEXPORT void load_bin<uint8_t>(AlignedFileReader &reader, std::unique_ptr<uint8_t[]> &data,
+ size_t &npts, size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<int8_t>(AlignedFileReader &reader, std::unique_ptr<int8_t[]> &data,
+ size_t &npts, size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<uint32_t>(AlignedFileReader &reader, std::unique_ptr<uint32_t[]> &data,
+ size_t &npts, size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<uint64_t>(AlignedFileReader &reader, std::unique_ptr<uint64_t[]> &data,
+ size_t &npts, size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<int64_t>(AlignedFileReader &reader, std::unique_ptr<int64_t[]> &data,
+ size_t &npts, size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<float>(AlignedFileReader &reader, std::unique_ptr<float[]> &data, size_t &npts,
+ size_t &ndim, size_t offset);
+
+template DISKANN_DLLEXPORT void load_bin<uint8_t>(AlignedFileReader &reader, uint8_t *&data, size_t &npts, size_t &ndim,
+ size_t offset);
+template DISKANN_DLLEXPORT void load_bin<int64_t>(AlignedFileReader &reader, int64_t *&data, size_t &npts, size_t &ndim,
+ size_t offset);
+template DISKANN_DLLEXPORT void load_bin<uint64_t>(AlignedFileReader &reader, uint64_t *&data, size_t &npts,
+ size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<uint32_t>(AlignedFileReader &reader, uint32_t *&data, size_t &npts,
+ size_t &ndim, size_t offset);
+template DISKANN_DLLEXPORT void load_bin<int32_t>(AlignedFileReader &reader, int32_t *&data, size_t &npts, size_t &ndim,
+ size_t offset);
+
+template DISKANN_DLLEXPORT void copy_aligned_data_from_file<uint8_t>(AlignedFileReader &reader, uint8_t *&data,
+ size_t &npts, size_t &dim,
+ const size_t &rounded_dim, size_t offset);
+template DISKANN_DLLEXPORT void copy_aligned_data_from_file<int8_t>(AlignedFileReader &reader, int8_t *&data,
+ size_t &npts, size_t &dim,
+ const size_t &rounded_dim, size_t offset);
+template DISKANN_DLLEXPORT void copy_aligned_data_from_file<float>(AlignedFileReader &reader, float *&data,
+ size_t &npts, size_t &dim, const size_t &rounded_dim,
+ size_t offset);
+
+template DISKANN_DLLEXPORT void read_array<char>(AlignedFileReader &reader, char *data, size_t size, size_t offset);
+
+template DISKANN_DLLEXPORT void read_array<uint8_t>(AlignedFileReader &reader, uint8_t *data, size_t size,
+ size_t offset);
+template DISKANN_DLLEXPORT void read_array<int8_t>(AlignedFileReader &reader, int8_t *data, size_t size, size_t offset);
+template DISKANN_DLLEXPORT void read_array<uint32_t>(AlignedFileReader &reader, uint32_t *data, size_t size,
+ size_t offset);
+template DISKANN_DLLEXPORT void read_array<float>(AlignedFileReader &reader, float *data, size_t size, size_t offset);
+
+template DISKANN_DLLEXPORT void read_value<uint8_t>(AlignedFileReader &reader, uint8_t &value, size_t offset);
+template DISKANN_DLLEXPORT void read_value<int8_t>(AlignedFileReader &reader, int8_t &value, size_t offset);
+template DISKANN_DLLEXPORT void read_value<float>(AlignedFileReader &reader, float &value, size_t offset);
+template DISKANN_DLLEXPORT void read_value<uint32_t>(AlignedFileReader &reader, uint32_t &value, size_t offset);
+template DISKANN_DLLEXPORT void read_value<uint64_t>(AlignedFileReader &reader, uint64_t &value, size_t offset);
+
+#endif
+
+} // namespace diskann
diff --git a/be/src/extern/diskann/src/windows_aligned_file_reader.cpp b/be/src/extern/diskann/src/windows_aligned_file_reader.cpp
new file mode 100644
index 0000000..3650b92
--- /dev/null
+++ b/be/src/extern/diskann/src/windows_aligned_file_reader.cpp
@@ -0,0 +1,189 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#ifdef _WINDOWS
+#ifndef USE_BING_INFRA
+#include "windows_aligned_file_reader.h"
+#include <iostream>
+#include "utils.h"
+#include <stdlib.h>
+
+#define SECTOR_LEN 4096
+
+void WindowsAlignedFileReader::open(const std::string &fname)
+{
+#ifdef UNICODE
+ m_filename = std::wstring(fname.begin(), fname.end());
+#else
+ m_filename = fname;
+#endif
+
+ this->register_thread();
+}
+
+void WindowsAlignedFileReader::close()
+{
+ for (auto &k_v : ctx_map)
+ {
+ IOContext ctx = ctx_map[k_v.first];
+ CloseHandle(ctx.fhandle);
+ }
+}
+
+void WindowsAlignedFileReader::register_thread()
+{
+ std::unique_lock<std::mutex> lk(this->ctx_mut);
+ if (this->ctx_map.find(std::this_thread::get_id()) != ctx_map.end())
+ {
+ diskann::cout << "Warning:: Duplicate registration for thread_id : " << std::this_thread::get_id() << std::endl;
+ }
+
+ IOContext ctx;
+ ctx.fhandle = CreateFile(
+ m_filename.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING,
+ FILE_ATTRIBUTE_READONLY | FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED | FILE_FLAG_RANDOM_ACCESS, NULL);
+ if (ctx.fhandle == INVALID_HANDLE_VALUE)
+ {
+ const size_t c_max_filepath_len = 256;
+ size_t actual_len = 0;
+ char filePath[c_max_filepath_len];
+ if (wcstombs_s(&actual_len, filePath, c_max_filepath_len, m_filename.c_str(), m_filename.length()) == 0)
+ {
+ diskann::cout << "Error opening " << filePath << " -- error=" << GetLastError() << std::endl;
+ }
+ else
+ {
+ diskann::cout << "Error converting wchar to char -- error=" << GetLastError() << std::endl;
+ }
+ }
+
+ // create IOCompletionPort
+ ctx.iocp = CreateIoCompletionPort(ctx.fhandle, ctx.iocp, 0, 0);
+
+ // create MAX_DEPTH # of reqs
+ for (uint64_t i = 0; i < MAX_IO_DEPTH; i++)
+ {
+ OVERLAPPED os;
+ memset(&os, 0, sizeof(OVERLAPPED));
+ // os.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL);
+ ctx.reqs.push_back(os);
+ }
+ this->ctx_map.insert(std::make_pair(std::this_thread::get_id(), ctx));
+}
+
+IOContext &WindowsAlignedFileReader::get_ctx()
+{
+ std::unique_lock<std::mutex> lk(this->ctx_mut);
+ if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end())
+ {
+ std::stringstream stream;
+ stream << "unable to find IOContext for thread_id : " << std::this_thread::get_id() << "\n";
+ throw diskann::ANNException(stream.str(), -2, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ IOContext &ctx = ctx_map[std::this_thread::get_id()];
+ lk.unlock();
+ return ctx;
+}
+
+void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async)
+{
+ using namespace std::chrono_literals;
+ // execute each request sequentially
+ size_t n_reqs = read_reqs.size();
+ uint64_t n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH;
+ for (uint64_t i = 0; i < n_batches; i++)
+ {
+ // reset all OVERLAPPED objects
+ for (auto &os : ctx.reqs)
+ {
+ // HANDLE evt = os.hEvent;
+ memset(&os, 0, sizeof(os));
+ // os.hEvent = evt;
+
+ /*
+ if (ResetEvent(os.hEvent) == 0) {
+ diskann::cerr << "ResetEvent failed" << std::endl;
+ exit(-3);
+ }
+ */
+ }
+
+ // batch start/end
+ uint64_t batch_start = MAX_IO_DEPTH * i;
+ uint64_t batch_size = std::min((uint64_t)(n_reqs - batch_start), (uint64_t)MAX_IO_DEPTH);
+
+ // fill OVERLAPPED and issue them
+ for (uint64_t j = 0; j < batch_size; j++)
+ {
+ AlignedRead &req = read_reqs[batch_start + j];
+ OVERLAPPED &os = ctx.reqs[j];
+
+ uint64_t offset = req.offset;
+ uint64_t nbytes = req.len;
+ char *read_buf = (char *)req.buf;
+ assert(IS_ALIGNED(read_buf, SECTOR_LEN));
+ assert(IS_ALIGNED(offset, SECTOR_LEN));
+ assert(IS_ALIGNED(nbytes, SECTOR_LEN));
+
+ // fill in OVERLAPPED struct
+ os.Offset = offset & 0xffffffff;
+ os.OffsetHigh = (offset >> 32);
+
+ BOOL ret = ReadFile(ctx.fhandle, read_buf, (DWORD)nbytes, NULL, &os);
+ if (ret == FALSE)
+ {
+ auto error = GetLastError();
+ if (error != ERROR_IO_PENDING)
+ {
+ diskann::cerr << "Error queuing IO -- " << error << "\n";
+ }
+ }
+ else
+ {
+ diskann::cerr << "Error queueing IO -- ReadFile returned TRUE" << std::endl;
+ }
+ }
+ DWORD n_read = 0;
+ uint64_t n_complete = 0;
+ ULONG_PTR completion_key = 0;
+ OVERLAPPED *lp_os;
+ while (n_complete < batch_size)
+ {
+ if (GetQueuedCompletionStatus(ctx.iocp, &n_read, &completion_key, &lp_os, INFINITE) != 0)
+ {
+ // successfully dequeued a completed I/O
+ n_complete++;
+ }
+ else
+ {
+ // failed to dequeue OR dequeued failed I/O
+ if (lp_os == NULL)
+ {
+ DWORD error = GetLastError();
+ if (error != WAIT_TIMEOUT)
+ {
+ diskann::cerr << "GetQueuedCompletionStatus() failed "
+ "with error = "
+ << error << std::endl;
+ throw diskann::ANNException("GetQueuedCompletionStatus failed with error: ", error, __FUNCSIG__,
+ __FILE__, __LINE__);
+ }
+ // no completion packet dequeued ==> sleep for 5us and try
+ // again
+ std::this_thread::sleep_for(5us);
+ }
+ else
+ {
+ // completion packet for failed IO dequeued
+ auto op_idx = lp_os - ctx.reqs.data();
+ std::stringstream stream;
+ stream << "I/O failed , offset: " << read_reqs[op_idx].offset
+ << "with error code: " << GetLastError() << std::endl;
+ throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+ }
+ }
+ }
+ }
+}
+#endif
+#endif
diff --git a/be/src/index-tools/ann_tool.cpp b/be/src/index-tools/ann_tool.cpp
new file mode 100644
index 0000000..1d9b6b5
--- /dev/null
+++ b/be/src/index-tools/ann_tool.cpp
@@ -0,0 +1,361 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <CLucene.h>
+#include <CLucene/config/repl_wchar.h>
+#include <gen_cpp/PaloInternalService_types.h>
+#include <gen_cpp/olap_file.pb.h>
+#include <gflags/gflags.h>
+
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <nlohmann/json.hpp>
+#include <roaring/roaring.hh>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "io/fs/file_reader.h"
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wshadow-field"
+#endif
+#include "CLucene/analysis/standard95/StandardAnalyzer.h"
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+#include "gutil/strings/strip.h"
+#include "io/fs/local_file_system.h"
+#include "olap/rowset/segment_v2/inverted_index/query/conjunction_query.h"
+#include "olap/rowset/segment_v2/inverted_index_compound_reader.h"
+#include "olap/rowset/segment_v2/inverted_index_desc.h"
+#include "olap/rowset/segment_v2/inverted_index_file_reader.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
+#include "olap/tablet_schema.h"
+#include "io/fs/file_system.h"
+#include "vector/diskann_vector_index.h"
+#include "vector/vector_index.h"
+#include "olap/rowset/segment_v2/ann_index_writer.h"
+
+#include "olap/options.h"
+#include "common/signal_handler.h"
+#include "util/disk_info.h"
+#include "util/mem_info.h"
+#include "io/fs/file_writer.h"
+
+using doris::segment_v2::DorisCompoundReader;
+using doris::segment_v2::DorisFSDirectoryFactory;
+using doris::segment_v2::XIndexFileWriter;
+using doris::segment_v2::InvertedIndexDescriptor;
+using doris::segment_v2::InvertedIndexFileReader;
+using doris::io::FileInfo;
+using doris::TabletIndex;
+using namespace doris::segment_v2;
+using namespace lucene::analysis;
+using namespace lucene::index;
+using namespace lucene::util;
+using namespace lucene::search;
+using doris::io::FileSystem;
+
+#include "io/fs/path.h"
+
+
+const std::string& file_dir = "/home/users/clz/run/test_diskann/123456_0.idx";
+std::filesystem::path index_data_path(file_dir);
+
+int index_id = 100;
+std::string rowset_id="rowset_id";
+int seg_id = 0;
+
+
+std::shared_ptr<FileSystem> get_local_file_filesystem(){
+ return doris::io::global_local_filesystem();
+}
+
+void test_add(){
+ auto fs = get_local_file_filesystem();
+ int index_id = 100;
+
+ doris::io::FileWriterPtr file_writer;
+
+ auto st = doris::io::global_local_filesystem()->create_file(file_dir, &file_writer);
+ if(!st.ok()){
+ std::cerr<<"failed create_file" << file_dir << std::endl;
+ }
+ std::unique_ptr<XIndexFileWriter> index_file_writer = std::make_unique<XIndexFileWriter>(
+ fs,
+ index_data_path.parent_path(),
+ rowset_id, seg_id, doris::InvertedIndexStorageFormatPB::V2, std::move(file_writer));
+
+ doris::TabletIndexPB index_pb;
+
+ index_pb.set_index_id(index_id);
+ TabletIndex index_meta;
+ index_meta.init_from_pb(index_pb);
+ index_meta._index_type = doris::IndexType::ANN;
+ index_meta._properties["index_type"]="diskann";
+ index_meta._properties["metric_type"]="l2";
+ index_meta._properties["dim"]="7";
+ index_meta._properties["max_degree"]="32";
+ index_meta._properties["search_list"]="100";
+
+
+ std::string field_name="word_embeding";
+ std::unique_ptr<AnnIndexColumnWriter> ann_writer = std::make_unique<AnnIndexColumnWriter>(
+ field_name, index_file_writer.get(), &index_meta, true);
+ st = ann_writer->init();
+ if (!st.ok()) {
+ std::cout << "failed to ann_writer->init()" << std::endl;
+ return ;
+ }
+
+ //writer
+ int field_size = 4;
+ float value[14]={1,2,3,4,5,6,7,7,6,5,4,3,2,1};
+ int64_t offsets[3]={0,7,14};
+ st = ann_writer->add_array_values(field_size, (const void*)value, nullptr, (const uint8_t*)offsets, 2);
+ if (!st.ok()) {
+ std::cout << "failed to ann_writer->add_array_values" << std::endl;
+ return ;
+ }
+
+ //finish
+ st = ann_writer->finish();
+ if (!st.ok()) {
+ std::cout << "failed to ann_writer->finish" << std::endl;
+ return ;
+ }
+
+ //save to disk
+ st = index_file_writer->close();
+ if (!st.ok()) {
+ std::cout << "failed to indexwriter->close" << std::endl;
+ return ;
+ }
+}
+
+
+void init_env(){
+ doris::CpuInfo::init();
+ doris::DiskInfo::init();
+ doris::MemInfo::init();
+
+ string custom_conffile = "/home/users/clz/run/be_1.conf";
+ if (!doris::config::init(custom_conffile.c_str(), true, false, true)) {
+ fprintf(stderr, "error read custom config file. \n");
+ return ;
+ }
+
+ std::vector<doris::StorePath> paths;
+ std::string storage="/home/users/clz/run/storage";
+ std::string spill="/home/users/clz/run/splill";
+ std::string broken_storage_path="/home/users/clz/run/broken";
+
+ auto olap_res = doris::parse_conf_store_paths(storage, &paths);
+ if (!olap_res) {
+ LOG(ERROR) << "parse config storage path failed, path=" << storage;
+ exit(-1);
+ }
+
+ std::vector<doris::StorePath> spill_paths;
+ olap_res = doris::parse_conf_store_paths(spill, &spill_paths);
+ if (!olap_res) {
+ LOG(ERROR) << "parse config spill storage path failed, path="
+ << spill;
+ exit(-1);
+ }
+ std::set<std::string> broken_paths;
+ doris::parse_conf_broken_store_paths(broken_storage_path, &broken_paths);
+
+ // auto it = paths.begin();
+ // for (; it != paths.end();) {
+ // if (broken_paths.count(it->path) > 0) {
+ // if (doris::config::ignore_broken_disk) {
+ // LOG(WARNING) << "ignore broken disk, path = " << it->path;
+ // it = paths.erase(it);
+ // } else {
+ // LOG(ERROR) << "a broken disk is found " << it->path;
+ // exit(-1);
+ // }
+ // } else if (!doris::check_datapath_rw(it->path)) {
+ // if (doris::config::ignore_broken_disk) {
+ // LOG(WARNING) << "read write test file failed, path=" << it->path;
+ // it = paths.erase(it);
+ // } else {
+ // LOG(ERROR) << "read write test file failed, path=" << it->path;
+ // // if only one disk and the disk is full, also need exit because rocksdb will open failed
+ // exit(-1);
+ // }
+ // } else {
+ // ++it;
+ // }
+ // }
+
+ // if (paths.empty()) {
+ // LOG(ERROR) << "All disks are broken, exit.";
+ // exit(-1);
+ // }
+
+ // it = spill_paths.begin();
+ // for (; it != spill_paths.end();) {
+ // if (!doris::check_datapath_rw(it->path)) {
+ // if (doris::config::ignore_broken_disk) {
+ // LOG(WARNING) << "read write test file failed, path=" << it->path;
+ // it = spill_paths.erase(it);
+ // } else {
+ // LOG(ERROR) << "read write test file failed, path=" << it->path;
+ // exit(-1);
+ // }
+ // } else {
+ // ++it;
+ // }
+ // }
+ // if (spill_paths.empty()) {
+ // LOG(ERROR) << "All spill disks are broken, exit.";
+ // exit(-1);
+ // }
+
+ // // initialize libcurl here to avoid concurrent initialization
+ // auto curl_ret = curl_global_init(CURL_GLOBAL_ALL);
+ // if (curl_ret != 0) {
+ // LOG(ERROR) << "fail to initialize libcurl, curl_ret=" << curl_ret;
+ // exit(-1);
+ // }
+ // // add logger for thrift internal
+ // apache::thrift::GlobalOutput.setOutputFunction(doris::thrift_output);
+
+ // Status status = Status::OK();
+ // if (doris::config::enable_java_support) {
+ // // Init jni
+ // status = doris::JniUtil::Init();
+ // if (!status.ok()) {
+ // LOG(WARNING) << "Failed to initialize JNI: " << status;
+ // exit(1);
+ // } else {
+ // LOG(INFO) << "Doris backend JNI is initialized.";
+ // }
+ // }
+
+ // // Doris own signal handler must be register after jvm is init.
+ // // Or our own sig-handler for SIGINT & SIGTERM will not be chained ...
+ // // https://www.oracle.com/java/technologies/javase/signals.html
+ // doris::init_signals();
+ // // ATTN: MUST init before `ExecEnv`, `StorageEngine` and other daemon services
+ // //
+ // // Daemon ───┬──► StorageEngine ──► ExecEnv ──► Disk/Mem/CpuInfo
+ // // │
+ // // │
+ // // BackendService ─┘
+ // doris::CpuInfo::init();
+ // doris::DiskInfo::init();
+ // doris::MemInfo::init();
+
+ // LOG(INFO) << doris::CpuInfo::debug_string();
+ // LOG(INFO) << doris::DiskInfo::debug_string();
+ // LOG(INFO) << doris::MemInfo::debug_string();
+
+ // // PHDR speed up exception handling, but exceptions from dynamically loaded libraries (dlopen)
+ // // will work only after additional call of this function.
+ // // rewrites dl_iterate_phdr will cause Jemalloc to fail to run after enable profile. see #
+ // // updatePHDRCache();
+ // if (!doris::BackendOptions::init()) {
+ // exit(-1);
+ // }
+
+ // doris::ThreadLocalHandle::create_thread_local_if_not_exits();
+
+ // init exec env
+ //auto* exec_env(doris::ExecEnv::GetInstance());
+ doris::Status status = doris::ExecEnv::init(doris::ExecEnv::GetInstance(), paths, spill_paths, broken_paths);
+ if(!status.ok()){
+ std::cout << "init fail" << std::endl;
+ }
+}
+
+void test_search(){
+ auto fs = get_local_file_filesystem();
+ auto index_file_reader = std::make_unique<InvertedIndexFileReader>(
+ fs, "/home/users/clz/run/test_diskann/123456_0", doris::InvertedIndexStorageFormatPB::V2);
+ auto st = index_file_reader->init(4096);
+ if (!st.ok()) {
+ std::cout << "failed to index_file_reader->init" << st << std::endl;
+ return ;
+ }
+ doris::TabletIndexPB index_pb;
+ index_pb.set_index_id(index_id);
+ TabletIndex index_meta;
+ index_meta.init_from_pb(index_pb);
+
+
+ auto ret = index_file_reader->open(&index_meta);
+ if (!ret.has_value()) {
+ std::cerr << "InvertedIndexFileReader open error:" << ret.error() << std::endl;
+ return ;
+ }
+ using T = std::decay_t<decltype(ret)>;
+ std::shared_ptr<DorisCompoundReader> dir = std::forward<T>(ret).value();
+
+ std::shared_ptr<DiskannVectorIndex> vindex = std::make_shared<DiskannVectorIndex>(dir);
+ st = vindex->load(VectorIndex::Metric::L2);
+ if (!st.ok()) {
+ std::cout << "failed to vindex->load" << std::endl;
+ return ;
+ }
+ float query_vec[7]={1,2,3,4,5,6,7};
+ SearchResult result;
+ std::shared_ptr<DiskannSearchParameter> searchParams = std::make_shared<DiskannSearchParameter>();
+ searchParams->with_search_list(100);
+ searchParams->with_beam_width(2);
+
+    // Set up optional ID-filter conditions (bitmap-based; currently disabled below)
+ std::shared_ptr<IDFilter> filter = nullptr;
+ std::shared_ptr<roaring::Roaring> bitmap = std::make_shared<roaring::Roaring>();
+ // bitmap->add(1);
+ // filter.reset(new IDFilter(bitmap));
+ // searchParams->set_filter(filter);
+ st = vindex->search(query_vec, 5, &result, searchParams.get());
+ if (!st.ok()) {
+ std::cout << "failed to vindex->search" << std::endl;
+ return ;
+ }
+ if(result.has_rows()){
+ for(int i=0;i<result.row_count();i++){
+ std::cout << "idx:" << result.get_id(i) << ", distance:" << result.get_distance(i) << std::endl;
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+ if (argc < 2) {
+        std::cerr << "Usage: " << argv[0] << " <add|search>" << std::endl;
+ return 1;
+ }
+ doris::signal::InstallFailureSignalHandler();
+ init_env();
+ std::string command = argv[1];
+ if(command=="add"){
+ test_add();
+ }else if(command=="search"){
+ test_search();
+ }else{
+        std::cout << "unknown command" << std::endl;
+ }
+ return 0;
+}
diff --git a/be/src/index-tools/index_tool.cpp b/be/src/index-tools/index_tool.cpp
index e45902c..bcad403 100644
--- a/be/src/index-tools/index_tool.cpp
+++ b/be/src/index-tools/index_tool.cpp
@@ -46,13 +46,13 @@
#include "olap/rowset/segment_v2/inverted_index_compound_reader.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
#include "olap/rowset/segment_v2/inverted_index_file_reader.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
#include "olap/tablet_schema.h"
using doris::segment_v2::DorisCompoundReader;
using doris::segment_v2::DorisFSDirectoryFactory;
-using doris::segment_v2::InvertedIndexFileWriter;
+using doris::segment_v2::XIndexFileWriter;
using doris::segment_v2::InvertedIndexDescriptor;
using doris::segment_v2::InvertedIndexFileReader;
using doris::io::FileInfo;
@@ -548,14 +548,14 @@
}
auto fs = doris::io::global_local_filesystem();
- auto index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ auto index_file_writer = std::make_unique<XIndexFileWriter>(
fs,
std::string {InvertedIndexDescriptor::get_index_file_path_prefix(
doris::local_segment_path(file_dir, rowset_id, seg_id))},
rowset_id, seg_id, doris::InvertedIndexStorageFormatPB::V2);
auto st = index_file_writer->open(&index_meta);
if (!st.has_value()) {
- std::cerr << "InvertedIndexFileWriter init error:" << st.error() << std::endl;
+ std::cerr << "XIndexFileWriter init error:" << st.error() << std::endl;
return -1;
}
using T = std::decay_t<decltype(st)>;
@@ -599,7 +599,7 @@
auto ret = index_file_writer->close();
if (!ret.ok()) {
- std::cerr << "InvertedIndexFileWriter close error:" << ret.msg() << std::endl;
+ std::cerr << "XIndexFileWriter close error:" << ret.msg() << std::endl;
return -1;
}
} else if (FLAGS_operation == "show_nested_files_v2") {
diff --git a/be/src/io/CMakeLists.txt b/be/src/io/CMakeLists.txt
index 02b34f2..38ee068 100644
--- a/be/src/io/CMakeLists.txt
+++ b/be/src/io/CMakeLists.txt
@@ -16,6 +16,11 @@
# under the License.
# where to put generated libraries
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+add_compile_options(-Wno-deprecated-copy)
+add_compile_options(-Wno-reorder)
+add_compile_options(-Wno-unused-but-set-variable)
set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/io")
# where to put generated binaries
diff --git a/be/src/olap/CMakeLists.txt b/be/src/olap/CMakeLists.txt
index bf19ef2..34b01aa 100644
--- a/be/src/olap/CMakeLists.txt
+++ b/be/src/olap/CMakeLists.txt
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+
# where to put generated libraries
set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/olap")
diff --git a/be/src/olap/compaction.cpp b/be/src/olap/compaction.cpp
index aec3869..0e9af91 100644
--- a/be/src/olap/compaction.cpp
+++ b/be/src/olap/compaction.cpp
@@ -60,7 +60,7 @@
#include "olap/rowset/segment_v2/inverted_index_compaction.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
#include "olap/rowset/segment_v2/inverted_index_file_reader.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
#include "olap/storage_engine.h"
#include "olap/storage_policy.h"
@@ -685,21 +685,21 @@
// dest index files
// format: rowsetId_segmentId
- auto& inverted_index_file_writers = dynamic_cast<BaseBetaRowsetWriter*>(_output_rs_writer.get())
- ->inverted_index_file_writers();
+ auto& x_index_file_writers = dynamic_cast<BaseBetaRowsetWriter*>(_output_rs_writer.get())
+ ->x_index_file_writers();
DBUG_EXECUTE_IF(
- "Compaction::do_inverted_index_compaction_inverted_index_file_writers_size_error",
- { inverted_index_file_writers.clear(); })
- if (inverted_index_file_writers.size() != dest_segment_num) {
+ "Compaction::do_inverted_index_compaction_x_index_file_writers_size_error",
+ { x_index_file_writers.clear(); })
+ if (x_index_file_writers.size() != dest_segment_num) {
LOG(WARNING) << "failed to do index compaction, dest segment num not match. tablet_id="
<< _tablet->tablet_id() << " dest_segment_num=" << dest_segment_num
- << " inverted_index_file_writers.size()="
- << inverted_index_file_writers.size();
+ << " x_index_file_writers.size()="
+ << x_index_file_writers.size();
mark_skip_index_compaction(ctx, error_handler);
return Status::Error<INVERTED_INDEX_COMPACTION_ERROR>(
"dest segment num not match. tablet_id={} dest_segment_num={} "
- "inverted_index_file_writers.size()={}",
- _tablet->tablet_id(), dest_segment_num, inverted_index_file_writers.size());
+ "x_index_file_writers.size()={}",
+ _tablet->tablet_id(), dest_segment_num, x_index_file_writers.size());
}
// use tmp file dir to store index files
@@ -745,10 +745,10 @@
src_idx_dirs[src_segment_id] = std::move(res.value());
}
for (int dest_segment_id = 0; dest_segment_id < dest_segment_num; dest_segment_id++) {
- auto res = inverted_index_file_writers[dest_segment_id]->open(index_meta);
- DBUG_EXECUTE_IF("Compaction::open_inverted_index_file_writer", {
+ auto res = x_index_file_writers[dest_segment_id]->open(index_meta);
+ DBUG_EXECUTE_IF("Compaction::open_x_index_file_writer", {
res = ResultError(Status::Error<ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
- "debug point: Compaction::open_inverted_index_file_writer error"));
+ "debug point: Compaction::open_x_index_file_writer error"));
})
if (!res.has_value()) {
LOG(WARNING) << "failed to do index compaction, open inverted index file "
@@ -759,7 +759,7 @@
throw Exception(ErrorCode::INVERTED_INDEX_COMPACTION_ERROR, res.error().msg());
}
// Destination directories in dest_index_dirs do not need to be deconstructed,
- // but their lifecycle must be managed by inverted_index_file_writers.
+ // but their lifecycle must be managed by x_index_file_writers.
dest_index_dirs[dest_segment_id] = res.value().get();
}
auto st = compact_column(index_meta->index_id(), src_idx_dirs, dest_index_dirs,
diff --git a/be/src/olap/rowset/beta_rowset_writer.cpp b/be/src/olap/rowset/beta_rowset_writer.cpp
index 9207c52..43d537a 100644
--- a/be/src/olap/rowset/beta_rowset_writer.cpp
+++ b/be/src/olap/rowset/beta_rowset_writer.cpp
@@ -192,20 +192,20 @@
InvertedIndexFileCollection::~InvertedIndexFileCollection() = default;
-Status InvertedIndexFileCollection::add(int seg_id, InvertedIndexFileWriterPtr&& index_writer) {
+Status InvertedIndexFileCollection::add(int seg_id, XIndexFileWriterPtr&& index_writer) {
std::lock_guard lock(_lock);
- if (_inverted_index_file_writers.find(seg_id) != _inverted_index_file_writers.end())
+ if (_x_index_file_writers.find(seg_id) != _x_index_file_writers.end())
[[unlikely]] {
DCHECK(false);
return Status::InternalError("The seg_id already exists, seg_id is {}", seg_id);
}
- _inverted_index_file_writers.emplace(seg_id, std::move(index_writer));
+ _x_index_file_writers.emplace(seg_id, std::move(index_writer));
return Status::OK();
}
Status InvertedIndexFileCollection::close() {
std::lock_guard lock(_lock);
- for (auto&& [id, writer] : _inverted_index_file_writers) {
+ for (auto&& [id, writer] : _x_index_file_writers) {
RETURN_IF_ERROR(writer->close());
_total_size += writer->get_index_file_total_size();
}
@@ -218,9 +218,9 @@
std::lock_guard lock(_lock);
Status st;
- std::vector<const InvertedIndexFileInfo*> idx_file_info(_inverted_index_file_writers.size());
+ std::vector<const InvertedIndexFileInfo*> idx_file_info(_x_index_file_writers.size());
bool succ = std::all_of(
- _inverted_index_file_writers.begin(), _inverted_index_file_writers.end(),
+ _x_index_file_writers.begin(), _x_index_file_writers.end(),
[&](auto&& it) {
auto&& [seg_id, writer] = it;
@@ -233,7 +233,7 @@
st = Status::InternalError(err_msg);
return false;
}
- idx_file_info[idx] = _inverted_index_file_writers[seg_id]->get_index_file_info();
+ idx_file_info[idx] = _x_index_file_writers[seg_id]->get_index_file_info();
return true;
});
@@ -952,9 +952,9 @@
fmt::format("failed to create file = {}, file type = {}", segment_path, file_type));
}
-Status BaseBetaRowsetWriter::create_inverted_index_file_writer(
- uint32_t segment_id, InvertedIndexFileWriterPtr* index_file_writer) {
- RETURN_IF_ERROR(RowsetWriter::create_inverted_index_file_writer(segment_id, index_file_writer));
+Status BaseBetaRowsetWriter::create_x_index_file_writer(
+ uint32_t segment_id, XIndexFileWriterPtr* index_file_writer) {
+ RETURN_IF_ERROR(RowsetWriter::create_x_index_file_writer(segment_id, index_file_writer));
// used for inverted index format v1
(*index_file_writer)->set_file_writer_opts(_context.get_file_writer_options());
return Status::OK();
@@ -968,7 +968,7 @@
io::FileWriterPtr file_writer;
RETURN_IF_ERROR(_create_file_writer(path, file_writer));
- InvertedIndexFileWriterPtr index_file_writer;
+ XIndexFileWriterPtr index_file_writer;
if (_context.tablet_schema->has_inverted_index()) {
io::FileWriterPtr idx_file_writer;
std::string prefix(InvertedIndexDescriptor::get_index_file_path_prefix(path));
@@ -977,7 +977,7 @@
std::string index_path = InvertedIndexDescriptor::get_index_file_path_v2(prefix);
RETURN_IF_ERROR(_create_file_writer(index_path, idx_file_writer));
}
- index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ index_file_writer = std::make_unique<XIndexFileWriter>(
_context.fs(), prefix, _context.rowset_id.to_string(), _num_segcompacted,
_context.tablet_schema->get_inverted_index_storage_format(),
std::move(idx_file_writer));
@@ -999,11 +999,11 @@
RETURN_IF_ERROR(_segcompaction_worker->get_file_writer()->close());
}
_segcompaction_worker->get_file_writer().reset(file_writer.release());
- if (auto& idx_file_writer = _segcompaction_worker->get_inverted_index_file_writer();
+ if (auto& idx_file_writer = _segcompaction_worker->get_x_index_file_writer();
idx_file_writer != nullptr) {
RETURN_IF_ERROR(idx_file_writer->close());
}
- _segcompaction_worker->get_inverted_index_file_writer().reset(index_file_writer.release());
+ _segcompaction_worker->get_x_index_file_writer().reset(index_file_writer.release());
return Status::OK();
}
diff --git a/be/src/olap/rowset/beta_rowset_writer.h b/be/src/olap/rowset/beta_rowset_writer.h
index a69d106..33278ba 100644
--- a/be/src/olap/rowset/beta_rowset_writer.h
+++ b/be/src/olap/rowset/beta_rowset_writer.h
@@ -42,7 +42,7 @@
#include "olap/rowset/rowset_writer.h"
#include "olap/rowset/rowset_writer_context.h"
#include "olap/rowset/segment_creator.h"
-#include "segment_v2/inverted_index_file_writer.h"
+#include "segment_v2/x_index_file_writer.h"
#include "segment_v2/segment.h"
#include "util/spinlock.h"
@@ -90,27 +90,27 @@
~InvertedIndexFileCollection();
// `seg_id` -> inverted index file writer
- Status add(int seg_id, InvertedIndexFileWriterPtr&& writer);
+ Status add(int seg_id, XIndexFileWriterPtr&& writer);
// Close all file writers
// If the inverted index file writer is not closed, an error will be thrown during destruction
Status close();
// Get inverted index file info in segment id order.
- // `seg_id_offset` is the offset of the segment id relative to the subscript of `_inverted_index_file_writers`,
+ // `seg_id_offset` is the offset of the segment id relative to the subscript of `_x_index_file_writers`,
// for more details, see `Tablet::create_transient_rowset_writer`.
Result<std::vector<const InvertedIndexFileInfo*>> inverted_index_file_info(int seg_id_offset);
// return all inverted index file writers
- std::unordered_map<int, InvertedIndexFileWriterPtr>& get_file_writers() {
- return _inverted_index_file_writers;
+ std::unordered_map<int, XIndexFileWriterPtr>& get_file_writers() {
+ return _x_index_file_writers;
}
int64_t get_total_index_size() const { return _total_size; }
private:
mutable SpinLock _lock;
- std::unordered_map<int /* seg_id */, InvertedIndexFileWriterPtr> _inverted_index_file_writers;
+ std::unordered_map<int /* seg_id */, XIndexFileWriterPtr> _x_index_file_writers;
int64_t _total_size = 0;
};
@@ -132,8 +132,8 @@
Status create_file_writer(uint32_t segment_id, io::FileWriterPtr& writer,
FileType file_type = FileType::SEGMENT_FILE) override;
- Status create_inverted_index_file_writer(uint32_t segment_id,
- InvertedIndexFileWriterPtr* writer) override;
+ Status create_x_index_file_writer(uint32_t segment_id,
+ XIndexFileWriterPtr* writer) override;
Status add_segment(uint32_t segment_id, const SegmentStatistics& segstat,
TabletSchemaSPtr flush_schema) override;
@@ -194,7 +194,7 @@
return _seg_files.get_file_writers();
}
- std::unordered_map<int, InvertedIndexFileWriterPtr>& inverted_index_file_writers() {
+ std::unordered_map<int, XIndexFileWriterPtr>& x_index_file_writers() {
return this->_idx_files.get_file_writers();
}
@@ -219,7 +219,7 @@
// Only during vertical compaction is this method called
// Some index files are written during normal compaction and some files are written during index compaction.
// After all index writes are completed, call this method to write the final compound index file.
- Status _close_inverted_index_file_writers() {
+ Status _close_x_index_file_writers() {
RETURN_NOT_OK_STATUS_WITH_WARN(_idx_files.close(),
"failed to close index file when build new rowset");
this->_total_index_size += _idx_files.get_total_index_size();
diff --git a/be/src/olap/rowset/rowset_writer.h b/be/src/olap/rowset/rowset_writer.h
index 0a0d36e..798aa0d 100644
--- a/be/src/olap/rowset/rowset_writer.h
+++ b/be/src/olap/rowset/rowset_writer.h
@@ -31,7 +31,7 @@
#include "olap/column_mapping.h"
#include "olap/rowset/rowset.h"
#include "olap/rowset/rowset_writer_context.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/tablet_fwd.h"
#include "olap/tablet_schema.h"
#include "vec/core/block.h"
@@ -96,8 +96,8 @@
return Status::NotSupported("RowsetWriter does not support create_file_writer");
}
- virtual Status create_inverted_index_file_writer(
- uint32_t segment_id, InvertedIndexFileWriterPtr* index_file_writer) {
+ virtual Status create_x_index_file_writer(
+ uint32_t segment_id, XIndexFileWriterPtr* index_file_writer) {
// Create file writer for the inverted index format v2.
io::FileWriterPtr idx_file_v2_ptr;
if (_context.tablet_schema->get_inverted_index_storage_format() !=
@@ -107,7 +107,7 @@
}
std::string segment_prefix {InvertedIndexDescriptor::get_index_file_path_prefix(
_context.segment_path(segment_id))};
- *index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ *index_file_writer = std::make_unique<XIndexFileWriter>(
_context.fs(), segment_prefix, _context.rowset_id.to_string(), segment_id,
_context.tablet_schema->get_inverted_index_storage_format(),
std::move(idx_file_v2_ptr));
diff --git a/be/src/olap/rowset/segcompaction.cpp b/be/src/olap/rowset/segcompaction.cpp
index bf12ce8..b482355 100644
--- a/be/src/olap/rowset/segcompaction.cpp
+++ b/be/src/olap/rowset/segcompaction.cpp
@@ -344,8 +344,8 @@
_writer->_num_segcompacted);
}
RETURN_IF_ERROR(_writer->_rename_compacted_segments(begin, end));
- if (_inverted_index_file_writer != nullptr) {
- _inverted_index_file_writer.reset();
+ if (_x_index_file_writer != nullptr) {
+ _x_index_file_writer.reset();
}
if (VLOG_DEBUG_IS_ON) {
_writer->vlog_buffer.clear();
diff --git a/be/src/olap/rowset/segcompaction.h b/be/src/olap/rowset/segcompaction.h
index 0279b5b..497af7d 100644
--- a/be/src/olap/rowset/segcompaction.h
+++ b/be/src/olap/rowset/segcompaction.h
@@ -25,7 +25,7 @@
#include "olap/merger.h"
#include "olap/simple_rowid_conversion.h"
#include "olap/tablet.h"
-#include "segment_v2/inverted_index_file_writer.h"
+#include "segment_v2/x_index_file_writer.h"
#include "segment_v2/segment.h"
namespace doris {
@@ -70,8 +70,8 @@
DeleteBitmapPtr get_converted_delete_bitmap() { return _converted_delete_bitmap; }
io::FileWriterPtr& get_file_writer() { return _file_writer; }
- InvertedIndexFileWriterPtr& get_inverted_index_file_writer() {
- return _inverted_index_file_writer;
+ XIndexFileWriterPtr& get_x_index_file_writer() {
+ return _x_index_file_writer;
}
// set the cancel flag, tasks already started will not be cancelled.
@@ -101,7 +101,7 @@
// Currently cloud storage engine doesn't need segcompaction
BetaRowsetWriter* _writer = nullptr;
io::FileWriterPtr _file_writer;
- InvertedIndexFileWriterPtr _inverted_index_file_writer = nullptr;
+ XIndexFileWriterPtr _x_index_file_writer = nullptr;
// for unique key mow table
std::unique_ptr<SimpleRowIdConversion> _rowid_conversion = nullptr;
diff --git a/be/src/olap/rowset/segment_creator.cpp b/be/src/olap/rowset/segment_creator.cpp
index e0eb753..296168e 100644
--- a/be/src/olap/rowset/segment_creator.cpp
+++ b/be/src/olap/rowset/segment_creator.cpp
@@ -140,10 +140,10 @@
io::FileWriterPtr segment_file_writer;
RETURN_IF_ERROR(_context.file_writer_creator->create(segment_id, segment_file_writer));
- InvertedIndexFileWriterPtr inverted_index_file_writer;
- if (_context.tablet_schema->has_inverted_index()) {
+ XIndexFileWriterPtr x_index_file_writer;
+ if (_context.tablet_schema->has_extra_index()) {
RETURN_IF_ERROR(
- _context.file_writer_creator->create(segment_id, &inverted_index_file_writer));
+ _context.file_writer_creator->create(segment_id, &x_index_file_writer));
}
segment_v2::SegmentWriterOptions writer_options;
@@ -158,10 +158,10 @@
writer = std::make_unique<segment_v2::SegmentWriter>(
segment_file_writer.get(), segment_id, _context.tablet_schema, _context.tablet,
- _context.data_dir, writer_options, inverted_index_file_writer.get());
+ _context.data_dir, writer_options, x_index_file_writer.get());
RETURN_IF_ERROR(_seg_files.add(segment_id, std::move(segment_file_writer)));
- if (_context.tablet_schema->has_inverted_index()) {
- RETURN_IF_ERROR(_idx_files.add(segment_id, std::move(inverted_index_file_writer)));
+ if (_context.tablet_schema->has_extra_index()) {
+ RETURN_IF_ERROR(_idx_files.add(segment_id, std::move(x_index_file_writer)));
}
auto s = writer->init();
if (!s.ok()) {
@@ -178,10 +178,10 @@
io::FileWriterPtr segment_file_writer;
RETURN_IF_ERROR(_context.file_writer_creator->create(segment_id, segment_file_writer));
- InvertedIndexFileWriterPtr inverted_index_file_writer;
- if (_context.tablet_schema->has_inverted_index()) {
+ XIndexFileWriterPtr x_index_file_writer;
+ if (_context.tablet_schema->has_extra_index()) {
RETURN_IF_ERROR(
- _context.file_writer_creator->create(segment_id, &inverted_index_file_writer));
+ _context.file_writer_creator->create(segment_id, &x_index_file_writer));
}
segment_v2::VerticalSegmentWriterOptions writer_options;
@@ -195,10 +195,10 @@
writer = std::make_unique<segment_v2::VerticalSegmentWriter>(
segment_file_writer.get(), segment_id, _context.tablet_schema, _context.tablet,
- _context.data_dir, writer_options, inverted_index_file_writer.get());
+ _context.data_dir, writer_options, x_index_file_writer.get());
RETURN_IF_ERROR(_seg_files.add(segment_id, std::move(segment_file_writer)));
- if (_context.tablet_schema->has_inverted_index()) {
- RETURN_IF_ERROR(_idx_files.add(segment_id, std::move(inverted_index_file_writer)));
+ if (_context.tablet_schema->has_extra_index()) {
+ RETURN_IF_ERROR(_idx_files.add(segment_id, std::move(x_index_file_writer)));
}
auto s = writer->init();
if (!s.ok()) {
diff --git a/be/src/olap/rowset/segment_creator.h b/be/src/olap/rowset/segment_creator.h
index f8afd579..1775926 100644
--- a/be/src/olap/rowset/segment_creator.h
+++ b/be/src/olap/rowset/segment_creator.h
@@ -29,7 +29,7 @@
#include "io/fs/file_reader_writer_fwd.h"
#include "olap/olap_common.h"
#include "olap/rowset/rowset_writer_context.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/tablet_fwd.h"
#include "util/spinlock.h"
#include "vec/core/block.h"
@@ -56,7 +56,7 @@
virtual Status create(uint32_t segment_id, io::FileWriterPtr& file_writer,
FileType file_type = FileType::SEGMENT_FILE) = 0;
- virtual Status create(uint32_t segment_id, InvertedIndexFileWriterPtr* file_writer) = 0;
+ virtual Status create(uint32_t segment_id, XIndexFileWriterPtr* file_writer) = 0;
};
template <class T>
@@ -70,8 +70,8 @@
return _t->create_file_writer(segment_id, file_writer, file_type);
}
- Status create(uint32_t segment_id, InvertedIndexFileWriterPtr* file_writer) override {
- return _t->create_inverted_index_file_writer(segment_id, file_writer);
+ Status create(uint32_t segment_id, XIndexFileWriterPtr* file_writer) override {
+ return _t->create_x_index_file_writer(segment_id, file_writer);
}
private:
diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.cpp b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
new file mode 100644
index 0000000..0368843
--- /dev/null
+++ b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
@@ -0,0 +1,115 @@
+#include "olap/rowset/segment_v2/ann_index_writer.h"
+
+namespace doris::segment_v2 {
+
+
+AnnIndexColumnWriter::AnnIndexColumnWriter(const std::string& field_name,
+ XIndexFileWriter* index_file_writer,
+ const TabletIndex* index_meta,
+ const bool single_field)
+ : _single_field(single_field),
+ _index_file_writer(index_file_writer),
+ _index_meta(index_meta) {
+ _field_name = StringUtil::string_to_wstring(field_name);
+}
+
+
+AnnIndexColumnWriter::~AnnIndexColumnWriter() {}
+
+
+Status AnnIndexColumnWriter::init() {
+ RETURN_IF_ERROR(open_index_directory());
+ RETURN_IF_ERROR(init_ann_index());
+ return Status::OK();
+}
+
+std::string get_or_default(const std::map<std::string, std::string>& properties,
+ const std::string& key,
+ const std::string& default_value) {
+ auto it = properties.find(key);
+ if (it != properties.end()) {
+ return it->second;
+ }
+ return default_value;
+}
+
+Status AnnIndexColumnWriter::init_ann_index(){
+ if(get_or_default(_index_meta->properties(), INDEX_TYPE, "")=="diskann"){
+ _vector_index_writer = std::make_shared<DiskannVectorIndex>(_dir);
+ std::shared_ptr<DiskannBuilderParameter> builderParameterPtr = std::make_shared<DiskannBuilderParameter>();
+ builderParameterPtr->with_dim(std::stoi(get_or_default(_index_meta->properties(), DIM,"")))
+ .with_L(std::stoi(get_or_default(_index_meta->properties(), DISKANN_SEARCH_LIST,"")))
+ .with_R(std::stoi(get_or_default(_index_meta->properties(), DISKANN_MAX_DEGREE,"")))
+ .with_build_num_threads(8)
+ .with_sample_rate(1)
+ .with_indexing_ram_budget_mb(10*1024)
+ .with_search_ram_budget_mb(30)
+ .with_mertic_type(VectorIndex::string_to_metric(get_or_default(_index_meta->properties(), METRIC_TYPE,"")));
+ _vector_index_writer->set_build_params(std::static_pointer_cast<BuilderParameter>(builderParameterPtr));
+ return Status::OK();
+ }else{
+ return Status::NotSupported("index type is invalid, only support diskann");
+ }
+}
+
+
+
+Status AnnIndexColumnWriter::open_index_directory() {
+ _dir = DORIS_TRY(_index_file_writer->open(_index_meta));
+ return Status::OK();
+}
+
+
+Status AnnIndexColumnWriter::add_values(const std::string fn, const void* values, size_t count){
+ return Status::OK();
+}
+
+void AnnIndexColumnWriter::close_on_error(){
+
+}
+
+Status AnnIndexColumnWriter::add_array_values(size_t field_size, const void* value_ptr,
+ const uint8_t* null_map,
+ const uint8_t* offsets_ptr, size_t count) {
+ if (count == 0) {
+ return Status::OK();
+ }
+ const auto* offsets = reinterpret_cast<const uint64_t*>(offsets_ptr);
+ size_t start_off = 0;
+ for (int i = 0; i < count; ++i) {
+ auto array_elem_size = offsets[i + 1] - offsets[i];
+ const float* p = &reinterpret_cast<const float*>(value_ptr)[start_off];
+ RETURN_IF_ERROR(_vector_index_writer->add(1, p));
+ start_off += array_elem_size;
+ _rid++;
+ }
+ return Status::OK();
+}
+
+Status AnnIndexColumnWriter::add_array_values(size_t field_size, const CollectionValue* values,
+ size_t count) {
+ return Status::OK();
+}
+
+Status AnnIndexColumnWriter::add_nulls(uint32_t count) {
+    // TODO(ann-index): null handling is not implemented yet; nulls are silently ignored.
+ return Status::OK();
+}
+
+Status AnnIndexColumnWriter::add_array_nulls(uint32_t row_id) {
+    // TODO(ann-index): per-row null handling is not implemented yet; nulls are silently ignored.
+ return Status::OK();
+}
+
+
+
+int64_t AnnIndexColumnWriter::size() const {
+    return 0; // TODO: report the in-memory size of the ANN index being built
+}
+
+
+Status AnnIndexColumnWriter::finish() {
+ return _vector_index_writer->save();
+}
+
+} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.h b/be/src/olap/rowset/segment_v2/ann_index_writer.h
new file mode 100644
index 0000000..19e44b9
--- /dev/null
+++ b/be/src/olap/rowset/segment_v2/ann_index_writer.h
@@ -0,0 +1,111 @@
+#pragma once
+
+
+#include "olap/rowset/segment_v2/index_writer.h"
+
+#include <CLucene.h> // IWYU pragma: keep
+#include <CLucene/analysis/LanguageBasedAnalyzer.h>
+#include <CLucene/util/bkd/bkd_writer.h>
+#include <glog/logging.h>
+
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <roaring/roaring.hh>
+#include <string>
+#include <vector>
+
+#include "io/fs/local_file_system.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wshadow-field"
+#endif
+
+#include "CLucene/analysis/standard95/StandardAnalyzer.h"
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+
+#include "common/config.h"
+#include "gutil/strings/strip.h"
+#include "olap/field.h"
+#include "olap/inverted_index_parser.h"
+#include "olap/key_coder.h"
+#include "olap/olap_common.h"
+#include "olap/rowset/segment_v2/common.h"
+#include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h"
+#include "olap/rowset/segment_v2/inverted_index/char_filter/char_filter_factory.h"
+#include "olap/rowset/segment_v2/inverted_index_common.h"
+#include "olap/rowset/segment_v2/inverted_index_desc.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
+#include "olap/tablet_schema.h"
+#include "olap/types.h"
+#include "runtime/collection_value.h"
+#include "runtime/exec_env.h"
+#include "util/debug_points.h"
+#include "util/faststring.h"
+#include "util/slice.h"
+#include "util/string_util.h"
+
+#include "olap/rowset/segment_v2/index_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "vector/vector_index.h"
+#include "vector/diskann_vector_index.h"
+#include "olap/tablet_schema.h"
+#include "util/string_util.h"
+
+namespace doris::segment_v2 {
+
+const int32_t MAX_FIELD_LEN = 0x7FFFFFFFL;
+const int32_t MERGE_FACTOR = 100000000;
+const int32_t MAX_LEAF_COUNT = 1024;
+const float MAXMBSortInHeap = 512.0 * 8;
+const int DIMS = 1;
+
+class AnnIndexColumnWriter : public IndexColumnWriter {
+public:
+ static constexpr const char* INDEX_TYPE = "index_type";
+ static constexpr const char* METRIC_TYPE = "metric_type";
+ static constexpr const char* DIM = "dim";
+ static constexpr const char* DISKANN_MAX_DEGREE = "max_degree";
+ static constexpr const char* DISKANN_SEARCH_LIST = "search_list";
+
+
+ explicit AnnIndexColumnWriter(const std::string& field_name,
+ XIndexFileWriter* index_file_writer,
+ const TabletIndex* index_meta,
+ const bool single_field = true);
+
+ ~AnnIndexColumnWriter() override;
+
+ Status init() override;
+ void close_on_error() override;
+ Status add_nulls(uint32_t count) override;
+ Status add_array_nulls(uint32_t row_id) override;
+ Status add_values(const std::string fn, const void* values, size_t count) override;
+ Status add_array_values(size_t field_size, const void* value_ptr, const uint8_t* null_map,
+ const uint8_t* offsets_ptr, size_t count) override;
+ Status add_array_values(size_t field_size, const CollectionValue* values,
+ size_t count) override;
+ int64_t size() const override;
+ Status finish() override;
+
+private:
+ Status open_index_directory();
+ Status init_ann_index();
+
+private:
+ rowid_t _rid = 0;
+ bool _single_field = true;
+ std::shared_ptr<DorisFSDirectory> _dir = nullptr;
+ std::shared_ptr<VectorIndex> _vector_index_writer;
+ XIndexFileWriter* _index_file_writer;
+ uint32_t _ignore_above;
+ std::wstring _field_name;
+ const TabletIndex* _index_meta;
+};
+
+} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/column_writer.cpp b/be/src/olap/rowset/segment_v2/column_writer.cpp
index 2637017..05d33a2 100644
--- a/be/src/olap/rowset/segment_v2/column_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/column_writer.cpp
@@ -32,7 +32,7 @@
#include "olap/rowset/segment_v2/bloom_filter.h"
#include "olap/rowset/segment_v2/bloom_filter_index_writer.h"
#include "olap/rowset/segment_v2/encoding_info.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/rowset/segment_v2/options.h"
#include "olap/rowset/segment_v2/ordinal_page_index.h"
#include "olap/rowset/segment_v2/page_builder.h"
@@ -49,6 +49,7 @@
#include "vec/core/types.h"
#include "vec/data_types/data_type_agg_state.h"
#include "vec/data_types/data_type_factory.hpp"
+#include "olap/rowset/segment_v2/index_writer.h"
namespace doris::segment_v2 {
@@ -452,7 +453,7 @@
if (_opts.need_inverted_index) {
do {
DBUG_EXECUTE_IF("column_writer.init", {
- class InvertedIndexColumnWriterEmptyImpl final : public InvertedIndexColumnWriter {
+ class IndexColumnWriterEmptyImpl final : public IndexColumnWriter {
public:
Status init() override { return Status::OK(); }
Status add_values(const std::string name, const void* values,
@@ -475,13 +476,13 @@
void close_on_error() override {}
};
- _inverted_index_builder = std::make_unique<InvertedIndexColumnWriterEmptyImpl>();
+ _inverted_index_builder = std::make_unique<IndexColumnWriterEmptyImpl>();
break;
});
- RETURN_IF_ERROR(InvertedIndexColumnWriter::create(get_field(), &_inverted_index_builder,
- _opts.inverted_index_file_writer,
+ RETURN_IF_ERROR(IndexColumnWriter::create(get_field(), &_inverted_index_builder,
+ _opts.x_index_file_writer,
_opts.inverted_index));
} while (false);
}
@@ -895,11 +896,19 @@
if (_opts.need_inverted_index) {
auto* writer = dynamic_cast<ScalarColumnWriter*>(_item_writer.get());
if (writer != nullptr) {
- RETURN_IF_ERROR(InvertedIndexColumnWriter::create(get_field(), &_inverted_index_builder,
- _opts.inverted_index_file_writer,
+ RETURN_IF_ERROR(IndexColumnWriter::create(get_field(), &_inverted_index_builder,
+ _opts.x_index_file_writer,
_opts.inverted_index));
}
}
+ if(_opts.need_ann_index){
+ auto* writer = dynamic_cast<ScalarColumnWriter*>(_item_writer.get());
+ if (writer != nullptr) {
+ RETURN_IF_ERROR(IndexColumnWriter::create(get_field(), &_ann_index_builder,
+ _opts.x_index_file_writer,
+ _opts.ann_index));
+ }
+ }
return Status::OK();
}
@@ -910,6 +919,13 @@
return Status::OK();
}
+Status ArrayColumnWriter::write_ann_index() {
+ if (_opts.need_ann_index) {
+ return _ann_index_builder->finish();
+ }
+ return Status::OK();
+}
+
// batch append data for array
Status ArrayColumnWriter::append_data(const uint8_t** ptr, size_t num_rows) {
// data_ptr contains
@@ -936,6 +952,17 @@
}
}
+ if (_opts.need_ann_index) {
+ auto* writer = dynamic_cast<ScalarColumnWriter*>(_item_writer.get());
+ // now only support nested type is scala
+ if (writer != nullptr) {
+ //NOTE: use array field name as index field, but item_writer size should be used when moving item_data_ptr
+ RETURN_IF_ERROR(_ann_index_builder->add_array_values(
+ _item_writer->get_field()->size(), reinterpret_cast<const void*>(data),
+ reinterpret_cast<const uint8_t*>(nested_null_map), offsets_ptr, num_rows));
+ }
+ }
+
RETURN_IF_ERROR(_offset_writer->append_data(&offsets_ptr, num_rows));
return Status::OK();
}
diff --git a/be/src/olap/rowset/segment_v2/column_writer.h b/be/src/olap/rowset/segment_v2/column_writer.h
index 2d66b94..38c2172 100644
--- a/be/src/olap/rowset/segment_v2/column_writer.h
+++ b/be/src/olap/rowset/segment_v2/column_writer.h
@@ -31,7 +31,7 @@
#include "common/status.h" // for Status
#include "olap/field.h" // for Field
#include "olap/rowset/segment_v2/common.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "util/bitmap.h" // for BitmapChange
#include "util/slice.h" // for OwnedSlice
@@ -61,11 +61,13 @@
bool need_bloom_filter = false;
bool is_ngram_bf_index = false;
bool need_inverted_index = false;
+ bool need_ann_index = false;
uint8_t gram_size;
uint16_t gram_bf_size;
std::vector<const TabletIndex*> indexes; // unused
const TabletIndex* inverted_index = nullptr;
- InvertedIndexFileWriter* inverted_index_file_writer;
+ const TabletIndex* ann_index = nullptr;
+ XIndexFileWriter* x_index_file_writer;
std::string to_string() const {
std::stringstream ss;
ss << std::boolalpha << "meta=" << meta->DebugString()
@@ -155,6 +157,9 @@
virtual Status write_bitmap_index() = 0;
virtual Status write_inverted_index() = 0;
+ virtual Status write_ann_index() {
+ return Status::OK();
+ }
virtual Status write_bloom_filter_index() = 0;
@@ -275,7 +280,7 @@
std::unique_ptr<OrdinalIndexWriter> _ordinal_index_builder;
std::unique_ptr<ZoneMapIndexWriter> _zone_map_index_builder;
std::unique_ptr<BitmapIndexWriter> _bitmap_index_builder;
- std::unique_ptr<InvertedIndexColumnWriter> _inverted_index_builder;
+ std::unique_ptr<IndexColumnWriter> _inverted_index_builder;
std::unique_ptr<BloomFilterIndexWriter> _bloom_filter_index_builder;
// call before flush data page.
@@ -388,6 +393,7 @@
return Status::OK();
}
Status write_inverted_index() override;
+ Status write_ann_index() override;
Status write_bloom_filter_index() override {
if (_opts.need_bloom_filter) {
return Status::NotSupported("array not support bloom filter index");
@@ -404,7 +410,8 @@
std::unique_ptr<OffsetColumnWriter> _offset_writer;
std::unique_ptr<ScalarColumnWriter> _null_writer;
std::unique_ptr<ColumnWriter> _item_writer;
- std::unique_ptr<InvertedIndexColumnWriter> _inverted_index_builder;
+ std::unique_ptr<IndexColumnWriter> _inverted_index_builder;
+ std::unique_ptr<IndexColumnWriter> _ann_index_builder;
ColumnWriterOptions _opts;
};
@@ -458,7 +465,7 @@
// we need null writer to make sure a row is null or not
std::unique_ptr<ScalarColumnWriter> _null_writer;
std::unique_ptr<OffsetColumnWriter> _offsets_writer;
- std::unique_ptr<InvertedIndexColumnWriter> _inverted_index_builder;
+ std::unique_ptr<IndexColumnWriter> _inverted_index_builder;
ColumnWriterOptions _opts;
};
diff --git a/be/src/olap/rowset/segment_v2/index_writer.cpp b/be/src/olap/rowset/segment_v2/index_writer.cpp
new file mode 100644
index 0000000..eeb9d8d
--- /dev/null
+++ b/be/src/olap/rowset/segment_v2/index_writer.cpp
@@ -0,0 +1,169 @@
+#include "olap/rowset/segment_v2/index_writer.h"
+
+#include <CLucene.h> // IWYU pragma: keep
+#include <CLucene/analysis/LanguageBasedAnalyzer.h>
+#include <CLucene/util/bkd/bkd_writer.h>
+#include <glog/logging.h>
+
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <roaring/roaring.hh>
+#include <string>
+#include <vector>
+
+#include "io/fs/local_file_system.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wshadow-field"
+#endif
+
+#include "CLucene/analysis/standard95/StandardAnalyzer.h"
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+
+#include "common/config.h"
+#include "gutil/strings/strip.h"
+#include "olap/field.h"
+#include "olap/inverted_index_parser.h"
+#include "olap/key_coder.h"
+#include "olap/olap_common.h"
+#include "olap/rowset/segment_v2/common.h"
+#include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h"
+#include "olap/rowset/segment_v2/inverted_index/char_filter/char_filter_factory.h"
+#include "olap/rowset/segment_v2/inverted_index_common.h"
+#include "olap/rowset/segment_v2/inverted_index_desc.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
+#include "olap/tablet_schema.h"
+#include "olap/types.h"
+#include "runtime/collection_value.h"
+#include "runtime/exec_env.h"
+#include "util/debug_points.h"
+#include "util/faststring.h"
+#include "util/slice.h"
+#include "util/string_util.h"
+
+#include "vector/vector_index.h"
+#include "vector/diskann_vector_index.h"
+
+#include "olap/rowset/segment_v2/ann_index_writer.h"
+#include "olap/rowset/segment_v2/inverted_index_writer.h"
+
+namespace doris {
+namespace segment_v2 {
+
+bool IndexColumnWriter::check_support_inverted_index(const TabletColumn& column){
+    // The types below are not supported by inverted index for extracted columns
+ static std::set<FieldType> invalid_types = {
+ FieldType::OLAP_FIELD_TYPE_DOUBLE,
+ FieldType::OLAP_FIELD_TYPE_JSONB,
+ FieldType::OLAP_FIELD_TYPE_ARRAY,
+ FieldType::OLAP_FIELD_TYPE_FLOAT,
+ };
+ if (column.is_extracted_column() && (invalid_types.contains(column.type()))) {
+ return false;
+ }
+ if (column.is_variant_type()) {
+ return false;
+ }
+ return true;
+}
+
+bool IndexColumnWriter::check_support_ann_index(const TabletColumn& column){
+    // ANN index is only supported on ARRAY-typed columns
+ return column.is_array_type();
+}
+
+
+Status IndexColumnWriter::create(const Field* field,
+ std::unique_ptr<IndexColumnWriter>* res,
+ XIndexFileWriter* index_file_writer,
+ const TabletIndex* index_meta) {
+ const auto* typeinfo = field->type_info();
+ FieldType type = typeinfo->type();
+ std::string field_name;
+ auto storage_format = index_file_writer->get_storage_format();
+ if (storage_format == InvertedIndexStorageFormatPB::V1) {
+ field_name = field->name();
+ } else {
+ if (field->is_extracted_column()) {
+ // variant sub col
+ // field_name format: parent_unique_id.sub_col_name
+ field_name = std::to_string(field->parent_unique_id()) + "." + field->name();
+ } else {
+ field_name = std::to_string(field->unique_id());
+ }
+ }
+ bool single_field = true;
+ if (type == FieldType::OLAP_FIELD_TYPE_ARRAY) {
+ const auto* array_typeinfo = dynamic_cast<const ArrayTypeInfo*>(typeinfo);
+ DBUG_EXECUTE_IF("IndexColumnWriter::create_array_typeinfo_is_nullptr",
+ { array_typeinfo = nullptr; })
+ if (array_typeinfo != nullptr) {
+ typeinfo = array_typeinfo->item_type_info();
+ type = typeinfo->type();
+ single_field = false;
+ } else {
+ return Status::NotSupported("unsupported array type for inverted index: " +
+ std::to_string(int(type)));
+ }
+ }
+
+ if(index_meta->index_type() == IndexType::ANN){
+ *res = std::make_unique<AnnIndexColumnWriter>(
+ field_name, index_file_writer, index_meta, single_field);
+ RETURN_IF_ERROR((*res)->init());
+ return Status::OK();
+ }
+
+ DBUG_EXECUTE_IF("IndexColumnWriter::create_unsupported_type_for_inverted_index",
+ { type = FieldType::OLAP_FIELD_TYPE_FLOAT; })
+ switch (type) {
+#define M(TYPE) \
+ case TYPE: \
+ *res = std::make_unique<InvertedIndexColumnWriter<TYPE>>( \
+ field_name, index_file_writer, index_meta, single_field); \
+ break;
+ M(FieldType::OLAP_FIELD_TYPE_TINYINT)
+ M(FieldType::OLAP_FIELD_TYPE_SMALLINT)
+ M(FieldType::OLAP_FIELD_TYPE_INT)
+ M(FieldType::OLAP_FIELD_TYPE_UNSIGNED_INT)
+ M(FieldType::OLAP_FIELD_TYPE_BIGINT)
+ M(FieldType::OLAP_FIELD_TYPE_LARGEINT)
+ M(FieldType::OLAP_FIELD_TYPE_CHAR)
+ M(FieldType::OLAP_FIELD_TYPE_VARCHAR)
+ M(FieldType::OLAP_FIELD_TYPE_STRING)
+ M(FieldType::OLAP_FIELD_TYPE_DATE)
+ M(FieldType::OLAP_FIELD_TYPE_DATETIME)
+ M(FieldType::OLAP_FIELD_TYPE_DECIMAL)
+ M(FieldType::OLAP_FIELD_TYPE_DATEV2)
+ M(FieldType::OLAP_FIELD_TYPE_DATETIMEV2)
+ M(FieldType::OLAP_FIELD_TYPE_DECIMAL32)
+ M(FieldType::OLAP_FIELD_TYPE_DECIMAL64)
+ M(FieldType::OLAP_FIELD_TYPE_DECIMAL128I)
+ M(FieldType::OLAP_FIELD_TYPE_DECIMAL256)
+ M(FieldType::OLAP_FIELD_TYPE_BOOL)
+ M(FieldType::OLAP_FIELD_TYPE_IPV4)
+ M(FieldType::OLAP_FIELD_TYPE_IPV6)
+#undef M
+ default:
+ return Status::NotSupported("unsupported type for inverted index: " +
+ std::to_string(int(type)));
+ }
+ if (*res != nullptr) {
+ auto st = (*res)->init();
+ if (!st.ok()) {
+ (*res)->close_on_error();
+ return st;
+ }
+ }
+ return Status::OK();
+}
+
+
+}
+}
\ No newline at end of file
diff --git a/be/src/olap/rowset/segment_v2/index_writer.h b/be/src/olap/rowset/segment_v2/index_writer.h
new file mode 100644
index 0000000..4c237a4
--- /dev/null
+++ b/be/src/olap/rowset/segment_v2/index_writer.h
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <butil/macros.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "common/config.h"
+#include "common/status.h"
+#include "gutil/strings/split.h"
+#include "io/fs/file_system.h"
+#include "io/fs/local_file_system.h"
+#include "olap/olap_common.h"
+#include "olap/options.h"
+
+namespace doris {
+class CollectionValue;
+
+class Field;
+
+class TabletIndex;
+class TabletColumn;
+
+namespace segment_v2 {
+class XIndexFileWriter;
+
+class IndexColumnWriter {
+public:
+ static Status create(const Field* field, std::unique_ptr<IndexColumnWriter>* res,
+ XIndexFileWriter* index_file_writer,
+ const TabletIndex* inverted_index);
+ virtual Status init() = 0;
+
+ IndexColumnWriter() = default;
+ virtual ~IndexColumnWriter() = default;
+
+ virtual Status add_values(const std::string name, const void* values, size_t count) = 0;
+ virtual Status add_array_values(size_t field_size, const CollectionValue* values,
+ size_t count) = 0;
+
+ virtual Status add_array_values(size_t field_size, const void* value_ptr,
+ const uint8_t* null_map, const uint8_t* offsets_ptr,
+ size_t count) = 0;
+
+ virtual Status add_nulls(uint32_t count) = 0;
+ virtual Status add_array_nulls(uint32_t row_id) = 0;
+
+ virtual Status finish() = 0;
+
+ virtual int64_t size() const = 0;
+
+ virtual void close_on_error() = 0;
+
+ // check if the column is valid for inverted index, some columns
+ // are generated from variant, but not all of them are supported
+ static bool check_support_inverted_index(const TabletColumn& column);
+ static bool check_support_ann_index(const TabletColumn& column);
+
+private:
+ DISALLOW_COPY_AND_ASSIGN(IndexColumnWriter);
+};
+
+class TmpFileDirs {
+public:
+ TmpFileDirs(const std::vector<doris::StorePath>& store_paths) {
+ for (const auto& store_path : store_paths) {
+ _tmp_file_dirs.emplace_back(store_path.path + "/" + config::tmp_file_dir);
+ }
+ };
+
+ Status init() {
+ for (auto& tmp_file_dir : _tmp_file_dirs) {
+ // delete the tmp dir to avoid the tmp files left by last crash
+ RETURN_IF_ERROR(io::global_local_filesystem()->delete_directory(tmp_file_dir));
+ RETURN_IF_ERROR(io::global_local_filesystem()->create_directory(tmp_file_dir));
+ }
+ return Status::OK();
+ };
+
+ io::Path get_tmp_file_dir() {
+ std::cout << "TmpFileDirs size: " << _tmp_file_dirs.size() << std::endl;
+ size_t cur_index = _next_index.fetch_add(1);
+ return _tmp_file_dirs[cur_index % _tmp_file_dirs.size()];
+ };
+
+ ~TmpFileDirs(){
+ std::cout << "TmpFileDirs destroyed!" << std::endl;
+ }
+
+private:
+ std::vector<io::Path> _tmp_file_dirs;
+ std::atomic_size_t _next_index {0}; // use for round-robin
+};
+
+} // namespace segment_v2
+} // namespace doris
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_compaction.cpp b/be/src/olap/rowset/segment_v2/inverted_index_compaction.cpp
index dcbdca9..0d7f1a4 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_compaction.cpp
+++ b/be/src/olap/rowset/segment_v2/inverted_index_compaction.cpp
@@ -17,7 +17,7 @@
#include "inverted_index_compaction.h"
-#include "inverted_index_file_writer.h"
+#include "x_index_file_writer.h"
#include "inverted_index_fs_directory.h"
#include "io/fs/local_file_system.h"
#include "olap/tablet_schema.h"
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_compaction.h b/be/src/olap/rowset/segment_v2/inverted_index_compaction.h
index 1a6e474..f344f12 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_compaction.h
+++ b/be/src/olap/rowset/segment_v2/inverted_index_compaction.h
@@ -28,7 +28,7 @@
namespace doris {
class TabletIndex;
namespace segment_v2 {
-class InvertedIndexFileWriter;
+class XIndexFileWriter;
class InvertedIndexFileReader;
Status compact_column(int64_t index_id,
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_file_reader.h b/be/src/olap/rowset/segment_v2/inverted_index_file_reader.h
index ed6ee85..f5977cf 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_file_reader.h
+++ b/be/src/olap/rowset/segment_v2/inverted_index_file_reader.h
@@ -32,7 +32,7 @@
#include "common/config.h"
#include "io/fs/file_system.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
namespace doris {
class TabletIndex;
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp b/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp
index 4136ada..b7c102b 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include <CLucene.h> // IWYU pragma: keep
#include <CLucene/analysis/LanguageBasedAnalyzer.h>
@@ -53,7 +53,7 @@
#include "olap/rowset/segment_v2/inverted_index/char_filter/char_filter_factory.h"
#include "olap/rowset/segment_v2/inverted_index_common.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
#include "olap/tablet_schema.h"
#include "olap/types.h"
@@ -63,6 +63,7 @@
#include "util/faststring.h"
#include "util/slice.h"
#include "util/string_util.h"
+#include "olap/rowset/segment_v2/inverted_index_writer.h"
namespace doris::segment_v2 {
const int32_t MAX_FIELD_LEN = 0x7FFFFFFFL;
@@ -71,32 +72,11 @@
const float MAXMBSortInHeap = 512.0 * 8;
const int DIMS = 1;
-bool InvertedIndexColumnWriter::check_support_inverted_index(const TabletColumn& column) {
- // bellow types are not supported in inverted index for extracted columns
- static std::set<FieldType> invalid_types = {
- FieldType::OLAP_FIELD_TYPE_DOUBLE,
- FieldType::OLAP_FIELD_TYPE_JSONB,
- FieldType::OLAP_FIELD_TYPE_ARRAY,
- FieldType::OLAP_FIELD_TYPE_FLOAT,
- };
- if (column.is_extracted_column() && (invalid_types.contains(column.type()))) {
- return false;
- }
- if (column.is_variant_type()) {
- return false;
- }
- return true;
-}
-
template <FieldType field_type>
-class InvertedIndexColumnWriterImpl : public InvertedIndexColumnWriter {
-public:
- using CppType = typename CppTypeTraits<field_type>::CppType;
-
- explicit InvertedIndexColumnWriterImpl(const std::string& field_name,
- InvertedIndexFileWriter* index_file_writer,
+InvertedIndexColumnWriter<field_type>::InvertedIndexColumnWriter(const std::string& field_name,
+ XIndexFileWriter* index_file_writer,
const TabletIndex* index_meta,
- const bool single_field = true)
+ const bool single_field)
: _single_field(single_field),
_index_meta(index_meta),
_index_file_writer(index_file_writer) {
@@ -106,19 +86,21 @@
_field_name = StringUtil::string_to_wstring(field_name);
}
- ~InvertedIndexColumnWriterImpl() override {
+template <FieldType field_type>
+InvertedIndexColumnWriter<field_type>::~InvertedIndexColumnWriter() {
if (_index_writer != nullptr) {
close_on_error();
}
- }
+}
- Status init() override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::init() {
try {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::init_field_type_not_supported", {
+ DBUG_EXECUTE_IF("IndexColumnWriter::init_field_type_not_supported", {
return Status::Error<doris::ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
"Field type not supported");
})
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::init_inverted_index_writer_init_error",
+ DBUG_EXECUTE_IF("IndexColumnWriter::init_inverted_index_writer_init_error",
{ _CLTHROWA(CL_ERR_IO, "debug point: init index error"); })
if constexpr (field_is_slice_type(field_type)) {
return init_fulltext_index();
@@ -134,9 +116,10 @@
}
}
- void close_on_error() override {
+template <FieldType field_type>
+void InvertedIndexColumnWriter<field_type>::close_on_error() {
try {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::close_on_error_throw_exception",
+ DBUG_EXECUTE_IF("IndexColumnWriter::close_on_error_throw_exception",
{ _CLTHROWA(CL_ERR_IO, "debug point: close on error"); })
// delete directory must be done before index_writer close
// because index_writer will close the directory
@@ -151,7 +134,8 @@
}
}
- Status init_bkd_index() {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::init_bkd_index() {
size_t value_length = sizeof(CppType);
// NOTE: initialize with 0, set to max_row_id when finished.
int32_t max_doc = 0;
@@ -159,13 +143,13 @@
_bkd_writer = std::make_shared<lucene::util::bkd::bkd_writer>(
max_doc, DIMS, DIMS, value_length, MAX_LEAF_COUNT, MAXMBSortInHeap,
total_point_count, true, config::max_depth_in_bkd_tree);
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::init_bkd_index_throw_error", {
+ DBUG_EXECUTE_IF("IndexColumnWriter::init_bkd_index_throw_error", {
_CLTHROWA(CL_ERR_IllegalArgument, "debug point: create bkd_writer error");
})
return open_index_directory();
}
-
- Result<std::unique_ptr<lucene::util::Reader>> create_char_string_reader(
+template <FieldType field_type>
+Result<std::unique_ptr<lucene::util::Reader>> InvertedIndexColumnWriter<field_type>::create_char_string_reader(
CharFilterMap& char_filter_map) {
try {
return inverted_index::InvertedIndexAnalyzer::create_reader(char_filter_map);
@@ -175,25 +159,26 @@
}
}
- Status open_index_directory() {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::open_index_directory_error", {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::open_index_directory() {
+ DBUG_EXECUTE_IF("IndexColumnWriter::open_index_directory_error", {
return Status::Error<ErrorCode::INTERNAL_ERROR>(
"debug point: open_index_directory_error");
})
_dir = DORIS_TRY(_index_file_writer->open(_index_meta));
return Status::OK();
}
-
- std::unique_ptr<lucene::index::IndexWriter> create_index_writer() {
+template <FieldType field_type>
+std::unique_ptr<lucene::index::IndexWriter> InvertedIndexColumnWriter<field_type>::create_index_writer() {
bool create_index = true;
bool close_dir_on_shutdown = true;
auto index_writer = std::make_unique<lucene::index::IndexWriter>(
_dir.get(), _analyzer.get(), create_index, close_dir_on_shutdown);
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_index_writer_setRAMBufferSizeMB_error",
+ DBUG_EXECUTE_IF("IndexColumnWriter::create_index_writer_setRAMBufferSizeMB_error",
{ index_writer->setRAMBufferSizeMB(-100); })
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_index_writer_setMaxBufferedDocs_error",
+ DBUG_EXECUTE_IF("IndexColumnWriter::create_index_writer_setMaxBufferedDocs_error",
{ index_writer->setMaxBufferedDocs(1); })
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_index_writer_setMergeFactor_error",
+ DBUG_EXECUTE_IF("IndexColumnWriter::create_index_writer_setMergeFactor_error",
{ index_writer->setMergeFactor(1); })
index_writer->setRAMBufferSizeMB(config::inverted_index_ram_buffer_size);
index_writer->setMaxBufferedDocs(config::inverted_index_max_buffered_docs);
@@ -204,7 +189,8 @@
return index_writer;
}
- Status create_field(lucene::document::Field** field) {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::create_field(lucene::document::Field** field) {
int field_config = int(lucene::document::Field::STORE_NO) |
int(lucene::document::Field::INDEX_NONORMS);
field_config |= (_parser_type == InvertedIndexParserType::PARSER_NONE)
@@ -214,10 +200,10 @@
(*field)->setOmitTermFreqAndPositions(
!(get_parser_phrase_support_string_from_properties(_index_meta->properties()) ==
INVERTED_INDEX_PARSER_PHRASE_SUPPORT_YES));
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::create_field_v3", {
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_field_v3", {
if (_index_file_writer->get_storage_format() != InvertedIndexStorageFormatPB::V3) {
return Status::Error<doris::ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
- "debug point: InvertedIndexColumnWriterImpl::create_field_v3 error");
+ "debug point: InvertedIndexColumnWriter::create_field_v3 error");
}
})
if (_index_file_writer->get_storage_format() >= InvertedIndexStorageFormatPB::V3) {
@@ -225,11 +211,11 @@
// Only effective in v3
std::string dict_compression =
get_parser_dict_compression_from_properties(_index_meta->properties());
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::create_field_dic_compression", {
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_field_dic_compression", {
if (dict_compression != INVERTED_INDEX_PARSER_TRUE) {
return Status::Error<doris::ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
"debug point: "
- "InvertedIndexColumnWriterImpl::create_field_dic_compression error");
+ "InvertedIndexColumnWriter::create_field_dic_compression error");
}
})
if (dict_compression == INVERTED_INDEX_PARSER_TRUE) {
@@ -239,7 +225,8 @@
return Status::OK();
}
- Result<std::unique_ptr<lucene::analysis::Analyzer>> create_analyzer(
+template <FieldType field_type>
+Result<std::unique_ptr<lucene::analysis::Analyzer>> InvertedIndexColumnWriter<field_type>::create_analyzer(
std::shared_ptr<InvertedIndexCtx>& inverted_index_ctx) {
try {
return inverted_index::InvertedIndexAnalyzer::create_analyzer(inverted_index_ctx.get());
@@ -249,7 +236,8 @@
}
}
- Status init_fulltext_index() {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::init_fulltext_index() {
_inverted_index_ctx = std::make_shared<InvertedIndexCtx>(
get_inverted_index_parser_type_from_string(
get_parser_string_from_properties(_index_meta->properties())),
@@ -276,12 +264,13 @@
return Status::OK();
}
- Status add_document() {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_document() {
DBUG_EXECUTE_IF("inverted_index_writer.add_document", { return Status::OK(); });
try {
_index_writer->addDocument(_doc.get());
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_document_throw_error",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_document_throw_error",
{ _CLTHROWA(CL_ERR_IO, "debug point: add_document io error"); })
} catch (const CLuceneError& e) {
close_on_error();
@@ -291,10 +280,11 @@
return Status::OK();
}
- Status add_null_document() {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_null_document() {
try {
_index_writer->addNullDocument(_doc.get());
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_null_document_throw_error",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_null_document_throw_error",
{ _CLTHROWA(CL_ERR_IO, "debug point: add_null_document io error"); })
} catch (const CLuceneError& e) {
close_on_error();
@@ -304,13 +294,14 @@
return Status::OK();
}
- Status add_nulls(uint32_t count) override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_nulls(uint32_t count) {
_null_bitmap.addRange(_rid, _rid + count);
_rid += count;
if constexpr (field_is_slice_type(field_type)) {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_nulls_field_nullptr",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_nulls_field_nullptr",
{ _field = nullptr; })
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_nulls_index_writer_nullptr",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_nulls_index_writer_nullptr",
{ _index_writer = nullptr; })
if (_field == nullptr || _index_writer == nullptr) {
LOG(ERROR) << "field or index writer is null in inverted index writer.";
@@ -325,12 +316,14 @@
return Status::OK();
}
- Status add_array_nulls(uint32_t row_id) override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_array_nulls(uint32_t row_id) {
_null_bitmap.add(row_id);
return Status::OK();
}
- Status new_inverted_index_field(const char* field_value_data, size_t field_value_size) {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::new_inverted_index_field(const char* field_value_data, size_t field_value_size) {
try {
if (_parser_type != InvertedIndexParserType::PARSER_UNKNOWN &&
_parser_type != InvertedIndexParserType::PARSER_NONE) {
@@ -345,10 +338,11 @@
return Status::OK();
}
- void new_char_token_stream(const char* s, size_t len, lucene::document::Field* field) {
+template <FieldType field_type>
+void InvertedIndexColumnWriter<field_type>::new_char_token_stream(const char* s, size_t len, lucene::document::Field* field) {
_char_string_reader->init(s, len, false);
DBUG_EXECUTE_IF(
- "InvertedIndexColumnWriterImpl::new_char_token_stream__char_string_reader_init_"
+ "InvertedIndexColumnWriter::new_char_token_stream__char_string_reader_init_"
"error",
{
_CLTHROWA(CL_ERR_UnsupportedOperation,
@@ -358,22 +352,25 @@
field->setValue(stream);
}
- void new_field_value(const char* s, size_t len, lucene::document::Field* field) {
+template <FieldType field_type>
+void InvertedIndexColumnWriter<field_type>::new_field_value(const char* s, size_t len, lucene::document::Field* field) {
auto* field_value = lucene::util::Misc::_charToWide(s, len);
field->setValue(field_value, false);
// setValue did not duplicate value, so we don't have to delete
//_CLDELETE_ARRAY(field_value)
}
- void new_field_char_value(const char* s, size_t len, lucene::document::Field* field) {
+template <FieldType field_type>
+void InvertedIndexColumnWriter<field_type>::new_field_char_value(const char* s, size_t len, lucene::document::Field* field) {
field->setValue((char*)s, len);
}
- Status add_values(const std::string fn, const void* values, size_t count) override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_values(const std::string fn, const void* values, size_t count) {
if constexpr (field_is_slice_type(field_type)) {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_values_field_is_nullptr",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_values_field_is_nullptr",
{ _field = nullptr; })
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_values_index_writer_is_nullptr",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_values_index_writer_is_nullptr",
{ _index_writer = nullptr; })
if (_field == nullptr || _index_writer == nullptr) {
LOG(ERROR) << "field or index writer is null in inverted index writer.";
@@ -400,9 +397,10 @@
return Status::OK();
}
- Status add_array_values(size_t field_size, const void* value_ptr, const uint8_t* null_map,
- const uint8_t* offsets_ptr, size_t count) override {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_array_values_count_is_zero",
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_array_values(size_t field_size, const void* value_ptr, const uint8_t* null_map,
+ const uint8_t* offsets_ptr, size_t count) {
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_array_values_count_is_zero",
{ count = 0; })
if (count == 0) {
// no values to add inverted index
@@ -411,7 +409,7 @@
const auto* offsets = reinterpret_cast<const uint64_t*>(offsets_ptr);
if constexpr (field_is_slice_type(field_type)) {
DBUG_EXECUTE_IF(
- "InvertedIndexColumnWriterImpl::add_array_values_index_writer_is_nullptr",
+ "InvertedIndexColumnWriter::add_array_values_index_writer_is_nullptr",
{ _index_writer = nullptr; })
if (_index_writer == nullptr) {
LOG(ERROR) << "index writer is null in inverted index writer.";
@@ -440,7 +438,7 @@
// now we temp create field . later make a pool
Status st = create_field(&new_field);
DBUG_EXECUTE_IF(
- "InvertedIndexColumnWriterImpl::add_array_values_create_field_"
+ "InvertedIndexColumnWriter::add_array_values_create_field_"
"error",
{
st = Status::Error<ErrorCode::INTERNAL_ERROR>(
@@ -500,7 +498,7 @@
// resetCurrentFieldData
Status st = create_field(&new_field);
DBUG_EXECUTE_IF(
- "InvertedIndexColumnWriterImpl::add_array_values_create_field_error_2",
+ "InvertedIndexColumnWriter::add_array_values_create_field_error_2",
{
st = Status::Error<ErrorCode::INTERNAL_ERROR>(
"debug point: add_array_values_create_field_error_2");
@@ -536,13 +534,14 @@
return Status::OK();
}
- Status add_array_values(size_t field_size, const CollectionValue* values,
- size_t count) override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_array_values(size_t field_size, const CollectionValue* values,
+ size_t count) {
if constexpr (field_is_slice_type(field_type)) {
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_array_values_field_is_nullptr",
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_array_values_field_is_nullptr",
{ _field = nullptr; })
DBUG_EXECUTE_IF(
- "InvertedIndexColumnWriterImpl::add_array_values_index_writer_is_nullptr",
+ "InvertedIndexColumnWriter::add_array_values_index_writer_is_nullptr",
{ _index_writer = nullptr; })
if (_field == nullptr || _index_writer == nullptr) {
LOG(ERROR) << "field or index writer is null in inverted index writer.";
@@ -588,7 +587,8 @@
return Status::OK();
}
- Status add_numeric_values(const void* values, size_t count) {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_numeric_values(const void* values, size_t count) {
auto p = reinterpret_cast<const CppType*>(values);
for (size_t i = 0; i < count; ++i) {
RETURN_IF_ERROR(add_value(*p));
@@ -599,12 +599,13 @@
return Status::OK();
}
- Status add_value(const CppType& value) {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::add_value(const CppType& value) {
try {
std::string new_value;
size_t value_length = sizeof(CppType);
- DBUG_EXECUTE_IF("InvertedIndexColumnWriterImpl::add_value_bkd_writer_add_throw_error", {
+ DBUG_EXECUTE_IF("InvertedIndexColumnWriter::add_value_bkd_writer_add_throw_error", {
_CLTHROWA(CL_ERR_IllegalArgument, ("packedValue should be length=xxx"));
});
@@ -617,12 +618,13 @@
return Status::OK();
}
- int64_t size() const override {
+template <FieldType field_type>
+int64_t InvertedIndexColumnWriter<field_type>::size() const {
//TODO: get memory size of inverted index
return 0;
}
-
- void write_null_bitmap(lucene::store::IndexOutput* null_bitmap_out) {
+template <FieldType field_type>
+void InvertedIndexColumnWriter<field_type>::write_null_bitmap(lucene::store::IndexOutput* null_bitmap_out) {
// write null_bitmap file
_null_bitmap.runOptimize();
size_t size = _null_bitmap.getSizeInBytes(false);
@@ -634,7 +636,8 @@
}
}
- Status finish() override {
+template <FieldType field_type>
+Status InvertedIndexColumnWriter<field_type>::finish() {
if (_dir != nullptr) {
std::unique_ptr<lucene::store::IndexOutput> null_bitmap_out = nullptr;
std::unique_ptr<lucene::store::IndexOutput> data_out = nullptr;
@@ -707,106 +710,30 @@
"Inverted index writer finish error occurred: dir is nullptr");
}
-private:
- rowid_t _rid = 0;
- uint32_t _row_ids_seen_for_bkd = 0;
- roaring::Roaring _null_bitmap;
- uint64_t _reverted_index_size;
+#define M(TYPE) template class InvertedIndexColumnWriter<TYPE>;
- std::unique_ptr<lucene::document::Document> _doc = nullptr;
- lucene::document::Field* _field = nullptr;
- bool _single_field = true;
- // Since _index_writer's write.lock is created by _dir.lockFactory,
- // _dir must destruct after _index_writer, so _dir must be defined before _index_writer.
- std::shared_ptr<DorisFSDirectory> _dir = nullptr;
- std::unique_ptr<lucene::index::IndexWriter> _index_writer = nullptr;
- std::unique_ptr<lucene::analysis::Analyzer> _analyzer = nullptr;
- std::unique_ptr<lucene::util::Reader> _char_string_reader = nullptr;
- std::shared_ptr<lucene::util::bkd::bkd_writer> _bkd_writer = nullptr;
- InvertedIndexCtxSPtr _inverted_index_ctx = nullptr;
- const KeyCoder* _value_key_coder;
- const TabletIndex* _index_meta;
- InvertedIndexParserType _parser_type;
- std::wstring _field_name;
- InvertedIndexFileWriter* _index_file_writer;
- uint32_t _ignore_above;
-};
+// Explicit template instantiations for all supported field types
+M(FieldType::OLAP_FIELD_TYPE_TINYINT)
+M(FieldType::OLAP_FIELD_TYPE_SMALLINT)
+M(FieldType::OLAP_FIELD_TYPE_INT)
+M(FieldType::OLAP_FIELD_TYPE_UNSIGNED_INT)
+M(FieldType::OLAP_FIELD_TYPE_BIGINT)
+M(FieldType::OLAP_FIELD_TYPE_LARGEINT)
+M(FieldType::OLAP_FIELD_TYPE_CHAR)
+M(FieldType::OLAP_FIELD_TYPE_VARCHAR)
+M(FieldType::OLAP_FIELD_TYPE_STRING)
+M(FieldType::OLAP_FIELD_TYPE_DATE)
+M(FieldType::OLAP_FIELD_TYPE_DATETIME)
+M(FieldType::OLAP_FIELD_TYPE_DECIMAL)
+M(FieldType::OLAP_FIELD_TYPE_DATEV2)
+M(FieldType::OLAP_FIELD_TYPE_DATETIMEV2)
+M(FieldType::OLAP_FIELD_TYPE_DECIMAL32)
+M(FieldType::OLAP_FIELD_TYPE_DECIMAL64)
+M(FieldType::OLAP_FIELD_TYPE_DECIMAL128I)
+M(FieldType::OLAP_FIELD_TYPE_DECIMAL256)
+M(FieldType::OLAP_FIELD_TYPE_BOOL)
+M(FieldType::OLAP_FIELD_TYPE_IPV4)
+M(FieldType::OLAP_FIELD_TYPE_IPV6)
-Status InvertedIndexColumnWriter::create(const Field* field,
- std::unique_ptr<InvertedIndexColumnWriter>* res,
- InvertedIndexFileWriter* index_file_writer,
- const TabletIndex* index_meta) {
- const auto* typeinfo = field->type_info();
- FieldType type = typeinfo->type();
- std::string field_name;
- auto storage_format = index_file_writer->get_storage_format();
- if (storage_format == InvertedIndexStorageFormatPB::V1) {
- field_name = field->name();
- } else {
- if (field->is_extracted_column()) {
- // variant sub col
- // field_name format: parent_unique_id.sub_col_name
- field_name = std::to_string(field->parent_unique_id()) + "." + field->name();
- } else {
- field_name = std::to_string(field->unique_id());
- }
- }
- bool single_field = true;
- if (type == FieldType::OLAP_FIELD_TYPE_ARRAY) {
- const auto* array_typeinfo = dynamic_cast<const ArrayTypeInfo*>(typeinfo);
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_array_typeinfo_is_nullptr",
- { array_typeinfo = nullptr; })
- if (array_typeinfo != nullptr) {
- typeinfo = array_typeinfo->item_type_info();
- type = typeinfo->type();
- single_field = false;
- } else {
- return Status::NotSupported("unsupported array type for inverted index: " +
- std::to_string(int(type)));
- }
- }
- DBUG_EXECUTE_IF("InvertedIndexColumnWriter::create_unsupported_type_for_inverted_index",
- { type = FieldType::OLAP_FIELD_TYPE_FLOAT; })
- switch (type) {
-#define M(TYPE) \
- case TYPE: \
- *res = std::make_unique<InvertedIndexColumnWriterImpl<TYPE>>( \
- field_name, index_file_writer, index_meta, single_field); \
- break;
- M(FieldType::OLAP_FIELD_TYPE_TINYINT)
- M(FieldType::OLAP_FIELD_TYPE_SMALLINT)
- M(FieldType::OLAP_FIELD_TYPE_INT)
- M(FieldType::OLAP_FIELD_TYPE_UNSIGNED_INT)
- M(FieldType::OLAP_FIELD_TYPE_BIGINT)
- M(FieldType::OLAP_FIELD_TYPE_LARGEINT)
- M(FieldType::OLAP_FIELD_TYPE_CHAR)
- M(FieldType::OLAP_FIELD_TYPE_VARCHAR)
- M(FieldType::OLAP_FIELD_TYPE_STRING)
- M(FieldType::OLAP_FIELD_TYPE_DATE)
- M(FieldType::OLAP_FIELD_TYPE_DATETIME)
- M(FieldType::OLAP_FIELD_TYPE_DECIMAL)
- M(FieldType::OLAP_FIELD_TYPE_DATEV2)
- M(FieldType::OLAP_FIELD_TYPE_DATETIMEV2)
- M(FieldType::OLAP_FIELD_TYPE_DECIMAL32)
- M(FieldType::OLAP_FIELD_TYPE_DECIMAL64)
- M(FieldType::OLAP_FIELD_TYPE_DECIMAL128I)
- M(FieldType::OLAP_FIELD_TYPE_DECIMAL256)
- M(FieldType::OLAP_FIELD_TYPE_BOOL)
- M(FieldType::OLAP_FIELD_TYPE_IPV4)
- M(FieldType::OLAP_FIELD_TYPE_IPV6)
-#undef M
- default:
- return Status::NotSupported("unsupported type for inverted index: " +
- std::to_string(int(type)));
- }
- if (*res != nullptr) {
- auto st = (*res)->init();
- if (!st.ok()) {
- (*res)->close_on_error();
- return st;
- }
- }
- return Status::OK();
-}
} // namespace doris::segment_v2
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_writer.h b/be/src/olap/rowset/segment_v2/inverted_index_writer.h
index da90752..17cab73 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_writer.h
+++ b/be/src/olap/rowset/segment_v2/inverted_index_writer.h
@@ -17,95 +17,122 @@
#pragma once
-#include <butil/macros.h>
-#include <stddef.h>
-#include <stdint.h>
+#include "olap/rowset/segment_v2/index_writer.h"
-#include <atomic>
+#include <CLucene.h> // IWYU pragma: keep
+#include <CLucene/analysis/LanguageBasedAnalyzer.h>
+#include <CLucene/util/bkd/bkd_writer.h>
+#include <glog/logging.h>
+
+#include <limits>
#include <memory>
+#include <ostream>
+#include <roaring/roaring.hh>
#include <string>
#include <vector>
-#include "common/config.h"
-#include "common/status.h"
-#include "gutil/strings/split.h"
-#include "io/fs/file_system.h"
#include "io/fs/local_file_system.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wshadow-field"
+#endif
+
+#include "CLucene/analysis/standard95/StandardAnalyzer.h"
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+
+#include "common/config.h"
+#include "gutil/strings/strip.h"
+#include "olap/field.h"
+#include "olap/inverted_index_parser.h"
+#include "olap/key_coder.h"
#include "olap/olap_common.h"
-#include "olap/options.h"
+#include "olap/rowset/segment_v2/common.h"
+#include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h"
+#include "olap/rowset/segment_v2/inverted_index/char_filter/char_filter_factory.h"
+#include "olap/rowset/segment_v2/inverted_index_common.h"
+#include "olap/rowset/segment_v2/inverted_index_desc.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
+#include "olap/tablet_schema.h"
+#include "olap/types.h"
+#include "runtime/collection_value.h"
+#include "runtime/exec_env.h"
+#include "util/debug_points.h"
+#include "util/faststring.h"
+#include "util/slice.h"
+#include "util/string_util.h"
+#include "olap/rowset/segment_v2/index_writer.h"
-namespace doris {
-class CollectionValue;
+namespace doris::segment_v2 {
-class Field;
-
-class TabletIndex;
-class TabletColumn;
-
-namespace segment_v2 {
-class InvertedIndexFileWriter;
-
-class InvertedIndexColumnWriter {
+template <FieldType field_type>
+class InvertedIndexColumnWriter : public IndexColumnWriter {
public:
- static Status create(const Field* field, std::unique_ptr<InvertedIndexColumnWriter>* res,
- InvertedIndexFileWriter* index_file_writer,
- const TabletIndex* inverted_index);
- virtual Status init() = 0;
+ using CppType = typename CppTypeTraits<field_type>::CppType;
- InvertedIndexColumnWriter() = default;
- virtual ~InvertedIndexColumnWriter() = default;
+ explicit InvertedIndexColumnWriter(const std::string& field_name,
+ XIndexFileWriter* index_file_writer,
+ const TabletIndex* index_meta,
+ const bool single_field = true);
- virtual Status add_values(const std::string name, const void* values, size_t count) = 0;
- virtual Status add_array_values(size_t field_size, const CollectionValue* values,
- size_t count) = 0;
+ ~InvertedIndexColumnWriter() override;
- virtual Status add_array_values(size_t field_size, const void* value_ptr,
- const uint8_t* null_map, const uint8_t* offsets_ptr,
- size_t count) = 0;
-
- virtual Status add_nulls(uint32_t count) = 0;
- virtual Status add_array_nulls(uint32_t row_id) = 0;
-
- virtual Status finish() = 0;
-
- virtual int64_t size() const = 0;
-
- virtual void close_on_error() = 0;
-
- // check if the column is valid for inverted index, some columns
- // are generated from variant, but not all of them are supported
- static bool check_support_inverted_index(const TabletColumn& column);
+ Status init() override;
+ void close_on_error() override;
+ Status add_nulls(uint32_t count) override;
+ Status add_array_nulls(uint32_t row_id) override;
+ Status add_values(const std::string fn, const void* values, size_t count) override;
+ Status add_array_values(size_t field_size, const void* value_ptr, const uint8_t* null_map,
+ const uint8_t* offsets_ptr, size_t count) override;
+ Status add_array_values(size_t field_size, const CollectionValue* values,
+ size_t count) override;
+ int64_t size() const override;
+ Status finish() override;
private:
- DISALLOW_COPY_AND_ASSIGN(InvertedIndexColumnWriter);
+ Status init_bkd_index();
+ Result<std::unique_ptr<lucene::util::Reader>> create_char_string_reader(
+ CharFilterMap& char_filter_map);
+ Status open_index_directory();
+ std::unique_ptr<lucene::index::IndexWriter> create_index_writer();
+ Status create_field(lucene::document::Field** field);
+ Result<std::unique_ptr<lucene::analysis::Analyzer>> create_analyzer(
+ std::shared_ptr<InvertedIndexCtx>& inverted_index_ctx);
+ Status init_fulltext_index();
+ Status add_document();
+ Status add_null_document();
+ Status new_inverted_index_field(const char* field_value_data, size_t field_value_size);
+ void new_char_token_stream(const char* s, size_t len, lucene::document::Field* field);
+ void new_field_value(const char* s, size_t len, lucene::document::Field* field);
+ void new_field_char_value(const char* s, size_t len, lucene::document::Field* field);
+ Status add_numeric_values(const void* values, size_t count);
+ Status add_value(const CppType& value);
+ void write_null_bitmap(lucene::store::IndexOutput* null_bitmap_out);
+
+ rowid_t _rid = 0;
+ uint32_t _row_ids_seen_for_bkd = 0;
+ roaring::Roaring _null_bitmap;
+ uint64_t _reverted_index_size;
+
+ std::unique_ptr<lucene::document::Document> _doc = nullptr;
+ lucene::document::Field* _field = nullptr;
+ bool _single_field = true;
+ std::shared_ptr<DorisFSDirectory> _dir = nullptr;
+ std::unique_ptr<lucene::index::IndexWriter> _index_writer = nullptr;
+ std::unique_ptr<lucene::analysis::Analyzer> _analyzer = nullptr;
+ std::unique_ptr<lucene::util::Reader> _char_string_reader = nullptr;
+ std::shared_ptr<lucene::util::bkd::bkd_writer> _bkd_writer = nullptr;
+ InvertedIndexCtxSPtr _inverted_index_ctx = nullptr;
+ const KeyCoder* _value_key_coder;
+ const TabletIndex* _index_meta;
+ InvertedIndexParserType _parser_type;
+ std::wstring _field_name;
+ XIndexFileWriter* _index_file_writer;
+ uint32_t _ignore_above;
};
-class TmpFileDirs {
-public:
- TmpFileDirs(const std::vector<doris::StorePath>& store_paths) {
- for (const auto& store_path : store_paths) {
- _tmp_file_dirs.emplace_back(store_path.path + "/" + config::tmp_file_dir);
- }
- };
-
- Status init() {
- for (auto& tmp_file_dir : _tmp_file_dirs) {
- // delete the tmp dir to avoid the tmp files left by last crash
- RETURN_IF_ERROR(io::global_local_filesystem()->delete_directory(tmp_file_dir));
- RETURN_IF_ERROR(io::global_local_filesystem()->create_directory(tmp_file_dir));
- }
- return Status::OK();
- };
-
- io::Path get_tmp_file_dir() {
- size_t cur_index = _next_index.fetch_add(1);
- return _tmp_file_dirs[cur_index % _tmp_file_dirs.size()];
- };
-
-private:
- std::vector<io::Path> _tmp_file_dirs;
- std::atomic_size_t _next_index {0}; // use for round-robin
-};
-
-} // namespace segment_v2
-} // namespace doris
+}
\ No newline at end of file
diff --git a/be/src/olap/rowset/segment_v2/segment_writer.cpp b/be/src/olap/rowset/segment_v2/segment_writer.cpp
index 2457a44d..22155f6 100644
--- a/be/src/olap/rowset/segment_v2/segment_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_writer.cpp
@@ -43,8 +43,8 @@
#include "olap/rowset/rowset_writer_context.h" // RowsetWriterContext
#include "olap/rowset/segment_creator.h"
#include "olap/rowset/segment_v2/column_writer.h" // ColumnWriter
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/rowset/segment_v2/page_io.h"
#include "olap/rowset/segment_v2/page_pointer.h"
#include "olap/segment_loader.h"
@@ -85,14 +85,14 @@
SegmentWriter::SegmentWriter(io::FileWriter* file_writer, uint32_t segment_id,
TabletSchemaSPtr tablet_schema, BaseTabletSPtr tablet,
DataDir* data_dir, const SegmentWriterOptions& opts,
- InvertedIndexFileWriter* inverted_file_writer)
+ XIndexFileWriter* inverted_file_writer)
: _segment_id(segment_id),
_tablet_schema(std::move(tablet_schema)),
_tablet(std::move(tablet)),
_data_dir(data_dir),
_opts(opts),
_file_writer(file_writer),
- _inverted_index_file_writer(inverted_file_writer),
+ _x_index_file_writer(inverted_file_writer),
_mem_tracker(std::make_unique<MemTracker>(segment_mem_tracker_name(segment_id))),
_mow_context(std::move(opts.mow_ctx)) {
CHECK_NOTNULL(file_writer);
@@ -224,10 +224,19 @@
index != nullptr && !skip_inverted_index) {
opts.inverted_index = index;
opts.need_inverted_index = true;
- DCHECK(_inverted_index_file_writer != nullptr);
- opts.inverted_index_file_writer = _inverted_index_file_writer;
+ DCHECK(_x_index_file_writer != nullptr);
+ opts.x_index_file_writer = _x_index_file_writer;
// TODO support multiple inverted index
}
+
+ // indexes for this column
+ if (const auto& index = schema->ann_index(column); index != nullptr) {
+ opts.ann_index = index;
+ opts.need_ann_index = true;
+ DCHECK(_x_index_file_writer != nullptr);
+ opts.x_index_file_writer = _x_index_file_writer;
+ }
+
#define DISABLE_INDEX_IF_FIELD_TYPE(TYPE, type_name) \
if (column.type() == FieldType::OLAP_FIELD_TYPE_##TYPE) { \
opts.need_zone_map = false; \
@@ -1029,7 +1038,7 @@
_num_rows_written = 0;
for (auto& column_writer : _column_writers) {
- RETURN_IF_ERROR(column_writer->finish());
+        RETURN_IF_ERROR(column_writer->finish()); // finalize index writing for this column
}
RETURN_IF_ERROR(_write_data());
@@ -1042,6 +1051,7 @@
RETURN_IF_ERROR(_write_zone_map());
RETURN_IF_ERROR(_write_bitmap_index());
RETURN_IF_ERROR(_write_inverted_index());
+ RETURN_IF_ERROR(_write_ann_index());
RETURN_IF_ERROR(_write_bloom_filter_index());
*index_size = _file_writer->bytes_appended() - index_start;
@@ -1172,6 +1182,13 @@
return Status::OK();
}
+Status SegmentWriter::_write_ann_index() {
+    for (auto& column_writer : _column_writers) {
+        // Must call write_ann_index() here; calling write_inverted_index() would
+        // re-write the inverted index and never flush the ANN index data
+        // (mirrors VerticalSegmentWriter::_write_ann_index).
+        RETURN_IF_ERROR(column_writer->write_ann_index());
+    }
+    return Status::OK();
+}
+
Status SegmentWriter::_write_bloom_filter_index() {
for (auto& column_writer : _column_writers) {
RETURN_IF_ERROR(column_writer->write_bloom_filter_index());
diff --git a/be/src/olap/rowset/segment_v2/segment_writer.h b/be/src/olap/rowset/segment_v2/segment_writer.h
index 6030038..095daf0 100644
--- a/be/src/olap/rowset/segment_v2/segment_writer.h
+++ b/be/src/olap/rowset/segment_v2/segment_writer.h
@@ -34,7 +34,7 @@
#include "gutil/strings/substitute.h"
#include "olap/olap_define.h"
#include "olap/rowset/segment_v2/column_writer.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/tablet.h"
#include "olap/tablet_schema.h"
#include "util/faststring.h"
@@ -84,7 +84,7 @@
explicit SegmentWriter(io::FileWriter* file_writer, uint32_t segment_id,
TabletSchemaSPtr tablet_schema, BaseTabletSPtr tablet, DataDir* data_dir,
const SegmentWriterOptions& opts,
- InvertedIndexFileWriter* inverted_file_writer);
+ XIndexFileWriter* inverted_file_writer);
~SegmentWriter();
Status init();
@@ -147,12 +147,12 @@
Status close_inverted_index(int64_t* inverted_index_file_size) {
// no inverted index
- if (_inverted_index_file_writer == nullptr) {
+ if (_x_index_file_writer == nullptr) {
*inverted_index_file_size = 0;
return Status::OK();
}
- RETURN_IF_ERROR(_inverted_index_file_writer->close());
- *inverted_index_file_size = _inverted_index_file_writer->get_index_file_total_size();
+ RETURN_IF_ERROR(_x_index_file_writer->close());
+ *inverted_index_file_size = _x_index_file_writer->get_index_file_total_size();
return Status::OK();
}
@@ -169,6 +169,7 @@
Status _write_zone_map();
Status _write_bitmap_index();
Status _write_inverted_index();
+ Status _write_ann_index();
Status _write_bloom_filter_index();
Status _write_short_key_index();
Status _write_primary_key_index();
@@ -214,7 +215,7 @@
// Not owned. owned by RowsetWriter or SegmentFlusher
io::FileWriter* _file_writer = nullptr;
// Not owned. owned by RowsetWriter or SegmentFlusher
- InvertedIndexFileWriter* _inverted_index_file_writer = nullptr;
+ XIndexFileWriter* _x_index_file_writer = nullptr;
SegmentFooterPB _footer;
// for mow tables with cluster key, the sort key is the cluster keys not unique keys
diff --git a/be/src/olap/rowset/segment_v2/vertical_segment_writer.cpp b/be/src/olap/rowset/segment_v2/vertical_segment_writer.cpp
index 0846b0f..9d6ed1d 100644
--- a/be/src/olap/rowset/segment_v2/vertical_segment_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/vertical_segment_writer.cpp
@@ -48,7 +48,7 @@
#include "olap/rowset/segment_creator.h"
#include "olap/rowset/segment_v2/column_writer.h" // ColumnWriter
#include "olap/rowset/segment_v2/inverted_index_desc.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/page_io.h"
#include "olap/rowset/segment_v2/page_pointer.h"
#include "olap/segment_loader.h"
@@ -90,14 +90,14 @@
TabletSchemaSPtr tablet_schema, BaseTabletSPtr tablet,
DataDir* data_dir,
const VerticalSegmentWriterOptions& opts,
- InvertedIndexFileWriter* inverted_file_writer)
+ XIndexFileWriter* inverted_file_writer)
: _segment_id(segment_id),
_tablet_schema(std::move(tablet_schema)),
_tablet(std::move(tablet)),
_data_dir(data_dir),
_opts(opts),
_file_writer(file_writer),
- _inverted_index_file_writer(inverted_file_writer),
+ _x_index_file_writer(inverted_file_writer),
_mem_tracker(std::make_unique<MemTracker>(
vertical_segment_writer_mem_tracker_name(segment_id))),
_mow_context(std::move(opts.mow_ctx)) {
@@ -216,11 +216,22 @@
index != nullptr && !skip_inverted_index) {
opts.inverted_index = index;
opts.need_inverted_index = true;
- DCHECK(_inverted_index_file_writer != nullptr);
- opts.inverted_index_file_writer = _inverted_index_file_writer;
+ DCHECK(_x_index_file_writer != nullptr);
+ opts.x_index_file_writer = _x_index_file_writer;
// TODO support multiple inverted index
}
+ if (const auto& index = tablet_schema->ann_index(column);
+ index != nullptr) {
+ opts.ann_index = index;
+ opts.need_ann_index = true;
+ DCHECK(_x_index_file_writer != nullptr);
+ opts.x_index_file_writer = _x_index_file_writer;
+            // TODO support multiple ann indexes
+ }
+
+
+
#define DISABLE_INDEX_IF_FIELD_TYPE(TYPE, type_name) \
if (column.type() == FieldType::OLAP_FIELD_TYPE_##TYPE) { \
opts.need_zone_map = false; \
@@ -1427,6 +1438,7 @@
RETURN_IF_ERROR(_write_zone_map());
RETURN_IF_ERROR(_write_bitmap_index());
RETURN_IF_ERROR(_write_inverted_index());
+ RETURN_IF_ERROR(_write_ann_index());
RETURN_IF_ERROR(_write_bloom_filter_index());
*index_size = _file_writer->bytes_appended() - index_start;
@@ -1522,6 +1534,13 @@
return Status::OK();
}
+Status VerticalSegmentWriter::_write_ann_index() {
+ for (auto& column_writer : _column_writers) {
+ RETURN_IF_ERROR(column_writer->write_ann_index());
+ }
+ return Status::OK();
+}
+
Status VerticalSegmentWriter::_write_bloom_filter_index() {
for (auto& column_writer : _column_writers) {
RETURN_IF_ERROR(column_writer->write_bloom_filter_index());
diff --git a/be/src/olap/rowset/segment_v2/vertical_segment_writer.h b/be/src/olap/rowset/segment_v2/vertical_segment_writer.h
index 8cec6ed..5ccd53a 100644
--- a/be/src/olap/rowset/segment_v2/vertical_segment_writer.h
+++ b/be/src/olap/rowset/segment_v2/vertical_segment_writer.h
@@ -34,7 +34,7 @@
#include "gutil/strings/substitute.h"
#include "olap/olap_define.h"
#include "olap/rowset/segment_v2/column_writer.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/tablet.h"
#include "olap/tablet_schema.h"
#include "util/faststring.h"
@@ -60,7 +60,7 @@
} // namespace io
namespace segment_v2 {
-class InvertedIndexFileWriter;
+class XIndexFileWriter;
struct VerticalSegmentWriterOptions {
uint32_t num_rows_per_block = 1024;
@@ -83,7 +83,7 @@
explicit VerticalSegmentWriter(io::FileWriter* file_writer, uint32_t segment_id,
TabletSchemaSPtr tablet_schema, BaseTabletSPtr tablet,
DataDir* data_dir, const VerticalSegmentWriterOptions& opts,
- InvertedIndexFileWriter* inverted_file_writer);
+ XIndexFileWriter* inverted_file_writer);
~VerticalSegmentWriter();
VerticalSegmentWriter(const VerticalSegmentWriter&) = delete;
@@ -125,12 +125,12 @@
Status close_inverted_index(int64_t* inverted_index_file_size) {
// no inverted index
- if (_inverted_index_file_writer == nullptr) {
+ if (_x_index_file_writer == nullptr) {
*inverted_index_file_size = 0;
return Status::OK();
}
- RETURN_IF_ERROR(_inverted_index_file_writer->close());
- *inverted_index_file_size = _inverted_index_file_writer->get_index_file_total_size();
+ RETURN_IF_ERROR(_x_index_file_writer->close());
+ *inverted_index_file_size = _x_index_file_writer->get_index_file_total_size();
return Status::OK();
}
@@ -143,6 +143,7 @@
Status _write_zone_map();
Status _write_bitmap_index();
Status _write_inverted_index();
+ Status _write_ann_index();
Status _write_bloom_filter_index();
Status _write_short_key_index();
Status _write_primary_key_index();
@@ -223,7 +224,7 @@
// Not owned. owned by RowsetWriter
io::FileWriter* _file_writer = nullptr;
// Not owned. owned by RowsetWriter or SegmentFlusher
- InvertedIndexFileWriter* _inverted_index_file_writer = nullptr;
+ XIndexFileWriter* _x_index_file_writer = nullptr;
SegmentFooterPB _footer;
// for mow tables with cluster key, the sort key is the cluster keys not unique keys
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_file_writer.cpp b/be/src/olap/rowset/segment_v2/x_index_file_writer.cpp
similarity index 88%
rename from be/src/olap/rowset/segment_v2/inverted_index_file_writer.cpp
rename to be/src/olap/rowset/segment_v2/x_index_file_writer.cpp
index 4d6892a..7f485c4 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_file_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/x_index_file_writer.cpp
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include <glog/logging.h>
@@ -30,30 +30,30 @@
namespace doris::segment_v2 {
-Status InvertedIndexFileWriter::initialize(InvertedIndexDirectoryMap& indices_dirs) {
+Status XIndexFileWriter::initialize(InvertedIndexDirectoryMap& indices_dirs) {
_indices_dirs = std::move(indices_dirs);
return Status::OK();
}
-Status InvertedIndexFileWriter::_insert_directory_into_map(int64_t index_id,
+Status XIndexFileWriter::_insert_directory_into_map(int64_t index_id,
const std::string& index_suffix,
std::shared_ptr<DorisFSDirectory> dir) {
auto key = std::make_pair(index_id, index_suffix);
auto [it, inserted] = _indices_dirs.emplace(key, std::move(dir));
if (!inserted) {
- LOG(ERROR) << "InvertedIndexFileWriter::open attempted to insert a duplicate key: ("
+ LOG(ERROR) << "XIndexFileWriter::open attempted to insert a duplicate key: ("
<< key.first << ", " << key.second << ")";
LOG(ERROR) << "Directories already in map: ";
for (const auto& entry : _indices_dirs) {
LOG(ERROR) << "Key: (" << entry.first.first << ", " << entry.first.second << ")";
}
return Status::InternalError(
- "InvertedIndexFileWriter::open attempted to insert a duplicate dir");
+ "XIndexFileWriter::open attempted to insert a duplicate dir");
}
return Status::OK();
}
-Result<std::shared_ptr<DorisFSDirectory>> InvertedIndexFileWriter::open(
+Result<std::shared_ptr<DorisFSDirectory>> XIndexFileWriter::open(
const TabletIndex* index_meta) {
auto local_fs_index_path = InvertedIndexDescriptor::get_temporary_index_path(
_tmp_dir, _rowset_id, _seg_id, index_meta->index_id(), index_meta->get_index_suffix());
@@ -69,8 +69,8 @@
return dir;
}
-Status InvertedIndexFileWriter::delete_index(const TabletIndex* index_meta) {
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::delete_index_index_meta_nullptr",
+Status XIndexFileWriter::delete_index(const TabletIndex* index_meta) {
+ DBUG_EXECUTE_IF("XIndexFileWriter::delete_index_index_meta_nullptr",
{ index_meta = nullptr; });
if (!index_meta) {
return Status::Error<ErrorCode::INVALID_ARGUMENT>("Index metadata is null.");
@@ -81,7 +81,7 @@
// Check if the specified index exists
auto index_it = _indices_dirs.find(std::make_pair(index_id, index_suffix));
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::delete_index_indices_dirs_reach_end",
+ DBUG_EXECUTE_IF("XIndexFileWriter::delete_index_indices_dirs_reach_end",
{ index_it = _indices_dirs.end(); })
if (index_it == _indices_dirs.end()) {
std::ostringstream errMsg;
@@ -95,7 +95,7 @@
return Status::OK();
}
-int64_t InvertedIndexFileWriter::headerLength() {
+int64_t XIndexFileWriter::headerLength() {
int64_t header_size = 0;
header_size +=
sizeof(int32_t) * 2; // Account for the size of the version number and number of indices
@@ -120,7 +120,7 @@
return header_size;
}
-Status InvertedIndexFileWriter::close() {
+Status XIndexFileWriter::close() {
DCHECK(!_closed) << debug_string();
_closed = true;
if (_indices_dirs.empty()) {
@@ -129,7 +129,7 @@
DBUG_EXECUTE_IF("inverted_index_storage_format_must_be_v2", {
if (_storage_format != InvertedIndexStorageFormatPB::V2) {
return Status::Error<ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
- "InvertedIndexFileWriter::close fault injection:inverted index storage format "
+ "XIndexFileWriter::close fault injection:inverted index storage format "
"must be v2");
}
})
@@ -169,7 +169,7 @@
return Status::OK();
}
-void InvertedIndexFileWriter::sort_files(std::vector<FileInfo>& file_infos) {
+void XIndexFileWriter::sort_files(std::vector<FileInfo>& file_infos) {
auto file_priority = [](const std::string& filename) {
if (filename.find("segments") != std::string::npos) {
return 1;
@@ -192,13 +192,13 @@
});
}
-void InvertedIndexFileWriter::copyFile(const char* fileName, lucene::store::Directory* dir,
+void XIndexFileWriter::copyFile(const char* fileName, lucene::store::Directory* dir,
lucene::store::IndexOutput* output, uint8_t* buffer,
int64_t bufferLength) {
lucene::store::IndexInput* tmp = nullptr;
CLuceneError err;
auto open = dir->openInput(fileName, tmp, err);
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::copyFile_openInput_error", {
+ DBUG_EXECUTE_IF("XIndexFileWriter::copyFile_openInput_error", {
open = false;
err.set(CL_ERR_IO, "debug point: copyFile_openInput_error");
});
@@ -218,7 +218,7 @@
output->writeBytes(buffer, len);
remainder -= len;
}
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::copyFile_remainder_is_not_zero", { remainder = 10; });
+ DBUG_EXECUTE_IF("XIndexFileWriter::copyFile_remainder_is_not_zero", { remainder = 10; });
if (remainder != 0) {
std::ostringstream errMsg;
errMsg << "Non-zero remainder length after copying: " << remainder << " (id: " << fileName
@@ -229,7 +229,7 @@
int64_t end_ptr = output->getFilePointer();
int64_t diff = end_ptr - start_ptr;
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::copyFile_diff_not_equals_length",
+ DBUG_EXECUTE_IF("XIndexFileWriter::copyFile_diff_not_equals_length",
{ diff = length - 10; });
if (diff != length) {
std::ostringstream errMsg;
@@ -241,7 +241,7 @@
input->close();
}
-Status InvertedIndexFileWriter::write_v1() {
+Status XIndexFileWriter::write_v1() {
int64_t total_size = 0;
std::unique_ptr<lucene::store::Directory, DirectoryDeleter> out_dir = nullptr;
std::unique_ptr<lucene::store::IndexOutput> output = nullptr;
@@ -293,7 +293,7 @@
return Status::OK();
}
-Status InvertedIndexFileWriter::write() {
+Status XIndexFileWriter::write() {
std::unique_ptr<lucene::store::Directory, DirectoryDeleter> out_dir = nullptr;
std::unique_ptr<lucene::store::IndexOutput> compound_file_output = nullptr;
ErrorContext error_context;
@@ -337,7 +337,7 @@
}
// Helper function implementations
-std::vector<FileInfo> InvertedIndexFileWriter::prepare_sorted_files(
+std::vector<FileInfo> XIndexFileWriter::prepare_sorted_files(
lucene::store::Directory* directory) {
std::vector<std::string> files;
directory->list(&files);
@@ -359,7 +359,7 @@
return sorted_files;
}
-void InvertedIndexFileWriter::add_index_info(int64_t index_id, const std::string& index_suffix,
+void XIndexFileWriter::add_index_info(int64_t index_id, const std::string& index_suffix,
int64_t compound_file_size) {
InvertedIndexFileInfo_IndexInfo index_info;
index_info.set_index_id(index_id);
@@ -369,15 +369,15 @@
*new_index_info = index_info;
}
-std::pair<int64_t, int32_t> InvertedIndexFileWriter::calculate_header_length(
+std::pair<int64_t, int32_t> XIndexFileWriter::calculate_header_length(
const std::vector<FileInfo>& sorted_files, lucene::store::Directory* directory) {
// Use RAMDirectory to calculate header length
lucene::store::RAMDirectory ram_dir;
auto* out_idx = ram_dir.createOutput("temp_idx");
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::calculate_header_length_ram_output_is_nullptr",
+ DBUG_EXECUTE_IF("XIndexFileWriter::calculate_header_length_ram_output_is_nullptr",
{ out_idx = nullptr; })
if (out_idx == nullptr) {
- LOG(WARNING) << "InvertedIndexFileWriter::calculate_header_length error: RAMDirectory "
+ LOG(WARNING) << "XIndexFileWriter::calculate_header_length error: RAMDirectory "
"output is nullptr.";
_CLTHROWA(CL_ERR_IO, "Create RAMDirectory output error");
}
@@ -409,7 +409,7 @@
std::pair<std::unique_ptr<lucene::store::Directory, DirectoryDeleter>,
std::unique_ptr<lucene::store::IndexOutput>>
-InvertedIndexFileWriter::create_output_stream_v1(int64_t index_id,
+XIndexFileWriter::create_output_stream_v1(int64_t index_id,
const std::string& index_suffix) {
io::Path cfs_path(InvertedIndexDescriptor::get_index_file_path_v1(_index_path_prefix, index_id,
index_suffix));
@@ -421,10 +421,10 @@
std::unique_ptr<lucene::store::Directory, DirectoryDeleter> out_dir_ptr(out_dir);
auto* out = out_dir->createOutput(idx_name.c_str());
- DBUG_EXECUTE_IF("InvertedIndexFileWriter::write_v1_out_dir_createOutput_nullptr",
+ DBUG_EXECUTE_IF("XIndexFileWriter::write_v1_out_dir_createOutput_nullptr",
{ out = nullptr; });
if (out == nullptr) {
- LOG(WARNING) << "InvertedIndexFileWriter::create_output_stream_v1 error: CompoundDirectory "
+ LOG(WARNING) << "XIndexFileWriter::create_output_stream_v1 error: CompoundDirectory "
"output is nullptr.";
_CLTHROWA(CL_ERR_IO, "Create CompoundDirectory output error");
}
@@ -433,7 +433,7 @@
return {std::move(out_dir_ptr), std::move(output)};
}
-void InvertedIndexFileWriter::write_header_and_data_v1(lucene::store::IndexOutput* output,
+void XIndexFileWriter::write_header_and_data_v1(lucene::store::IndexOutput* output,
const std::vector<FileInfo>& sorted_files,
lucene::store::Directory* directory,
int64_t header_length,
@@ -470,7 +470,7 @@
std::pair<std::unique_ptr<lucene::store::Directory, DirectoryDeleter>,
std::unique_ptr<lucene::store::IndexOutput>>
-InvertedIndexFileWriter::create_output_stream() {
+XIndexFileWriter::create_output_stream() {
io::Path index_path {InvertedIndexDescriptor::get_index_file_path_v2(_index_path_prefix)};
auto* out_dir = DorisFSDirectoryFactory::getDirectory(_fs, index_path.parent_path().c_str());
@@ -484,7 +484,7 @@
return {std::move(out_dir_ptr), std::move(compound_file_output)};
}
-void InvertedIndexFileWriter::write_version_and_indices_count(lucene::store::IndexOutput* output) {
+void XIndexFileWriter::write_version_and_indices_count(lucene::store::IndexOutput* output) {
// Write the version number
output->writeInt(_storage_format);
@@ -493,7 +493,7 @@
output->writeInt(num_indices);
}
-std::vector<InvertedIndexFileWriter::FileMetadata> InvertedIndexFileWriter::prepare_file_metadata(
+std::vector<XIndexFileWriter::FileMetadata> XIndexFileWriter::prepare_file_metadata(
int64_t& current_offset) {
std::vector<FileMetadata> file_metadata;
@@ -514,7 +514,7 @@
return file_metadata;
}
-void InvertedIndexFileWriter::write_index_headers_and_metadata(
+void XIndexFileWriter::write_index_headers_and_metadata(
lucene::store::IndexOutput* output, const std::vector<FileMetadata>& file_metadata) {
// Group files by index_id and index_suffix
std::map<std::pair<int64_t, std::string>, std::vector<FileMetadata>> indices;
@@ -546,7 +546,7 @@
}
}
-void InvertedIndexFileWriter::copy_files_data(lucene::store::IndexOutput* output,
+void XIndexFileWriter::copy_files_data(lucene::store::IndexOutput* output,
const std::vector<FileMetadata>& file_metadata) {
const int64_t buffer_length = 16384;
uint8_t buffer[buffer_length];
diff --git a/be/src/olap/rowset/segment_v2/inverted_index_file_writer.h b/be/src/olap/rowset/segment_v2/x_index_file_writer.h
similarity index 96%
rename from be/src/olap/rowset/segment_v2/inverted_index_file_writer.h
rename to be/src/olap/rowset/segment_v2/x_index_file_writer.h
index ab7cdbf..42922e1 100644
--- a/be/src/olap/rowset/segment_v2/inverted_index_file_writer.h
+++ b/be/src/olap/rowset/segment_v2/x_index_file_writer.h
@@ -41,8 +41,8 @@
using InvertedIndexDirectoryMap =
std::map<std::pair<int64_t, std::string>, std::shared_ptr<lucene::store::Directory>>;
-class InvertedIndexFileWriter;
-using InvertedIndexFileWriterPtr = std::unique_ptr<InvertedIndexFileWriter>;
+class XIndexFileWriter;
+using XIndexFileWriterPtr = std::unique_ptr<XIndexFileWriter>;
class FileInfo {
public:
@@ -50,9 +50,9 @@
int64_t filesize;
};
-class InvertedIndexFileWriter {
+class XIndexFileWriter {
public:
- InvertedIndexFileWriter(io::FileSystemSPtr fs, std::string index_path_prefix,
+ XIndexFileWriter(io::FileSystemSPtr fs, std::string index_path_prefix,
std::string rowset_id, int64_t seg_id,
InvertedIndexStorageFormatPB storage_format,
io::FileWriterPtr file_writer = nullptr)
@@ -70,7 +70,7 @@
Result<std::shared_ptr<DorisFSDirectory>> open(const TabletIndex* index_meta);
Status delete_index(const TabletIndex* index_meta);
Status initialize(InvertedIndexDirectoryMap& indices_dirs);
- virtual ~InvertedIndexFileWriter() = default;
+ virtual ~XIndexFileWriter() = default;
Status write();
Status write_v1();
Status close();
diff --git a/be/src/olap/rowset/vertical_beta_rowset_writer.cpp b/be/src/olap/rowset/vertical_beta_rowset_writer.cpp
index f493f21..e59fa15 100644
--- a/be/src/olap/rowset/vertical_beta_rowset_writer.cpp
+++ b/be/src/olap/rowset/vertical_beta_rowset_writer.cpp
@@ -170,10 +170,10 @@
RETURN_IF_ERROR(BaseBetaRowsetWriter::create_file_writer(seg_id, segment_file_writer));
DCHECK(segment_file_writer != nullptr);
- InvertedIndexFileWriterPtr inverted_index_file_writer;
+ XIndexFileWriterPtr x_index_file_writer;
if (context.tablet_schema->has_inverted_index()) {
- RETURN_IF_ERROR(RowsetWriter::create_inverted_index_file_writer(
- seg_id, &inverted_index_file_writer));
+ RETURN_IF_ERROR(RowsetWriter::create_x_index_file_writer(
+ seg_id, &x_index_file_writer));
}
segment_v2::SegmentWriterOptions writer_options;
@@ -183,11 +183,11 @@
// TODO if support VerticalSegmentWriter, also need to handle cluster key primary key index
*writer = std::make_unique<segment_v2::SegmentWriter>(
segment_file_writer.get(), seg_id, context.tablet_schema, context.tablet,
- context.data_dir, writer_options, inverted_index_file_writer.get());
+ context.data_dir, writer_options, x_index_file_writer.get());
RETURN_IF_ERROR(this->_seg_files.add(seg_id, std::move(segment_file_writer)));
if (context.tablet_schema->has_inverted_index()) {
- RETURN_IF_ERROR(this->_idx_files.add(seg_id, std::move(inverted_index_file_writer)));
+ RETURN_IF_ERROR(this->_idx_files.add(seg_id, std::move(x_index_file_writer)));
}
auto s = (*writer)->init(column_ids, is_key);
@@ -219,7 +219,7 @@
template <class T>
requires std::is_base_of_v<BaseBetaRowsetWriter, T>
Status VerticalBetaRowsetWriter<T>::_close_file_writers() {
- RETURN_IF_ERROR(BaseBetaRowsetWriter::_close_inverted_index_file_writers());
+ RETURN_IF_ERROR(BaseBetaRowsetWriter::_close_x_index_file_writers());
return this->_seg_files.close();
}
diff --git a/be/src/olap/schema_change.cpp b/be/src/olap/schema_change.cpp
index 063e5e9..2211d2a 100644
--- a/be/src/olap/schema_change.cpp
+++ b/be/src/olap/schema_change.cpp
@@ -57,7 +57,7 @@
#include "olap/rowset/rowset_writer_context.h"
#include "olap/rowset/segment_v2/column_reader.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/rowset/segment_v2/segment.h"
#include "olap/schema.h"
#include "olap/segment_loader.h"
diff --git a/be/src/olap/schema_change.h b/be/src/olap/schema_change.h
index c29cb49..632fb44 100644
--- a/be/src/olap/schema_change.h
+++ b/be/src/olap/schema_change.h
@@ -43,7 +43,7 @@
#include "olap/rowset/rowset.h"
#include "olap/rowset/rowset_reader.h"
#include "olap/rowset/rowset_writer.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/storage_engine.h"
#include "olap/tablet.h"
#include "olap/tablet_fwd.h"
diff --git a/be/src/olap/tablet_meta.cpp b/be/src/olap/tablet_meta.cpp
index 43b0d5d..3d03847 100644
--- a/be/src/olap/tablet_meta.cpp
+++ b/be/src/olap/tablet_meta.cpp
@@ -285,6 +285,9 @@
case TIndexType::INVERTED:
index_pb->set_index_type(IndexType::INVERTED);
break;
+ case TIndexType::ANN:
+ index_pb->set_index_type(IndexType::ANN);
+ break;
case TIndexType::BLOOMFILTER:
index_pb->set_index_type(IndexType::BLOOMFILTER);
break;
diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp
index 7b6b5f3..f3d8a3f 100644
--- a/be/src/olap/tablet_schema.cpp
+++ b/be/src/olap/tablet_schema.cpp
@@ -767,6 +767,9 @@
case TIndexType::INVERTED:
_index_type = IndexType::INVERTED;
break;
+ case TIndexType::ANN:
+ _index_type = IndexType::ANN;
+ break;
case TIndexType::BLOOMFILTER:
_index_type = IndexType::BLOOMFILTER;
break;
@@ -794,6 +797,9 @@
case TIndexType::INVERTED:
_index_type = IndexType::INVERTED;
break;
+ case TIndexType::ANN:
+ _index_type = IndexType::ANN;
+ break;
case TIndexType::BLOOMFILTER:
_index_type = IndexType::BLOOMFILTER;
break;
@@ -1416,9 +1422,26 @@
return nullptr;
}
+const TabletIndex* TabletSchema::ann_index(int32_t col_unique_id,
+ const std::string& suffix_path) const {
+ for (size_t i = 0; i < _indexes.size(); i++) {
+ if (_indexes[i].index_type() == IndexType::ANN) {
+ for (int32_t id : _indexes[i].col_unique_ids()) {
+ if (id == col_unique_id &&
+ _indexes[i].get_index_suffix() == escape_for_path_name(suffix_path)) {
+ return &(_indexes[i]);
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+
+
const TabletIndex* TabletSchema::inverted_index(const TabletColumn& col) const {
// Some columns(Float, Double, JSONB ...) from the variant do not support inverted index
- if (!segment_v2::InvertedIndexColumnWriter::check_support_inverted_index(col)) {
+ if (!segment_v2::IndexColumnWriter::check_support_inverted_index(col)) {
return nullptr;
}
// TODO use more efficient impl
@@ -1427,6 +1450,17 @@
return inverted_index(col_unique_id, escape_for_path_name(col.suffix_path()));
}
+const TabletIndex* TabletSchema::ann_index(const TabletColumn& col) const {
+    // Some columns(Float, Double, JSONB ...) from the variant do not support ann index
+ if (!segment_v2::IndexColumnWriter::check_support_ann_index(col)) {
+ return nullptr;
+ }
+ // TODO use more efficient impl
+ // Use parent id if unique not assigned, this could happend when accessing subcolumns of variants
+ int32_t col_unique_id = col.is_extracted_column() ? col.parent_unique_id() : col.unique_id();
+ return ann_index(col_unique_id, escape_for_path_name(col.suffix_path()));
+}
+
bool TabletSchema::has_ngram_bf_index(int32_t col_unique_id) const {
// TODO use more efficient impl
for (size_t i = 0; i < _indexes.size(); i++) {
diff --git a/be/src/olap/tablet_schema.h b/be/src/olap/tablet_schema.h
index 3dfe055..30848d495 100644
--- a/be/src/olap/tablet_schema.h
+++ b/be/src/olap/tablet_schema.h
@@ -414,15 +414,39 @@
}
return false;
}
+
+ bool has_ann_index() const {
+ for (const auto& index : _indexes) {
+ if (index.index_type() == IndexType::ANN) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+    // True when the schema defines any secondary index (inverted or ANN).
+    // const-qualified so it is callable through const TabletSchema references,
+    // consistent with the sibling accessors (has_ann_index, has_inverted_index_with_index_id).
+    bool has_extra_index() const { return has_inverted_index() || has_ann_index(); }
+
bool has_inverted_index_with_index_id(int64_t index_id) const;
// Check whether this column supports inverted index
// Some columns (Float, Double, JSONB ...) from the variant do not support index, but they are listed in TabletIndex.
const TabletIndex* inverted_index(const TabletColumn& col) const;
+
// Regardless of whether this column supports inverted index
// TabletIndex information will be returned as long as it exists.
const TabletIndex* inverted_index(int32_t col_unique_id,
const std::string& suffix_path = "") const;
+
+ const TabletIndex* ann_index(const TabletColumn& col) const;
+
+
+    // Regardless of whether this column supports ann index
+    // TabletIndex information will be returned as long as it exists.
+ const TabletIndex* ann_index(int32_t col_unique_id,
+ const std::string& suffix_path = "") const;
+
bool has_ngram_bf_index(int32_t col_unique_id) const;
const TabletIndex* get_ngram_bf_index(int32_t col_unique_id) const;
void update_indexes_from_thrift(const std::vector<doris::TOlapTableIndex>& indexes);
diff --git a/be/src/olap/task/index_builder.cpp b/be/src/olap/task/index_builder.cpp
index 2ce3152..20ab8f1 100644
--- a/be/src/olap/task/index_builder.cpp
+++ b/be/src/olap/task/index_builder.cpp
@@ -24,9 +24,9 @@
#include "olap/rowset/rowset_writer_context.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
#include "olap/rowset/segment_v2/inverted_index_file_reader.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/segment_loader.h"
#include "olap/storage_engine.h"
#include "olap/tablet_schema.h"
@@ -317,20 +317,20 @@
<< ", err: " << st;
return st;
}
- auto inverted_index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ auto x_index_file_writer = std::make_unique<XIndexFileWriter>(
fs, std::move(index_path_prefix),
output_rowset_meta->rowset_id().to_string(), seg_ptr->id(),
output_rowset_schema->get_inverted_index_storage_format(),
std::move(file_writer));
- RETURN_IF_ERROR(inverted_index_file_writer->initialize(dirs));
+ RETURN_IF_ERROR(x_index_file_writer->initialize(dirs));
// create inverted index writer
for (auto& index_meta : _dropped_inverted_indexes) {
- RETURN_IF_ERROR(inverted_index_file_writer->delete_index(&index_meta));
+ RETURN_IF_ERROR(x_index_file_writer->delete_index(&index_meta));
}
- _inverted_index_file_writers.emplace(seg_ptr->id(),
- std::move(inverted_index_file_writer));
+ _x_index_file_writers.emplace(seg_ptr->id(),
+ std::move(x_index_file_writer));
}
- for (auto&& [seg_id, inverted_index_writer] : _inverted_index_file_writers) {
+ for (auto&& [seg_id, inverted_index_writer] : _x_index_file_writers) {
auto st = inverted_index_writer->close();
if (!st.ok()) {
LOG(ERROR) << "close inverted_index_writer error:" << st;
@@ -338,7 +338,7 @@
}
inverted_index_size += inverted_index_writer->get_index_file_total_size();
}
- _inverted_index_file_writers.clear();
+ _x_index_file_writers.clear();
output_rowset_meta->set_data_disk_size(output_rowset_meta->data_disk_size());
output_rowset_meta->set_total_disk_size(output_rowset_meta->total_disk_size() +
inverted_index_size);
@@ -361,7 +361,7 @@
std::vector<std::pair<int64_t, int64_t>> inverted_index_writer_signs;
_olap_data_convertor->reserve(_alter_inverted_indexes.size());
- std::unique_ptr<InvertedIndexFileWriter> inverted_index_file_writer = nullptr;
+ std::unique_ptr<XIndexFileWriter> x_index_file_writer = nullptr;
if (output_rowset_schema->get_inverted_index_storage_format() >=
InvertedIndexStorageFormatPB::V2) {
auto idx_file_reader_iter = _inverted_index_file_readers.find(
@@ -383,13 +383,13 @@
return st;
}
auto dirs = DORIS_TRY(idx_file_reader_iter->second->get_all_directories());
- inverted_index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ x_index_file_writer = std::make_unique<XIndexFileWriter>(
fs, index_path_prefix, output_rowset_meta->rowset_id().to_string(),
seg_ptr->id(), output_rowset_schema->get_inverted_index_storage_format(),
std::move(file_writer));
- RETURN_IF_ERROR(inverted_index_file_writer->initialize(dirs));
+ RETURN_IF_ERROR(x_index_file_writer->initialize(dirs));
} else {
- inverted_index_file_writer = std::make_unique<InvertedIndexFileWriter>(
+ x_index_file_writer = std::make_unique<XIndexFileWriter>(
fs, index_path_prefix, output_rowset_meta->rowset_id().to_string(),
seg_ptr->id(), output_rowset_schema->get_inverted_index_storage_format());
}
@@ -414,7 +414,7 @@
auto column = output_rowset_schema->column(column_idx);
// variant column is not support for building index
auto is_support_inverted_index =
- InvertedIndexColumnWriter::check_support_inverted_index(column);
+ IndexColumnWriter::check_support_inverted_index(column);
DBUG_EXECUTE_IF("IndexBuilder::handle_single_rowset_support_inverted_index",
{ is_support_inverted_index = false; })
if (!is_support_inverted_index) {
@@ -425,10 +425,10 @@
return_columns.emplace_back(column_idx);
std::unique_ptr<Field> field(FieldFactory::create(column));
const auto* index_meta = output_rowset_schema->inverted_index(column);
- std::unique_ptr<segment_v2::InvertedIndexColumnWriter> inverted_index_builder;
+ std::unique_ptr<segment_v2::IndexColumnWriter> inverted_index_builder;
try {
- RETURN_IF_ERROR(segment_v2::InvertedIndexColumnWriter::create(
- field.get(), &inverted_index_builder, inverted_index_file_writer.get(),
+ RETURN_IF_ERROR(segment_v2::IndexColumnWriter::create(
+ field.get(), &inverted_index_builder, x_index_file_writer.get(),
index_meta));
DBUG_EXECUTE_IF(
"IndexBuilder::handle_single_rowset_index_column_writer_create_error", {
@@ -454,8 +454,8 @@
break;
}
- _inverted_index_file_writers.emplace(seg_ptr->id(),
- std::move(inverted_index_file_writer));
+ _x_index_file_writers.emplace(seg_ptr->id(),
+ std::move(x_index_file_writer));
// create iterator for each segment
StorageReadOptions read_options;
@@ -528,8 +528,8 @@
_olap_data_convertor->reset();
}
- for (auto&& [seg_id, inverted_index_file_writer] : _inverted_index_file_writers) {
- auto st = inverted_index_file_writer->close();
+ for (auto&& [seg_id, x_index_file_writer] : _x_index_file_writers) {
+ auto st = x_index_file_writer->close();
DBUG_EXECUTE_IF("IndexBuilder::handle_single_rowset_file_writer_close_error", {
st = Status::Error<ErrorCode::INVERTED_INDEX_CLUCENE_ERROR>(
"debug point: handle_single_rowset_file_writer_close_error");
@@ -538,10 +538,10 @@
LOG(ERROR) << "close inverted_index_writer error:" << st;
return st;
}
- inverted_index_size += inverted_index_file_writer->get_index_file_total_size();
+ inverted_index_size += x_index_file_writer->get_index_file_total_size();
}
_inverted_index_builders.clear();
- _inverted_index_file_writers.clear();
+ _x_index_file_writers.clear();
output_rowset_meta->set_data_disk_size(output_rowset_meta->data_disk_size());
output_rowset_meta->set_total_disk_size(output_rowset_meta->total_disk_size() +
inverted_index_size);
diff --git a/be/src/olap/task/index_builder.h b/be/src/olap/task/index_builder.h
index 8c996bb..765fb1f 100644
--- a/be/src/olap/task/index_builder.h
+++ b/be/src/olap/task/index_builder.h
@@ -23,15 +23,15 @@
#include "olap/rowset/pending_rowset_helper.h"
#include "olap/rowset/rowset_fwd.h"
#include "olap/rowset/segment_v2/inverted_index_desc.h"
-#include "olap/rowset/segment_v2/inverted_index_file_writer.h"
+#include "olap/rowset/segment_v2/x_index_file_writer.h"
#include "olap/rowset/segment_v2/segment.h"
#include "olap/tablet_fwd.h"
#include "vec/olap/olap_data_convertor.h"
namespace doris {
namespace segment_v2 {
-class InvertedIndexColumnWriter;
-class InvertedIndexFileWriter;
+class IndexColumnWriter;
+class XIndexFileWriter;
} // namespace segment_v2
namespace vectorized {
class OlapBlockDataConvertor;
@@ -81,12 +81,12 @@
std::vector<PendingRowsetGuard> _pending_rs_guards;
std::vector<RowsetReaderSharedPtr> _input_rs_readers;
std::unique_ptr<vectorized::OlapBlockDataConvertor> _olap_data_convertor;
- // "<segment_id, index_id>" -> InvertedIndexColumnWriter
+ // "<segment_id, index_id>" -> IndexColumnWriter
std::unordered_map<std::pair<int64_t, int64_t>,
- std::unique_ptr<segment_v2::InvertedIndexColumnWriter>>
+ std::unique_ptr<segment_v2::IndexColumnWriter>>
_inverted_index_builders;
- std::unordered_map<int64_t, std::unique_ptr<InvertedIndexFileWriter>>
- _inverted_index_file_writers;
+ std::unordered_map<int64_t, std::unique_ptr<XIndexFileWriter>>
+ _x_index_file_writers;
// <rowset_id, segment_id>
std::unordered_map<std::pair<std::string, int64_t>, std::unique_ptr<InvertedIndexFileReader>>
_inverted_index_file_readers;
diff --git a/be/src/pipeline/CMakeLists.txt b/be/src/pipeline/CMakeLists.txt
index fc69608..477f4a5 100644
--- a/be/src/pipeline/CMakeLists.txt
+++ b/be/src/pipeline/CMakeLists.txt
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+
# where to put generated libraries
set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/pipeline")
diff --git a/be/src/runtime/CMakeLists.txt b/be/src/runtime/CMakeLists.txt
index a0b3b79..de36172 100644
--- a/be/src/runtime/CMakeLists.txt
+++ b/be/src/runtime/CMakeLists.txt
@@ -17,6 +17,9 @@
# add_subdirectory(bufferpool)
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+
# where to put generated libraries
set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/runtime")
diff --git a/be/src/runtime/exec_env.h b/be/src/runtime/exec_env.h
index 0c9a415..e5e4be0 100644
--- a/be/src/runtime/exec_env.h
+++ b/be/src/runtime/exec_env.h
@@ -30,7 +30,7 @@
#include "io/cache/fs_file_cache_storage.h"
#include "olap/memtable_memory_limiter.h"
#include "olap/options.h"
-#include "olap/rowset/segment_v2/inverted_index_writer.h"
+#include "olap/rowset/segment_v2/index_writer.h"
#include "olap/tablet_fwd.h"
#include "pipeline/pipeline_tracing.h"
#include "runtime/cluster_info.h"
diff --git a/be/src/runtime/exec_env_init.cpp b/be/src/runtime/exec_env_init.cpp
index df66315..9be2e86 100644
--- a/be/src/runtime/exec_env_init.cpp
+++ b/be/src/runtime/exec_env_init.cpp
@@ -214,6 +214,7 @@
_store_paths = store_paths;
_tmp_file_dirs = std::make_unique<segment_v2::TmpFileDirs>(_store_paths);
RETURN_IF_ERROR(_tmp_file_dirs->init());
+ // return Status::OK(); // uncomment when debugging diskann
_user_function_cache = new UserFunctionCache();
static_cast<void>(_user_function_cache->init(doris::config::user_function_dir));
_external_scan_context_mgr = new ExternalScanContextMgr(this);
@@ -453,6 +454,7 @@
}
Status ExecEnv::_init_mem_env() {
+ // return Status::OK(); // uncomment when debugging diskann
bool is_percent = false;
std::stringstream ss;
// 1. init mem tracker
@@ -460,6 +462,7 @@
_heap_profiler = HeapProfiler::create_global_instance();
init_mem_tracker();
thread_context()->thread_mem_tracker_mgr->init();
+ // NOTE: comment out the block below when debugging diskann
#if defined(USE_MEM_TRACKER) && !defined(__SANITIZE_ADDRESS__) && !defined(ADDRESS_SANITIZER) && \
!defined(LEAK_SANITIZER) && !defined(THREAD_SANITIZER) && !defined(USE_JEMALLOC)
init_hook();
diff --git a/be/src/service/CMakeLists.txt b/be/src/service/CMakeLists.txt
index e44045d..d8c1e9a 100644
--- a/be/src/service/CMakeLists.txt
+++ b/be/src/service/CMakeLists.txt
@@ -28,6 +28,11 @@
pch_reuse(Service)
+find_package(OpenMP REQUIRED)
+if (OpenMP_FOUND)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
+
if (${MAKE_TEST} STREQUAL "OFF" AND ${BUILD_BENCHMARK} STREQUAL "OFF")
add_executable(doris_be
doris_main.cpp
diff --git a/be/src/vector/CMakeLists.txt b/be/src/vector/CMakeLists.txt
new file mode 100644
index 0000000..ff6cf31
--- /dev/null
+++ b/be/src/vector/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# where to put generated libraries
+set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/vector")
+
+find_package(OpenMP REQUIRED)
+if (OpenMP_FOUND)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
+
+set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
+add_compile_options(-Wno-error=attributes)
+add_compile_options(-Wno-deprecated-copy)
+add_compile_options(-Wno-reorder)
+add_compile_options(-Wno-unused-but-set-variable)
+add_compile_options(-Wno-error=unused-variable)
+
+
+file(GLOB_RECURSE SRC_FILES *.cpp)
+add_library(vector STATIC ${SRC_FILES})
+
diff --git a/be/src/vector/ann_parse.h b/be/src/vector/ann_parse.h
new file mode 100644
index 0000000..7f23485
--- /dev/null
+++ b/be/src/vector/ann_parse.h
@@ -0,0 +1,143 @@
+
+#include <iostream>
+#include "extern/diskann/include/utils.h"
+
+// used for debug
+class AnnParse{
+ public:
+ static void parse_data_stream(std::stringstream &data_stream, const std::string &point) {
+ int32_t nrows_32 = 0, ncols_32 = 0;
+ data_stream.seekg(0, std::ios::beg);
+ if (!data_stream.read(reinterpret_cast<char*>(&nrows_32), sizeof(int32_t)) ||
+ !data_stream.read(reinterpret_cast<char*>(&ncols_32), sizeof(int32_t))) {
+ std::cerr << "Error: Failed to read nrows_32 and ncols_32 from data_stream." << std::endl;
+ return;
+ }
+ std::cout << "====================data stream parse result(" << point << ")====================" << std::endl;
+ std::cout << "rows:" << nrows_32
+ << ", clos:" << ncols_32 << std::endl;
+ for(int i=0;i<nrows_32;i++){
+ std::cout << "vec["<<i<<"]"<<"->"<<"[";
+ for(int j=0;j<ncols_32;j++){
+ float vec;
+ data_stream.read(reinterpret_cast<char*>(&vec), 4);
+ std::cout << vec << ",";
+ }
+ std::cout << "]\n";
+ }
+ data_stream.clear();
+ data_stream.seekg(0, std::ios::beg);
+ }
+
+ static void parse_compressed_stream(std::stringstream &data_stream, const std::string &point) {
+ int32_t nrows_32 = 0, ncols_32 = 0;
+ data_stream.seekg(0, std::ios::beg);
+ if (!data_stream.read(reinterpret_cast<char*>(&nrows_32), sizeof(int32_t)) ||
+ !data_stream.read(reinterpret_cast<char*>(&ncols_32), sizeof(int32_t))) {
+ std::cerr << "Error: Failed to read nrows_32 and ncols_32 from data_stream." << std::endl;
+ return;
+ }
+ std::cout << "====================compressed stream parse result(" << point << ")====================" << std::endl;
+ std::cout << "rows:" << nrows_32
+ << ", clos:" << ncols_32 << std::endl;
+ for(int i=0;i<nrows_32;i++){
+ std::cout << "vec["<<i<<"]"<<"->"<<"[";
+ for(int j=0;j<ncols_32;j++){
+ uint8_t vec;
+ if(!data_stream.read(reinterpret_cast<char*>(&vec), 1)){
+ std::cerr << "Error: Failed to read from compressed_stream." << std::endl;
+ }
+ std::cout << int(vec) << ",";
+ }
+ std::cout << "]\n";
+ }
+ data_stream.clear();
+ data_stream.seekg(0, std::ios::beg);
+ }
+
+ static void parse_pivots_stream(std::stringstream &ss, const std::string &point){
+ std::cout << "==================== pivots stream parse result(" << point << ")====================" << std::endl;
+ ss.seekg(0);
+ size_t *cumu_offsets;
+ size_t rows;
+ size_t clos;
+ diskann::load_bin<size_t>(ss, cumu_offsets, rows, clos, 0);
+ std::cout << "cumu_offsets:[";
+ for(int i=0;i<rows;i++){
+ std::cout << *(cumu_offsets+i) << ",";
+ }
+ std::cout << "]";
+
+ float *codebook;
+ diskann::load_bin<float>(ss, codebook, rows, clos, cumu_offsets[0]);
+ std::cout << "codebook:["<<rows<<","<<clos<<"]"<< std::endl;
+ for(int i=0;i<rows;i++){
+ std::cout << "vec["<< i<<"]->[";
+ for(int j=0;j<clos;j++){
+ std::cout << codebook[i*clos+j] <<",";
+ }
+ std::cout << "]\n";
+ }
+
+ float *centroid;
+ diskann::load_bin<float>(ss, centroid, rows, clos, cumu_offsets[1]);
+ std::cout << "centroid:["<<rows<<","<<clos<<"]->[";
+ for(int i=0;i<rows;i++){
+ std::cout << centroid[i]<<",";
+ }
+ std::cout << "]\n";
+
+ uint32_t *chunk_offsets;
+ diskann::load_bin<uint32_t>(ss, chunk_offsets, rows, clos, cumu_offsets[2]);
+ std::cout << "chunk_offsets:["<<rows<<","<<clos<<"]->[";
+ for(int i=0;i<rows;i++){
+ std::cout << chunk_offsets[i]<<",";
+ }
+ std::cout << "]\n";
+ delete []cumu_offsets;
+ delete []codebook;
+ delete []centroid;
+ delete []chunk_offsets;
+ }
+
+ static void parse_index_stream(std::stringstream &ss, const std::string &point){
+ std::cout << "====================index stream parse result(" << point << ")====================" << std::endl;
+ ss.seekg(0);
+ // read the header fields
+ uint64_t index_size;
+ uint32_t max_observed_max_degree;
+ uint32_t start_point;
+ uint64_t num_forzen_points;
+ ss.read(reinterpret_cast<char*>(&index_size), sizeof(uint64_t));
+ ss.read(reinterpret_cast<char*>(&max_observed_max_degree), sizeof(uint32_t));
+ ss.read(reinterpret_cast<char*>(&start_point), sizeof(uint32_t));
+ ss.read(reinterpret_cast<char*>(&num_forzen_points), sizeof(uint64_t));
+
+ std::cout << "index_size:" << index_size << "\n"
+ << "max_observed_max_degree:" << max_observed_max_degree << "\n"
+ << "start_point:" << start_point << "\n"
+ << "num_forzen_points" << num_forzen_points << "\n";
+
+ // keep reading adjacency lists until index_size is reached
+ size_t current_offset = 24;
+ int idx = 0;
+ while (current_offset < index_size) {
+ // read the neighbor count
+ uint32_t neighbor_count;
+ ss.read(reinterpret_cast<char*>(&neighbor_count), sizeof(uint32_t));
+ current_offset += 4;
+ // read the neighbor ID list
+ std::vector<uint32_t> neighbor_ids(neighbor_count);
+ ss.read(reinterpret_cast<char*>(neighbor_ids.data()), neighbor_count * sizeof(uint32_t));
+ current_offset += neighbor_count * 4;
+ std::cout <<"vec["<<idx<<"] has "<< neighbor_count <<" neighbors, neighbor ids=[";
+ for(int m=0;m<neighbor_count;m++){
+ std::cout << neighbor_ids[m] << ",";
+ }
+ std::cout << "]" << std::endl;
+ idx++;
+ }
+ ss.seekg(0);
+ }
+
+};
\ No newline at end of file
diff --git a/be/src/vector/diskann_vector_index.cpp b/be/src/vector/diskann_vector_index.cpp
new file mode 100644
index 0000000..f293456
--- /dev/null
+++ b/be/src/vector/diskann_vector_index.cpp
@@ -0,0 +1,206 @@
+
+#include <iostream>
+#include <omp.h>
+#include <boost/program_options.hpp>
+#include <random>
+
+#include "extern/diskann/include/utils.h"
+#include "extern/diskann/include/disk_utils.h"
+#include "extern/diskann/include/math_utils.h"
+#include "extern/diskann/include/index.h"
+#include "extern/diskann/include/partition.h"
+
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include "extern/diskann/include/linux_aligned_file_reader.h"
+
+#include "extern/diskann/include/pq_flash_index.h"
+#include "extern/diskann/include/combined_file.h"
+#include <string.h>
+#include "extern/diskann/include/timer.h"
+
+#include <fstream>
+#include <filesystem>
+#include <queue>
+#include <map>
+#include <condition_variable>
+
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <atomic>
+#include "diskann_vector_index.h"
+#include "ann_parse.h"
+
+#define FINALLY_CLOSE(x) \
+ try { \
+ if (x != nullptr) { \
+ x->close(); \
+ delete x; \
+ } \
+ } catch (...) { \
+ }
+
+doris::Status DiskannVectorIndex::add(int n, const float *vec){
+ // append the vector data
+ _data_stream.write(reinterpret_cast<const char*>(vec), n * ndim * sizeof(float));
+ npt_num += n;
+ // save the current write position
+ std::streampos current_pos = _data_stream.tellp();
+ // seek back to npt_num's position and update it
+ _data_stream.seekp(npt_num_pos);
+ _data_stream.write(reinterpret_cast<const char*>(&npt_num), sizeof(npt_num));
+ // restore the write pointer
+ _data_stream.seekp(current_pos);
+ if (_data_stream.fail()) {
+ return doris::Status::IOError("Failed to write vector data");
+ }
+ return doris::Status::OK();
+}
+
+
+int DiskannVectorIndex::calculate_num_pq_chunks(){
+ double final_index_ram_limit = diskann::get_memory_budget(builderParameterPtr->get_search_ram_budget());
+ int num_pq_chunks = diskann::calculate_num_pq_chunks(final_index_ram_limit, npt_num, builderParameterPtr->dim);
+ return num_pq_chunks;
+}
+
+doris::Status DiskannVectorIndex::build(){
+ try{
+ diskann::generate_quantized_data<float>(_data_stream,
+ _pq_pivots_stream,
+ _pq_compressed_stream,
+ builderParameterPtr->get_metric_type(),
+ builderParameterPtr->get_sample_rate(),
+ calculate_num_pq_chunks(),
+ false,
+ "");
+ _data_stream.seekg(0, _data_stream.beg);
+ diskann::build_merged_vamana_index<float>(_data_stream,
+ builderParameterPtr->get_metric_type(),
+ builderParameterPtr->get_l(),
+ builderParameterPtr->get_r(),
+ 1,
+ builderParameterPtr->get_indexing_ram_budget(),
+ _vamana_index_stream, "", "",
+ 0, false,
+ builderParameterPtr->get_num_threads(), false, "",
+ "", "", 0);
+ _data_stream.seekg(0, _data_stream.beg);
+ diskann::create_disk_layout<float>(_data_stream, _vamana_index_stream, _disk_layout_stream);
+ return doris::Status::OK();
+ } catch (const std::exception& e) {
+ return doris::Status::InternalError<true>(std::string(e.what()));
+ }
+}
+
+
+
+
+doris::Status DiskannVectorIndex::save(){
+ try{
+ // build the index into temporary stringstreams
+ RETURN_IF_ERROR(build());
+ // flush the streams down to the storage layer
+ lucene::store::IndexOutput* pq_pivots_output = _dir->createOutput(DiskannFileDesc::get_pq_pivots_file_name());
+ lucene::store::IndexOutput* pq_compressed_output = _dir->createOutput(DiskannFileDesc::get_pq_compressed_file_name());
+ lucene::store::IndexOutput* vamana_index_output = _dir->createOutput(DiskannFileDesc::get_vamana_index_file_name());
+ lucene::store::IndexOutput* disk_layout_output = _dir->createOutput(DiskannFileDesc::get_disklayout_file_name());
+ lucene::store::IndexOutput* tag_output = _dir->createOutput(DiskannFileDesc::get_tag_file_name());
+ RETURN_IF_ERROR(stream_write_to_output(_pq_pivots_stream, pq_pivots_output));
+ RETURN_IF_ERROR(stream_write_to_output(_pq_compressed_stream, pq_compressed_output));
+ RETURN_IF_ERROR(stream_write_to_output(_vamana_index_stream, vamana_index_output));
+ RETURN_IF_ERROR(stream_write_to_output(_disk_layout_stream, disk_layout_output));
+ RETURN_IF_ERROR(stream_write_to_output(_tag_stream, tag_output));
+ FINALLY_CLOSE(pq_pivots_output);
+ FINALLY_CLOSE(pq_compressed_output);
+ FINALLY_CLOSE(vamana_index_output);
+ FINALLY_CLOSE(disk_layout_output);
+ FINALLY_CLOSE(tag_output);
+ } catch (const std::exception& e) {
+ return doris::Status::InternalError(e.what());
+ }
+ return doris::Status::OK();
+}
+
+doris::Status DiskannVectorIndex::stream_write_to_output(std::stringstream &stream, lucene::store::IndexOutput *output) {
+ try {
+ stream.seekg(0, std::ios::beg); // make sure we read from the beginning
+ if (!stream.good()) {
+ return doris::Status::Corruption("stream seekg failed");
+ }
+ const size_t buffer_size = 4096; // 4 KB buffer
+ std::vector<char> buffer(buffer_size);
+ while (stream) { // keep reading while the stream is still valid
+ stream.read(buffer.data(), buffer_size);
+ std::streamsize bytes_read = stream.gcount(); // number of bytes actually read
+ if (bytes_read > 0) {
+ output->writeBytes(reinterpret_cast<const uint8_t*>(buffer.data()), static_cast<int32_t>(bytes_read));
+ }
+ }
+ return doris::Status::OK();
+ } catch (const std::exception &e) {
+ return doris::Status::Corruption(std::string("failed stream write to output, message=") + e.what());
+ }
+}
+
+
+
+doris::Status DiskannVectorIndex::load(VectorIndex::Metric dist_fn){
+ diskann::Metric metric;
+ if (dist_fn == VectorIndex::Metric::L2) {
+ metric = diskann::Metric::L2;
+ } else if (dist_fn == VectorIndex::Metric::INNER_PRODUCT) {
+ metric = diskann::Metric::INNER_PRODUCT;
+ } else if (dist_fn == VectorIndex::Metric::COSINE){
+ metric = diskann::Metric::COSINE;
+ } else {
+ std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
+ return doris::Status::InternalError("Error. Only l2 and mips distance functions are supported");
+ }
+ lucene::store::IndexInput* pq_pivots_input = _dir->openInput(DiskannFileDesc::get_pq_pivots_file_name());
+ std::cout << "Actual type: " << typeid(*pq_pivots_input).name() << std::endl;
+
+ lucene::store::IndexInput* pq_compressed_input = _dir->openInput(DiskannFileDesc::get_pq_compressed_file_name());
+ lucene::store::IndexInput* vamana_index_input = _dir->openInput(DiskannFileDesc::get_vamana_index_file_name());
+ lucene::store::IndexInput* disk_layout_input = _dir->openInput(DiskannFileDesc::get_disklayout_file_name());
+ lucene::store::IndexInput* tag_input = _dir->openInput(DiskannFileDesc::get_tag_file_name());
+ // Try to minimize the intrusion of Lucene code into the DiskANN source code, so we introduce a wrapper layer here
+ std::shared_ptr<IndexInputReaderWrapper> pq_pivots_reader(new IndexInputReaderWrapper(pq_pivots_input));
+ std::shared_ptr<IndexInputReaderWrapper> pq_compressed_reader(new IndexInputReaderWrapper(pq_compressed_input));
+ std::shared_ptr<IndexInputReaderWrapper> vamana_index_reader(new IndexInputReaderWrapper(vamana_index_input));
+ std::shared_ptr<IndexInputReaderWrapper> disk_layout_reader(new IndexInputReaderWrapper(disk_layout_input));
+ std::shared_ptr<IndexInputReaderWrapper> tag_reader(new IndexInputReaderWrapper(tag_input));
+
+ _pFlashIndex = std::make_shared<diskann::PQFlashIndex<float,uint16_t>>(disk_layout_reader, metric);
+ _pFlashIndex->load(8, pq_pivots_reader, pq_compressed_reader, vamana_index_reader, disk_layout_reader, tag_reader);
+ return doris::Status::OK();
+}
+
+doris::Status DiskannVectorIndex::search(const float * query_vec,int topk, SearchResult *result,const SearchParameters* params){
+ try {
+ DiskannSearchParameter *searchParam = (DiskannSearchParameter*)params;
+ IDFilter *filter = searchParam->get_filter();
+ int optimized_beamwidth = searchParam->get_beam_width();
+ int search_list = searchParam->get_search_list();
+ std::vector<uint64_t> query_result_ids_64(topk);
+ std::vector<float> query_result_dists(topk);
+ diskann::QueryStats* stats = static_cast<diskann::QueryStats*>(result->stat);
+ uint32_t k = _pFlashIndex->cached_beam_search(query_vec, topk, search_list,
+ query_result_ids_64.data(),
+ query_result_dists.data(),
+ optimized_beamwidth, filter, stats);
+ result->rows = k;
+ for(int i=0;i<k;i++){
+ result->distances.push_back(query_result_dists[i]);
+ result->ids.push_back(query_result_ids_64[i]);
+ }
+ return doris::Status::OK();
+ } catch (const std::exception& e) {
+ return doris::Status::InternalError(e.what());
+ }
+}
+
+
+
diff --git a/be/src/vector/diskann_vector_index.h b/be/src/vector/diskann_vector_index.h
new file mode 100644
index 0000000..44cb949
--- /dev/null
+++ b/be/src/vector/diskann_vector_index.h
@@ -0,0 +1,274 @@
+#pragma once
+
+#include <CLucene.h>
+#include <CLucene/store/IndexInput.h>
+#include <CLucene/store/IndexOutput.h>
+
+#include "extern/diskann/include/distance.h"
+#include "extern/diskann/include/pq_flash_index.h"
+#include "vector_index.h"
+#include <roaring/roaring.hh>
+
+struct DiskannBuilderParameter : public BuilderParameter{
+ diskann::Metric metric_type;
+ int L;
+ int R;
+ int num_threads;
+ double sample_rate;
+ float indexing_ram_budget; // unit: GB
+ float search_ram_budget; // unit: GB
+ int dim;
+
+ DiskannBuilderParameter& with_mertic_type(VectorIndex::Metric metric){
+ metric_type = convert_to_diskann_metric(metric);
+ return *this;
+ }
+
+ diskann::Metric convert_to_diskann_metric(VectorIndex::Metric metric) {
+ switch (metric) {
+ case VectorIndex::Metric::L2:
+ return diskann::Metric::L2;
+ case VectorIndex::Metric::COSINE:
+ return diskann::Metric::COSINE;
+ case VectorIndex::Metric::INNER_PRODUCT:
+ return diskann::Metric::INNER_PRODUCT;
+ default:
+ throw std::invalid_argument("Unknown metric type");
+ }
+ }
+
+ std::string metric_to_string(diskann::Metric metric) {
+ switch (metric) {
+ case diskann::Metric::L2:
+ return "L2";
+ case diskann::Metric::INNER_PRODUCT:
+ return "INNER_PRODUCT";
+ case diskann::Metric::COSINE:
+ return "COSINE";
+ case diskann::Metric::FAST_L2:
+ return "FAST_L2";
+ default:
+ return "UNKNOWN";
+ }
+ }
+
+ DiskannBuilderParameter& with_dim(int d){
+ dim = d;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_indexing_ram_budget_mb(float ram){
+ indexing_ram_budget = ram / 1024;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_search_ram_budget_mb(float ram){
+ search_ram_budget = ram / 1024;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_sample_rate(float rate){
+ sample_rate = rate;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_build_num_threads(int threads){
+ num_threads = threads;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_L(int l){
+ L = l;
+ return *this;
+ }
+
+ DiskannBuilderParameter& with_R(int r){
+ R = r;
+ return *this;
+ }
+
+ std::string to_string(){
+ std::ostringstream oss;
+ oss << "metric_type:" << metric_to_string(metric_type)
+ << ", L:" << L
+ << ", R:" << R
+ << ", num_threads:" << num_threads
+ << ", sample_rate:" << sample_rate
+ << ", indexing_ram_budget:" << indexing_ram_budget <<"G"
+ << ", search_ram_budget:" << search_ram_budget <<"G"
+ << ", dim:" << dim;
+ return oss.str();
+ }
+ diskann::Metric get_metric_type(){
+ return metric_type;
+ }
+ int get_l(){
+ return L;
+ }
+ int get_r(){
+ return R;
+ }
+ int get_num_threads(){
+ return num_threads;
+ }
+ double get_sample_rate(){
+ return sample_rate;
+ }
+ float get_indexing_ram_budget(){
+ return indexing_ram_budget;
+ }
+ int get_dim(){
+ return dim;
+ }
+ float get_search_ram_budget(){
+ return search_ram_budget;
+ }
+};
+
+
+struct IDFilter : public diskann::Filter {
+ private:
+ std::shared_ptr<roaring::Roaring> _bitmap;
+ public:
+ IDFilter(std::shared_ptr<roaring::Roaring> bitmap){
+ _bitmap = bitmap;
+ }
+ bool is_member(uint32_t idx){
+ return _bitmap->contains(idx);
+ }
+};
+
+struct DiskannSearchParameter : public SearchParameters{
+ int search_list;
+ int beam_width;
+ std::shared_ptr<IDFilter> filter;
+ DiskannSearchParameter(){
+ filter = nullptr;
+ search_list = 100;
+ beam_width = 2;
+ }
+ DiskannSearchParameter& with_search_list(int l){
+ search_list = l;
+ return *this;
+ }
+ DiskannSearchParameter& with_beam_width(int width){
+ beam_width = width;
+ return *this;
+ }
+
+ DiskannSearchParameter& set_filter(std::shared_ptr<IDFilter> f){
+ filter = f;
+ return *this;
+ }
+
+ IDFilter *get_filter(){
+ return filter.get();
+ }
+ int get_search_list(){
+ return search_list;
+ }
+ int get_beam_width(){
+ return beam_width;
+ }
+};
+
+
+
+class WriterWrapper {
+ private:
+ std::shared_ptr<lucene::store::IndexOutput> _out;
+ public:
+ WriterWrapper(std::shared_ptr<lucene::store::IndexOutput> out);
+ void write(uint8_t* b, uint64_t len);
+};
+
+
+ // DiskANN's index files are merged into a single file; redefine what each part means here
+class DiskannFileDesc {
+ public:
+ static constexpr const char* PQ_PIVOTS_FILE_NAME = "pq_pivots_file";
+ static constexpr const char* PQ_COMPRESSED_FILE_NAME = "pq_compressed_file";
+ static constexpr const char* VAMANA_INDEX_FILE_NAME = "vamana_index_file";
+ static constexpr const char* DISK_LAYOUT_FILE_NAME = "disk_layout_file";
+ static constexpr const char* TAG_FILE_NAME = "tag_file";
+ static const char* get_pq_pivots_file_name(){
+ return PQ_PIVOTS_FILE_NAME;
+ }
+ static const char* get_pq_compressed_file_name(){
+ return PQ_COMPRESSED_FILE_NAME;
+ }
+ static const char* get_vamana_index_file_name(){
+ return VAMANA_INDEX_FILE_NAME;
+ }
+ static const char* get_disklayout_file_name(){
+ return DISK_LAYOUT_FILE_NAME;
+ }
+ static const char* get_tag_file_name(){
+ return TAG_FILE_NAME;
+ }
+};
+
+class DiskannVectorIndex : public VectorIndex{
+ private:
+ std::shared_ptr<diskann::PQFlashIndex<float,uint16_t> > _pFlashIndex;
+ std::shared_ptr<DiskannBuilderParameter> builderParameterPtr;
+ // adapts to Doris's storage medium
+ std::shared_ptr<lucene::store::Directory> _dir;
+
+ // raw vectors
+ std::stringstream _data_stream;
+ // temporary codebook cache
+ std::stringstream _pq_pivots_stream;
+ // temporary cache for quantized vectors
+ std::stringstream _pq_compressed_stream;
+ // graph index
+ std::stringstream _vamana_index_stream;
+ // final on-disk layout
+ std::stringstream _disk_layout_stream;
+ std::stringstream _tag_stream;
+
+ int npt_num = 0 ; // total number of vectors
+ int ndim = 0; // vector dimensionality
+ std::streampos npt_num_pos; // position of npt_num within the stream
+
+ std::mutex _data_stream_mutex;
+
+
+
+ private:
+ int calculate_num_pq_chunks();
+
+ public:
+ DiskannVectorIndex(std::shared_ptr<lucene::store::Directory> dir){
+ builderParameterPtr = nullptr;
+ _dir = dir;
+ // write placeholder npt_num and ndim first
+ npt_num_pos = _data_stream.tellp(); // record the offset of npt_num
+ _data_stream.seekp(static_cast<std::streampos>(npt_num_pos));
+ _data_stream.write(reinterpret_cast<const char*>(&npt_num), sizeof(npt_num));
+ _data_stream.write(reinterpret_cast<const char*>(&ndim), sizeof(ndim));
+ }
+ doris::Status add(int n, const float *vec);
+ doris::Status build();
+ void set_build_params(std::shared_ptr<BuilderParameter> params){
+ builderParameterPtr = std::static_pointer_cast<DiskannBuilderParameter>(params);
+
+ // set the dimension
+ ndim = builderParameterPtr->get_dim();
+ std::streamoff pos = static_cast<std::streamoff>(npt_num_pos + std::streampos(sizeof(npt_num)));
+ _data_stream.seekp(pos, std::ios::beg);
+ _data_stream.write(reinterpret_cast<const char*>(&ndim), sizeof(ndim));
+ }
+ doris::Status search(
+ const float * query_vec,
+ int k,
+ SearchResult *result,
+ const SearchParameters* params = nullptr);
+ // flush the std::stringstream contents into _dir
+ doris::Status save();
+ // responsible for parsing the contents back from _dir
+ doris::Status load(VectorIndex::Metric dist_fn);
+ private:
+ doris::Status stream_write_to_output(std::stringstream &stream, lucene::store::IndexOutput *output);
+};
+
diff --git a/be/src/vector/stream_wrapper.h b/be/src/vector/stream_wrapper.h
new file mode 100644
index 0000000..16a9629
--- /dev/null
+++ b/be/src/vector/stream_wrapper.h
@@ -0,0 +1,142 @@
+#pragma once
+
+#include <CLucene.h>
+#include <CLucene/store/IndexInput.h>
+#include <CLucene/store/IndexOutput.h>
+
+// Standard headers used directly by this file; do not rely on CLucene
+// pulling them in transitively.
+#include <cstdint>
+#include <mutex>
+#include <sstream>
+#include <vector>
+
+// Abstract reader interface so diskann loading code can consume either an
+// in-memory std::stringstream or a CLucene IndexInput through one API.
+class IReaderWrapper {
+public:
+    virtual ~IReaderWrapper() = default;
+    // Position the reader's cursor at absolute offset `pos`.
+    virtual void seek(uint64_t pos) = 0;
+    // Read `n` bytes starting at absolute `offset` (positioned read).
+    virtual void read(char *s, uint64_t n, uint64_t offset) = 0;
+    // Read `n` bytes from the current cursor position (sequential read).
+    virtual void read(char *s, uint64_t n) = 0;
+};
+
+// Thread-safe reader over a std::stringstream shared by several readers.
+// Each wrapper keeps its own cursor (_offset); the shared stream's get
+// position is only meaningful while `mtx` is held.
+class ShareStringStreamReaderWrapper : public IReaderWrapper {
+private:
+    std::stringstream *ss;            // shared underlying stream (not owned)
+    std::mutex *mtx;                  // guards all access to *ss
+    uint64_t _offset;                 // this reader's private cursor
+    std::streamsize last_read_count;  // bytes transferred by the last sequential read
+
+public:
+    // Takes the shared stringstream and the mutex protecting it.
+    ShareStringStreamReaderWrapper(std::stringstream &stream, std::mutex& mutex)
+            : ss(&stream), mtx(&mutex), _offset(0), last_read_count(0) {}
+
+    // Sequential read at the private cursor; advances the cursor by the
+    // number of bytes actually read.
+    void read(char* s, uint64_t n) override {
+        std::lock_guard<std::mutex> lock(*mtx);
+        // A previous short read may have set eofbit/failbit on the shared
+        // stream; without clear() the following seekg/read silently fail.
+        ss->clear();
+        ss->seekg(_offset, ss->beg);
+        ss->read(s, n);
+        last_read_count = ss->gcount();
+        _offset += static_cast<uint64_t>(last_read_count);
+    }
+
+    // Positioned read at absolute `offset`; does NOT move the private cursor.
+    void read(char* s, uint64_t n, uint64_t offset) override {
+        std::lock_guard<std::mutex> lock(*mtx);
+        ss->clear(); // recover from a prior EOF before repositioning
+        ss->seekg(offset, ss->beg);
+        ss->read(s, n);
+    }
+
+    // Only moves the private cursor; the shared stream is repositioned
+    // lazily on the next read, so no lock is required here.
+    void seek(uint64_t pos) override {
+        _offset = pos;
+    }
+
+    ~ShareStringStreamReaderWrapper() override = default;
+};
+
+// Minimal adapter around std::stringstream so diskann::load_bin can use the
+// same IReaderWrapper interface. Not thread-safe; single consumer only.
+class SampleStringStreamReaderWrapper : public IReaderWrapper {
+private:
+    std::stringstream *ss;            // underlying stream (not owned)
+    size_t _offset;                   // unused; kept for parity with the shared variant
+    std::streamsize last_read_count;  // unused; kept for parity with the shared variant
+public:
+    SampleStringStreamReaderWrapper(std::stringstream &stream)
+            : ss(&stream), _offset(0), last_read_count(0) {}
+
+    // Sequential read from the stream's own get cursor.
+    void read(char* s, uint64_t n) override {
+        ss->read(s, n);
+    }
+
+    // Positioned read at absolute `offset`.
+    void read(char* s, uint64_t n, uint64_t offset) override {
+        // A previous short read may have set eofbit/failbit; without
+        // clear() the following seekg/read silently fail.
+        ss->clear();
+        ss->seekg(offset, ss->beg);
+        ss->read(s, n);
+    }
+
+    void seek(uint64_t pos) override {
+        ss->clear(); // seekg on a failed stream is a no-op without this
+        ss->seekg(pos, ss->beg);
+    }
+
+    ~SampleStringStreamReaderWrapper() override = default;
+};
+
+// Adapter over a CLucene IndexInput. Takes ownership of `input` and closes
+// and deletes it on destruction.
+class IndexInputReaderWrapper : public IReaderWrapper {
+private:
+    lucene::store::IndexInput *_input; // owned; closed + deleted in dtor
+    std::mutex mtx;                    // serializes positioned reads only
+public:
+    IndexInputReaderWrapper(lucene::store::IndexInput *input) : _input(input) {}
+
+    void seek(uint64_t offset) override {
+        _input->seek(offset);
+    }
+
+    // Note that the offset here and the offset in the IndexInput do not have
+    // the same meaning: this is an absolute position within the input.
+    void read(char *s, uint64_t n, uint64_t offset) override {
+        std::lock_guard<std::mutex> lock(mtx);
+        _input->seek(offset);
+        _input->readBytes(reinterpret_cast<uint8_t*>(s), static_cast<int32_t>(n));
+    }
+
+    // Sequential read. NOTE(review): not guarded by `mtx`, so it is not safe
+    // to mix with concurrent positioned reads -- confirm callers never do.
+    void read(char *s, uint64_t n) override {
+        _input->readBytes(reinterpret_cast<uint8_t*>(s), static_cast<int32_t>(n));
+    }
+
+    ~IndexInputReaderWrapper() override {
+        if (_input != nullptr) {
+            _input->close();
+            delete _input;
+            _input = nullptr;
+        }
+    }
+
+    // Used for debugging: copy the entire input into a stringstream.
+    std::stringstream readAll() {
+        std::stringstream buffer;
+        uint64_t len = _input->length();
+        _input->seek(0); // make sure we read from the beginning
+
+        std::vector<char> data(len);
+        _input->readBytes(reinterpret_cast<uint8_t*>(data.data()), static_cast<int32_t>(len));
+
+        buffer.write(data.data(), len);
+        return buffer;
+    }
+};
+
+// Convenience shared_ptr aliases for the reader wrapper types above.
+using IReaderWrapperSPtr = std::shared_ptr<IReaderWrapper>;
+using ShareStringStreamReaderWrapperSPtr = std::shared_ptr<ShareStringStreamReaderWrapper>;
+using IndexInputReaderWrapperSPtr = std::shared_ptr<IndexInputReaderWrapper>;
+using SampleStringStreamReaderWrapperSPtr = std::shared_ptr<SampleStringStreamReaderWrapper>;
+
+
diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h
new file mode 100644
index 0000000..44c20ea
--- /dev/null
+++ b/be/src/vector/vector_index.h
@@ -0,0 +1,94 @@
+#pragma once
+
+// Standard headers used directly by this file.
+#include <cstdint>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include "common/status.h"
+#include "common/exception.h"
+
+// Result container for a k-NN search: `rows` hits described by the parallel
+// `distances` / `ids` arrays.
+struct SearchResult {
+    int rows;
+    std::vector<float> distances;
+    std::vector<int64_t> ids;
+    // Statistics / profiling payload. NOTE(review): raw pointer, never freed
+    // here -- confirm the producer owns its lifetime.
+    void *stat;
+
+    SearchResult(){
+        rows = 0;
+        stat = nullptr;
+    }
+    // Distance of the idx-th hit; throws std::out_of_range on a bad index.
+    float get_distance(int idx) const {
+        if (idx < 0 || idx >= static_cast<int>(distances.size())) {
+            throw std::out_of_range("Invalid distance index");
+        }
+        return distances[idx];
+    }
+    // Row id of the idx-th hit; throws std::out_of_range on a bad index.
+    int64_t get_id(int idx) const {
+        if (idx < 0 || idx >= static_cast<int>(ids.size())) {
+            throw std::out_of_range("Invalid ID index");
+        }
+        return ids[idx];
+    }
+    // Clear hits so the result can be reused. `stat` is intentionally kept.
+    void reset(){
+        rows = 0;
+        distances.clear();
+        ids.clear();
+    }
+    bool has_rows() const {
+        return rows > 0;
+    }
+    int row_count() const {
+        return rows;
+    }
+};
+
+// Base class for index-specific search parameters (e.g. diskann search_list);
+// polymorphic so implementations can downcast.
+struct SearchParameters {
+    virtual ~SearchParameters() = default;
+};
+
+// Base class for index build parameters. Derived classes (e.g.
+// DiskannBuilderParameter) are passed around as shared_ptr<BuilderParameter>
+// and downcast via std::static_pointer_cast, so the base needs a virtual
+// destructor to make deletion through a base pointer well-defined.
+struct BuilderParameter {
+    virtual ~BuilderParameter() = default;
+};
+
+// Abstract base for every vector index implementation: ingestion (add),
+// construction (set_build_params), querying (search) and persistence
+// (save / load).
+class VectorIndex {
+public:
+    // Distance metric used for similarity comparison.
+    enum class Metric {
+        L2,
+        COSINE,
+        INNER_PRODUCT,
+        UNKNOWN
+    };
+
+    // Append n vectors (contiguous floats) to the index input.
+    virtual doris::Status add(int n, const float *vec) = 0;
+    // Install index-type-specific build parameters.
+    virtual void set_build_params(std::shared_ptr<BuilderParameter> params) = 0;
+    // Run a k-NN query; hits are written into *result.
+    virtual doris::Status search(
+        const float * query_vec,
+        int k,
+        SearchResult *result,
+        const SearchParameters* params = nullptr) = 0;
+    // Persist the built index.
+    virtual doris::Status save() = 0;
+    // Load a previously saved index using the given distance metric.
+    virtual doris::Status load(Metric type) = 0;
+
+    // Human-readable name for a metric.
+    static std::string metric_to_string(Metric metric) {
+        switch (metric) {
+        case Metric::L2:
+            return "L2";
+        case Metric::COSINE:
+            return "COSINE";
+        case Metric::INNER_PRODUCT:
+            return "INNER_PRODUCT";
+        default:
+            return "UNKNOWN";
+        }
+    }
+    // Parses the lowercase metric names used in index properties
+    // ("l2" / "cosine" / "inner_product"); anything else maps to UNKNOWN.
+    static Metric string_to_metric(const std::string& metric) {
+        if (metric == "l2") return Metric::L2;
+        if (metric == "cosine") return Metric::COSINE;
+        if (metric == "inner_product") return Metric::INNER_PRODUCT;
+        return Metric::UNKNOWN;
+    }
+    virtual ~VectorIndex() = default;
+};
\ No newline at end of file
diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4
index 40a576c..1d0a31b 100644
--- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4
+++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4
@@ -363,6 +363,7 @@
NEVER: 'NEVER';
NEXT: 'NEXT';
NGRAM_BF: 'NGRAM_BF';
+ANN: 'ANN';
NO: 'NO';
NO_USE_MV: 'NO_USE_MV';
NON_NULLABLE: 'NON_NULLABLE';
diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
index 6eedeaa..bd0f2d2 100644
--- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
+++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
@@ -767,7 +767,7 @@
| CREATE (READ ONLY)? REPOSITORY name=identifier WITH storageBackend #createRepository
| CREATE INDEX (IF NOT EXISTS)? name=identifier
ON tableName=multipartIdentifier identifierList
- (USING (BITMAP | NGRAM_BF | INVERTED))?
+ (USING (BITMAP | NGRAM_BF | INVERTED | ANN))?
properties=propertyClause? (COMMENT STRING_LITERAL)? #createIndex
| CREATE EXTERNAL? RESOURCE (IF NOT EXISTS)?
name=identifierOrText properties=propertyClause? #createResource
@@ -1384,7 +1384,7 @@
;
indexDef
- : INDEX (ifNotExists=IF NOT EXISTS)? indexName=identifier cols=identifierList (USING indexType=(BITMAP | INVERTED | NGRAM_BF))? (PROPERTIES LEFT_PAREN properties=propertyItemList RIGHT_PAREN)? (COMMENT comment=STRING_LITERAL)?
+ : INDEX (ifNotExists=IF NOT EXISTS)? indexName=identifier cols=identifierList (USING indexType=(BITMAP | INVERTED | NGRAM_BF | ANN ))? (PROPERTIES LEFT_PAREN properties=propertyItemList RIGHT_PAREN)? (COMMENT comment=STRING_LITERAL)?
;
partitionsDef
@@ -1961,6 +1961,7 @@
| NEVER
| NEXT
| NGRAM_BF
+ | ANN
| NO
| NON_NULLABLE
| NULLS
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup
index 7046a71..830ad1a 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -455,6 +455,7 @@
KW_ISNULL,
KW_ISOLATION,
KW_INVERTED,
+ KW_ANN,
KW_JOB,
KW_JOBS,
KW_JOIN,
@@ -4136,6 +4137,10 @@
{:
RESULT = IndexDef.IndexType.INVERTED;
:}
+ | KW_USING KW_ANN
+ {:
+ RESULT = IndexDef.IndexType.ANN;
+ :}
;
opt_or_replace ::=
@@ -8298,6 +8303,8 @@
{: RESULT = id; :}
| KW_INVERTED:id
{: RESULT = id; :}
+ | KW_ANN:id
+ {: RESULT = id; :}
| KW_ISNULL:id
{: RESULT = id; :}
| KW_ISOLATION:id
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java
index b239eb8..8578201 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java
@@ -211,7 +211,8 @@
BITMAP,
INVERTED,
BLOOMFILTER,
- NGRAM_BF
+ NGRAM_BF,
+ ANN
}
public boolean isInvertedIndex() {
@@ -222,7 +223,7 @@
TInvertedIndexFileStorageFormat invertedIndexFileStorageFormat,
boolean disableInvertedIndexV1ForVariant) throws AnalysisException {
if (indexType == IndexType.BITMAP || indexType == IndexType.INVERTED || indexType == IndexType.BLOOMFILTER
- || indexType == IndexType.NGRAM_BF) {
+ || indexType == IndexType.NGRAM_BF || indexType == IndexType.ANN) {
String indexColName = column.getName();
caseSensitivityColumns.add(indexColName);
PrimitiveType colType = column.getDataType();
@@ -233,6 +234,10 @@
+ "invalid index: " + indexName);
}
+ if(indexType == IndexType.ANN && !colType.isArrayType() ){
+ throw new AnalysisException("ANN index column must be array type");
+ }
+
// In inverted index format v1, each subcolumn of a variant has its own index file, leading to high IOPS.
// when the subcolumn type changes, it may result in missing files, causing link file failure.
if (colType.isVariantType() && disableInvertedIndexV1ForVariant) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Index.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Index.java
index 8d4cc0e..4f09253 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Index.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Index.java
@@ -293,6 +293,10 @@
builder.setIndexType(OlapFile.IndexType.BLOOMFILTER);
break;
+ case ANN:
+ builder.setIndexType(OlapFile.IndexType.ANN);
+ break;
+
default:
throw new RuntimeException("indexType " + indexType + " is not processed in toPb");
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
index 1304bd0..a9aec9f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
@@ -76,6 +76,10 @@
this.indexType = IndexType.NGRAM_BF;
break;
}
+ case "ANN": {
+ this.indexType = IndexType.ANN;
+ break;
+ }
default:
throw new AnalysisException("unknown index type " + indexTypeName);
}
@@ -101,7 +105,7 @@
TInvertedIndexFileStorageFormat invertedIndexFileStorageFormat,
boolean disableInvertedIndexV1ForVariant) throws AnalysisException {
if (indexType == IndexType.BITMAP || indexType == IndexType.INVERTED
- || indexType == IndexType.BLOOMFILTER || indexType == IndexType.NGRAM_BF) {
+ || indexType == IndexType.BLOOMFILTER || indexType == IndexType.NGRAM_BF || indexType == IndexType.ANN) {
String indexColName = column.getName();
caseSensitivityCols.add(indexColName);
DataType colType = column.getType();
@@ -113,6 +117,10 @@
+ " index. " + "invalid index: " + name);
}
+ if(indexType == IndexType.ANN && !colType.isArrayType()){
+ throw new AnalysisException("Ann index column must be array type, invalid index: " + name);
+ }
+
// In inverted index format v1, each subcolumn of a variant has its own index file, leading to high IOPS.
// when the subcolumn type changes, it may result in missing files, causing link file failure.
if (colType.isVariantType() && disableInvertedIndexV1ForVariant) {
diff --git a/fe/fe-core/src/main/jflex/sql_scanner.flex b/fe/fe-core/src/main/jflex/sql_scanner.flex
index 9903675..af39d9b 100644
--- a/fe/fe-core/src/main/jflex/sql_scanner.flex
+++ b/fe/fe-core/src/main/jflex/sql_scanner.flex
@@ -123,6 +123,7 @@
keywordMap.put("binlog", new Integer(SqlParserSymbols.KW_BINLOG));
keywordMap.put("bitmap", new Integer(SqlParserSymbols.KW_BITMAP));
keywordMap.put("inverted", new Integer(SqlParserSymbols.KW_INVERTED));
+ keywordMap.put("ann", new Integer(SqlParserSymbols.KW_ANN));
keywordMap.put("bitmap_empty", new Integer(SqlParserSymbols.KW_BITMAP_EMPTY));
keywordMap.put("bitmap_union", new Integer(SqlParserSymbols.KW_BITMAP_UNION));
keywordMap.put("ngram_bf", new Integer(SqlParserSymbols.KW_NGRAM_BF));
diff --git a/gensrc/proto/olap_file.proto b/gensrc/proto/olap_file.proto
index 2c378fe..cb15db0 100644
--- a/gensrc/proto/olap_file.proto
+++ b/gensrc/proto/olap_file.proto
@@ -338,6 +338,7 @@
INVERTED = 1;
BLOOMFILTER = 2;
NGRAM_BF = 3;
+ ANN = 4;
}
enum InvertedIndexStorageFormatPB {
diff --git a/gensrc/thrift/Descriptors.thrift b/gensrc/thrift/Descriptors.thrift
index b80ce5c..a3f93db 100644
--- a/gensrc/thrift/Descriptors.thrift
+++ b/gensrc/thrift/Descriptors.thrift
@@ -155,7 +155,8 @@
BITMAP = 0,
INVERTED = 1,
BLOOMFILTER = 2,
- NGRAM_BF = 3
+ NGRAM_BF = 3,
+ ANN = 4
}
// Mapping from names defined by Avro to the enum.