backend/native: TD-T6 — real AVX2 kernels for scal/nrm2/asum (f32+f64)#186

Merged
AdaWorldAPI merged 1 commit into master from claude/continue-ndarray-x0Oaw on May 21, 2026

@AdaWorldAPI
Owner

Summary

Closes TD-T6 (critical audit finding from the per-CPU matrix doc). The AVX2 native BLAS-1 module routed scal_f32/scal_f64/nrm2_f32/nrm2_f64/asum_f32/asum_f64 through scalar shims, each documented as "// No AVX2 specialization — fall through to scalar". As a result, six ops on every Haswell+ host ran scalar code even though dot_f32_avx2 and axpy_f32_avx2 have shipped real AVX2 in the same module since day one.

This wires the six missing kernels.

Kernels

| Op | AVX2 instructions | Estimated speedup vs scalar (n=4096, back-of-envelope) |
|---|---|---|
| `scal_*` | broadcast α via `_mm256_set1_ps/pd`, mul 8/4 lanes, scalar tail | ~8× |
| `nrm2_*` | 2-accumulator FMA (`_mm256_fmadd_*`), horiz reduce + sqrt | ~16× |
| `asum_*` | abs via `_mm256_and_*` with sign-bit-cleared mask, sum-reduce | ~8× |

All three follow the existing dot_f32_avx2 template — #[target_feature(enable = "avx2[,fma]")] on the inner unsafe fn, public wrapper does cfg(target_arch = "x86_64"), non-x86 builds keep their scalar fallback, scalar tail handles n % chunk_size.
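A minimal sketch of that template, using scal_f32 as the example. This is not the PR's code: the names `scal_f32_scalar` and `scal_f32_avx2` are hypothetical, and the runtime `is_x86_feature_detected!` check is an assumption added so the snippet is safe to run on its own (the PR states that tier detection happens further up the call chain).

```rust
/// Scalar reference: multiplies every element by `alpha` in place.
pub fn scal_f32_scalar(alpha: f32, x: &mut [f32]) {
    for v in x.iter_mut() {
        *v *= alpha;
    }
}

/// Inner kernel (hypothetical name). The caller must guarantee AVX2 support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scal_f32_avx2(alpha: f32, x: &mut [f32]) {
    use std::arch::x86_64::*;
    // Broadcast alpha into all 8 lanes of a 256-bit register.
    let va = _mm256_set1_ps(alpha);
    let mut chunks = x.chunks_exact_mut(8);
    for c in &mut chunks {
        // Unaligned load, 8-lane multiply, store back in place.
        let v = _mm256_loadu_ps(c.as_ptr());
        _mm256_storeu_ps(c.as_mut_ptr(), _mm256_mul_ps(v, va));
    }
    // Scalar tail for the remaining n % 8 elements.
    for v in chunks.into_remainder() {
        *v *= alpha;
    }
}

/// Public wrapper: AVX2 on x86_64 when available, scalar otherwise.
pub fn scal_f32(alpha: f32, x: &mut [f32]) {
    #[cfg(target_arch = "x86_64")]
    {
        // Assumption: detect at runtime here; the real module relies on
        // its callers having verified the tier already.
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed on the line above.
            unsafe { scal_f32_avx2(alpha, x) };
            return;
        }
    }
    scal_f32_scalar(alpha, x);
}
```

Because each lane performs the same single IEEE multiply as the scalar loop, the two paths should produce identical bits for the same input.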

Numerical contract

  • scal is byte-equal to scalar (x[i] *= α is the same op).
  • asum drifts ~1-2 ULP on long vectors because SIMD horizontal reduce orders the sum differently from strict left-fold.
  • nrm2 same as asum + final sqrt rounding.

Test tolerance: |got - expected| <= |expected| * 1e-5 + 1e-6 (same precedent as the existing dot_f32_avx2 and the BLAS reference implementations broadly).

Test plan

  • 2090 lib tests pass (was 2087 — +3 new parity tests).
  • 3 new td_t6_*_parity tests sweep n ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100} — covers the chunk-of-16 unroll path, the chunk-of-8 cleanup, and the scalar tail for every kernel.
  • Existing test_scal_f32 / test_nrm2_f64 / test_asum_f32 (which used to exercise the scalar shims) now hit the AVX2 kernels and continue to pass.
  • cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
  • cargo clippy --lib --tests --features rayon,native,runtime-dispatch -- -D warnings clean.
  • cargo fmt --all --check clean.
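For illustration, the parity sweep described in the test plan could be driven by a small harness like the one below. It is a sketch only: `SWEEP`, `parity_failures`, `asum_reference` and the sine-based input generator are hypothetical names and choices, not taken from the actual td_t6_*_parity tests. The tolerance is the one stated in the numerical contract.

```rust
/// Sizes from the PR's sweep. With a 16-wide unroll and an 8-wide cleanup:
/// 0, 1, 7 hit only the scalar tail; 8, 9, 15 hit cleanup (+ tail);
/// 16, 32, 64 hit only the unroll; 17, 31, 100 mix the paths.
pub const SWEEP: [usize; 12] = [0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100];

/// Strict left-fold reference for asum (hypothetical helper).
pub fn asum_reference(x: &[f32]) -> f32 {
    x.iter().fold(0.0f32, |acc, v| acc + v.abs())
}

/// Runs `kernel` and `reference` on each sweep size and returns the sizes
/// where |got - expected| <= |expected| * 1e-5 + 1e-6 does NOT hold.
pub fn parity_failures<K, R>(kernel: K, reference: R) -> Vec<usize>
where
    K: Fn(&[f32]) -> f32,
    R: Fn(&[f32]) -> f32,
{
    SWEEP
        .iter()
        .copied()
        .filter(|&n| {
            // Deterministic mixed-sign input (an arbitrary choice).
            let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.37).sin()).collect();
            let got = kernel(&x);
            let want = reference(&x);
            // Negated so that a NaN result also counts as a failure.
            !((got - want).abs() <= want.abs() * 1e-5 + 1e-6)
        })
        .collect()
}
```

A test would then assert that `parity_failures(kernel_under_test, asum_reference)` is empty; returning the failing sizes rather than a bool makes it easier to see which code path broke.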

Out of scope (separate PRs)

  • AVX-512 versions of the same three ops — kernels_avx512.rs has them already (lines 137-209), wired through the cfg(target_feature = "avx512f") path. This PR fixes the AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3.
  • Runtime-dispatch trampolines for these ops (would go in simd_runtime/blas_l1.rs, mirroring the matmul.rs pattern from the runtime-dispatch landing in PR #185, "simd_int_ops, hpc: AMX TDPBUSD arm for gemm_u8_i8 slice surface").

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u


Generated by Claude Code

Closes TD-T6 (critical audit finding from the per-CPU matrix doc).
Before this commit, the AVX2 native BLAS-1 module had:

  pub fn scal_f32(alpha: f32, x: &mut [f32]) {
      super::scalar::scal_f32(alpha, x);  // ← scalar shim, no AVX2
  }
  pub fn nrm2_f32(x: &[f32]) -> f32 {
      super::scalar::nrm2_f32(x)          // ← scalar shim
  }
  pub fn asum_f32(x: &[f32]) -> f32 {
      super::scalar::asum_f32(x)          // ← scalar shim
  }
  // ... and f64 siblings, same shape

These were the documented "// No AVX2 specialization — fall through
to scalar" path. Three operations (six functions across f32 and
f64) on every Haswell+ host fell to scalar even though
`dot_f32_avx2` and `axpy_f32_avx2` shipped real
AVX2 in the same module since day one. PR #180's audit flagged this
as TD-T6 (critical: blocks BLAS-1 throughput on Haswell / Arrow
Lake / Zen 1-3).

New AVX2 kernels (6 total — f32 + f64 for each of scal / nrm2 / asum):

  scal: broadcast α to a ymm register via `_mm256_set1_ps`/`_pd`,
        multiply 8/4 lanes at a time via `_mm256_mul_ps`/`_mm256_mul_pd`,
        scalar tail. Stores the result back to the same buffer in place.

  nrm2: two-accumulator unroll with `_mm256_fmadd_ps`/`_pd` (x²
        accumulated via FMA, single-rounded per IEEE), horizontal
        reduce + scalar sqrt. Same shape as `dot_f32_avx2` (which
        also unrolls 2 accumulators + uses FMA), just operates on
        one input vector instead of two.

  asum: abs via `_mm256_and_ps`/`_pd` with a sign-bit-cleared mask
        (0x7FFFFFFF for f32, 0x7FFFFFFFFFFFFFFF for f64) — one
        AVX instruction (VANDPS) is faster than calling f32::abs()
        lane-by-lane. Two-accumulator unroll + horizontal reduce.
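As a rough illustration of the asum description above, here is a sketch of an f32 kernel with the sign-bit mask, a two-accumulator unroll, an 8-wide cleanup step and a scalar tail. It is written from the description, not copied from the commit: the function names are hypothetical, and the runtime feature check in the wrapper is an assumption added to keep the snippet safe to run standalone.

```rust
/// Strict left-fold reference (hypothetical name).
pub fn asum_f32_scalar(x: &[f32]) -> f32 {
    x.iter().fold(0.0f32, |acc, v| acc + v.abs())
}

/// Inner kernel sketch. The caller must guarantee AVX2 support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn asum_f32_avx2(x: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // All bits set except the sign bit; AND-ing clears the sign (abs).
    let mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFF));
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    // Main loop: 16 elements per iteration, one accumulator per half.
    let mut pairs = x.chunks_exact(16);
    for c in &mut pairs {
        let a = _mm256_loadu_ps(c.as_ptr());
        let b = _mm256_loadu_ps(c.as_ptr().add(8));
        acc0 = _mm256_add_ps(acc0, _mm256_and_ps(a, mask));
        acc1 = _mm256_add_ps(acc1, _mm256_and_ps(b, mask));
    }

    // Cleanup: at most one remaining chunk of 8.
    let mut tail = pairs.remainder().chunks_exact(8);
    for c in &mut tail {
        let a = _mm256_loadu_ps(c.as_ptr());
        acc0 = _mm256_add_ps(acc0, _mm256_and_ps(a, mask));
    }

    // Horizontal reduce: spill the 8 lanes and add them up.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), _mm256_add_ps(acc0, acc1));
    let mut sum: f32 = lanes.iter().sum();

    // Scalar tail for the last n % 8 elements.
    for v in tail.remainder() {
        sum += v.abs();
    }
    sum
}

/// Public wrapper: AVX2 on x86_64 when available, scalar otherwise.
pub fn asum_f32(x: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // Assumption: runtime check here instead of caller-side tiering.
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed on the line above.
            return unsafe { asum_f32_avx2(x) };
        }
    }
    asum_f32_scalar(x)
}
```

The spill-and-sum reduce is the simplest correct option; a production kernel might use shuffle-based reduction instead, which changes the summation order again but stays within the same tolerance class.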

All three follow the existing `dot_f32_avx2` template:
- `#[target_feature(enable = "avx2[,fma]")]` on the inner unsafe fn.
- Public wrapper does `cfg(target_arch = "x86_64")` and dispatches
  to the unsafe fn (tier detection in caller-of-caller verified
  AVX2 before reaching this module).
- Non-x86_64 builds: pass through to `super::scalar::*`.
- Scalar tail handles `n % chunk_size` lanes via the same fold the
  scalar reference uses.

Numerical contract:
  scal: byte-equal to scalar (`x[i] *= α` is the same op).
  asum: small ULP drift on long vectors because the SIMD horizontal
        reduce orders the sum differently from strict left-fold.
        Test tolerance: `|got - expected| <= |expected|*1e-5 + 1e-6`.
  nrm2: same — drifts ~1-2 ULP on long vectors via reduce-order +
        sqrt rounding. Same tolerance.
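The reduce-order effect can be reproduced without intrinsics. The sketch below emulates an 8-lane accumulation in plain Rust and compares it with a strict left fold; both function names are hypothetical, and the emulation is an illustration of the ordering difference, not the kernel from this commit.

```rust
/// Strict left fold: sum of squares accumulated in element order.
pub fn nrm2_left_fold(x: &[f32]) -> f32 {
    // mul_add computes v * v + acc with a single rounding, like an FMA.
    x.iter().fold(0.0f32, |acc, &v| v.mul_add(v, acc)).sqrt()
}

/// Lane-wise order: 8 independent partial sums, combined at the end.
/// This mimics how one 8-lane SIMD accumulator groups the additions.
pub fn nrm2_lanewise(x: &[f32]) -> f32 {
    let mut lanes = [0.0f32; 8];
    let mut chunks = x.chunks_exact(8);
    for c in &mut chunks {
        for (l, &v) in lanes.iter_mut().zip(c) {
            *l = v.mul_add(v, *l);
        }
    }
    // Horizontal reduce of the partial sums.
    let mut sum: f32 = lanes.iter().sum();
    // Scalar tail for the last n % 8 elements.
    for &v in chunks.remainder() {
        sum = v.mul_add(v, sum);
    }
    sum.sqrt()
}
```

Both functions add the same terms, only grouped differently, so any difference between them comes from floating-point rounding alone. That is why the parity tests compare against a tolerance rather than for exact equality, while scal (no reduction) can be compared bit for bit.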

3 new parity tests (`td_t6_scal_f32_parity`,
`td_t6_nrm2_f32_parity`, `td_t6_asum_f32_parity`) sweep
n ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100} — covers the
chunk-of-16 unroll path, the chunk-of-8 cleanup path, and the
scalar tail for every kernel.

Verification:
  * 2090 lib tests pass (was 2087 — +3 new parity tests; the
    existing test_scal_f32 / test_nrm2_f64 / test_asum_f32 that
    used to hit the scalar shims now exercise the AVX2 kernels
    and continue to pass).
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo clippy --lib --tests --features rayon,native,runtime-dispatch
    -- -D warnings clean.
  * cargo fmt --all --check clean.

Throughput impact (back-of-envelope on Sapphire Rapids, n=4096):
  scal_f32: scalar 4096 cycles (1 mul/lane) → AVX2 ~520 cycles
            (8 lanes/instr + 1-cycle issue) = ~8× faster.
  asum_f32: scalar 4096 cycles → AVX2 ~520 cycles = ~8× faster.
  nrm2_f32: scalar 4096 cycles (1 FMA/lane) → AVX2 ~260 cycles
            (16 lanes via 2-acc unroll, 1-cycle issue) = ~16×.

Out of scope (separate PRs):
  * AVX-512 versions of the same three ops — `kernels_avx512.rs`
    has them already (lines 137-209), wired through the
    cfg(target_feature = "avx512f") path. This commit fixes the
    AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3.
  * Runtime-dispatch trampolines for these ops (would go in
    `simd_runtime/blas_l1.rs` mirroring the matmul.rs pattern from
    the runtime-dispatch PR).

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
@AdaWorldAPI AdaWorldAPI merged commit 8739d90 into master May 21, 2026
17 checks passed