From 54caa469d9dff307050ccdca29671cabee3f9922 Mon Sep 17 00:00:00 2001 From: Rafik Saliev Date: Mon, 11 May 2026 12:34:07 +0200 Subject: [PATCH 1/3] Refactor inter-query thread pool creation to validate thread count against centroids --- include/svs/index/ivf/index.h | 57 ++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 26959c47..3119e08f 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -148,7 +148,9 @@ class IVFIndex { , cluster_{std::move(cluster)} , cluster0_{cluster_.view_cluster(0)} , distance_{std::move(distance_function)} - , inter_query_threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , inter_query_threadpool_{make_inter_query_threadpool( + std::move(threadpool_proto), centroids_.size(), logger + )} , intra_query_thread_count_{intra_query_thread_count} , logger_{std::move(logger)} { validate_thread_configuration(); @@ -572,6 +574,59 @@ class IVFIndex { ///// Initialization Methods ///// + // Generic inter-query thread pool creation which just validates the thread count + // against the number of centroids. + static InterQueryThreadPool make_inter_query_threadpool( + ThreadPoolProto proto, size_t num_centroids, const svs::logging::logger_ptr& logger + ) { + if (proto.size() > num_centroids) { + throw std::invalid_argument( + "Number of inter-query threads cannot exceed number of centroids" + ); + } + return InterQueryThreadPool{threads::as_threadpool(std::move(proto))}; + } + + // Specialization for size_t thread pool prototype to allow automatic resizing and + // logging of adjustments. + static InterQueryThreadPool make_inter_query_threadpool( + size_t proto, size_t num_centroids, const svs::logging::logger_ptr& logger + ) + requires std::is_same_v + { + if (proto > num_centroids) { + logger->warn( + "Provided thread pool has {} threads, but there are only {} centroids. 
" + "Reducing thread pool size to match number of centroids.", + proto, + num_centroids + ); + proto = num_centroids; + } + return InterQueryThreadPool{threads::DefaultThreadPool(proto)}; + } + + // Specialization for thread pool prototypes that support resizing. + static InterQueryThreadPool make_inter_query_threadpool( + ThreadPoolProto proto, size_t num_centroids, const svs::logging::logger_ptr& logger + ) + requires( + !std::is_same_v && + requires { proto.resize(num_centroids); } + ) + { + if (proto.size() > num_centroids) { + logger->warn( + "Provided thread pool has {} threads, but there are only {} centroids. " + "Reducing thread pool size to match number of centroids.", + proto.size(), + num_centroids + ); + proto.resize(num_centroids); + } + return InterQueryThreadPool{threads::as_threadpool(std::move(proto))}; + } + void validate_thread_configuration() { if (intra_query_thread_count_ < 1) { throw std::invalid_argument("Intra-query thread count must be at least 1"); From d05f7b7d120e2e704f7b80b902a4ffb1e47050e9 Mon Sep 17 00:00:00 2001 From: Rafik Saliev Date: Mon, 11 May 2026 12:45:16 +0200 Subject: [PATCH 2/3] Fix SortedBuffer::can_skip() evaluation order: check full() before reading back(). Co-Authored-By: mnorris11 --- include/svs/index/ivf/sorted_buffer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/svs/index/ivf/sorted_buffer.h b/include/svs/index/ivf/sorted_buffer.h index 85939ef6..d028063b 100644 --- a/include/svs/index/ivf/sorted_buffer.h +++ b/include/svs/index/ivf/sorted_buffer.h @@ -139,7 +139,7 @@ template > class SortedBuffer { /// @brief Return ``true`` if a neighbor with the given distance can be skipped. 
/// bool can_skip(float distance) const { - return compare_(back().distance(), distance) && full(); + return full() && compare_(back().distance(), distance); } /// From 2ea2002c987cebee57c452dd83ead5f2078b0a08 Mon Sep 17 00:00:00 2001 From: Rafik Saliev Date: Mon, 11 May 2026 14:50:42 +0200 Subject: [PATCH 3/3] Address code review: unify make_inter_query_threadpool overloads, add thread-count boundary tests --- include/svs/index/ivf/index.h | 94 ++++++++++++------------- tests/svs/index/ivf/index.cpp | 125 ++++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 50 deletions(-) diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 3119e08f..bcb2fca6 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -574,57 +574,51 @@ class IVFIndex { ///// Initialization Methods ///// - // Generic inter-query thread pool creation which just validates the thread count - // against the number of centroids. - static InterQueryThreadPool make_inter_query_threadpool( - ThreadPoolProto proto, size_t num_centroids, const svs::logging::logger_ptr& logger - ) { - if (proto.size() > num_centroids) { - throw std::invalid_argument( - "Number of inter-query threads cannot exceed number of centroids" - ); - } - return InterQueryThreadPool{threads::as_threadpool(std::move(proto))}; - } - - // Specialization for size_t thread pool prototype to allow automatic resizing and - // logging of adjustments. - static InterQueryThreadPool make_inter_query_threadpool( - size_t proto, size_t num_centroids, const svs::logging::logger_ptr& logger - ) - requires std::is_same_v - { - if (proto > num_centroids) { - logger->warn( - "Provided thread pool has {} threads, but there are only {} centroids. " - "Reducing thread pool size to match number of centroids.", - proto, - num_centroids - ); - proto = num_centroids; - } - return InterQueryThreadPool{threads::DefaultThreadPool(proto)}; - } - - // Specialization for thread pool prototypes that support resizing. 
- static InterQueryThreadPool make_inter_query_threadpool( - ThreadPoolProto proto, size_t num_centroids, const svs::logging::logger_ptr& logger - ) - requires( - !std::is_same_v && - requires { proto.resize(num_centroids); } - ) - { - if (proto.size() > num_centroids) { - logger->warn( - "Provided thread pool has {} threads, but there are only {} centroids. " - "Reducing thread pool size to match number of centroids.", - proto.size(), - num_centroids - ); - proto.resize(num_centroids); + static auto make_inter_query_threadpool( + ThreadPoolProto proto, size_t num_centroids, svs::logging::logger_ptr& logger + ) -> decltype(threads::as_threadpool(std::move(proto))) { + if constexpr (std::is_same_v) { + // Specialization for size_t thread pool prototype to allow automatic resizing + // and logging of adjustments. + if (proto > num_centroids) { + svs::logging::warn( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. " + "Reducing thread pool size to match number of centroids.", + proto, + num_centroids + ); + proto = num_centroids; + } + } else if constexpr (requires { proto.resize(num_centroids); }) { + // Specialization for thread pool prototypes that support resizing. + if (proto.size() > num_centroids) { + svs::logging::warn( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. " + "Reducing thread pool size to match number of centroids.", + proto.size(), + num_centroids + ); + proto.resize(num_centroids); + } + } else { + // Generic inter-query thread pool adjustment which just validates the thread + // count against the number of centroids. + if (proto.size() > num_centroids) { + svs::logging::error( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. 
" + "This configuration is not supported.", + proto.size(), + num_centroids + ); + throw std::invalid_argument( + "Number of inter-query threads cannot exceed number of centroids" + ); + } } - return InterQueryThreadPool{threads::as_threadpool(std::move(proto))}; + return threads::as_threadpool(std::move(proto)); } void validate_thread_configuration() { diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index b15f2604..d59cf8ca 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -28,6 +28,7 @@ // svs #include "svs/core/data.h" #include "svs/core/distance.h" +#include "svs/core/logging.h" #include "svs/index/ivf/clustering.h" #include "svs/index/ivf/hierarchical_kmeans.h" #include "svs/lib/saveload.h" @@ -36,6 +37,9 @@ #include #include +// third-party +#include + CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { namespace ivf = svs::index::ivf; @@ -417,3 +421,124 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { svs_test::cleanup_temp_directory(); } } + +CATCH_TEST_CASE("IVF Index Inter-Query Thread Count Boundaries", "[ivf][index][threads]") { + namespace ivf = svs::index::ivf; + + auto make_test_logger = [](std::vector& captured_logs, + std::vector& captured_levels) { + auto callback_sink = std::make_shared( + [&captured_logs, &captured_levels](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + captured_levels.push_back(svs::logging::detail::from_spdlog(msg.level)); + } + ); + callback_sink->set_level(spdlog::level::trace); + auto logger = + std::make_shared("ivf_threads_test_logger", callback_sink); + logger->set_level(spdlog::level::trace); + return logger; + }; + + auto build_components = []() { + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto distance = svs::distance::DistanceL2(); + auto build_params = ivf::IVFBuildParameters(2, 5, false); + auto build_threadpool = 
svs::threads::SequentialThreadPool(); + + auto clustering = ivf::build_clustering( + build_params, data, distance, build_threadpool, false + ); + + auto centroids = clustering.centroids(); + using Idx = uint32_t; + auto cluster = ivf::DenseClusteredDataset( + clustering, data, build_threadpool, svs::lib::Allocator() + ); + + return std::make_tuple( + std::move(centroids), std::move(cluster), std::move(distance) + ); + }; + + CATCH_SECTION("size_t thread prototype is clamped and warns") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + using IndexType = ivf:: + IVFIndex; + + IndexType index( + std::move(centroids), std::move(cluster), distance, size_t{4}, 1, logger + ); + + CATCH_REQUIRE(index.get_num_threads() == 2); + CATCH_REQUIRE( + std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) != + levels.end() + ); + } + + CATCH_SECTION("resizable thread prototype is clamped and warns") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + auto threadpool_proto = svs::threads::NativeThreadPool(4); + using IndexType = ivf::IVFIndex< + decltype(centroids), + decltype(cluster), + decltype(distance), + decltype(threadpool_proto)>; + + IndexType index( + std::move(centroids), + std::move(cluster), + distance, + std::move(threadpool_proto), + 1, + logger + ); + + CATCH_REQUIRE(index.get_num_threads() == 2); + CATCH_REQUIRE( + std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) != + levels.end() + ); + } + + CATCH_SECTION("non-resizable thread prototype throws") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + auto 
threadpool_proto = svs::threads::QueueThreadPoolWrapper(4); + using IndexType = ivf::IVFIndex< + decltype(centroids), + decltype(cluster), + decltype(distance), + decltype(threadpool_proto)>; + + CATCH_REQUIRE_THROWS_AS( + IndexType( + std::move(centroids), + std::move(cluster), + distance, + std::move(threadpool_proto), + 1, + logger + ), + std::invalid_argument + ); + } +}