Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 0 additions & 102 deletions include/svs/index/vamana/prune.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ void heuristic_prune_neighbors(

auto pruned = std::vector<PruneState>(poolsize, PruneState::Available);
float current_alpha = 1.0f;
float anchor_dist = 0.0f;
bool anchor_set = false;
bool all_duplicates = true;
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
size_t start = 0;
while (result.size() < max_result_size && start < poolsize) {
Expand All @@ -148,16 +145,6 @@ void heuristic_prune_neighbors(
const auto& query = accessor(dataset, id);
distance::maybe_fix_argument(distance_function, query);
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));

if (all_duplicates) {
if (!anchor_set) {
anchor_dist = pool[start].distance();
anchor_set = true;
} else if (pool[start].distance() != anchor_dist) {
all_duplicates = false;
}
}

for (size_t t = start + 1; t < poolsize; ++t) {
if (excluded(pruned[t])) {
continue;
Expand All @@ -184,44 +171,6 @@ void heuristic_prune_neighbors(
}
current_alpha *= alpha;
}

// Swap the last kept neighbor for a diversity edge if a duplicate cluster is detected.
// A "cluster" requires at least 2 kept candidates sharing the same
// distance; a single retained neighbor is not a cluster and must not
// be replaced (doing so would discard the only true nearest-neighbor
// edge for that node).
if (all_duplicates && anchor_set && result.size() >= 2) {
auto result_id = [](const I& r) -> size_t {
if constexpr (std::integral<I>) {
return static_cast<size_t>(r);
} else {
return static_cast<size_t>(r.id());
}
};
for (size_t t = 0; t < poolsize; ++t) {
const auto& candidate = pool[t];
auto cid = candidate.id();
if (cid == current_node_id || candidate.distance() == anchor_dist) {
continue;
}
bool in_result = false;
for (const auto& r : result) {
if (result_id(r) == static_cast<size_t>(cid)) {
in_result = true;
break;
}
}
assert(
!in_result &&
"Candidate with non-anchor distance should not already be in result"
);
if (in_result) {
continue;
}
result.back() = detail::construct_as(lib::Type<I>(), candidate);
break;
}
}
}

template <
Expand Down Expand Up @@ -254,9 +203,6 @@ void heuristic_prune_neighbors(
std::vector<float> pruned(poolsize, type_traits::tombstone_v<float, decltype(cmp)>);

float current_alpha = 1.0f;
float anchor_dist = 0.0f;
bool anchor_set = false;
bool all_duplicates = true;
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
size_t start = 0;
while (result.size() < max_result_size && start < poolsize) {
Expand All @@ -272,16 +218,6 @@ void heuristic_prune_neighbors(
const auto& query = accessor(dataset, id);
distance::maybe_fix_argument(distance_function, query);
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));

if (all_duplicates) {
if (!anchor_set) {
anchor_dist = pool[start].distance();
anchor_set = true;
} else if (pool[start].distance() != anchor_dist) {
all_duplicates = false;
}
}

for (size_t t = start + 1; t < poolsize; ++t) {
if (cmp(current_alpha, pruned[t])) {
continue;
Expand All @@ -300,44 +236,6 @@ void heuristic_prune_neighbors(
}
current_alpha *= alpha;
}

// Swap the last kept neighbor for a diversity edge if a duplicate cluster is detected.
// A "cluster" requires at least 2 kept candidates sharing the same
// distance; a single retained neighbor is not a cluster and must not
// be replaced (doing so would discard the only true nearest-neighbor
// edge for that node).
if (all_duplicates && anchor_set && result.size() >= 2) {
auto result_id = [](const I& r) -> size_t {
if constexpr (std::integral<I>) {
return static_cast<size_t>(r);
} else {
return static_cast<size_t>(r.id());
}
};
for (size_t t = 0; t < poolsize; ++t) {
const auto& candidate = pool[t];
auto cid = candidate.id();
if (cid == current_node_id || candidate.distance() == anchor_dist) {
continue;
}
bool in_result = false;
for (const auto& r : result) {
if (result_id(r) == static_cast<size_t>(cid)) {
in_result = true;
break;
}
}
assert(
!in_result &&
"Candidate with non-anchor distance should not already be in result"
);
if (in_result) {
continue;
}
result.back() = detail::construct_as(lib::Type<I>(), candidate);
break;
}
}
}

///
Expand Down
65 changes: 0 additions & 65 deletions tests/svs/index/vamana/prune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
// header under test
#include "svs/index/vamana/prune.h"

// core
#include "svs/core/data/simple.h"
#include "svs/core/distance/euclidean.h"

// catch2
#include "catch2/catch_test_macros.hpp"

Expand Down Expand Up @@ -50,65 +46,4 @@ CATCH_TEST_CASE("Pruning", "[index][vamana]") {
CATCH_REQUIRE(v::excluded(v::PruneState::Pruned) == true);
}
}

CATCH_SECTION("Duplicate Cluster Trap") {
auto data = svs::data::SimpleData<float>(6, 4);
auto d0 = std::vector<float>{1.0f, 1.0f, 1.0f, 1.0f};
auto d4 = std::vector<float>{2.0f, 1.0f, 1.0f, 1.0f};
auto d5 = std::vector<float>{1.5f, 1.0f, 1.0f, 1.0f};

for (size_t i = 0; i < 4; ++i) {
data.set_datum(i, d0);
}
data.set_datum(4, d4);
data.set_datum(5, d5);

auto dist = svs::distance::DistanceL2();
auto accessor = svs::data::GetDatumAccessor{};

std::vector<svs::Neighbor<size_t>> pool = {
{size_t{0}, 0.0f},
{size_t{1}, 0.0f},
{size_t{2}, 0.0f},
{size_t{3}, 0.0f},
{size_t{4}, 1.0f}};

CATCH_SECTION("Iterative Strategy Fix") {
std::vector<svs::Neighbor<size_t>> result;
v::heuristic_prune_neighbors(
v::IterativePruneStrategy{},
2,
1.3f,
data,
accessor,
dist,
size_t{5},
std::span<const svs::Neighbor<size_t>>(pool),
result
);

CATCH_REQUIRE(result.size() == 2);
CATCH_REQUIRE(result[0].id() == 0);
CATCH_REQUIRE(result[1].id() == 4);
}

CATCH_SECTION("Progressive Strategy Fix") {
std::vector<svs::Neighbor<size_t>> result;
v::heuristic_prune_neighbors(
v::ProgressivePruneStrategy{},
2,
1.3f,
data,
accessor,
dist,
size_t{5},
std::span<const svs::Neighbor<size_t>>(pool),
result
);

CATCH_REQUIRE(result.size() == 2);
CATCH_REQUIRE(result[0].id() == 0);
CATCH_REQUIRE(result[1].id() == 4);
}
}
}
Loading