Skip to content

Commit 256b23d

Browse files
committed
refactor: heel_f64x8 uses crate::simd::F64x8 polyfill, add SIMD cosine
Replace raw std::arch intrinsics with the crate::simd::F64x8 polyfill.
Automatic dispatch: AVX-512 (native __m512d) → AVX2 (2×__m256d) → scalar.
The consumer writes crate::simd::F64x8 — the polyfill handles tier selection.

Added SIMD cosine kernels using F64x8 FMA:
- cosine_f64_simd() — single-pass dot + norm_a + norm_b via F64x8
- cosine_f32_to_f64_simd() — f32 input, f64 precision cosine
- dot_f64_simd() — F64x8 FMA dot product on f64 slices
- sum_sq_f64_simd() — F64x8 sum of squares

12 tests passing (6 HEEL + 6 cosine).

https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
1 parent 5277b29 commit 256b23d

1 file changed

Lines changed: 221 additions & 91 deletions

File tree

src/hpc/heel_f64x8.rs

Lines changed: 221 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,94 +3,34 @@
33
//! p64 has 8 HEEL planes (u64 each). For weighted f64 distance computation,
44
//! each plane produces one f64 distance value → 8 values = one F64x8 register.
55
//!
6-
//! This module provides the SIMD kernel; p64-bridge calls it.
7-
//! ndarray = hardware acceleration, consumers use the kernel.
8-
//!
9-
//! Dispatch: AVX-512 (native __m512d) → AVX2 (2×__m256d) → scalar.
10-
//! LazyLock selects at startup.
11-
12-
use std::sync::LazyLock;
13-
14-
/// Kernel signature: 8 distances in, weighted sum out.
15-
/// `distances`: 8 f64 values (one per HEEL plane).
16-
/// `weights`: 8 f64 weights (per-expert importance).
17-
/// Returns: weighted sum = Σ(distance[i] × weight[i]).
18-
type HeelF64x8DotFn = unsafe fn(&[f64; 8], &[f64; 8]) -> f64;
19-
20-
#[cfg(target_arch = "x86_64")]
21-
#[target_feature(enable = "avx512f")]
22-
unsafe fn heel_dot_avx512(a: &[f64; 8], b: &[f64; 8]) -> f64 {
23-
use std::arch::x86_64::*;
24-
let va = _mm512_loadu_pd(a.as_ptr());
25-
let vb = _mm512_loadu_pd(b.as_ptr());
26-
let prod = _mm512_mul_pd(va, vb);
27-
_mm512_reduce_add_pd(prod)
28-
}
29-
30-
#[cfg(target_arch = "x86_64")]
31-
#[target_feature(enable = "avx2")]
32-
unsafe fn heel_dot_avx2(a: &[f64; 8], b: &[f64; 8]) -> f64 {
33-
use std::arch::x86_64::*;
34-
// 2 passes of 4 lanes
35-
let va0 = _mm256_loadu_pd(a.as_ptr());
36-
let vb0 = _mm256_loadu_pd(b.as_ptr());
37-
let p0 = _mm256_mul_pd(va0, vb0);
38-
39-
let va1 = _mm256_loadu_pd(a[4..].as_ptr());
40-
let vb1 = _mm256_loadu_pd(b[4..].as_ptr());
41-
let p1 = _mm256_mul_pd(va1, vb1);
42-
43-
let sum = _mm256_add_pd(p0, p1);
44-
// Horizontal sum of 4 f64
45-
let hi = _mm256_extractf128_pd(sum, 1);
46-
let lo = _mm256_castpd256_pd128(sum);
47-
let pair = _mm_add_pd(lo, hi);
48-
let hi64 = _mm_unpackhi_pd(pair, pair);
49-
let result = _mm_add_sd(pair, hi64);
50-
_mm_cvtsd_f64(result)
51-
}
52-
53-
fn heel_dot_scalar(a: &[f64; 8], b: &[f64; 8]) -> f64 {
54-
let mut sum = 0.0f64;
55-
for i in 0..8 {
56-
sum += a[i] * b[i];
57-
}
58-
sum
59-
}
6+
//! Uses `crate::simd::F64x8` polyfill — automatic dispatch:
7+
//! AVX-512: native __m512d (one register)
8+
//! AVX2: 2× __m256d (two registers, same API)
9+
//! Scalar: [f64; 8] fallback
10+
//! Consumer writes `crate::simd::F64x8`. The polyfill handles the rest.
6011
61-
static HEEL_DOT_KERNEL: LazyLock<HeelF64x8DotFn> = LazyLock::new(|| {
62-
#[cfg(target_arch = "x86_64")]
63-
{
64-
if is_x86_feature_detected!("avx512f") {
65-
return heel_dot_avx512 as HeelF64x8DotFn;
66-
}
67-
if is_x86_feature_detected!("avx2") {
68-
return heel_dot_avx2 as HeelF64x8DotFn;
69-
}
70-
}
71-
heel_dot_scalar as HeelF64x8DotFn
72-
});
12+
use crate::simd::F64x8;
7313

