diff --git a/src/backend/native.rs b/src/backend/native.rs index 56123970..ee14bbb7 100644 --- a/src/backend/native.rs +++ b/src/backend/native.rs @@ -540,24 +540,71 @@ mod avx2 { } } - // No AVX2 specialization — fall through to scalar pub fn scal_f32(alpha: f32, x: &mut [f32]) { - super::scalar::scal_f32(alpha, x); + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() already verified AVX2 support before calling. + unsafe { scal_f32_avx2(alpha, x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::scal_f32(alpha, x); + } } pub fn scal_f64(alpha: f64, x: &mut [f64]) { - super::scalar::scal_f64(alpha, x); + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() already verified AVX2 support before calling. + unsafe { scal_f64_avx2(alpha, x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::scal_f64(alpha, x); + } } pub fn nrm2_f32(x: &[f32]) -> f32 { - super::scalar::nrm2_f32(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2+FMA. + unsafe { nrm2_f32_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::nrm2_f32(x) + } } pub fn nrm2_f64(x: &[f64]) -> f64 { - super::scalar::nrm2_f64(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2+FMA. + unsafe { nrm2_f64_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::nrm2_f64(x) + } } pub fn asum_f32(x: &[f32]) -> f32 { - super::scalar::asum_f32(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2. + unsafe { asum_f32_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::asum_f32(x) + } } pub fn asum_f64(x: &[f64]) -> f64 { - super::scalar::asum_f64(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2. 
+ unsafe { asum_f64_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::asum_f64(x) + } } // ── AVX2 intrinsic implementations ───────────────────────────── @@ -677,6 +724,201 @@ mod avx2 { i += 1; } } + + // ── scal: x[i] *= alpha ──────────────────────────────────────── + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn scal_f32_avx2(alpha: f32, x: &mut [f32]) { + use core::arch::x86_64::*; + let n = x.len(); + let valpha = _mm256_set1_ps(alpha); + let mut i = 0; + while i + 8 <= n { + let v = _mm256_loadu_ps(x.as_ptr().add(i)); + _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, valpha)); + i += 8; + } + while i < n { + x[i] *= alpha; + i += 1; + } + } + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn scal_f64_avx2(alpha: f64, x: &mut [f64]) { + use core::arch::x86_64::*; + let n = x.len(); + let valpha = _mm256_set1_pd(alpha); + let mut i = 0; + while i + 4 <= n { + let v = _mm256_loadu_pd(x.as_ptr().add(i)); + _mm256_storeu_pd(x.as_mut_ptr().add(i), _mm256_mul_pd(v, valpha)); + i += 4; + } + while i < n { + x[i] *= alpha; + i += 1; + } + } + + // ── nrm2: sqrt(Σ x[i]²) ──────────────────────────────────────── + // + // Two-accumulator unroll + FMA for the squared sum, scalar sqrt at + // the end. SIMD horizontal reduce ordering differs from the strict + // left-fold the scalar reference uses, so the ULP error can drift + // by 1-2 ULP on long vectors — same tolerance the existing + // `dot_f32_avx2` carries, accepted in BLAS-1. 
/// AVX2 + FMA kernel for the Euclidean norm `sqrt(Σ x[i]²)` of an `f32` slice.
///
/// Squares are accumulated with fused multiply-add into two independent
/// 8-lane accumulators (16 elements per iteration), then one accumulator
/// handles a final 8-element chunk if present. The lanes are reduced
/// horizontally, the remaining (at most seven) elements are added with plain
/// scalar multiply-add, and a scalar `sqrt` finishes the result.
///
/// # Safety
///
/// The CPU executing this function must support AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn nrm2_f32_avx2(x: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    let mut pairs = x.chunks_exact(16);
    for chunk in &mut pairs {
        // SAFETY: `chunk` holds exactly 16 `f32`s, so the unaligned 8-lane
        // loads at offsets 0 and 8 both stay inside it.
        let (v0, v1) = unsafe {
            (
                _mm256_loadu_ps(chunk.as_ptr()),
                _mm256_loadu_ps(chunk.as_ptr().add(8)),
            )
        };
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
    }

    // Fewer than 16 elements are left, so this yields at most one chunk.
    let mut singles = pairs.remainder().chunks_exact(8);
    for chunk in &mut singles {
        // SAFETY: `chunk` holds exactly 8 `f32`s — one full unaligned load.
        let v = unsafe { _mm256_loadu_ps(chunk.as_ptr()) };
        acc0 = _mm256_fmadd_ps(v, v, acc0);
    }

    // Horizontal reduction: 8 lanes → 4 → 2 → 1.
    let folded = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(folded, 1);
    let lo = _mm256_castps256_ps128(folded);
    let sum128 = _mm_add_ps(lo, hi);
    let sums = _mm_add_ps(sum128, _mm_movehdup_ps(sum128));
    let result = _mm_add_ss(sums, _mm_movehl_ps(sums, sums));
    let mut total = _mm_cvtss_f32(result);

    for &v in singles.remainder() {
        total += v * v;
    }
    total.sqrt()
}

