33//! p64 has 8 HEEL planes (u64 each). For weighted f64 distance computation,
44//! each plane produces one f64 distance value → 8 values = one F64x8 register.
55//!
6- //! This module provides the SIMD kernel; p64-bridge calls it.
7- //! ndarray = hardware acceleration, consumers use the kernel.
8- //!
9- //! Dispatch: AVX-512 (native __m512d) → AVX2 (2×__m256d) → scalar.
10- //! LazyLock selects at startup.
11-
12- use std:: sync:: LazyLock ;
13-
14- /// Kernel signature: 8 distances in, weighted sum out.
15- /// `distances`: 8 f64 values (one per HEEL plane).
16- /// `weights`: 8 f64 weights (per-expert importance).
17- /// Returns: weighted sum = Σ(distance[i] × weight[i]).
18- type HeelF64x8DotFn = unsafe fn ( & [ f64 ; 8 ] , & [ f64 ; 8 ] ) -> f64 ;
19-
20- #[ cfg( target_arch = "x86_64" ) ]
21- #[ target_feature( enable = "avx512f" ) ]
22- unsafe fn heel_dot_avx512 ( a : & [ f64 ; 8 ] , b : & [ f64 ; 8 ] ) -> f64 {
23- use std:: arch:: x86_64:: * ;
24- let va = _mm512_loadu_pd ( a. as_ptr ( ) ) ;
25- let vb = _mm512_loadu_pd ( b. as_ptr ( ) ) ;
26- let prod = _mm512_mul_pd ( va, vb) ;
27- _mm512_reduce_add_pd ( prod)
28- }
29-
30- #[ cfg( target_arch = "x86_64" ) ]
31- #[ target_feature( enable = "avx2" ) ]
32- unsafe fn heel_dot_avx2 ( a : & [ f64 ; 8 ] , b : & [ f64 ; 8 ] ) -> f64 {
33- use std:: arch:: x86_64:: * ;
34- // 2 passes of 4 lanes
35- let va0 = _mm256_loadu_pd ( a. as_ptr ( ) ) ;
36- let vb0 = _mm256_loadu_pd ( b. as_ptr ( ) ) ;
37- let p0 = _mm256_mul_pd ( va0, vb0) ;
38-
39- let va1 = _mm256_loadu_pd ( a[ 4 ..] . as_ptr ( ) ) ;
40- let vb1 = _mm256_loadu_pd ( b[ 4 ..] . as_ptr ( ) ) ;
41- let p1 = _mm256_mul_pd ( va1, vb1) ;
42-
43- let sum = _mm256_add_pd ( p0, p1) ;
44- // Horizontal sum of 4 f64
45- let hi = _mm256_extractf128_pd ( sum, 1 ) ;
46- let lo = _mm256_castpd256_pd128 ( sum) ;
47- let pair = _mm_add_pd ( lo, hi) ;
48- let hi64 = _mm_unpackhi_pd ( pair, pair) ;
49- let result = _mm_add_sd ( pair, hi64) ;
50- _mm_cvtsd_f64 ( result)
51- }
52-
53- fn heel_dot_scalar ( a : & [ f64 ; 8 ] , b : & [ f64 ; 8 ] ) -> f64 {
54- let mut sum = 0.0f64 ;
55- for i in 0 ..8 {
56- sum += a[ i] * b[ i] ;
57- }
58- sum
59- }
6+ //! Uses `crate::simd::F64x8` polyfill — automatic dispatch:
7+ //! AVX-512: native __m512d (one register)
8+ //! AVX2: 2× __m256d (two registers, same API)
9+ //! Scalar: [f64; 8] fallback
10+ //! Consumer writes `crate::simd::F64x8`. The polyfill handles the rest.
6011
61- static HEEL_DOT_KERNEL : LazyLock < HeelF64x8DotFn > = LazyLock :: new ( || {
62- #[ cfg( target_arch = "x86_64" ) ]
63- {
64- if is_x86_feature_detected ! ( "avx512f" ) {
65- return heel_dot_avx512 as HeelF64x8DotFn ;
66- }
67- if is_x86_feature_detected ! ( "avx2" ) {
68- return heel_dot_avx2 as HeelF64x8DotFn ;
69- }
70- }
71- heel_dot_scalar as HeelF64x8DotFn
72- } ) ;
12+ use crate :: simd:: F64x8 ;
7313
7414/// Compute weighted dot product of 8 HEEL plane distances.
7515///
7616/// `distances[i]` = distance for HEEL plane i.
7717/// `weights[i]` = importance weight for plane i.
7818/// Returns: Σ(distances[i] × weights[i]).
7919///
80- /// One SIMD pass on AVX-512 ( single ` vmulpd` + ` vreducepd`) .
81- /// Two passes on AVX2 . Scalar fallback for non-x86 .
20+ /// One F64x8 multiply + reduce_sum. On AVX-512: single vmulpd + vreducepd.
21+ /// On AVX2: 2× vmulpd + 2× haddpd . Scalar: 8 multiplies + sum .
8222#[ inline]
8323pub fn heel_weighted_distance ( distances : & [ f64 ; 8 ] , weights : & [ f64 ; 8 ] ) -> f64 {
84- unsafe { HEEL_DOT_KERNEL ( distances, weights) }
24+ let vd = F64x8 :: from_slice ( distances) ;
25+ let vw = F64x8 :: from_slice ( weights) ;
26+ ( vd * vw) . reduce_sum ( )
8527}
8628
8729/// Compute L1-like distance across 8 HEEL planes.
8830///
8931/// For each plane i: distance[i] = popcount(a[i] XOR b[i]) as f64.
90- /// Then weighted sum via F64x8 dot product.
91- ///
92- /// This converts binary Hamming distances to f64 for weighted combination,
93- /// where each plane's contribution is scaled by expert importance.
32+ /// This is Hamming on binary HEEL planes — valid because HEEL planes
33+ /// ARE uniform binary data (unlike bgz17 i16 which must use L1).
9434pub fn heel_plane_distances ( a : & [ u64 ; 8 ] , b : & [ u64 ; 8 ] ) -> [ f64 ; 8 ] {
9535 let mut dists = [ 0.0f64 ; 8 ] ;
9636 for i in 0 ..8 {
@@ -99,7 +39,7 @@ pub fn heel_plane_distances(a: &[u64; 8], b: &[u64; 8]) -> [f64; 8] {
9939 dists
10040}
10141
102- /// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar distance .
42+ /// Full pipeline: 8 HEEL planes → Hamming per plane → weighted F64x8 dot → scalar.
10343#[ inline]
10444pub fn heel_weighted_hamming (
10545 a_planes : & [ u64 ; 8 ] ,
@@ -117,20 +57,151 @@ pub const UNIFORM_WEIGHTS: [f64; 8] = [1.0; 8];
11757/// Contradiction plane (index 7) gets 0.5× weight.
11858pub const HEEL_7PLUS1_WEIGHTS : [ f64 ; 8 ] = [ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.5 ] ;
11959
60+ // ═══════════════════════════════════════════════════════════════════════════
61+ // SIMD cosine similarity via F64x8 — for CLAM cosine clustering
62+ // ═══════════════════════════════════════════════════════════════════════════
63+
64+ /// SIMD dot product on f64 slices via F64x8.
65+ ///
66+ /// Processes 8 elements per iteration. Remainder handled scalar.
67+ /// Used by cosine_simd as the inner kernel.
68+ pub fn dot_f64_simd ( a : & [ f64 ] , b : & [ f64 ] ) -> f64 {
69+ let n = a. len ( ) . min ( b. len ( ) ) ;
70+ let chunks = n / 8 ;
71+ let remainder = n % 8 ;
72+
73+ let mut acc = F64x8 :: splat ( 0.0 ) ;
74+ for i in 0 ..chunks {
75+ let va = F64x8 :: from_slice ( & a[ i * 8 ..] ) ;
76+ let vb = F64x8 :: from_slice ( & b[ i * 8 ..] ) ;
77+ acc = va. mul_add ( vb, acc) ; // acc = va * vb + acc (FMA)
78+ }
79+ let mut sum = acc. reduce_sum ( ) ;
80+
81+ // Scalar remainder
82+ let offset = chunks * 8 ;
83+ for i in 0 ..remainder {
84+ sum += a[ offset + i] * b[ offset + i] ;
85+ }
86+ sum
87+ }
88+
89+ /// SIMD sum of squares via F64x8.
90+ pub fn sum_sq_f64_simd ( a : & [ f64 ] ) -> f64 {
91+ let n = a. len ( ) ;
92+ let chunks = n / 8 ;
93+ let remainder = n % 8 ;
94+
95+ let mut acc = F64x8 :: splat ( 0.0 ) ;
96+ for i in 0 ..chunks {
97+ let va = F64x8 :: from_slice ( & a[ i * 8 ..] ) ;
98+ acc = va. mul_add ( va, acc) ; // acc = va * va + acc
99+ }
100+ let mut sum = acc. reduce_sum ( ) ;
101+
102+ let offset = chunks * 8 ;
103+ for i in 0 ..remainder {
104+ sum += a[ offset + i] * a[ offset + i] ;
105+ }
106+ sum
107+ }
108+
109+ /// SIMD cosine similarity on f64 slices.
110+ ///
111+ /// Computes dot(a,b) / (||a|| × ||b||) using F64x8 FMA.
112+ /// Single pass: accumulates dot, norm_a, norm_b simultaneously.
113+ pub fn cosine_f64_simd ( a : & [ f64 ] , b : & [ f64 ] ) -> f64 {
114+ let n = a. len ( ) . min ( b. len ( ) ) ;
115+ let chunks = n / 8 ;
116+ let remainder = n % 8 ;
117+
118+ let mut dot_acc = F64x8 :: splat ( 0.0 ) ;
119+ let mut na_acc = F64x8 :: splat ( 0.0 ) ;
120+ let mut nb_acc = F64x8 :: splat ( 0.0 ) ;
121+
122+ for i in 0 ..chunks {
123+ let va = F64x8 :: from_slice ( & a[ i * 8 ..] ) ;
124+ let vb = F64x8 :: from_slice ( & b[ i * 8 ..] ) ;
125+ dot_acc = va. mul_add ( vb, dot_acc) ; // dot += a*b
126+ na_acc = va. mul_add ( va, na_acc) ; // na += a*a
127+ nb_acc = vb. mul_add ( vb, nb_acc) ; // nb += b*b
128+ }
129+
130+ let mut dot = dot_acc. reduce_sum ( ) ;
131+ let mut na = na_acc. reduce_sum ( ) ;
132+ let mut nb = nb_acc. reduce_sum ( ) ;
133+
134+ let offset = chunks * 8 ;
135+ for i in 0 ..remainder {
136+ dot += a[ offset + i] * b[ offset + i] ;
137+ na += a[ offset + i] * a[ offset + i] ;
138+ nb += b[ offset + i] * b[ offset + i] ;
139+ }
140+
141+ let denom = ( na * nb) . sqrt ( ) ;
142+ if denom < 1e-12 { 0.0 } else { dot / denom }
143+ }
144+
145+ /// SIMD cosine similarity on f32 slices (converts to f64 internally for precision).
146+ ///
147+ /// For hot paths where input is f32 but you need f64 precision cosine.
148+ /// Converts 8 f32 → 8 f64 per chunk via scalar widening, then F64x8 FMA.
149+ pub fn cosine_f32_to_f64_simd ( a : & [ f32 ] , b : & [ f32 ] ) -> f64 {
150+ let n = a. len ( ) . min ( b. len ( ) ) ;
151+ let chunks = n / 8 ;
152+ let remainder = n % 8 ;
153+
154+ let mut dot_acc = F64x8 :: splat ( 0.0 ) ;
155+ let mut na_acc = F64x8 :: splat ( 0.0 ) ;
156+ let mut nb_acc = F64x8 :: splat ( 0.0 ) ;
157+
158+ let mut buf_a = [ 0.0f64 ; 8 ] ;
159+ let mut buf_b = [ 0.0f64 ; 8 ] ;
160+
161+ for i in 0 ..chunks {
162+ let off = i * 8 ;
163+ for j in 0 ..8 {
164+ buf_a[ j] = a[ off + j] as f64 ;
165+ buf_b[ j] = b[ off + j] as f64 ;
166+ }
167+ let va = F64x8 :: from_slice ( & buf_a) ;
168+ let vb = F64x8 :: from_slice ( & buf_b) ;
169+ dot_acc = va. mul_add ( vb, dot_acc) ;
170+ na_acc = va. mul_add ( va, na_acc) ;
171+ nb_acc = vb. mul_add ( vb, nb_acc) ;
172+ }
173+
174+ let mut dot = dot_acc. reduce_sum ( ) ;
175+ let mut na = na_acc. reduce_sum ( ) ;
176+ let mut nb = nb_acc. reduce_sum ( ) ;
177+
178+ let offset = chunks * 8 ;
179+ for i in 0 ..remainder {
180+ let ai = a[ offset + i] as f64 ;
181+ let bi = b[ offset + i] as f64 ;
182+ dot += ai * bi;
183+ na += ai * ai;
184+ nb += bi * bi;
185+ }
186+
187+ let denom = ( na * nb) . sqrt ( ) ;
188+ if denom < 1e-12 { 0.0 } else { dot / denom }
189+ }
190+
120191#[ cfg( test) ]
121192mod tests {
122193 use super :: * ;
123194
124195 #[ test]
125- fn dot_product_basic ( ) {
196+ fn heel_dot_basic ( ) {
126197 let a = [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ] ;
127- let b = [ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ] ;
198+ let b = [ 1.0 ; 8 ] ;
128199 let result = heel_weighted_distance ( & a, & b) ;
129- assert ! ( ( result - 36.0 ) . abs( ) < 1e-10 , "1+2+3+4+5+6+7 +8 = 36, got {}" , result) ;
200+ assert ! ( ( result - 36.0 ) . abs( ) < 1e-10 , "1+2+... +8 = 36, got {}" , result) ;
130201 }
131202
132203 #[ test]
133- fn dot_product_weighted ( ) {
204+ fn heel_dot_weighted ( ) {
134205 let distances = [ 10.0 , 20.0 , 30.0 , 40.0 , 50.0 , 60.0 , 70.0 , 80.0 ] ;
135206 let weights = [ 2.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.5 ] ;
136207 let result = heel_weighted_distance ( & distances, & weights) ;
@@ -141,38 +212,97 @@ mod tests {
141212 fn plane_distances_self_zero ( ) {
142213 let planes = [ 0x1234u64 ; 8 ] ;
143214 let dists = heel_plane_distances ( & planes, & planes) ;
144- for d in & dists {
145- assert_eq ! ( * d, 0.0 ) ;
146- }
215+ for d in & dists { assert_eq ! ( * d, 0.0 ) ; }
147216 }
148217
149218 #[ test]
150219 fn plane_distances_opposite ( ) {
151220 let a = [ 0u64 ; 8 ] ;
152221 let b = [ u64:: MAX ; 8 ] ;
153222 let dists = heel_plane_distances ( & a, & b) ;
154- for d in & dists {
155- assert_eq ! ( * d, 64.0 ) ;
156- }
223+ for d in & dists { assert_eq ! ( * d, 64.0 ) ; }
157224 }
158225
159226 #[ test]
160227 fn full_pipeline_uniform ( ) {
161228 let a = [ 0xFFFF_0000_FFFF_0000u64 ; 8 ] ;
162229 let b = [ 0x0000_FFFF_0000_FFFFu64 ; 8 ] ;
163230 let d = heel_weighted_hamming ( & a, & b, & UNIFORM_WEIGHTS ) ;
164- // Each plane: all bits differ = 64
165- assert ! ( ( d - 64.0 * 8.0 ) . abs( ) < 1e-10 , "8 planes × 64 bits = 512, got {}" , d) ;
231+ assert ! ( ( d - 512.0 ) . abs( ) < 1e-10 , "8×64 = 512, got {}" , d) ;
166232 }
167233
168234 #[ test]
169235 fn seven_plus_one_weights ( ) {
170236 let a = [ 0u64 ; 8 ] ;
171237 let b = [ u64:: MAX ; 8 ] ;
172- let d_uniform = heel_weighted_hamming ( & a, & b, & UNIFORM_WEIGHTS ) ;
173- let d_7plus1 = heel_weighted_hamming ( & a, & b, & HEEL_7PLUS1_WEIGHTS ) ;
174- // 7+1: plane 7 at 0.5× = 7×64 + 0.5×64 = 480 vs 512
175- assert ! ( ( d_uniform - 512.0 ) . abs( ) < 1e-10 ) ;
176- assert ! ( ( d_7plus1 - 480.0 ) . abs( ) < 1e-10 , "7×64 + 0.5×64 = 480, got {}" , d_7plus1) ;
238+ let d = heel_weighted_hamming ( & a, & b, & HEEL_7PLUS1_WEIGHTS ) ;
239+ assert ! ( ( d - 480.0 ) . abs( ) < 1e-10 , "7×64 + 0.5×64 = 480, got {}" , d) ;
240+ }
241+
242+ // ── SIMD cosine tests ───────────────────────────────────────────
243+
244+ #[ test]
245+ fn cosine_identical ( ) {
246+ let a: Vec < f64 > = ( 0 ..1024 ) . map ( |i| ( i as f64 * 0.01 ) . sin ( ) ) . collect ( ) ;
247+ let c = cosine_f64_simd ( & a, & a) ;
248+ assert ! ( ( c - 1.0 ) . abs( ) < 1e-10 , "self-cosine should be 1.0: {}" , c) ;
249+ }
250+
251+ #[ test]
252+ fn cosine_opposite ( ) {
253+ let a: Vec < f64 > = ( 0 ..256 ) . map ( |i| i as f64 * 0.1 ) . collect ( ) ;
254+ let b: Vec < f64 > = a. iter ( ) . map ( |v| -v) . collect ( ) ;
255+ let c = cosine_f64_simd ( & a, & b) ;
256+ assert ! ( ( c - ( -1.0 ) ) . abs( ) < 1e-10 , "opposite should be -1.0: {}" , c) ;
257+ }
258+
259+ #[ test]
260+ fn cosine_orthogonal ( ) {
261+ let mut a = vec ! [ 0.0f64 ; 256 ] ;
262+ let mut b = vec ! [ 0.0f64 ; 256 ] ;
263+ a[ 0 ] = 1.0 ;
264+ b[ 1 ] = 1.0 ;
265+ let c = cosine_f64_simd ( & a, & b) ;
266+ assert ! ( c. abs( ) < 1e-10 , "orthogonal should be 0.0: {}" , c) ;
267+ }
268+
269+ #[ test]
270+ fn cosine_matches_scalar ( ) {
271+ let a: Vec < f64 > = ( 0 ..333 ) . map ( |i| ( i as f64 * 0.037 ) . sin ( ) ) . collect ( ) ;
272+ let b: Vec < f64 > = ( 0 ..333 ) . map ( |i| ( i as f64 * 0.023 ) . cos ( ) ) . collect ( ) ;
273+
274+ let simd_cos = cosine_f64_simd ( & a, & b) ;
275+
276+ // Scalar reference
277+ let dot: f64 = a. iter ( ) . zip ( & b) . map ( |( x, y) | x * y) . sum ( ) ;
278+ let na: f64 = a. iter ( ) . map ( |x| x * x) . sum ( ) ;
279+ let nb: f64 = b. iter ( ) . map ( |x| x * x) . sum ( ) ;
280+ let scalar_cos = dot / ( na * nb) . sqrt ( ) ;
281+
282+ assert ! ( ( simd_cos - scalar_cos) . abs( ) < 1e-10 ,
283+ "SIMD {:.12} vs scalar {:.12}" , simd_cos, scalar_cos) ;
284+ }
285+
286+ #[ test]
287+ fn cosine_f32_matches_f64 ( ) {
288+ let a_f32: Vec < f32 > = ( 0 ..500 ) . map ( |i| ( i as f32 * 0.01 ) . sin ( ) ) . collect ( ) ;
289+ let b_f32: Vec < f32 > = ( 0 ..500 ) . map ( |i| ( i as f32 * 0.02 ) . cos ( ) ) . collect ( ) ;
290+
291+ let a_f64: Vec < f64 > = a_f32. iter ( ) . map ( |& v| v as f64 ) . collect ( ) ;
292+ let b_f64: Vec < f64 > = b_f32. iter ( ) . map ( |& v| v as f64 ) . collect ( ) ;
293+
294+ let cos_f64 = cosine_f64_simd ( & a_f64, & b_f64) ;
295+ let cos_f32 = cosine_f32_to_f64_simd ( & a_f32, & b_f32) ;
296+
297+ assert ! ( ( cos_f64 - cos_f32) . abs( ) < 1e-6 ,
298+ "f32 {:.10} vs f64 {:.10}" , cos_f32, cos_f64) ;
299+ }
300+
301+ #[ test]
302+ fn dot_f64_simd_basic ( ) {
303+ let a = [ 1.0f64 ; 24 ] ;
304+ let b = [ 2.0f64 ; 24 ] ;
305+ let d = dot_f64_simd ( & a, & b) ;
306+ assert ! ( ( d - 48.0 ) . abs( ) < 1e-10 , "24×2 = 48, got {}" , d) ;
177307 }
178308}
0 commit comments