7414
/// Compute weighted dot product of 8 HEEL plane distances.
7515
///
7616
/// `distances[i]` = distance for HEEL plane i.
7717
/// `weights[i]` = importance weight for plane i.
7818
/// Returns: Σ(distances[i] × weights[i]).
7919
///
80-
/// One SIMD pass on AVX-512 (single `vmulpd` + `vreducepd`).
81-
/// Two passes on AVX2. Scalar fallback for non-x86.
20+
/// One F64x8 multiply + reduce_sum. On AVX-512: a single vmulpd followed by a
21+
/// horizontal add (a shuffle/add sequence, not `vreducepd`). On AVX2: 2× vmulpd + horizontal add. Scalar: 8 multiplies + sum.
8222
#[inline]
8323
pub fn heel_weighted_distance(distances: &[f64; 8], weights: &[f64; 8]) -> f64 {
84-
unsafe { HEEL_DOT_KERNEL(distances, weights) }
24+
let vd = F64x8::from_slice(distances);
25+
let vw = F64x8::from_slice(weights);
26+
(vd * vw).reduce_sum()
8527
}
8628

8729
/// Compute L1-like distance across 8 HEEL planes.
8830
///
8931
/// For each plane i: distance[i] = popcount(a[i] XOR b[i]) as f64.
90-
/// Then weighted sum via F64x8 dot product.
91-
///
92-
/// This converts binary Hamming distances to f64 for weighted combination,
93-
/// where each plane's contribution is scaled by expert importance.
32+
/// This is Hamming on binary HEEL planes — valid because HEEL planes
33+
/// ARE uniform binary data (unlike bgz17 i16 which must use L1).
9434
pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
9535
let mut dists = [0.0f64; 8];
9636
for i in 0..8 {
@@ -99,7 +39,7 @@ pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
9939
dists
10040
}
10141

102-
/// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar distance.
42+
/// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar.
10343
#[inline]
10444
pub fn heel_weighted_hamming(
10545
a_planes: &[u64; 8],
@@ -117,20 +57,151 @@ pub const UNIFORM_WEIGHTS: [f64; 8] = [1.0; 8];
11757
/// Contradiction plane (index 7) gets 0.5× weight.
11858
pub const HEEL_7PLUS1_WEIGHTS: [f64; 8] = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5];
11959