/// AVX2 + FMA kernel for the Euclidean norm `sqrt(Σ x[i]²)` of an `f64` slice.
///
/// Same structure as the `f32` kernel with 4-lane vectors: 8 elements per
/// unrolled iteration, an optional 4-element chunk, a horizontal reduction, a
/// scalar tail of at most three elements, and a scalar `sqrt`.
///
/// # Safety
///
/// The CPU executing this function must support AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn nrm2_f64_avx2(x: &[f64]) -> f64 {
    use core::arch::x86_64::*;

    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();

    let mut pairs = x.chunks_exact(8);
    for chunk in &mut pairs {
        // SAFETY: `chunk` holds exactly 8 `f64`s, so the unaligned 4-lane
        // loads at offsets 0 and 4 both stay inside it.
        let (v0, v1) = unsafe {
            (
                _mm256_loadu_pd(chunk.as_ptr()),
                _mm256_loadu_pd(chunk.as_ptr().add(4)),
            )
        };
        acc0 = _mm256_fmadd_pd(v0, v0, acc0);
        acc1 = _mm256_fmadd_pd(v1, v1, acc1);
    }

    // Fewer than 8 elements are left, so this yields at most one chunk.
    let mut singles = pairs.remainder().chunks_exact(4);
    for chunk in &mut singles {
        // SAFETY: `chunk` holds exactly 4 `f64`s — one full unaligned load.
        let v = unsafe { _mm256_loadu_pd(chunk.as_ptr()) };
        acc0 = _mm256_fmadd_pd(v, v, acc0);
    }

    // Horizontal reduction: 4 lanes → 2 → 1.
    let folded = _mm256_add_pd(acc0, acc1);
    let hi = _mm256_extractf128_pd(folded, 1);
    let lo = _mm256_castpd256_pd128(folded);
    let sum128 = _mm_add_pd(lo, hi);
    let result = _mm_add_sd(sum128, _mm_unpackhi_pd(sum128, sum128));
    let mut total = _mm_cvtsd_f64(result);

    for &v in singles.remainder() {
        total += v * v;
    }
    total.sqrt()
}

// ── asum: Σ |x[i]|
// ─────────────────────────────────────────────
//
// Abs via AND with sign-bit-cleared mask (one AVX instruction —
// VANDPS), horizontal sum at the end. Same ordering caveat as
// nrm2: lanes are summed independently and reduced at the end, so
// the rounding can differ slightly from a strict left-to-right sum.

/// AVX2 kernel for `Σ |x[i]|` over an `f32` slice.
///
/// Absolute values are formed by masking off the sign bit, accumulated in two
/// independent 8-lane accumulators (16 elements per iteration) plus an
/// optional single 8-element chunk, reduced horizontally, and completed with a
/// scalar tail of at most seven elements.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn asum_f32_avx2(x: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    // Sign-bit-cleared mask: 0x7FFFFFFF in every lane.
    let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFFi32));
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    let mut pairs = x.chunks_exact(16);
    for chunk in &mut pairs {
        // SAFETY: `chunk` holds exactly 16 `f32`s, so the unaligned 8-lane
        // loads at offsets 0 and 8 both stay inside it.
        let (raw0, raw1) = unsafe {
            (
                _mm256_loadu_ps(chunk.as_ptr()),
                _mm256_loadu_ps(chunk.as_ptr().add(8)),
            )
        };
        acc0 = _mm256_add_ps(acc0, _mm256_and_ps(raw0, abs_mask));
        acc1 = _mm256_add_ps(acc1, _mm256_and_ps(raw1, abs_mask));
    }

    // Fewer than 16 elements are left, so this yields at most one chunk.
    let mut singles = pairs.remainder().chunks_exact(8);
    for chunk in &mut singles {
        // SAFETY: `chunk` holds exactly 8 `f32`s — one full unaligned load.
        let raw = unsafe { _mm256_loadu_ps(chunk.as_ptr()) };
        acc0 = _mm256_add_ps(acc0, _mm256_and_ps(raw, abs_mask));
    }

    // Horizontal reduction: 8 lanes → 4 → 2 → 1.
    let folded = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(folded, 1);
    let lo = _mm256_castps256_ps128(folded);
    let sum128 = _mm_add_ps(lo, hi);
    let sums = _mm_add_ps(sum128, _mm_movehdup_ps(sum128));
    let result = _mm_add_ss(sums, _mm_movehl_ps(sums, sums));
    let mut total = _mm_cvtss_f32(result);

    for &v in singles.remainder() {
        total += v.abs();
    }
    total
}

