// Patch overview (unified diff touching three files).
//
// include/svs/index/ivf/index.h
//   * The IVFIndex constructor no longer initializes inter_query_threadpool_ with
//     threads::as_threadpool(std::move(threadpool_proto)) directly; it now calls the new
//     static helper make_inter_query_threadpool(std::move(threadpool_proto),
//     centroids_.size(), logger).
//   * make_inter_query_threadpool(proto, num_centroids, logger) selects one of three
//     compile-time branches and then returns threads::as_threadpool(std::move(proto)):
//       1. `if constexpr (std::is_same_v...)` - per its own comment, the size_t prototype
//          case: when proto > num_centroids it logs a warning and assigns
//          proto = num_centroids.
//       2. `requires { proto.resize(num_centroids); }` - when proto.size() > num_centroids
//          it logs the same warning and calls proto.resize(num_centroids).
//       3. otherwise - when proto.size() > num_centroids it logs an error and throws
//          std::invalid_argument("Number of inter-query threads cannot exceed number of
//          centroids").
//
// include/svs/index/ivf/sorted_buffer.h
//   * SortedBuffer::can_skip now evaluates full() before compare_(back().distance(),
//     distance), so back() is only reached when full() returned true.
//
// tests/svs/index/ivf/index.cpp
//   * Adds the "svs/core/logging.h" include plus one third-party include, and a new test
//     case "IVF Index Inter-Query Thread Count Boundaries" with three sections. Each builds
//     an index over 2 centroids (the section requires centroids.size() == 2) and requests
//     4 inter-query threads:
//       - size_t{4}: expects get_num_threads() == 2 and a Warn-level message captured.
//       - svs::threads::NativeThreadPool(4): same expectations.
//       - svs::threads::QueueThreadPoolWrapper(4): expects std::invalid_argument.
//
// NOTE(review): this copy of the patch is collapsed onto a few lines and text inside angle
//   brackets appears to have been stripped (e.g. `std::is_same_v)`, `std::vector&`,
//   `std::make_shared(`, `template >`, and the empty `#include` directives). Restore it
//   from the original commit before applying or compiling - TODO confirm.
// NOTE(review): the constructor passes `logger` to the helper and later initializes
//   logger_ with std::move(logger). Members are initialized in declaration order, which is
//   not visible here; presumably inter_query_threadpool_ is declared before logger_ -
//   confirm, otherwise the helper would receive a moved-from logger.
// NOTE(review): the helper takes `svs::logging::logger_ptr& logger` by non-const reference
//   although the visible body only forwards it to svs::logging::warn / error; a const
//   reference looks sufficient - confirm against those signatures.
// NOTE(review): with num_centroids == 0 the first two branches reduce the prototype to
//   zero threads; whether threads::as_threadpool accepts that is not visible here.
// NOTE(review): in the new tests the captured `logs` vector is filled by the sink but never
//   asserted on (only `levels` is checked, and the throwing section checks neither); the
//   error-level log in the non-resizable branch is therefore untested.
diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 26959c47..bcb2fca6 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -148,7 +148,9 @@ class IVFIndex { , cluster_{std::move(cluster)} , cluster0_{cluster_.view_cluster(0)} , distance_{std::move(distance_function)} - , inter_query_threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , inter_query_threadpool_{make_inter_query_threadpool( + std::move(threadpool_proto), centroids_.size(), logger + )} , intra_query_thread_count_{intra_query_thread_count} , logger_{std::move(logger)} { validate_thread_configuration(); @@ -572,6 +574,53 @@ class IVFIndex { ///// Initialization Methods ///// + static auto make_inter_query_threadpool( + ThreadPoolProto proto, size_t num_centroids, svs::logging::logger_ptr& logger + ) -> decltype(threads::as_threadpool(std::move(proto))) { + if constexpr (std::is_same_v) { + // Specialization for size_t thread pool prototype to allow automatic resizing + // and logging of adjustments. + if (proto > num_centroids) { + svs::logging::warn( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. " + "Reducing thread pool size to match number of centroids.", + proto, + num_centroids + ); + proto = num_centroids; + } + } else if constexpr (requires { proto.resize(num_centroids); }) { + // Specialization for thread pool prototypes that support resizing. + if (proto.size() > num_centroids) { + svs::logging::warn( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. " + "Reducing thread pool size to match number of centroids.", + proto.size(), + num_centroids + ); + proto.resize(num_centroids); + } + } else { + // Generic inter-query thread pool adjustment which just validates the thread + // count against the number of centroids. 
+ if (proto.size() > num_centroids) { + svs::logging::error( + logger, + "Provided thread pool has {} threads, but there are only {} centroids. " + "This configuration is not supported.", + proto.size(), + num_centroids + ); + throw std::invalid_argument( + "Number of inter-query threads cannot exceed number of centroids" + ); + } + } + return threads::as_threadpool(std::move(proto)); + } + void validate_thread_configuration() { if (intra_query_thread_count_ < 1) { throw std::invalid_argument("Intra-query thread count must be at least 1"); diff --git a/include/svs/index/ivf/sorted_buffer.h b/include/svs/index/ivf/sorted_buffer.h index 85939ef6..d028063b 100644 --- a/include/svs/index/ivf/sorted_buffer.h +++ b/include/svs/index/ivf/sorted_buffer.h @@ -139,7 +139,7 @@ template > class SortedBuffer { /// @brief Return ``true`` if a neighbor with the given distance can be skipped. /// bool can_skip(float distance) const { - return compare_(back().distance(), distance) && full(); + return full() && compare_(back().distance(), distance); } /// diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index b15f2604..d59cf8ca 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -28,6 +28,7 @@ // svs #include "svs/core/data.h" #include "svs/core/distance.h" +#include "svs/core/logging.h" #include "svs/index/ivf/clustering.h" #include "svs/index/ivf/hierarchical_kmeans.h" #include "svs/lib/saveload.h" @@ -36,6 +37,9 @@ #include #include +// third-party +#include + CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { namespace ivf = svs::index::ivf; @@ -417,3 +421,124 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { svs_test::cleanup_temp_directory(); } } + +CATCH_TEST_CASE("IVF Index Inter-Query Thread Count Boundaries", "[ivf][index][threads]") { + namespace ivf = svs::index::ivf; + + auto make_test_logger = [](std::vector& captured_logs, + std::vector& captured_levels) { + auto 
callback_sink = std::make_shared( + [&captured_logs, &captured_levels](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + captured_levels.push_back(svs::logging::detail::from_spdlog(msg.level)); + } + ); + callback_sink->set_level(spdlog::level::trace); + auto logger = + std::make_shared("ivf_threads_test_logger", callback_sink); + logger->set_level(spdlog::level::trace); + return logger; + }; + + auto build_components = []() { + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto distance = svs::distance::DistanceL2(); + auto build_params = ivf::IVFBuildParameters(2, 5, false); + auto build_threadpool = svs::threads::SequentialThreadPool(); + + auto clustering = ivf::build_clustering( + build_params, data, distance, build_threadpool, false + ); + + auto centroids = clustering.centroids(); + using Idx = uint32_t; + auto cluster = ivf::DenseClusteredDataset( + clustering, data, build_threadpool, svs::lib::Allocator() + ); + + return std::make_tuple( + std::move(centroids), std::move(cluster), std::move(distance) + ); + }; + + CATCH_SECTION("size_t thread prototype is clamped and warns") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + using IndexType = ivf:: + IVFIndex; + + IndexType index( + std::move(centroids), std::move(cluster), distance, size_t{4}, 1, logger + ); + + CATCH_REQUIRE(index.get_num_threads() == 2); + CATCH_REQUIRE( + std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) != + levels.end() + ); + } + + CATCH_SECTION("resizable thread prototype is clamped and warns") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + auto threadpool_proto = 
svs::threads::NativeThreadPool(4); + using IndexType = ivf::IVFIndex< + decltype(centroids), + decltype(cluster), + decltype(distance), + decltype(threadpool_proto)>; + + IndexType index( + std::move(centroids), + std::move(cluster), + distance, + std::move(threadpool_proto), + 1, + logger + ); + + CATCH_REQUIRE(index.get_num_threads() == 2); + CATCH_REQUIRE( + std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) != + levels.end() + ); + } + + CATCH_SECTION("non-resizable thread prototype throws") { + auto [centroids, cluster, distance] = build_components(); + CATCH_REQUIRE(centroids.size() == 2); + + std::vector logs; + std::vector levels; + auto logger = make_test_logger(logs, levels); + + auto threadpool_proto = svs::threads::QueueThreadPoolWrapper(4); + using IndexType = ivf::IVFIndex< + decltype(centroids), + decltype(cluster), + decltype(distance), + decltype(threadpool_proto)>; + + CATCH_REQUIRE_THROWS_AS( + IndexType( + std::move(centroids), + std::move(cluster), + distance, + std::move(threadpool_proto), + 1, + logger + ), + std::invalid_argument + ); + } +}