60+
// ═══════════════════════════════════════════════════════════════════════════
61+
// SIMD cosine similarity via F64x8 — for CLAM cosine clustering
62+
// ═══════════════════════════════════════════════════════════════════════════
63+
64+
/// SIMD dot product on f64 slices via F64x8.
65+
///
66+
/// Processes 8 elements per iteration. Remainder handled scalar.
67+
/// Used by cosine_simd as the inner kernel.
68+
pub fn dot_f64_simd(a: &[f64], b: &[f64]) -> f64 {
69+
let n = a.len().min(b.len());
70+
let chunks = n / 8;
71+
let remainder = n % 8;
72+
73+
let mut acc = F64x8::splat(0.0);
74+
for i in 0..chunks {
75+
let va = F64x8::from_slice(&a[i * 8..]);
76+
let vb = F64x8::from_slice(&b[i * 8..]);
77+
acc = va.mul_add(vb, acc); // acc = va * vb + acc (FMA)
78+
}
79+
let mut sum = acc.reduce_sum();
80+
81+
// Scalar remainder
82+
let offset = chunks * 8;
83+
for i in 0..remainder {
84+
sum += a[offset + i] * b[offset + i];
85+
}
86+
sum
87+
}
88+
89+
/// SIMD sum of squares via F64x8.
90+
pub fn sum_sq_f64_simd(a: &[f64]) -> f64 {
91+
let n = a.len();
92+
let chunks = n / 8;
93+
let remainder = n % 8;
94+
95+
let mut acc = F64x8::splat(0.0);
96+
for i in 0..chunks {
97+
let va = F64x8::from_slice(&a[i * 8..]);
98+
acc = va.mul_add(va, acc); // acc = va * va + acc
99+
}
100+
let mut sum = acc.reduce_sum();
101+
102+
let offset = chunks * 8;
103+
for i in 0..remainder {
104+
sum += a[offset + i] * a[offset + i];
105+
}
106+
sum
107+
}
108+
109+
/// SIMD cosine similarity on f64 slices.
110+
///
111+
/// Computes dot(a,b) / (||a|| × ||b||) using F64x8 FMA.
112+
/// Single pass: accumulates dot, norm_a, norm_b simultaneously.
113+
pub fn cosine_f64_simd(a: &[f64], b: &[f64]) -> f64 {
114+
let n = a.len().min(b.len());
115+
let chunks = n / 8;
116+
let remainder = n % 8;
117+
118+
let mut dot_acc = F64x8::splat(0.0);
119+
let mut na_acc = F64x8::splat(0.0);
120+
let mut nb_acc = F64x8::splat(0.0);
121+
122+
for i in 0..chunks {
123+
let va = F64x8::from_slice(&a[i * 8..]);
124+
let vb = F64x8::from_slice(&b[i * 8..]);
125+
dot_acc = va.mul_add(vb, dot_acc); // dot += a*b
126+
na_acc = va.mul_add(va, na_acc); // na += a*a
127+
nb_acc = vb.mul_add(vb, nb_acc); // nb += b*b
128+
}
129+
130+
let mut dot = dot_acc.reduce_sum();
131+
let mut na = na_acc.reduce_sum();
132+
let mut nb = nb_acc.reduce_sum();
133+
134+
let offset = chunks * 8;
135+
for i in 0..remainder {
136+
dot += a[offset + i] * b[offset + i];
137+
na += a[offset + i] * a[offset + i];
138+
nb += b[offset + i] * b[offset + i];
139+
}
140+
141+
let denom = (na * nb).sqrt();
142+
if denom < 1e-12 { 0.0 } else { dot / denom }
143+
}
144+
145+
/// SIMD cosine similarity on f32 slices (converts to f64 internally for precision).
146+
///
147+
/// For hot paths where input is f32 but you need f64 precision cosine.
148+
/// Converts 8 f32 → 8 f64 per chunk via scalar widening, then F64x8 FMA.
149+
pub fn cosine_f32_to_f64_simd(a: &[f32], b: &[f32]) -> f64 {
150+
let n = a.len().min(b.len());
151+
let chunks = n / 8;
152+
let remainder = n % 8;
153+
154+
let mut dot_acc = F64x8::splat(0.0);
155+
let mut na_acc = F64x8::splat(0.0);
156+
let mut nb_acc = F64x8::splat(0.0);
157+
158+
let mut buf_a = [0.0f64; 8];
159+
let mut buf_b = [0.0f64; 8];
160+
161+
for i in 0..chunks {
162+
let off = i * 8;
163+
for j in 0..8 {
164+
buf_a[j] = a[off + j] as f64;
165+
buf_b[j] = b[off + j] as f64;
166+
}
167+
let va = F64x8::from_slice(&buf_a);
168+
let vb = F64x8::from_slice(&buf_b);
169+
dot_acc = va.mul_add(vb, dot_acc);
170+
na_acc = va.mul_add(va, na_acc);
171+
nb_acc = vb.mul_add(vb, nb_acc);
172+
}
173+
174+
let mut dot = dot_acc.reduce_sum();
175+
let mut na = na_acc.reduce_sum();
176+
let mut nb = nb_acc.reduce_sum();
177+
178+
let offset = chunks * 8;
179+
for i in 0..remainder {
180+
let ai = a[offset + i] as f64;
181+
let bi = b[offset + i] as f64;
182+
dot += ai * bi;
183+
na += ai * ai;
184+
nb += bi * bi;
185+
}
186+
187+
let denom = (na * nb).sqrt();
188+
if denom < 1e-12 { 0.0 } else { dot / denom }
189+
}
190+
120191
#[cfg(test)]
121192
mod tests {
122193
use super::*;
123194

124195
#[test]
125-
fn dot_product_basic() {
196+
fn heel_dot_basic() {
126197
let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
127-
let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
198+
let b = [1.0; 8];
128199
let result = heel_weighted_distance(&a, &b);
129-
assert!((result - 36.0).abs() < 1e-10, "1+2+3+4+5+6+7+8 = 36, got {}", result);
200+
assert!((result - 36.0).abs() < 1e-10, "1+2+...+8 = 36, got {}", result);
130201
}
131202

132203
#[test]
133-
fn dot_product_weighted() {
204+
fn heel_dot_weighted() {
134205
let distances = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
135206
let weights = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5];
136207
let result = heel_weighted_distance(&distances, &weights);
@@ -141,38 +212,97 @@ mod tests {
141212
fn plane_distances_self_zero() {
142213
let planes = [0x1234u64; 8];
143214
let dists = heel_plane_distances(&planes, &planes);
144-
for d in &dists {
145-
assert_eq!(*d, 0.0);
146-
}
215+
for d in &dists { assert_eq!(*d, 0.0); }
147216
}
148217

149218
#[test]
150219
fn plane_distances_opposite() {
151220
let a = [0u64; 8];
152221
let b = [u64::MAX; 8];
153222
let dists = heel_plane_distances(&a, &b);
154-
for d in &dists {
155-
assert_eq!(*d, 64.0);
156-
}
223+
for d in &dists { assert_eq!(*d, 64.0); }
157224
}
158225

159226
#[test]
160227
fn full_pipeline_uniform() {
161228
let a = [0xFFFF_0000_FFFF_0000u64; 8];
162229
let b = [0x0000_FFFF_0000_FFFFu64; 8];
163230
let d = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
164-
// Each plane: all bits differ = 64
165-
assert!((d - 64.0 * 8.0).abs() < 1e-10, "8 planes × 64 bits = 512, got {}", d);
231+
assert!((d - 512.0).abs() < 1e-10, "8×64 = 512, got {}", d);
166232
}
167233

168234
#[test]
169235
fn seven_plus_one_weights() {
170236
let a = [0u64; 8];
171237
let b = [u64::MAX; 8];
172-
let d_uniform = heel_weighted_hamming(&a, &b, &UNIFORM_WEIGHTS);
173-
let d_7plus1 = heel_weighted_hamming(&a, &b, &HEEL_7PLUS1_WEIGHTS);
174-
// 7+1: plane 7 at 0.5× = 7×64 + 0.5×64 = 480 vs 512
175-
assert!((d_uniform - 512.0).abs() < 1e-10);
176-
assert!((d_7plus1 - 480.0).abs() < 1e-10, "7×64 + 0.5×64 = 480, got {}", d_7plus1);
238+
let d = heel_weighted_hamming(&a, &b, &HEEL_7PLUS1_WEIGHTS);
239+
assert!((d - 480.0).abs() < 1e-10, "7×64 + 0.5×64 = 480, got {}", d);
240+
}
241+
242+
// ── SIMD cosine tests ───────────────────────────────────────────
243+
244+
#[test]
245+
fn cosine_identical() {
246+
let a: Vec<f64> = (0..1024).map(|i| (i as f64 * 0.01).sin()).collect();
247+
let c = cosine_f64_simd(&a, &a);
248+
assert!((c - 1.0).abs() < 1e-10, "self-cosine should be 1.0: {}", c);
249+
}
250+
251+
#[test]
252+
fn cosine_opposite() {
253+
let a: Vec<f64> = (0..256).map(|i| i as f64 * 0.1).collect();
254+
let b: Vec<f64> = a.iter().map(|v| -v).collect();
255+
let c = cosine_f64_simd(&a, &b);
256+
assert!((c - (-1.0)).abs() < 1e-10, "opposite should be -1.0: {}", c);
257+
}
258+
259+
#[test]
260+
fn cosine_orthogonal() {
261+
let mut a = vec![0.0f64; 256];
262+
let mut b = vec![0.0f64; 256];
263+
a[0] = 1.0;
264+
b[1] = 1.0;
265+
let c = cosine_f64_simd(&a, &b);
266+
assert!(c.abs() < 1e-10, "orthogonal should be 0.0: {}", c);
267+
}
268+
269+
#[test]
270+
fn cosine_matches_scalar() {
271+
let a: Vec<f64> = (0..333).map(|i| (i as f64 * 0.037).sin()).collect();
272+
let b: Vec<f64> = (0..333).map(|i| (i as f64 * 0.023).cos()).collect();
273+
274+
let simd_cos = cosine_f64_simd(&a, &b);
275+
276+
// Scalar reference
277+
let dot: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
278+
let na: f64 = a.iter().map(|x| x * x).sum();
279+
let nb: f64 = b.iter().map(|x| x * x).sum();
280+
let scalar_cos = dot / (na * nb).sqrt();
281+
282+
assert!((simd_cos - scalar_cos).abs() < 1e-10,
283+
"SIMD {:.12} vs scalar {:.12}", simd_cos, scalar_cos);
284+
}
285+
286+
#[test]
287+
fn cosine_f32_matches_f64() {
288+
let a_f32: Vec<f32> = (0..500).map(|i| (i as f32 * 0.01).sin()).collect();
289+
let b_f32: Vec<f32> = (0..500).map(|i| (i as f32 * 0.02).cos()).collect();
290+
291+
let a_f64: Vec<f64> = a_f32.iter().map(|&v| v as f64).collect();
292+
let b_f64: Vec<f64> = b_f32.iter().map(|&v| v as f64).collect();
293+
294+
let cos_f64 = cosine_f64_simd(&a_f64, &b_f64);
295+
let cos_f32 = cosine_f32_to_f64_simd(&a_f32, &b_f32);
296+
297+
assert!((cos_f64 - cos_f32).abs() < 1e-6,
298+
"f32 {:.10} vs f64 {:.10}", cos_f32, cos_f64);
299+
}
300+
301+
#[test]
302+
fn dot_f64_simd_basic() {
303+
let a = [1.0f64; 24];
304+
let b = [2.0f64; 24];
305+
let d = dot_f64_simd(&a, &b);
306+
assert!((d - 48.0).abs() < 1e-10, "24×2 = 48, got {}", d);
177307
}
178308
}

0 commit comments

Comments
 (0)