/// AVX2 kernel for `Σ |x[i]|` over an `f64` slice.
///
/// Same structure as the `f32` kernel with 4-lane vectors: 8 elements per
/// unrolled iteration, an optional 4-element chunk, a horizontal reduction,
/// and a scalar tail of at most three elements.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn asum_f64_avx2(x: &[f64]) -> f64 {
    use core::arch::x86_64::*;

    // Sign-bit-cleared mask: 0x7FFF_FFFF_FFFF_FFFF in every lane.
    let abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFi64));
    let mut acc0 = _mm256_setzero_pd();
    let mut acc1 = _mm256_setzero_pd();

    let mut pairs = x.chunks_exact(8);
    for chunk in &mut pairs {
        // SAFETY: `chunk` holds exactly 8 `f64`s, so the unaligned 4-lane
        // loads at offsets 0 and 4 both stay inside it.
        let (raw0, raw1) = unsafe {
            (
                _mm256_loadu_pd(chunk.as_ptr()),
                _mm256_loadu_pd(chunk.as_ptr().add(4)),
            )
        };
        acc0 = _mm256_add_pd(acc0, _mm256_and_pd(raw0, abs_mask));
        acc1 = _mm256_add_pd(acc1, _mm256_and_pd(raw1, abs_mask));
    }

    // Fewer than 8 elements are left, so this yields at most one chunk.
    let mut singles = pairs.remainder().chunks_exact(4);
    for chunk in &mut singles {
        // SAFETY: `chunk` holds exactly 4 `f64`s — one full unaligned load.
        let raw = unsafe { _mm256_loadu_pd(chunk.as_ptr()) };
        acc0 = _mm256_add_pd(acc0, _mm256_and_pd(raw, abs_mask));
    }

    // Horizontal reduction: 4 lanes → 2 → 1.
    let folded = _mm256_add_pd(acc0, acc1);
    let hi = _mm256_extractf128_pd(folded, 1);
    let lo = _mm256_castpd256_pd128(folded);
    let sum128 = _mm_add_pd(lo, hi);
    let result = _mm_add_sd(sum128, _mm_unpackhi_pd(sum128, sum128));
    let mut total = _mm_cvtsd_f64(result);

    for &v in singles.remainder() {
        total += v.abs();
    }
    total
}

// ── TD-T6: parity sweep for the new AVX2 BLAS-1 kernels ────────
//
// The shim → real-intrinsic switch flipped scal/nrm2/asum from
// scalar-fallthrough to AVX2 chunked + scalar-tail kernels. Each
// new kernel is checked against a scalar reference — bit-equal for
// scal, tolerance-bounded for nrm2 and asum (their sum order differs
// from the reference, and nrm2 also includes a sqrt) — across shapes
// that exercise the chunk-of-16, chunk-of-8, and scalar-tail code
// paths.
/// Vector lengths swept by every TD-T6 test: empty, below one 8-lane vector,
/// exactly on and one either side of the 8- and 16-element chunk boundaries,
/// and a few longer inputs.
const TD_T6_SIZES: [usize; 12] = [0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100];

/// Scalar reference for `scal`: returns a new vector holding `x[i] * alpha`.
fn ref_scal(alpha: f32, x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| v * alpha).collect()
}

/// Scalar reference for `nrm2`: strict left-to-right sum of squares, then sqrt.
fn ref_nrm2(x: &[f32]) -> f32 {
    x.iter().map(|&v| v * v).sum::<f32>().sqrt()
}

/// Scalar reference for `asum`: strict left-to-right sum of absolute values.
fn ref_asum(x: &[f32]) -> f32 {
    x.iter().map(|&v| v.abs()).sum()
}

/// Tolerance used for the reductions: relative 1e-5 of the expected value plus
/// an absolute floor of 1e-6 so expected values near zero still compare sanely.
fn td_t6_tolerance(expected: f32) -> f32 {
    expected.abs() * 1e-5 + 1e-6
}

#[test]
fn td_t6_scal_f32_parity() {
    for n in TD_T6_SIZES {
        let alpha = 1.5f32;
        let init: Vec<f32> = (0..n).map(|i| (i as f32 * 0.5) - 1.0).collect();
        let expected = ref_scal(alpha, &init);
        let mut got = init.clone();
        scal_f32(alpha, &mut got);
        // Element-wise multiply has no reordering, so the result must be
        // bit-identical to the reference.
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            assert_eq!(g.to_bits(), e.to_bits(), "scal_f32 n={n} i={i}: got {g} want {e}");
        }
    }
}

#[test]
fn td_t6_nrm2_f32_parity() {
    for n in TD_T6_SIZES {
        let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
        let expected = ref_nrm2(&x);
        let got = nrm2_f32(&x);
        // Relative/absolute tolerance rather than bit equality: the SIMD
        // reduce order differs from the strict left-fold, and nrm2 also
        // includes the final sqrt.
        let abs_err = (got - expected).abs();
        let rel_tol = td_t6_tolerance(expected);
        assert!(abs_err <= rel_tol, "nrm2_f32 n={n}: got {got} want {expected} (err {abs_err})");
    }
}

#[test]
fn td_t6_asum_f32_parity() {
    for n in TD_T6_SIZES {
        let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
        let expected = ref_asum(&x);
        let got = asum_f32(&x);
        // Same tolerance as nrm2: the summation order differs from the
        // reference's left-fold.
        let abs_err = (got - expected).abs();
        let rel_tol = td_t6_tolerance(expected);
        assert!(abs_err <= rel_tol, "asum_f32 n={n}: got {got} want {expected} (err {abs_err})");
    }
}