diff --git a/.cargo/config-avx512.toml b/.cargo/config-avx512.toml index a4349ab9..1a7601c2 100644 --- a/.cargo/config-avx512.toml +++ b/.cargo/config-avx512.toml @@ -1,16 +1,50 @@ [build] -# Explicit AVX-512 config — `x86-64-v4`. Use with: +# Explicit AVX-512 config — Sapphire Rapids baseline. Use with: # cargo --config .cargo/config-avx512.toml build # cargo --config .cargo/config-avx512.toml test # -# Compiles `target_feature = "avx512f"` on, so `src/simd.rs` selects the -# `simd_avx512` backend with native `__m512` / `__m512d` / `__m512i` -# storage. Required for the Sapphire Rapids / Granite Rapids hot paths -# (`f32_to_bf16_batch_rne`, the AVX-512BF16 BF16 lanes, the AMX tiles). +# `-Ctarget-cpu=sapphirerapids` enables, in addition to the +# `x86-64-v4` AVX-512 baseline (F + BW + CD + DQ + VL): # -# Binary produced here will SIGILL on AVX2-only silicon — only use on -# hosts that report `avx512f` in `/proc/cpuinfo`. For shipping a single -# release artifact that adapts at process start, see the LazyLock runtime -# dispatch path in § 7.1 of the architecture doc instead. +# - AVX-512 VNNI (VPDPBUSD u8×i8 → i32) +# - AVX-512 BF16 (VDPBF16PS, VCVTNE2PS2BF16) +# - AVX-512 FP16 (16-wide native FP16 arithmetic) +# - AVX-512 VBMI / VBMI2 (byte permute) +# - AVX-512 IFMA, BITALG, VPOPCNTDQ, GFNI, VAES, VPCLMUL +# - AVX-VNNI (ymm VPDPBUSD on Alder/Sapphire client) +# - AMX-TILE + AMX-INT8 + AMX-BF16 (16×16×k tile kernels) +# +# Effect on the agnostic surfaces in `src/simd_*ops.rs`: +# +# - `simd_int_ops::gemm_u8_i8` resolves to the AVX-512 VNNI `VPDPBUSD` +# zmm kernel (`hpc::vnni_gemm::int8_gemm_vnni_avx512`). When the +# planned `amx-int8` arm lands, it will preempt this one and route +# to `TDPBUSD` instead — same source, no caller changes. +# - BF16 / FP16 lane ops in `src/simd_avx512.rs` light up. +# - `simd_amx::*` tile primitives are usable without further gating. 
+# +# Pure `x86-64-v4` is NOT used here — Skylake-X is the only AVX-512 CPU +# without VNNI and the project's design pins VNNI as the lowest common +# denominator above the scalar reference. SKX users either build with +# `-Ctarget-cpu=x86-64-v4` explicitly (and accept the scalar arm for +# `gemm_u8_i8`) or run a runtime-LazyLock dispatch binary. +# +# Binary produced here will SIGILL on CPUs that lack any of the +# enabled feature sets — i.e. anything pre-Sapphire-Rapids on x86_64: +# +# - Cooper Lake / Cascade Lake / Ice Lake-SP (no BF16+FP16+AMX) +# - Skylake-X / Skylake-SP / Skylake-W (no VNNI either) +# - Zen 4 / Zen 5 (no AMX) +# - Alder Lake / Arrow Lake (no AVX-512 at all) +# - Haswell ⇢ Coffee Lake (AVX2 only) +# +# Only deploy artifacts built with this config to hosts that report +# `amx_int8 amx_bf16 avx512_bf16 avx512_fp16 avx512_vnni` in +# `/proc/cpuinfo`. For Cascade Lake → Ice Lake-SP → Zen 4 silicon +# (AVX-512 + VNNI but no AMX/BF16/FP16), build with +# `-Ctarget-cpu=cascadelake` or `-Ctarget-cpu=znver4` instead. For +# shipping a single release artifact that adapts at process start, +# see the LazyLock runtime dispatch path in § 7.1 of the architecture +# doc instead. [target.'cfg(target_arch = "x86_64")'] -rustflags = ["-Ctarget-cpu=x86-64-v4"] +rustflags = ["-Ctarget-cpu=sapphirerapids"] diff --git a/.claude/knowledge/agnostic-surface-cpu-matrix.md b/.claude/knowledge/agnostic-surface-cpu-matrix.md new file mode 100644 index 00000000..93e97b07 --- /dev/null +++ b/.claude/knowledge/agnostic-surface-cpu-matrix.md @@ -0,0 +1,548 @@ +# Agnostic SIMD Surface — Per-CPU Resolution Matrix + Integration Plan + +> **Companion to:** `td-simd-cpu-dispatch-matrix.md` (CPU feature presence), +> `td-simd-tier-audit.md` (debt inventory), `td-simd-integration-plan.md` +> (`SimdProfile` architecture). 
This doc is the **cross-tab**: every public +> primitive in `crate::simd::*` × every CPU profile we target, showing the +> kernel that actually runs on that silicon. Gaps drive the integration plan. + +## CPU profile columns (abbreviations) + +Same set as `td-simd-cpu-dispatch-matrix.md` § "Master matrix — x86_64" and +§ "aarch64 profiles", with short (two- or three-character) codes for table width: + +| Code | Profile (Cargo cpu / SimdProfile) | Generation | Critical features | +|------|-----------------------------------------|--------------------|--------------------------------| +| SKX | `skylake-avx512` / `SkylakeX` | Intel 2017 | AVX-512F+BW+DQ+CD+VL | +| CLX | `cascadelake` / `CascadeLake` | Intel 2019 | + AVX-512 VNNI | +| CPL | `cooperlake` / `CooperLake` | Intel 2020 | + AVX-512 BF16 (no VBMI) | +| ICX | `icelake-server` / `IceLakeSp` | Intel 2021 | + VBMI, no BF16 | +| SPR | `sapphirerapids` / `SapphireRapids` | Intel 2023 | + BF16+FP16+VBMI+AMX-INT8+BF16 | +| GNR | `graniterapids-d` / `GraniteRapids` | Intel 2024 | + AMX-FP16 | +| Z4 | `znver4` / `Zen4Avx512` | AMD 2022 | AVX-512 + VNNI+BF16+VBMI | +| Z5 | `znver5` / `Zen4Avx512` (same dispatch) | AMD 2024 | same as Z4 + minor uarch | +| ARL | `arrowlake` / `ArrowLake` | Intel 2024 | AVX2+FMA + AVX-VNNI+VNNI-INT8 | +| HSW | `x86-64-v3` / `HaswellAvx2` | Intel 2013→2021 | AVX2+FMA (no VNNI/AVX-512) | +| A76 | `cortex-a76` / `A76DotProd` | ARMv8.2 (Pi 5, M1) | NEON+dotprod+bf16+fp16 | +| A72 | `cortex-a72` / `A72Fast` | ARMv8.0 (Pi 4) | NEON only (no dotprod) | +| A53 | `cortex-a53` / `A53Baseline` | ARMv8.0 (Pi 3/Z2W) | NEON, lower IPC | +| SCA | scalar fallback | wasm32/riscv/i686 | no SIMD | + +Cell legend: + +- ✅ `kernel-name` — wired today, exercises the indicated kernel/intrinsic +- ⏳ `kernel-name` — kernel exists but **not** dispatched here yet (debt) +- 🟦 `kernel-name` — planned, no kernel exists yet (new code needed) +- 🟡 polyfill-pass — the call delegates to the polyfilled SIMD *type*; that + type's per-CPU lowering 
does the work (transparent dispatch — entry on + table A) +- ✗ scalar — falls back to a triple-loop scalar reference +- — — N/A on this profile + +--- + +## A. Polyfilled SIMD types — backing storage per CPU + +The polyfilled types in `crate::simd::*` ARE the CPU DTO surface (per the +session's "polyfill is everything" rule). Consumers write `F32x16`, the +type chooses native storage at compile time. Storage selection is driven +by `target_feature` cfg gates in `src/simd.rs` (lines 221-366). + +### Float vectors + +| Type | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|----------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------| +| `F32x16` | `__m512` | `__m512` | `__m512` | `__m512` | `__m512` | `__m512` | `__m512` | `__m512` | 2×`__m256` | 2×`__m256` | 4×`float32x4_t` (paired-load) | 4×`float32x4_t` | 4×`float32x4_t` | `[f32;16]` | +| `F32x8` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | `__m256` | 2×`float32x4_t` | 2×`float32x4_t` | 2×`float32x4_t` | `[f32;8]` | +| `F64x8` | `__m512d` | `__m512d` | `__m512d` | `__m512d` | `__m512d` | `__m512d` | `__m512d` | `__m512d` | 2×`__m256d`| 2×`__m256d`| 4×`float64x2_t` | 4×`float64x2_t` | 4×`float64x2_t` | `[f64;8]` | +| `F64x4` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | `__m256d` | 2×`float64x2_t` | 2×`float64x2_t` | 2×`float64x2_t` | `[f64;4]` | + +### Half-precision vectors + +| Type | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|-----------|-----|-----|--------------------------|-----|--------------------------|--------------------------|--------------------------|--------------------------|-----|-----|-------------|-----|-----|-----| +| `BF16x16` (avx512bf16) | — | — | `__m256bh` 
(`simd_avx512`) | — | `__m256bh` | `__m256bh` | `__m256bh` | `__m256bh` | — | — | — | — | — | — | +| `BF16x16` (portable) | `[u16;16]` | `[u16;16]` | (uses native) | `[u16;16]` | (uses native) | (uses native) | (uses native) | (uses native) | `[u16;16]` | `[u16;16]` | `[u16;16]` 🚨 | `[u16;16]` | `[u16;16]` | `[u16;16]` | +| `BF16x8` (avx512bf16) | — | — | `__m128bh` | — | `__m128bh` | `__m128bh` | `__m128bh` | `__m128bh` | — | — | — | — | — | — | +| `F16x16` | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 | `[u16;16]` 🚨 (has fp16 HW!) | `[u16;16]` | `[u16;16]` | `[u16;16]` | + +🚨 = scalar polyfill where hardware exists — see TD-SIMD-8 in +`simd-dispatch-architecture.md` and § D gaps below. + +### Integer vectors (lane widths matching the audit's "missing lanes" sweep PR #179) + +Storage shape per CPU. "AVX-512" means native `__m512i`; "2×AVX2" means +two `__m256i` halves; "4×NEON" means four 128-bit NEON registers (e.g. +`int8x16x4_t`); "scalar" means `[T; N]` array, no SIMD register. 
+ +| Type | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|----------|------------|-----|-----|-----|-----|-----|-----|-----|------------|------------|------------|------------|------------|------------| +| `I8x64` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`int8x16_t` | ← | ← | `[i8;64]` | +| `I8x32` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`int8x16_t` | ← | ← | `[i8;32]` | +| `U8x64` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`uint8x16_t` | ← | ← | `[u8;64]` | +| `U8x32` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`uint8x16_t` | ← | ← | `[u8;32]` | +| `I16x32` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`int16x8_t` | ← | ← | `[i16;32]` | +| `I16x16` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`int16x8_t` | ← | ← | `[i16;16]` | +| `U16x32` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`⏳| 2×`__m256i`⏳| 4×`uint16x8_t` | ← | ← | `[u16;32]` | +| `U16x16` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`uint16x8_t` | ← | ← | `[u16;16]` | +| `I32x16` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`int32x4_t` | ← | ← | `[i32;16]` | +| `I32x8` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`int32x4_t` | ← | ← | `[i32;8]` | +| `U32x16` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`⏳| 2×`__m256i`⏳| 4×`uint32x4_t` | ← | ← | `[u32;16]` | +| `U32x8` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i`⏳ | `__m256i`⏳ | 2×`uint32x4_t` | ← | ← | `[u32;8]` | +| `I64x8` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`int64x2_t` | ← | ← | `[i64;8]` | +| `I64x4` | `__m256i` | ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`int64x2_t` | ← | ← | `[i64;4]` | +| `U64x8` | `__m512i` | ← | ← | ← | ← | ← | ← | ← | 2×`__m256i`| 2×`__m256i`| 4×`uint64x2_t` | ← | ← | `[u64;8]` | +| `U64x4` | `__m256i` 
| ← | ← | ← | ← | ← | ← | ← | `__m256i` | `__m256i` | 2×`uint64x2_t` | ← | ← | `[u64;4]` | + +⏳ = TD-T22 polyfill audit — the 256-bit `U16x16/U16x32/U32x8/U32x16` +inner ops may currently use scalar storage under `#[target_feature]` rather +than real `__m256i` intrinsics. Needs verification (see § J integration plan). + +### Mask vectors + +| Type | SKX/CLX/CPL/ICX/SPR/GNR/Z4/Z5 | HSW/ARL | A76/A72/A53 | SCA | +|-----------|-------------------------------|---------|-------------|-----| +| `F32Mask16` | `__mmask16` (1 bit per lane) | `__m256i` (two-half mask) | 4×`uint32x4_t` (lane-mask) | `[bool;16]` | +| `F32Mask8` | `__mmask8` | `__m256i` (one-half mask) | 2×`uint32x4_t` | `[bool;8]` | +| `F64Mask8` | `__mmask8` | `__m256i` (two-half mask) | 4×`uint64x2_t` | `[bool;8]` | +| `F64Mask4` | `__mmask8` | `__m256i` (one-half mask) | 2×`uint64x2_t` | `[bool;4]` | + +### Critical type-method per-CPU lowerings (where it matters) + +Most methods (add, sub, mul, div, simd_lt, etc.) just delegate to the +storage's native op. 
The non-obvious lowerings: + +| Method | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|-------------------------|------------|------------|------------|------------|------------|------------|----|----|------------|------------|------------|------------|------------|------------------| +| `F32x16::mul_add` | `vfmadd231ps zmm` | ← | ← | ← | ← | ← | ← | ← | 2×`vfmadd231ps ymm` (FMA3) | 2×`vfmadd231ps ymm` | 4×`vfmaq_f32` | 4×`vfmaq_f32` | 4×`vfmaq_f32` | `f32::mul_add` | +| `F64x8::mul_add` | `vfmadd231pd zmm` | ← | ← | ← | ← | ← | ← | ← | 2×`vfmadd231pd ymm` | 2×`vfmadd231pd ymm` | 4×`vfmaq_f64` | 4×`vfmaq_f64` | 4×`vfmaq_f64` | `f64::mul_add` | +| `F32x16::simd_min/max` | `vminps/vmaxps zmm` | ← | ← | ← | ← | ← | ← | ← | 2×`vminps/vmaxps ymm` | 2×`vminps/vmaxps ymm` | 4×`vminq/vmaxq_f32` | ← | ← | scalar loop | +| `F32x16::reduce_sum` | `vaddps` + `_mm512_reduce_add_ps` ladder | ← | ← | ← | ← | ← | ← | ← | ymm reduce ladder | ymm reduce ladder | NEON paired-add ladder | ← | ← | iter sum | +| `simd_exp_f32` | Remez poly (F32x16) | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | (lib expects F32x16 from polyfill — currently no scalar override; scalar reduces lane-by-lane) | +| `simd_ln_f32` | scalar `f32::ln` per lane 🚨 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← (TD-T18 in audit — no SIMD path on any backend) | + +--- + +## B. `simd_ops` — float slice ops + +All `simd_ops` slice functions are written **once** against the +polyfilled types (`F32x16`, `F64x8`) and inherit their per-CPU lowering. +The 🟡 cells indicate "transparent polyfill dispatch — see table A". 
+ +| Function | SKX–GNR/Z4/Z5/ARL/HSW | A76/A72/A53 | SCA | Notes | +|----------------------|-------------------------|-----------------------|-------------------|-------| +| `add_f32` | 🟡 F32x16 + scalar tail | 🟡 | 🟡 + scalar tail | binary_f32 helper | +| `sub_f32` | 🟡 | 🟡 | 🟡 | | +| `mul_f32` | 🟡 | 🟡 | 🟡 | | +| `div_f32` | 🟡 | 🟡 | 🟡 | | +| `add_f32_inplace` | 🟡 | 🟡 | 🟡 | inplace_f32 helper | +| `sub_f32_inplace` | 🟡 | 🟡 | 🟡 | | +| `mul_f32_inplace` | 🟡 | 🟡 | 🟡 | | +| `div_f32_inplace` | 🟡 | 🟡 | 🟡 | | +| `scale_f32` | 🟡 | 🟡 | 🟡 | F32x16::mul broadcast | +| `add_scalar_f32` | 🟡 | 🟡 | 🟡 | F32x16::add broadcast | +| `scale_f32_inplace` | 🟡 | 🟡 | 🟡 | | +| **`add_mul_f32`** ✅ | 🟡 F32x16::mul_add + scalar tail (f32::mul_add) | 🟡 | 🟡 | NEW (this session) — FMA into accumulator | +| `add_f64` | 🟡 F64x8 + scalar tail | 🟡 | 🟡 | binary_f64 helper | +| `mul_f64` | 🟡 | 🟡 | 🟡 | | +| `add_f64_inplace` | 🟡 | 🟡 | 🟡 | | +| **`add_mul_f64`** ✅ | 🟡 F64x8::mul_add + scalar tail (f64::mul_add) | 🟡 | 🟡 | NEW (this session) | +| `array_chunks` | uniform — `slice::as_chunks` (stable) | uniform | uniform | const-size **non-overlapping** | +| `array_chunks_checked` | uniform | uniform | uniform | | +| **`array_windows`** ✅ | uniform — index-based iter | uniform | uniform | NEW (this session) — const-size **overlapping** | +| **`array_windows_checked`** ✅ | uniform | uniform | uniform | NEW (this session) | + +**Gap:** none — every `simd_ops` surface rides on the polyfill primitives. +Floats are the well-served path. Any speedup at this layer requires the +polyfilled types themselves to expose a faster primitive (e.g. a `dpbusd` +op on `I32x16`, see § J integration plan Phase 4). + +--- + +## C. 
`simd_int_ops` — integer slice ops + +| Function | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|--------------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|-----| +| `add_i8` ✅ MX-T1a | ✅ `_mm512_add_epi8` via `I8x64` | ← | ← | ← | ← | ← | ← | ← | ✅ `_mm256_add_epi8` ×2 via `I8x64` polyfill | ← | ✅ `vaddq_s8` via `I8x16` | ← | ← | ✅ scalar wrapping_add | +| `sub_i8` ✅ MX-T1a | ✅ `_mm512_sub_epi8` | ← | ← | ← | ← | ← | ← | ← | ✅ `_mm256_sub_epi8` ×2 | ← | ✅ `vsubq_s8` | ← | ← | ✅ scalar wrapping_sub | +| `add_i16` ✅ MX-T1a| ✅ `_mm512_add_epi16` via `I16x32` | ← | ← | ← | ← | ← | ← | ← | ✅ `_mm256_add_epi16` via `I16x32` polyfill | ← | ✅ `vaddq_s16` via `I16x8` | ← | ← | ✅ scalar wrapping_add | +| `dot_i8` | ✗ scalar 🚨 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ✗ | +| `dot_i16` | ✗ scalar 🚨 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ✗ | +| `min_i8` | ✅ `vpminsb zmm` via I8x64 | ← | ← | ← | ← | ← | ← | ← | ✅ `vpminsb ymm` via I8x32 polyfill of I8x64 | ← | ✅ `vminq_s8` via I8x16 | ← | ← | ✗ scalar loop | +| `max_i8` | ✅ `vpmaxsb zmm` via I8x64 | ← | ← | ← | ← | ← | ← | ← | ✅ `vpmaxsb ymm` | ← | ✅ `vmaxq_s8` | ← | ← | ✗ | +| **`gemm_u8_i8`** ✅ | ✗ scalar (no VNNI) | ✅ `vpdpbusd zmm` (CLX+) | ← | ← | ← | ← | ← | ← | ✅ `vpdpbusd ymm` (avxvnni) | ✗ scalar | 🟦 `sdot+128-bias` (planned) | ✗ scalar | ✗ scalar | ✗ scalar | +| `gemm_u8_i8` AMX preempt | — | — | — | — | 🟦 `tdpbusd` 16×16 tile (planned) | 🟦 `tdpbusd` | — | — | — | — | — | — | — | — | + +🚨 = scalar where SIMD exists. Each of these has 64-lane `I8x64::add` etc. +already in the polyfill but the slice ops don't reach for them. Trivial fix +once we decide to land an int-slice-ops sweep — see § J Phase 1b. + +--- + +## D. 
`simd_half` — BF16 / F16 ops + +The half-precision surface is **uniformly scalar** today: every op upcasts +to f32 lane-by-lane, computes, downcasts back via round-to-nearest-even. +This is TD-SIMD-8 in the audit — hardware paths exist on every CPU class +but only one (`BF16x16` on avx512bf16) is wired. + +| Function | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|---------------------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|-----| +| `BF16x16::from_slice` | uniform — `[u16;16]` load | ← | ← | ← | ← (native `__m256bh` swap-in) | ← | ← (native) | ← (native) | ← | ← | ← | ← | ← | ← | +| `BF16x16::add/sub/mul` | 🚨 scalar f32 upcast | ← | ⏳ `vdpbf16ps`-able via F32x16 mul | ← | ⏳ ditto + AMX-BF16 tile | ← | ⏳ | ⏳ | 🚨 scalar | 🚨 scalar | 🚨 scalar (BFMLALB-able) | 🚨 scalar | 🚨 scalar | 🚨 scalar | +| `BF16x16::fma` | 🚨 scalar f32 mul_add | ← | ⏳ `vdpbf16ps zmm` | ← | ⏳ AMX-BF16 / VDPBF16PS | ← | ⏳ VDPBF16PS | ⏳ | 🚨 scalar | 🚨 scalar | 🚨 scalar (BFMMLA-able) | 🚨 | 🚨 | 🚨 | +| `BF16x16::to_f32x16` | 🚨 scalar bit-shift | ← | ⏳ `vcvtne2ps2bf16` reverse | ← | ⏳ | ⏳ | ⏳ | ⏳ | 🚨 scalar | 🚨 | 🚨 (BFCVTN-able) | 🚨 | 🚨 | 🚨 | +| `F16x16::add/sub/mul` | 🚨 scalar | ← | ← | ← | ⏳ `vmulph zmm` (avx512fp16) | ← | ⏳ avx512fp16 | ⏳ | 🚨 | 🚨 | 🚨 (FMLA `v.8h`) | 🚨 | 🚨 | 🚨 | +| `F16x16::fma` | 🚨 scalar mul_add | ← | ← | ← | ⏳ `vfmadd231ph zmm` | ← | ⏳ | ⏳ | 🚨 | 🚨 | 🚨 (FMLA `v.8h`) | 🚨 | 🚨 | 🚨 | +| `F16x16::to_f32x16` | 🚨 scalar | ← | ← | ← | ← (could use F16C `vcvtph_ps` for ymm halves on every x86 from Ivy Bridge — TD-SIMD-8 misses this on ALL profiles) | ← | ← | ← | 🚨 | 🚨 (F16C wired-able) | 🚨 (`vcvt_f32_f16`) | 🚨 | 🚨 | 🚨 | +| `add_bf16_inplace` | 🟡 BF16x16 + scalar tail (inherits whatever BF16x16::add does) | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `mul_bf16_inplace` | 🟡 BF16x16 | ← | ← | ← | ← | ← | ← | ← | 
← | ← | ← | ← | ← | ← | +| `add_f16_inplace` | 🟡 F16x16 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `mul_f16_inplace` | 🟡 F16x16 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `cast_bf16_to_f32_batch` | 🟡 BF16x16::to_f32x16 + tail | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `cast_f16_to_f32_batch` | 🟡 F16x16::to_f32x16 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `cast_f32_to_bf16_batch` | ✗ scalar per-element 🚨 | ← | ⏳ should call `f32_to_bf16_batch_rne` (already exists for AVX-512) | ← | ⏳ AMX-BF16 / `vcvtne2ps2bf16` | ← | ⏳ | ⏳ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | +| `cast_f32_to_f16_batch` | ✗ scalar per-element 🚨 | ← | ← | ← | ⏳ `vcvtps2phx zmm` (avx512fp16) | ← | ⏳ | ⏳ | ✗ (F16C wired-able) | ✗ (F16C) | ✗ (`vcvt_f16_f32`) | ✗ | ✗ | ✗ | + +**Gap, severe.** F16/BF16 is the AI/ML hot path and the entire surface is +scalar-equivalent on every CPU. Even where F16C has been stable since 2012 +(Ivy Bridge) the dispatch doesn't reach for it. Phases F1–F3 in the +integration plan below. + +--- + +## E. Batch converters + transcendentals (`crate::simd::*` direct) + +These don't go through the polyfilled types — they're standalone +functions in `src/simd.rs` and `src/simd_avx512.rs`. 
+ +| Function | SKX | CLX | CPL | ICX | SPR | GNR | Z4 | Z5 | ARL | HSW | A76 | A72 | A53 | SCA | +|-----------------------------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|-----| +| `bf16_to_f32_batch` | ✅ scalar batch via `<< 16` cast | ← | ✅ same | ← | ✅ same | ← | ✅ | ✅ | ✅ | ✅ | ✅ (NEON-batchable, currently scalar) | ✅ | ✅ | ✅ | +| `bf16_to_f32_scalar` | uniform — scalar reference | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `f32_to_bf16_batch` | ✅ scalar truncate (no rounding) | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `f32_to_bf16_scalar` | uniform — scalar reference | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| **`f32_to_bf16_batch_rne`** | ✅ AVX-512-F bit-fiddle (no avx512bf16 dep!) 500–20000× faster than scalar; byte-exact vs `_mm512_cvtneps_pbh` | ← | ← | ← | ← | ← | ← | ← | ✗ scalar 🚨 (uses AVX-512-F-only ops on byte loads — could be lifted to AVX2 in principle) | ✗ scalar 🚨 | ✗ scalar 🚨 | ✗ scalar | ✗ scalar | ✗ scalar | +| `f32_to_bf16_scalar_rne` | uniform — reference impl, must NOT be in hot loops | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | +| `simd_exp_f32` | ✅ Remez poly via F32x16 | ← | ← | ← | ← | ← | ← | ← | ✅ (lower lane count via F32x16 polyfill of two ymm) | ✅ same | ✅ | ✅ | ✅ | ✗ scalar | +| `simd_ln_f32` | ✗ scalar `f32::ln` per lane on ALL profiles 🚨 | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | ← | + +--- + +## F. `simd_soa` — SoA carriers (`MultiLaneColumn`) + +Layout-only. Every method is uniform across CPUs — the per-CPU dispatch +lives inside the polyfilled types returned by `iter_u8x64` / `iter_f32x16` +/ `iter_f64x8` / `iter_u64x8`. See table A. 
+ +| Method | Behavior across all CPUs | +|-----------------------|---------------------------------------------------------------------| +| `MultiLaneColumn::new`| `Arc<[u8]>` carrier validation (multiple-of-64 byte buffer) | +| `len_*` / `is_empty` | u64 arithmetic on `Arc.len()` | +| `iter_u8x64` | `as_chunks::<64>` + `U8x64::from_array` (delegates to polyfill) | +| `iter_f32x16` | `as_chunks::<64>` + per-chunk `f32::from_le_bytes` × 16 + `from_array` | +| `iter_f64x8` | `as_chunks::<64>` + per-chunk `f64::from_le_bytes` × 8 + `from_array` | +| `iter_u64x8` | `as_chunks::<64>` + per-chunk `u64::from_le_bytes` × 8 + `from_array` | +| `as_bytes` | Arc-aliased `&[u8]` view | + +**Gap:** none at this layer — gaps in the polyfilled types propagate +transparently, gain from filling them is automatic. + +--- + +## G. Cognitive / HPC re-exports surfaced through `crate::simd::*` + +These are re-exports of functions that themselves use `crate::simd::*` — +their per-CPU resolution is the polyfill's, but they're listed here for +inventory completeness since they appear in the public `crate::simd::*` API. + +| Symbol | Behavior across CPUs | +|-----------------------------------------------------------------|---------------------| +| `Fingerprint{,1K,2K,64K}`, `VectorConfig`, `VectorWidth` | 🟡 polyfill-pass (uses F32x16 / U64x8 internally) | +| `hamming_distance_raw`, `popcount_raw` | TD-T-? — needs audit. AVX-512 VPOPCNTDQ wiring partially landed. | +| `wht_f32`, `wht_f32_new` | 🟡 polyfill-pass (uses F32x16) | +| `CollapseGate` | 🟡 polyfill-pass | +| `kmeans`, `squared_l2` | 🟡 polyfill-pass (uses F32x16) | +| `cosine_f32_to_f64_simd` (heel_f64x8) | 🟡 polyfill-pass (uses F64x8 + F32x16) | +| `quantize_f32_to_{i2,i4,i8}`, `dequantize_{i2,i4,i8}_to_f32` | TD-? — needs audit. Likely scalar today. 
| +| `QuantParams` | data carrier, no per-CPU divergence | +| `MultiLaneColumn` | covered in § F | +| `array_chunks` / `array_windows` | covered in § B | +| `add_f32` / … / `add_mul_f32` / `add_mul_f64` | covered in § B | +| `add_bf16_inplace`, `cast_*_batch`, `BF16x16`, `F16x16` | covered in § D | + +--- + +## H. Currently-MISSING agnostic surfaces (mentioned in integration plans but not yet present) + +Things we know we want but haven't built yet — sourced from the audit ++ integration plan + dispatch matrix companions. + +| Symbol | Purpose | Currently | +|-----------------------------------------|------------------------------------------------------|-----------| +| `simd_int_ops::gemm_i8` (s8 × s8 → i32) | True symmetric VNNI2 surface (Arrow Lake / GNR `vpdpbssd`) | ✗ missing | +| `simd_int_ops::gemm_u8` (u8 × u8 → u32) | Symmetric unsigned VNNI2 (`vpdpbuud`) | ✗ missing | +| `simd_int_ops::dot4_u8_i8` (vector op) | The polyfilled dot-4 primitive on `I32x{8,16}` | ✗ missing | +| `simd_ops::axpy_f32` (scalar α) | BLAS-1 `y += α * x` (different from `add_mul_f32`'s vector β) | ✗ missing | +| `simd_ops::dot_f32` | BLAS-1 f32 dot product | ✗ missing | +| `simd_ops::nrm2_f32`, `asum_f32` | BLAS-1 vector norms | ✗ missing | +| `simd_ops::gemv_f32` | BLAS-2 matrix-vector (currently TD-T7 scalar) | ✗ missing | +| `simd_ops::gemm_f32` | BLAS-3 (currently uses `matrixmultiply` workspace) | ✗ deferred — `matrixmultiply` is the production path | +| `simd_int_ops::dot_i32` / `dot_i32_i64` | INT32 dot, INT16×INT16→INT32 via VPDPWSSD | ✗ missing | +| `SimdProfile` enum + `simd_profile()` | Phase 3 dispatch foundation per integration plan | ✗ missing | +| `cpu-spr` / `cpu-zen4` / etc. features | Compile-time pin cargo features (integration plan) | ✗ missing | + +--- + +## I. 
Cross-cutting infrastructure status + +| Item | Status | +|-------------------------------------------------|---------------| +| **`.cargo/config.toml`** default `x86-64-v3` | ✅ (CI baseline) | +| **`.cargo/config-avx512.toml`** = `sapphirerapids` | ✅ (this session) | +| **`.cargo/config-native.toml`** = `native` | ✅ already in tree | +| **`.cargo/config-apple-m2.toml`** | ✅ in tree | +| **`.cargo/config-pi5.toml`** (A76+) | ✅ in tree | +| **`.cargo/config-graviton.toml`** (A72/A76 AWS)| ✅ in tree | +| Cargo features `cpu-spr` / `cpu-icx` / `cpu-zen4` / etc. | ✗ missing (Phase 3) | +| Cargo feature `runtime-dispatch` (LazyLock-once table) | ✗ missing (Phase 3) | +| `SimdProfile` enum | ✗ missing (Phase 3) | +| GitHub CI matrix (default v3, nightly-simd, avx512, aarch64) | ✅ partial — verified per CI doc | +| Bench harness for `gemm_u8_i8` | ✅ this session (ignored test) | +| Bench harness for BF16 / F16 ops | ✗ missing | +| Bench harness for `simd_ops` slice ops | ✗ missing | + +--- + +## J. INTEGRATION PLAN + +Filling the matrix in deliberate phases. Each item is one PR-sized unit. + +### Phase 0 — Already landed (this session) + +- ✅ `simd_int_ops::gemm_u8_i8` agnostic surface with `avx512vnni` / `avxvnni` / scalar arms (compile-time cfg chain). +- ✅ `int8_gemm_avxvnni_ymm` kernel (VEX `vpdpbusd` ymm). +- ✅ `int8_gemm_vnni_avx512` promoted to `pub(crate)` for direct dispatcher call. +- ✅ `.cargo/config-avx512.toml` → `sapphirerapids` (was bare v4 without VNNI). +- ✅ `simd_ops::array_windows` + `array_windows_checked` (overlapping const-size). +- ✅ `simd_ops::add_mul_f32` + `add_mul_f64` (slice-level FMA, polyfill-routed). +- ✅ "Foundation primitives — do not remove" doc-callout in `simd_ops.rs`. +- ✅ Bench harness (`bench_gemm_u8_i8_vs_scalar`, `#[ignore]`'d). +- ✅ MX-T1a — `add_i8` / `sub_i8` / `add_i16` lifted from scalar to polyfilled + `I8x64` / `I8x16` / `I16x32` / `I16x8` (matrix § C cells flipped). 
+ +### Design rule for AMX / BF16 / FP16 paths: inline asm-byte encoding + +> **Hard constraint for Phases 1b (AMX-INT8), 3b (AVX-512-FP16), +> 3c (NEON BF16+FP16), 4d (AMX-FP16):** every instruction that lacks +> stable Rust intrinsics on the project's pinned 1.95 stable toolchain +> MUST be emitted via raw-`.byte`-string inline asm, matching the +> pattern already proven in `src/simd_amx.rs` (lines 16-19 of its +> module docs). Rationale: +> +> 1. **AMX intrinsics are nightly-only** (Rust issue #126622). The +> project pins Rust 1.95 stable per `CLAUDE.md` line 9. The +> existing `simd_amx.rs` lifts AMX onto stable today via +> `asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem))` +> for TILEZERO and equivalent encodings for TDPBUSD / TDPBF16PS. +> 2. **AVX-512-FP16 intrinsics** (`_mm512_add_ph`, `_mm512_fmadd_ph`, +> `vcvtph2ps`/`vcvtps2ph` zmm forms) — have historically had +> stabilization churn. Asm-byte encoding skips the version dance. +> 3. **NEON FP16** (FMLA `v.8h`, BFDOT, BFMMLA, USDOT) — likewise +> nightly-gated for several Rust releases. The existing +> `simd_neon_bf16.rs` and `simd_neon_dotprod.rs` stub files (TD-T10 +> / TD-T11) are placeholders meant to be filled with asm-byte +> encodings per the same pattern. +> +> Concrete recipe: +> +> ```rust +> #[cfg(target_arch = "x86_64")] +> #[target_feature(enable = "amx-tile,amx-int8")] +> unsafe fn tdpbusd_t0_t1_t2() { +> // TDPBUSD tmm0, tmm1, tmm2 — opcode VEX C4 E2 69 5E C1 +> // 69 = VEX byte 2: W=0, vvvv=!tmm2 (1101), L=0, pp=01 (66 prefix = unsigned-by-signed selector) +> // 5E = opcode shared by the TDPB{SS,SU,US,UU}D family +> // C1 = ModR/M (mod=11, reg=tmm0 dest, r/m=tmm1 src1); tmm2 src2 travels in VEX.vvvv +> // The byte sequence is the canonical VEX form documented in +> // Intel SDM Vol. 2D § TDPBUSD; verify with `objdump -d` of a +> // gas-assembled stub the first time it lands. 
+> core::arch::asm!( +> ".byte 0xc4, 0xe2, 0x69, 0x5e, 0xc1", +> options(nostack, nomem) +> ); +> } +> ``` +> +> Same pattern for NEON F16: +> +> ```rust +> #[cfg(target_arch = "aarch64")] +> #[target_feature(enable = "neon,fp16")] +> unsafe fn fmla_v8h(_acc: &mut float16x8_t, _a: float16x8_t, _b: float16x8_t) { +> // FMLA v0.8h, v1.8h, v2.8h — encoding 0x4e40_0c00 | (Rd << 0) | (Rn << 5) | (Rm << 16) +> // Same byte-encoded pattern as simd_amx.rs uses for AMX on x86. +> core::arch::asm!( +> ".inst 0x4e420c20", // FMLA v0.8h, v1.8h, v2.8h +> options(nostack, nomem) +> ); +> } +> ``` +> +> **Verification harness:** each newly-encoded instruction lands with an +> `objdump -d` check in the doc-comment showing the gas-disassembly +> matches the intended mnemonic. The first such verification in this +> project is recorded in `simd_amx.rs:16-19` ("verified working" line). +> +> **What this rule does NOT apply to:** instructions with already-stable +> intrinsics on Rust 1.95 — `_mm512_dpbusd_epi32` (avx512vnni), +> `_mm256_dpbusd_avx_epi32` (avxvnni), `_mm256_cvtph_ps` (F16C), +> `_mm512_cvtne2ps_pbh` (avx512bf16). Those continue to use the +> intrinsics directly per the existing `simd_avx512.rs` patterns. + +### Phase 1 — Wire what already exists (highest ROI per audit) + +P0 — closes 7 of 22 audit findings. 
From `td-simd-integration-plan.md` Phase 1, refined with this matrix's findings: + +| Task | Surface affected | Change | Effort | +|---------|--------------------------------|--------|--------| +| TD-T1 | `hpc::amx_matmul::matmul_bf16_to_f32` | Route AMX arm through `bf16_tile_gemm_16x16` instead of scalar `bf16_gemm_f32` | 1h | +| TD-T2 | `hpc::amx_matmul::matmul_f32` | AMX arm: convert to BF16, call tile kernel — drop duplicate scalar call | 30m | +| TD-T3 | `hpc::amx_matmul::matmul_i8_to_i32` | AMX arm wires `tile_dpbusd`; non-AMX arm uses `int8_gemm_vnni` instead of scalar | 1.5h | +| TD-T4 | `hpc::quantized::bf16_gemm_f32` | Rewrite using `F32x16::mul_add` over decoded BF16 rows | 3h | +| TD-T6 | `backend::native::avx2::{scal,nrm2,asum}_f32/f64` | Replace scalar delegations with real `_mm256_*` intrinsics | 2h | +| TD-T7 | `backend::native::gemv_f32/f64` | Wire through `dispatch!` macro to AVX-512/AVX2 row-dot kernels | 2h | + +**Plus from this matrix (new):** + +| Task | Surface affected | Change | Effort | +|---------|--------------------------------|--------|--------| +| MX-T1 | `simd_int_ops::{add_i8, sub_i8, add_i16, dot_i8, dot_i16}` | Lift from scalar to polyfilled `I8x{32,64}` / `I16x{16,32}` ops. They already exist as types on every backend; just route the slice ops through them. | 3h | +| MX-T2 | `simd::cast_f32_to_bf16_batch` | Currently scalar — route to existing `f32_to_bf16_batch_rne` (AVX-512-F-only; works on every AVX-512 CPU) when available, scalar otherwise. | 30m | +| MX-T3 | `simd::cast_f32_to_f16_batch` | Add F16C (`vcvtps2ph`) fast path — stable since 2012 Ivy Bridge — currently scalar on every x86 profile. | 2h | + +**Phase 1 total: ~15–18h.** Closes all 7 CRITICAL audit findings plus the +three new "low-hanging integer/cast" wins surfaced here. 
+ +### Phase 2 — aarch64 fills (Pi 5 / Apple M-series silicon ceiling) + +From `td-simd-integration-plan.md` Phase 2, restated: + +| Task | Surface | Change | Effort | +|---------|---------|--------|--------| +| TD-T10 | `simd_neon_bf16::BF16x{8,16}Stub` → real `bfloat16x8_t` pairs, BFDOT via asm-byte, BFMMLA wiring | Live BF16 NEON arithmetic | 4h | +| TD-T11 | `simd_neon_dotprod::F16x16Stub` → real `float16x8_t` pair via asm-byte FMLA `v.8h` | Live FP16 NEON arithmetic | 4h | +| TD-T21 | `simd::*` aarch64 integer re-exports (currently scalar polyfill from `simd_scalar::*`) → real NEON quartets | Live integer NEON for I32x8, U8x64 etc. | 8h | +| TD-T8 | `hpc::simd_dispatch` aarch64 dispatch — currently `Self::scalar()` → real NEON wrappers | byte_find_all_neon, byte_count_neon, … | 6h | +| MX-T4 | `simd_int_ops::gemm_u8_i8` NEON arm | New `int8_gemm_sdot_neon` kernel using `vdotq_s32` + +128-bias for u8×i8 | 4h | + +**Phase 2 total: ~26h.** Requires aarch64 CI runner / cross-compile verification (Pi 5 or Apple M-series). 
+ +### Phase 3 — `SimdProfile` dispatch foundation + +From `td-simd-integration-plan.md` Phase 3 — unchanged: + +| Task | Surface | Change | Effort | +|---------|---------|--------|--------| +| T3.1 | `src/hpc/simd_profile.rs` (new) | `SimdProfile` enum + `detect()` per dispatch matrix | 3h | +| T3.2 | `Cargo.toml` features + `.cargo/config-{profile}.toml` per silicon profile | `cpu-spr`, `cpu-icx`, …, mutually exclusive | 4h | +| T3.3 | `src/hpc/gemm_dispatch.rs` (new) | First `*Dispatch` table — `bf16_gemm`, `int8_gemm`, `f32_gemv` | 4h | +| T3.4 | `src/hpc/blas1_dispatch.rs` (new) | `Blas1Dispatch` for dot/axpy/scal/nrm2/asum f32/f64 | 3h | +| T3.5 | `backend::native::dispatch!` | Migrate from local `Tier` to `simd_profile()` | 2h | +| T3.6 | `simd::tier()` | Alias to `simd_profile().coarse()` (preserve callers) | 2h | +| T3.7 | `hpc::simd_dispatch::detect()` | Migrate to `simd_profile()`; add Avx512f-only, AvxVnniInt8, IceLakeSp dispatches | 3h | +| MX-T5 | `simd_int_ops::gemm_u8_i8` | Migrate cfg chain to `GemmDispatch.int8_gemm` pointer (both compile-time pin and LazyLock-once modes) | 2h | + +**Phase 3 total: ~23h.** Provides the framework for Phase 4 and removes +the three duplicate Tier enums (TD-T12/T13/T14). + +### Phase 4 — Intra-bucket SIMD fills (parallelizable) + +Each task is one PR. 
Restated from `td-simd-integration-plan.md` Phase 4 +with priority rebalanced based on this matrix: + +| Task | Profile unlocking it | Surface that gets faster | Effort | +|---------|--------------------------|--------------------------|--------| +| MX-F1 (HOT) | SPR/GNR/CPL/Z4/Z5 | `BF16x16::add/sub/mul/fma` via `vdpbf16ps`-style F32x16 mul_add (drop scalar f32 round-trip) | 4h | +| MX-F2 (HOT) | All x86 (F16C stable since 2012) | `F16x16::to_f32x16` + `add/sub/mul/fma` via `vcvtph_ps`/`vcvtps_ph` round-trip + F32x16 ops | 4h | +| MX-F3 (HOT) | A76 + (arm fp16) | `F16x16` arm with FMLA `v.8h` asm-byte | 3h | +| MX-F4 | SPR/GNR (avx512fp16) | Native `F16x{8,16}` `__m{256,512}h` storage on Sapphire+/Granite (skips F32 round-trip)| 6h | +| MX-F5 | All AVX-512F | `simd_ln_f32` Remez polynomial (currently scalar everywhere) | 3h | +| MX-F6 | All AVX-512BW | `nibble_unpack`, `nibble_above_threshold` 2× width — TD-T16 | 2h | +| MX-F7 | HSW | `nibble_unpack_avx2` real `_mm256_*` (TD-T17) | 2h | +| MX-F8 | All AVX-512F | `distance::squared_distances_f32` 16-wide L2 (TD-T19) | 2h | +| MX-F9 | All AVX-512F | `spatial_hash::batch_sq_dist` 16-wide (TD-T20) | 2h | +| MX-F10 | IceLakeSp+/SPR/GNR/Z4/Z5 | VPOPCNTDQ paths — Hamming/popcount audit | 4h | +| MX-F11 | IceLakeSp+/SPR/GNR/Z4/Z5 | VBMI byte-permute audit beyond `simd_avx512.rs:695` | 4h | +| MX-F12 | IceLakeSp+/Z4/Z5 | GFNI bitmatrix multiply audit | 6h | +| MX-F13 | ARL/GNR | `simd_int_ops::gemm_i8` (s8×s8 → i32) via `vpdpbssd` ymm/zmm — NEW agnostic surface | 4h | +| MX-F14 | ARL/GNR/A76(+usdot) | `simd_int_ops::gemm_u8` (u8×u8 → u32) via `vpdpbuud` / NEON `udot` | 4h | +| MX-F15 | SPR/GNR (amx-int8) | AMX arm of `simd_int_ops::gemm_u8_i8` — `tile_dpbusd` 16×16 (the kernel exists in `bf16_tile_gemm.rs`-shape, needs INT8 sibling) | 6h | +| MX-F16 | GNR (amx-fp16) | AMX-FP16 `tdpfp16ps` — gated on CPUID.07H.1H:EAX[21], needs SimdCaps extension | 4h | + +**Phase 4 total: ~60h, parallelizable.** Every task is gated on 
Phase 3's +`SimdProfile` infrastructure but otherwise independent. Land in any order. + +### Phase 5 — BLAS-graph GEMM kernel polish (the JIT-parity zone) + +The kernels that the user's earlier session brought to within "a few %" of +a Cranelift-JIT inner loop, via `array_chunks` + `array_windows` + the +polyfilled `mul_add` + `add_mul_*`. Once Phases 1–4 land, this phase +verifies that no per-CPU regression has crept in vs the historical baseline: + +| Task | Surface | Action | Effort | +|---------|---------|--------|--------| +| MX-P1 | `gemm_u8_i8` bench | Land the `#[ignore]` bench from Phase 0 as a published `benches/int8_gemm.rs` criterion bench so CI can detect regressions per arm | 2h | +| MX-P2 | `gemm_u8_i8` AMX path | Verify AMX kernel reaches ≥ 2× of avx512vnni zmm on SPR (audit's expected 256:64 mul-add ratio) | 2h | +| MX-P3 | `add_mul_f32` bench | Add as `benches/blas1.rs` — compare to scalar reference and to `f32::mul_add` per-element loop. Floor: SIMD ≥ 4× scalar at length ≥ 256 on each arm | 2h | +| MX-P4 | `bgz17_bridge` GEMM | Re-bench against JIT path (now retired). Confirm the original within-a-few-% gap still holds with the post-Phase-4 polyfill | 4h | +| MX-P5 | NO_REMOVE doc audit | Walk `simd_ops.rs`, `simd_int_ops.rs`, `simd_half.rs`, `simd_soa.rs`. 
Confirm every helper that bench-shows ≥ 1.5× over scalar has a "Foundation primitive — do not remove" call-out with the bench number cited inline | 1h | + +**Phase 5 total: ~11h.** + +### Phase 6 — Future / out-of-current-scope + +| Item | Why deferred | +|--------------------------------|--------------| +| `gemm_f32` BLAS-3 | `matrixmultiply` workspace dep handles this — wrapping it is API design, not SIMD work | +| GPU offload | Out of scope per CLAUDE.md "HPC Rust transformation" charter | +| Cranelift-JIT GEMM revival | Dropped after the BLAS-graph polyfill reached parity — only reconsider if Phase 5 shows > 5% gap | +| `wasm32` SIMD128 backend | `core::simd` via `nightly-simd` covers it; no per-target intrinsic wiring planned | +| RISC-V Vector extension | `core::simd` ditto | +| Multi-core threading | `matrixmultiply-threading` feature exists; deeper threading is a separate phase | + +--- + +## K. How to read this doc + +1. **Picking the cfg config for a deployment:** find your CPU profile column. + Cells with ✅ on that column are wired. Cells with ⏳ are the speedups + that landed kernels but didn't wire (low-hanging gains). +2. **Adding a new agnostic surface:** copy the `simd_int_ops::gemm_u8_i8` + pattern — compile-time `#[cfg(target_feature)]` chain on `simd_int_ops` + (the entry point), kernels in `hpc::vnni_gemm` / `hpc::neon_dotprod_gemm` + / etc., scalar fallback as the universal arm. +3. **Verifying a per-CPU lowering is correct:** run the matching + `bench_*_vs_scalar` ignored test under `RUSTFLAGS='-Ctarget-cpu=$CPU'` + — the runner must have the silicon to execute the emitted instructions + (Sapphire Rapids covers everything down to and including A76's intrinsic + semantics; aarch64 needs a separate runner). +4. **Spotting matrix drift:** when adding a new public symbol to + `crate::simd::*`, this table must grow a row. Reviewers should reject + PRs that add a public symbol without a corresponding matrix entry. + +## L. 
Provenance + +- CPU feature presence: sourced from `td-simd-cpu-dispatch-matrix.md`. +- Audit findings (TD-T*): sourced from `td-simd-tier-audit.md`. +- Phase 1–4 effort estimates: cross-referenced with + `td-simd-integration-plan.md`; new MX-T* / MX-F* items estimated in this + doc. +- Polyfilled type backing: read directly from `src/simd.rs` lines 197–366 + (cfg-gated re-exports per `target_feature`), `src/simd_avx512.rs` + re-exports at 2260, `src/simd_avx2.rs` (256-bit polyfills), `src/simd_neon.rs` + paired-load wrappers, `src/simd_scalar.rs` arrays. +- Surface function inventory: read directly from + `src/simd_ops.rs`, `src/simd_int_ops.rs`, `src/simd_half.rs`, + `src/simd_soa.rs`, `src/simd.rs` re-exports. +- No grep / tail / head sampling — every entry traceable to a full-file + Read per the workspace rule. diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index a6838f0f..4c64177d 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -297,8 +297,15 @@ fn write_contig(view: &mut ArrayViewMut2<'_, A>, src: &[A]) { /// Matrix multiply BF16 × BF16 → f32: `out = lhs · rhs`. /// -/// Uses AMX `TDPBF16PS` (256 mul-adds per instruction) when available, -/// otherwise falls back to [`bf16_gemm_f32`]. +/// On AMX hardware (Sapphire Rapids+, Granite Rapids), shapes with M and N +/// multiples of 16 and K a multiple of 32 dispatch to +/// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`], which emits +/// `TDPBF16PS` via the asm-byte path in `simd_amx.rs` — 16×16×32 = 8 192 +/// BF16×BF16 multiply-accumulates per instruction into a 16×16 f32 tile. +/// +/// Every other case (any dimension unaligned — the whole matmul, not just +/// the tail — or no AMX) uses `VDPBF16PS` when `avx512bf16` is detected at +/// runtime, else the scalar [`crate::hpc::quantized::bf16_gemm_f32`]. /// /// `out` must be row-contiguous (column stride = 1); inputs may be strided. 
pub fn matmul_bf16_to_f32( @@ -310,26 +317,180 @@ pub fn matmul_bf16_to_f32( let b = pack_contig(&rhs); let mut c = vec![0.0f32; m * n]; - // AMX path: a tiled 16×16 kernel exists in `bf16_tile_gemm` for sizes that - // fit cleanly. For any leftover tail (or hosts without AMX), defer to the - // scalar `bf16_gemm_f32`. The tile kernel itself is maintained alongside - // the low-level primitives at the top of this file; the public surface - // intentionally goes through the validated scalar path so we always - // produce a numerically-stable f32 result. - if amx_available() { - // Future: AMX-tiled fast path. Today we route through the same - // f32 reference kernel; correctness is identical regardless of - // hardware. The `amx_available()` branch is preserved so callers - // can be sure the AMX detection runs. - bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0); - } else { - bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0); - } + bf16_gemm_dispatch(&a, &b, &mut c, m, n, k); write_contig(&mut out, &c); Ok(()) } +/// BF16 × BF16 → f32 GEMM with three-tier dispatch (AMX → VDPBF16PS → scalar). +/// +/// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c` +/// is M × N row-major and is overwritten (not accumulated). +/// +/// Tier selection: +/// +/// 1. **AMX `TDPBF16PS`** (Sapphire Rapids+, Granite Rapids) when +/// `amx_available()` is true AND shapes are 16/16/32-aligned. +/// Dispatches through +/// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] → +/// `simd_amx::tile_dpbf16ps` via asm-byte (`TDPBF16PS` intrinsic is +/// nightly-only on Rust 1.95). 8 192 BF16×BF16 multiplies + 256 f32 +/// accumulates per instruction. +/// 2. **`VDPBF16PS`** (Cooper Lake, Sapphire Rapids+, Zen 4+) +/// when `is_x86_feature_detected!("avx512bf16")` is true. The +/// intrinsic `_mm512_dpbf16_ps` is stable on Rust 1.95 (no asm-byte +/// needed). Per instruction: 32 BF16×BF16 multiplies + 16 f32 +/// accumulates, single-rounded. 
Handles arbitrary shapes — M / N +/// tails fall through the per-iteration j-block trimming; K-tail +/// (odd K) is handled with a final scalar multiply. +/// 3. **Scalar reference** [`bf16_gemm_f32`] when `avx512bf16` is not +/// detected — including shapes the AMX arm rejects on such hosts. +/// +/// The per-tier dispatch table comes from PR #180's BF16 GEMM column. +fn bf16_gemm_dispatch(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) { + if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 { + // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)` + // (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as + // `&[u16]` is bit-pattern preserving. + let a_u16: &[u16] = unsafe { core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len()) }; + + // B is packed row-major K × N; the 16×16 tile kernel wants a + // K × 16 contiguous sub-block. Extract per (j_tile) into a + // scratch buffer once and reuse across i_tile. + let mut b_tile = vec![0u16; k * 16]; + let mut tile_c = vec![0.0f32; 256]; + + for j_tile in (0..n).step_by(16) { + // Pack b[0..k, j_tile..j_tile+16] into row-major 16-wide K-rows. + for kk in 0..k { + let row = kk * n + j_tile; + for jj in 0..16 { + b_tile[kk * 16 + jj] = b[row + jj].0; + } + } + for i_tile in (0..m).step_by(16) { + // A_tile = a[i_tile..i_tile+16, 0..k] — already contiguous + // since `a` is packed row-major M × K. + let a_tile = &a_u16[i_tile * k..(i_tile + 16) * k]; + tile_c.fill(0.0); + crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); + // Write tile_c (16 × 16, row-major) into c (M × N, row-major). + for ii in 0..16 { + let dst_off = (i_tile + ii) * n + j_tile; + c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); + } + } + } + return; + } + + #[cfg(target_arch = "x86_64")] + { + if std::is_x86_feature_detected!("avx512bf16") { + // SAFETY: feature-detected at runtime; the kernel is + // `#[target_feature(enable = "avx512bf16,avx512f")]`. 
+ unsafe { + bf16_gemm_vdpbf16ps(a, b, c, m, n, k); + } + return; + } + } + + bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0); +} + +/// AVX-512BF16 BF16 GEMM using `_mm512_dpbf16_ps` (`VDPBF16PS`). +/// +/// One VDPBF16PS instruction: 16 f32 accumulator lanes each receive +/// `acc[j] += a.bf16[2j] * b.bf16[2j] + a.bf16[2j+1] * b.bf16[2j+1]`, +/// single-rounded. The kernel maps the 16 output lanes to a row of 16 +/// j-columns of C[i, ·], with one i row processed at a time and a K-pair +/// inner loop accumulating into the same 16 f32 lanes across iterations. +/// +/// B-column packing: VDPBF16PS wants the 32 B BF16s per call laid out +/// as 16 lane-pairs (lane j contains `B[2k_pair, j_base+j]` followed by +/// `B[2k_pair+1, j_base+j]`, packed into one u32). We pre-pack B for +/// the current j-block into `b_col_pairs[k_pair * 16 + j] = u32` once +/// per j_block and reuse across all i — amortizes the gather cost. +/// +/// K-tail (when K is odd) is handled with a final scalar BF16 multiply +/// per output cell; N-tail (when the j-block has < 16 valid columns) +/// is handled by trimming the store after the VDPBF16PS chain. +/// +/// # Safety +/// Caller must have feature-detected `avx512bf16` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512bf16,avx512f")] +unsafe fn bf16_gemm_vdpbf16ps(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) { + use core::arch::x86_64::{ + __m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_ps, _mm512_storeu_ps, + }; + + let k_pairs = k / 2; + let k_tail = k % 2; + + // SAFETY: BF16 is repr(transparent) over u16. + let a_u16: &[u16] = core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len()); + let b_u16: &[u16] = core::slice::from_raw_parts(b.as_ptr() as *const u16, b.len()); + + // Pre-pack scratch: 16 u32 lanes per k_pair, holding (b_lo | b_hi << 16). 
+ let mut b_col_pairs = vec![0u32; k_pairs.max(1) * 16]; + // Scratch for the 16-wide store + N-tail trim. + let mut out_buf = [0.0f32; 16]; + + for j_base in (0..n).step_by(16) { + let j_count = 16.min(n - j_base); + + // Pack B columns [j_base..j_base+j_count] in pair-interleaved layout. + // For lanes j >= j_count (the N-tail of this j_block), pad with 0 — + // they're not stored back, but the VDPBF16PS still touches them. + for k_pair in 0..k_pairs { + let row_lo = 2 * k_pair * n; + let row_hi = (2 * k_pair + 1) * n; + for jj in 0..j_count { + let b_lo = b_u16[row_lo + j_base + jj] as u32; + let b_hi = b_u16[row_hi + j_base + jj] as u32; + b_col_pairs[k_pair * 16 + jj] = (b_hi << 16) | b_lo; + } + for jj in j_count..16 { + b_col_pairs[k_pair * 16 + jj] = 0; + } + } + + for i in 0..m { + let mut acc = _mm512_setzero_ps(); + let a_row_off = i * k; + for k_pair in 0..k_pairs { + // Broadcast A[i, 2k_pair..2k_pair+2] as the (BF16 lo, BF16 hi) + // pair across all 16 lanes. + let a_lo = a_u16[a_row_off + 2 * k_pair] as u32; + let a_hi = a_u16[a_row_off + 2 * k_pair + 1] as u32; + let pair = (a_hi << 16) | a_lo; + let a_bh: __m512bh = core::mem::transmute(_mm512_set1_epi32(pair as i32)); + let b_bh: __m512bh = + core::mem::transmute(_mm512_loadu_si512(b_col_pairs.as_ptr().add(k_pair * 16) as *const __m512i)); + acc = _mm512_dpbf16_ps(acc, a_bh, b_bh); + } + _mm512_storeu_ps(out_buf.as_mut_ptr(), acc); + + // K-tail: one extra scalar BF16 multiply for k = k_pairs*2. + if k_tail == 1 { + let a_last_f32 = BF16(a_u16[a_row_off + k - 1]).to_f32(); + let tail_row = (k - 1) * n; + for jj in 0..j_count { + let b_last_f32 = BF16(b_u16[tail_row + j_base + jj]).to_f32(); + out_buf[jj] += a_last_f32 * b_last_f32; + } + } + + // Store the j_count valid lanes (drops N-tail padding lanes). 
+ let dst_off = i * n + j_base; + c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]); + } + } +} + // ── f32 → f32 (BF16 compute on AMX) ──────────────────────────────────────── /// Matrix multiply f32 × f32 → f32: `out = lhs · rhs`. @@ -349,10 +510,13 @@ pub fn matmul_f32( let mut c = vec![0.0f32; m * n]; if amx_available() { - // AMX path: down-cast to BF16, run BF16 GEMM, accumulate in f32. + // AMX path: down-cast to BF16 (RNE, ~1 ULP at BF16 mantissa + // precision), then dispatch through the shared BF16 helper: + // `TDPBF16PS` tile kernel for 16/16/32-aligned shapes, otherwise + // `VDPBF16PS` if `avx512bf16` is detected, else scalar `bf16_gemm_f32`. let a_bf16: Vec = a_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect(); let b_bf16: Vec = b_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect(); - bf16_gemm_f32(&a_bf16, &b_bf16, &mut c, m, n, k, 1.0, 0.0); + bf16_gemm_dispatch(&a_bf16, &b_bf16, &mut c, m, n, k); } else { // Pure f32 reference path. for i in 0..m { diff --git a/src/hpc/vnni_gemm.rs b/src/hpc/vnni_gemm.rs index b156a8b2..05c08212 100644 --- a/src/hpc/vnni_gemm.rs +++ b/src/hpc/vnni_gemm.rs @@ -89,9 +89,20 @@ pub fn has_vnni() -> bool { /// B[p+2,j+L], B[p+3,j+L]]. /// - We pre-pack B into VNNI layout: b_packed[p/4][j..j+16] where each i32 /// contains 4 bytes from consecutive rows. +/// AVX-512 VNNI INT8 GEMM kernel — `pub(crate)` so the agnostic +/// `simd_int_ops::gemm_u8_i8` surface can call it directly under a +/// compile-time `target_feature = "avx512vnni"` gate, bypassing the +/// per-call caps branch in [`int8_gemm_vnni`]. See § "compile-time +/// dispatch table" in `.claude/knowledge/td-simd-integration-plan.md`. +/// +/// # Safety +/// +/// Caller must guarantee the CPU supports AVX-512F + AVX-512VNNI + +/// AVX-512BW. Compile-time gating via `#[cfg(target_feature = …)]` at +/// the call site is the standard contract. 
#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx512f,avx512vnni,avx512bw")] -unsafe fn int8_gemm_vnni_avx512(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { +pub(crate) unsafe fn int8_gemm_vnni_avx512(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { use core::arch::x86_64::*; // Zero output @@ -191,6 +202,95 @@ unsafe fn int8_gemm_vnni_avx512(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: } } +// ── AVX-VNNI (ymm) inner kernel ────────────────────────────────────────── + +/// AVX-VNNI (256-bit ymm) INT8 GEMM kernel. +/// +/// VEX-encoded `VPDPBUSD` over 8-wide i32 accumulators, for Alder Lake / +/// Arrow Lake / Zen 5 / Sapphire Rapids (whenever the dispatcher resolves +/// to AVX2 + AVX-VNNI without selecting the AVX-512 zmm path). Half the +/// lane count of [`int8_gemm_vnni_avx512`]; VEX has no opmask registers +/// for a masked tail, so the column tail (`n % 8 != 0`) runs scalar. +/// +/// `pub(crate)` so [`crate::simd_int_ops::gemm_u8_i8`] can target it +/// directly under a compile-time `target_feature = "avxvnni"` gate. +/// +/// # Safety +/// +/// Caller must guarantee the CPU supports AVX + AVX2 + AVX-VNNI (the call +/// site's `#[cfg(target_feature = "avxvnni")]` gate is the usual proof) and +/// that `c.len() >= m * n`: the ymm store into `c` is not bounds-checked. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx,avx2,avxvnni")] +pub(crate) unsafe fn int8_gemm_avxvnni_ymm(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + use core::arch::x86_64::*; + + // Zero output + for v in c.iter_mut() { + *v = 0; + } + + // Pre-pack B into VNNI layout: groups of 4 rows, each i32 lane holds + // [b[p+0,j], b[p+1,j], b[p+2,j], b[p+3,j]] as 4 bytes. 
+ let k_groups = (k + 3) / 4; + let mut b_packed = vec![0i32; k_groups * n]; + + for pg in 0..k_groups { + let p_base = pg * 4; + for j in 0..n { + let mut bytes = [0u8; 4]; + for q in 0..4 { + let p = p_base + q; + if p < k { + bytes[q] = b[p * n + j] as u8; + } + } + b_packed[pg * n + j] = i32::from_le_bytes(bytes); + } + } + + // Main GEMM loop — 8 i32 columns per ymm register. + for i in 0..m { + let mut j = 0; + while j + 8 <= n { + let mut acc = _mm256_setzero_si256(); + + for pg in 0..k_groups { + let p_base = pg * 4; + + let mut a_bytes = [0u8; 4]; + for q in 0..4 { + let p = p_base + q; + if p < k { + a_bytes[q] = a[i * k + p]; + } + } + let a_val = u32::from_le_bytes(a_bytes) as i32; + let a_broadcast = _mm256_set1_epi32(a_val); + + let b_ptr = b_packed.as_ptr().add(pg * n + j); + let b_vec = _mm256_loadu_si256(b_ptr as *const __m256i); + + // VEX-encoded VPDPBUSD: acc += dot4(a_broadcast, b_vec) per lane. + acc = _mm256_dpbusd_avx_epi32(acc, a_broadcast, b_vec); + } + + _mm256_storeu_si256(c.as_mut_ptr().add(i * n + j) as *mut __m256i, acc); + j += 8; + } + + // Scalar tail for `n - j < 8` columns — no masked ymm VPDPBUSD on VEX. + while j < n { + let mut sum = 0i32; + for p in 0..k { + sum += (a[i * k + p] as i32) * (b[p * n + j] as i32); + } + c[i * n + j] = sum; + j += 1; + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/simd.rs b/src/simd.rs index f96ea9d3..ce449991 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -503,7 +503,14 @@ pub use crate::hpc::fingerprint::{ // PR-X1 — SoA carrier + const-size slice helpers, dispatched from their // respective `simd_{type}.rs` modules. The W1a consumer contract forbids // reaching past `crate::simd::*` into the implementation modules directly. 
-pub use crate::simd_ops::{array_chunks, array_chunks_checked}; +// +// `array_chunks` (non-overlapping) and `array_windows` (overlapping) are +// the stable-Rust foundation primitives for SIMD-staged kernels — together +// with `add_mul_f32` / `add_mul_f64` below, they reach within a few % +// of a Cranelift-JIT'd inner loop on the BLAS-graph GEMM path and are +// the reason the JIT-native option was deemed unnecessary. See the +// "Foundation primitives — do not remove" notice in `src/simd_ops.rs`. +pub use crate::simd_ops::{array_chunks, array_chunks_checked, array_windows, array_windows_checked}; pub use crate::simd_soa::MultiLaneColumn; pub use crate::hpc::quantized::{ @@ -542,8 +549,8 @@ pub use crate::hpc::heel_f64x8::cosine_f32_to_f64_simd; // Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail). #[cfg(feature = "std")] pub use crate::simd_ops::{ - add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_scalar_f32, div_f32, div_f32_inplace, mul_f32, - mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace, + add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, div_f32, + div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace, }; // ============================================================================ diff --git a/src/simd_int_ops.rs b/src/simd_int_ops.rs index e6637f39..b78ffe85 100644 --- a/src/simd_int_ops.rs +++ b/src/simd_int_ops.rs @@ -19,30 +19,162 @@ /// Element-wise `dst[i] += src[i]` (wrapping i8 add). /// -/// Panics if `dst.len() != src.len()`. 
+/// Dispatches to the widest available SIMD lane: +/// +/// | Backend | Lane | Per-iteration intrinsic | +/// |------------|---------|-------------------------| +/// | x86_64 | `I8x64` | `_mm512_add_epi8` zmm (AVX-512-BW) / 2× `_mm256_add_epi8` ymm (AVX2 polyfill of `I8x64`) | +/// | aarch64 | `I8x16` | `vaddq_s8` × N | +/// | other | scalar | `i8::wrapping_add` lane-by-lane | +/// +/// Wrapping arithmetic. Panics if `dst.len() != src.len()`. #[inline] pub fn add_i8(dst: &mut [i8], src: &[i8]) { assert_eq!(dst.len(), src.len(), "add_i8: length mismatch"); - for i in 0..dst.len() { - dst[i] = dst[i].wrapping_add(src[i]); + let n = dst.len(); + + #[cfg(target_arch = "x86_64")] + { + use crate::simd::I8x64; + const L: usize = 64; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I8x64::from_slice(&dst[off..]); + let s = I8x64::from_slice(&src[off..]); + let arr = (d + s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_add(src[i]); + } + } + + #[cfg(target_arch = "aarch64")] + { + use crate::simd_neon::I8x16; + const L: usize = 16; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I8x16::from_slice(&dst[off..]); + let s = I8x16::from_slice(&src[off..]); + let arr = d.add(s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_add(src[i]); + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..n { + dst[i] = dst[i].wrapping_add(src[i]); + } } } /// Element-wise `dst[i] -= src[i]` (wrapping i8 sub). +/// +/// Dispatches the same way as [`add_i8`] (zmm AVX-512-BW / ymm AVX2 / +/// 128-bit NEON / scalar) using the polyfilled lane's `Sub` +/// implementation. 
#[inline] pub fn sub_i8(dst: &mut [i8], src: &[i8]) { assert_eq!(dst.len(), src.len(), "sub_i8: length mismatch"); - for i in 0..dst.len() { - dst[i] = dst[i].wrapping_sub(src[i]); + let n = dst.len(); + + #[cfg(target_arch = "x86_64")] + { + use crate::simd::I8x64; + const L: usize = 64; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I8x64::from_slice(&dst[off..]); + let s = I8x64::from_slice(&src[off..]); + let arr = (d - s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_sub(src[i]); + } + } + + #[cfg(target_arch = "aarch64")] + { + use crate::simd_neon::I8x16; + const L: usize = 16; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I8x16::from_slice(&dst[off..]); + let s = I8x16::from_slice(&src[off..]); + let arr = d.sub(s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_sub(src[i]); + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..n { + dst[i] = dst[i].wrapping_sub(src[i]); + } } } /// Element-wise `dst[i] += src[i]` (wrapping i16 add). +/// +/// Dispatches to `I16x32` (AVX-512-BW `_mm512_add_epi16`) on x86_64, +/// `I16x8` (`vaddq_s16`) on aarch64, scalar otherwise. 
#[inline] pub fn add_i16(dst: &mut [i16], src: &[i16]) { assert_eq!(dst.len(), src.len(), "add_i16: length mismatch"); - for i in 0..dst.len() { - dst[i] = dst[i].wrapping_add(src[i]); + let n = dst.len(); + + #[cfg(target_arch = "x86_64")] + { + use crate::simd::I16x32; + const L: usize = 32; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I16x32::from_slice(&dst[off..]); + let s = I16x32::from_slice(&src[off..]); + let arr = (d + s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_add(src[i]); + } + } + + #[cfg(target_arch = "aarch64")] + { + use crate::simd_neon::I16x8; + const L: usize = 8; + let chunks = n / L; + for c in 0..chunks { + let off = c * L; + let d = I16x8::from_slice(&dst[off..]); + let s = I16x8::from_slice(&src[off..]); + let arr = d.add(s).to_array(); + dst[off..off + L].copy_from_slice(&arr); + } + for i in (chunks * L)..n { + dst[i] = dst[i].wrapping_add(src[i]); + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..n { + dst[i] = dst[i].wrapping_add(src[i]); + } } } @@ -77,6 +209,94 @@ pub fn dot_i16(a: &[i16], b: &[i16]) -> i64 { acc } +// ──────────────────────────────────────────────────────────────────────── +// gemm_u8_i8 — agnostic u8 × i8 → i32 matrix multiply +// ──────────────────────────────────────────────────────────────────────── + +/// `C = A · B` where `A` is `M × K` `u8`, `B` is `K × N` `i8`, `C` is `M × N` +/// `i32` (row-major, overwritten — not accumulated). +/// +/// Agnostic consumer surface. Resolves at **compile time** to one kernel +/// per the active `target_feature` set; consumers never branch on CPU +/// capability and the chosen kernel is fully inlined at the call site. 
+/// +/// Build matrix (additive, filled in as paths land): +/// +/// | `target_feature` | Kernel | +/// |----------------------------|-------------------------------------------------------| +/// | `amx-int8` *(planned)* | AMX `TDPBUSD` 16×16 tile (Sapphire / Granite Rapids) | +/// | `avx512vnni` | `VPDPBUSD` zmm — 16 i32 lanes (CLX → Zen 4 / SPR) | +/// | `avxvnni` | `VPDPBUSD` ymm — 8 i32 lanes (Alder/Arrow Lake, Zen 5)| +/// | `neon,dotprod` *(planned)* | NEON `SDOT` (A76+ / Apple M-series) | +/// | *(none)* | Scalar reference [`hpc::quantized::int8_gemm_i32`] | +/// +/// Arm precedence is widest-vector-first: when several `target_feature` +/// flags are set simultaneously (e.g. Sapphire Rapids enables `avx512vnni` +/// AND `avxvnni`), the highest-bandwidth arm wins via `#[cfg]` ordering. +/// +/// Build configs: +/// +/// * Default `x86-64-v3` (no VNNI) → scalar arm. Same result as calling +/// [`crate::hpc::quantized::int8_gemm_i32`] directly. +/// * `--config .cargo/config-avx512.toml` (= Sapphire Rapids, includes +/// VNNI + BF16 + FP16 + AMX) → the `avx512vnni` zmm arm. The future +/// `amx-int8` arm, once landed, will preempt this on the same config. +/// * `-Ctarget-cpu=cascadelake` / `znver4` → also lands in the +/// `avx512vnni` zmm arm (no AMX; `cascadelake` also lacks BF16). +/// * `RUSTFLAGS='-Ctarget-feature=+avxvnni'` on an AVX2 baseline → +/// the `avxvnni` ymm arm (Arrow Lake / Alder Lake without AVX-512). +/// +/// # Panics +/// +/// Panics if `a.len() < m * k`, `b.len() < k * n`, or `c.len() < m * n`. +#[inline] +pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + assert!(a.len() >= m * k, "gemm_u8_i8: a.len()={} < m*k={}", a.len(), m * k); + assert!(b.len() >= k * n, "gemm_u8_i8: b.len()={} < k*n={}", b.len(), k * n); + assert!(c.len() >= m * n, "gemm_u8_i8: c.len()={} < m*n={}", c.len(), m * n); + + // Compile-time dispatch chain. 
Exactly one arm survives per build; + // the others are stripped by `#[cfg]` so the compiler emits a direct + // call to the chosen kernel with no runtime branch. + + #[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))] + { + // SAFETY: `target_feature = "avx512vnni"` implies AVX-512F here; BW + // (also in the kernel's `#[target_feature(enable)]` set) is assumed + // from the hardware — NOTE(review): this `#[cfg]` does not test it. + unsafe { crate::hpc::vnni_gemm::int8_gemm_vnni_avx512(a, b, c, m, n, k) }; + return; + } + + #[cfg(all( + target_arch = "x86_64", + target_feature = "avxvnni", + not(target_feature = "avx512vnni"), + ))] + { + // SAFETY: `target_feature = "avxvnni"` at this site guarantees + // AVX + AVX2 + AVX-VNNI (the kernel's `#[target_feature(enable)]` + // set), and the asserts above give `c.len() >= m * n` for its raw + // ymm stores. Arm only fires without AVX-512 VNNI — Alder Lake / + // Arrow Lake, or Zen 5 / Sapphire Rapids pinned to a ymm-only target. + unsafe { crate::hpc::vnni_gemm::int8_gemm_avxvnni_ymm(a, b, c, m, n, k) }; + return; + } + + // Fallback: scalar reference kernel. Always correct; same result the + // VNNI / AMX / SDOT paths produce when they land. Targets without an + // INT8 dot-product instruction (x86-64-v3 baseline without AVX-VNNI, + // ARMv8.0 without dotprod, wasm32, riscv) reach this arm at compile + // time. 
+ #[cfg(not(any( + all(target_arch = "x86_64", target_feature = "avx512vnni"), + all(target_arch = "x86_64", target_feature = "avxvnni"), + )))] + { + crate::hpc::quantized::int8_gemm_i32(a, b, c, m, n, k); + } +} + // ──────────────────────────────────────────────────────────────────────── // min_i8 / max_i8 — horizontal reduction // ──────────────────────────────────────────────────────────────────────── @@ -367,4 +587,148 @@ mod tests { assert_eq!(min_i8(&s), i8::MAX); assert_eq!(max_i8(&s), i8::MIN); } + + // ── gemm_u8_i8 ──────────────────────────────────────────────────────── + + /// Independent scalar reference used to validate `gemm_u8_i8` against + /// the active compile-time dispatch arm (scalar or VNNI), without + /// going through `hpc::quantized::int8_gemm_i32` (which IS the scalar + /// arm — comparing against it on a v3 build would be tautological). + fn ref_gemm_u8_i8(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0i32; m * n]; + for i in 0..m { + for p in 0..k { + let av = a[i * k + p] as i32; + for j in 0..n { + c[i * n + j] += av * b[p * n + j] as i32; + } + } + } + c + } + + #[test] + fn gemm_u8_i8_4x4_identity() { + let m = 4; + let n = 4; + let k = 4; + let a: Vec = (1..=16).collect(); + let b: Vec = vec![1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]; + let expected = ref_gemm_u8_i8(&a, &b, m, n, k); + let mut c = vec![99i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + assert_eq!(c, expected); + } + + #[test] + fn gemm_u8_i8_rectangular_3x5x8() { + let m = 3; + let n = 5; + let k = 8; + let a: Vec = (0..m * k).map(|i| (i % 200) as u8).collect(); + let b: Vec = (0..k * n).map(|i| (i % 100) as i8 - 50).collect(); + let expected = ref_gemm_u8_i8(&a, &b, m, n, k); + let mut c = vec![0i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + assert_eq!(c, expected); + } + + #[test] + fn gemm_u8_i8_17x17_tail() { + // Exercises the VNNI tail-masking path on AVX-512 builds and the + // scalar fallback on v3 builds. 
Same expected output either way. + let m = 17; + let n = 17; + let k = 17; + let a: Vec = (0..m * k).map(|i| ((i * 7 + 3) % 256) as u8).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i * 11 + 5) % 256) as u8 as i8) + .collect(); + let expected = ref_gemm_u8_i8(&a, &b, m, n, k); + let mut c = vec![0i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + assert_eq!(c, expected); + } + + #[test] + fn gemm_u8_i8_extreme_values() { + // u8 = 255, i8 alternating +127 / -128 stresses i32 accumulation across + // the AVX-512 tail path and the scalar reference. + let m = 4; + let n = 4; + let k = 8; + let a = vec![255u8; m * k]; + let b: Vec = (0..k * n) + .map(|i| if i % 2 == 0 { 127i8 } else { -128i8 }) + .collect(); + let expected = ref_gemm_u8_i8(&a, &b, m, n, k); + let mut c = vec![0i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + assert_eq!(c, expected); + } + + /// Sanity-check timing harness — run with: + /// cargo test --release simd_int_ops::tests::bench_gemm_u8_i8_vs_scalar \ + /// -- --ignored --nocapture + /// + /// Re-run under each cfg arm to confirm the kernel actually beats the + /// scalar reference (the question the user raised: "if AVX2 ends up + /// slower than scalar GEMM something isn't done right"): + /// # scalar arm (default v3) + /// cargo test --release ... + /// # avxvnni ymm arm + /// RUSTFLAGS='-Ctarget-cpu=alderlake' cargo test --release ... + /// # avx512vnni zmm arm + /// cargo --config .cargo/config-avx512.toml test --release ... + #[test] + #[ignore] + fn bench_gemm_u8_i8_vs_scalar() { + use std::time::Instant; + + let sizes = [(64usize, 64, 64), (128, 128, 128), (256, 256, 256), (512, 512, 512)]; + + for (m, n, k) in sizes { + let a: Vec = (0..m * k).map(|i| (i % 251) as u8).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i % 127) as i8).wrapping_sub(63)) + .collect(); + let mut c_simd = vec![0i32; m * n]; + let mut c_scalar = vec![0i32; m * n]; + + // Warm-up — first call also resolves any one-time setup. 
+ for _ in 0..2 { + gemm_u8_i8(&a, &b, &mut c_simd, m, n, k); + } + for _ in 0..2 { + crate::hpc::quantized::int8_gemm_i32(&a, &b, &mut c_scalar, m, n, k); + } + + // Iterations scale down with size to keep total time reasonable. + let iters = match m { + 0..=64 => 50, + 65..=128 => 10, + 129..=256 => 3, + _ => 1, + }; + + let t0 = Instant::now(); + for _ in 0..iters { + gemm_u8_i8(&a, &b, &mut c_simd, m, n, k); + } + let dt_simd = t0.elapsed() / iters; + + let t0 = Instant::now(); + for _ in 0..iters { + crate::hpc::quantized::int8_gemm_i32(&a, &b, &mut c_scalar, m, n, k); + } + let dt_scalar = t0.elapsed() / iters; + + assert_eq!(c_simd, c_scalar, "perf bench failed correctness at {m}x{n}x{k}"); + let speedup = dt_scalar.as_nanos() as f64 / dt_simd.as_nanos() as f64; + println!( + "gemm_u8_i8 {m:>3}x{n:>3}x{k:>3}: simd={:>10.3?} scalar={:>10.3?} speedup={speedup:>6.2}x", + dt_simd, dt_scalar, + ); + } + } } diff --git a/src/simd_ops.rs b/src/simd_ops.rs index db00e219..6c223b61 100644 --- a/src/simd_ops.rs +++ b/src/simd_ops.rs @@ -5,6 +5,20 @@ //! These ops inherit that dispatch — no platform-specific code here. //! //! Re-exported flat through `ndarray::simd::add_f32`, etc. +//! +//! # ⚠️ Foundation primitives — do not remove ⚠️ +//! +//! `array_chunks`, `array_chunks_checked`, `array_windows`, +//! `array_windows_checked`, `add_mul_f32`, and `add_mul_f64` are the +//! stable-Rust building blocks the BLAS-graph GEMM hand-rolled kernels +//! lean on. Combined with the polyfilled `F32x16::mul_add` / `F64x8::mul_add` +//! (defined on the SIMD types in `src/simd_*.rs`), they reach within a few +//! percent of a Cranelift-JIT'd inner loop on the bgz17 path — the reason +//! the JIT-native option was deemed unnecessary and dropped. **Prior +//! sessions have removed these helpers under the impression they were +//! unused cruft; the truth is they are the polyfill foundation. Do not +//! delete them without verifying every BLAS-graph kernel still compiles +//! 
AND that the JIT alternative has been re-evaluated.** use crate::simd::{F32x16, F64x8}; @@ -109,6 +123,75 @@ pub fn scale_f32_inplace(a: &mut [f32], scalar: f32) { } } +// ═══════════════════════════════════════════════════════════════════ +// Fused multiply-add into accumulator slice +// ═══════════════════════════════════════════════════════════════════ +// +// `add_mul_*` is the slice-level FMA-into-accumulator: `acc[i] += a[i] * b[i]`. +// Built directly on `F32x16::mul_add` / `F64x8::mul_add`, the polyfilled +// FMA primitives on the SIMD types (AVX-512: vfmadd231ps; AVX2 + FMA: +// _mm256_fmadd_ps; NEON: vfmaq_f32; scalar: f32::mul_add → fused in +// software when no FMA unit exists). Single rounding step, no intermediate +// `a[i] * b[i]` temporary in scalar emulation — same semantic as +// BLAS-1 `axpy` with a vector multiplier, and the dominant inner-loop +// shape in the BLAS-graph GEMM kernels. + +/// Fused multiply-add into accumulator: `acc[i] += a[i] * b[i]`. +/// +/// Operates on the prefix `min(acc.len(), a.len(), b.len())` lanes. +/// Each lane is computed with a single rounding step via the polyfilled +/// `F32x16::mul_add` (hardware FMA where available). On scalar fallback +/// targets, `f32::mul_add` still rounds once (typically just slower without hardware FMA) — never +/// less precise than the manual `acc[i] += a[i] * b[i]` two-step. 
+/// +/// # Examples +/// +/// ``` +/// use ndarray::simd::add_mul_f32; +/// let mut acc = vec![1.0f32; 17]; +/// let a = vec![2.0f32; 17]; +/// let b = vec![3.0f32; 17]; +/// add_mul_f32(&mut acc, &a, &b); +/// assert!(acc.iter().all(|&v| (v - 7.0).abs() < 1e-6)); +/// ``` +#[inline] +pub fn add_mul_f32(acc: &mut [f32], a: &[f32], b: &[f32]) { + let n = acc.len().min(a.len()).min(b.len()); + let mut i = 0; + while i + 16 <= n { + let va = F32x16::from_slice(&a[i..]); + let vb = F32x16::from_slice(&b[i..]); + let vacc = F32x16::from_slice(&acc[i..]); + va.mul_add(vb, vacc).copy_to_slice(&mut acc[i..]); + i += 16; + } + while i < n { + acc[i] = a[i].mul_add(b[i], acc[i]); + i += 1; + } +} + +/// Fused multiply-add into accumulator (f64): `acc[i] += a[i] * b[i]`. +/// +/// f64 sibling of [`add_mul_f32`]. Uses `F64x8::mul_add` (8-wide on +/// AVX-512 / 4-wide on AVX2-FMA / 2-wide on NEON / scalar `f64::mul_add`). +#[inline] +pub fn add_mul_f64(acc: &mut [f64], a: &[f64], b: &[f64]) { + let n = acc.len().min(a.len()).min(b.len()); + let mut i = 0; + while i + 8 <= n { + let va = F64x8::from_slice(&a[i..]); + let vb = F64x8::from_slice(&b[i..]); + let vacc = F64x8::from_slice(&acc[i..]); + va.mul_add(vb, vacc).copy_to_slice(&mut acc[i..]); + i += 8; + } + while i < n { + acc[i] = a[i].mul_add(b[i], acc[i]); + i += 1; + } +} + // ═══════════════════════════════════════════════════════════════════ // f64 binary ops // ═══════════════════════════════════════════════════════════════════ @@ -288,14 +371,24 @@ mod tests { } // ════════════════════════════════════════════════════════════════════ -// PR-X1 — Const-size non-overlapping slice chunk helpers +// PR-X1 — Const-size slice window helpers (non-overlapping + overlapping) // ════════════════════════════════════════════════════════════════════ // -// Slicing primitive for SIMD-staged inner loops. 
Naming: `array_chunks` -// (NOT `array_windows`) because `std::slice::array_windows::()` -// (nightly) is the **overlapping** iterator already referenced in -// `src/simd.rs` comments. These helpers are the **non-overlapping** -// variant, matching `std::slice::ArrayChunks` / stable `slice::as_chunks`. +// Two stable-Rust helpers for SIMD-staged inner loops: +// +// - `array_chunks::` — **non-overlapping** `&[T; N]` iterator +// (matches `std::slice::ArrayChunks` / stable `slice::as_chunks`). +// - `array_windows::` — **overlapping** `&[T; N]` iterator +// (matches nightly `std::slice::array_windows::()`; this file +// ships the stable equivalent until the std API lands). +// +// Both surface compile-time-fixed window sizes so call sites can +// directly feed `F32x16::from_array` / `F32x8::from_array` etc. and +// let the compiler infer the lane count. The pair lets BLAS-graph +// GEMM-style kernels iterate over a row of `B` (overlapping windows +// of the inner K dimension) and a column of `A` (non-overlapping +// chunks of the M dimension) in one source — the polyfill type +// resolves both to native SIMD width per target. /// Walk `data` as a sequence of non-overlapping const-size windows. /// @@ -372,6 +465,97 @@ pub fn array_chunks_checked(data: &[T]) -> Result(data)) } +/// Walk `data` as a sequence of **overlapping** const-size windows. +/// +/// Stable-Rust equivalent of nightly `std::slice::array_windows::()`. +/// For input length `L` and `N >= 1`, yields `L.saturating_sub(N - 1)` windows; each +/// window slides forward by exactly one element. +/// +/// Low-overhead: iterates the window start indices and converts each +/// `&data[i..i + N]` sub-slice to `&[T; N]` via `TryFrom`, which cannot +/// fail because the sub-slice always has exactly `N` elements. +/// +/// # Edge case — `N == 0` +/// +/// `slice::windows(0)` panics. To preserve a panic-free surface, this +/// helper returns an empty iterator when `N == 0`. Strict-fallible +/// callers should use [`array_windows_checked`]. 
+/// +/// # Examples +/// +/// ``` +/// use ndarray::simd::array_windows; +/// let data: Vec = (0..6).collect(); +/// let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); +/// assert_eq!(ws.len(), 3); // 6 - 4 + 1 = 3 overlapping windows +/// assert_eq!(ws[0], &[0, 1, 2, 3]); +/// assert_eq!(ws[1], &[1, 2, 3, 4]); +/// assert_eq!(ws[2], &[2, 3, 4, 5]); +/// ``` +/// +/// # Examples — fewer elements than `N` +/// +/// ``` +/// use ndarray::simd::array_windows; +/// let data: Vec = (0..3).collect(); +/// let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); +/// assert_eq!(ws.len(), 0); +/// ``` +#[inline] +pub fn array_windows(data: &[T]) -> impl Iterator + '_ { + // Index-based iteration sidesteps `slice::windows(0)`'s panic — when + // N == 0 the count below evaluates to 0 and the iterator is empty. + let count = if N == 0 || data.len() < N { + 0 + } else { + data.len() - N + 1 + }; + (0..count).map(move |i| { + // `&data[i..i + N]` is exactly N elements by construction (bounds + // checked once when `count` was computed). `try_into` always + // succeeds on the inner conversion; the optimizer folds the + // length check out and lowers this to a pointer-cast. + <&[T; N]>::try_from(&data[i..i + N]).expect("computed window length == N") + }) +} + +/// Walk `data` as `&[T; N]` overlapping windows, returning `Err(())` if +/// the buffer is shorter than `N` or `N == 0`. +/// +/// Strict variant of [`array_windows`]: the consumer asserts up front +/// that the slice is long enough for at least one window and wants the +/// error surfaced rather than receiving an empty iterator. Useful for +/// kernels that pre-pad to `K + N - 1` and want to fail loudly if the +/// pad got dropped. +/// +/// # Edge case — `N == 0` +/// +/// Returns `Err(())`. The unchecked variant would silently produce an +/// empty iterator; the checked surface makes the misuse explicit. 
+/// +/// # Examples +/// +/// ``` +/// use ndarray::simd::array_windows_checked; +/// let data: Vec = (0..6).collect(); +/// let it = array_windows_checked::(&data).expect("6 >= 4"); +/// assert_eq!(it.count(), 3); +/// +/// let too_short: Vec = (0..3).collect(); +/// assert!(array_windows_checked::(&too_short).is_err()); +/// +/// // N == 0 surfaces as Err rather than yielding an empty iterator. +/// assert!(array_windows_checked::(&[0u8; 8]).is_err()); +/// ``` +#[inline] +#[allow(clippy::result_unit_err)] // matches PR-X1 design § 3 `Result<_, ()>` contract; no error variants needed +pub fn array_windows_checked(data: &[T]) -> Result + '_, ()> { + if N == 0 || data.len() < N { + return Err(()); + } + Ok(array_windows::(data)) +} + #[cfg(test)] mod array_chunks_tests { use super::*; @@ -424,4 +608,145 @@ mod array_chunks_tests { assert!(array_chunks_checked::(&[0u8; 8]).is_err()); assert!(array_chunks_checked::(&[1u32, 2, 3]).is_err()); } + + // ── array_windows (overlapping) ────────────────────────────────── + + #[test] + fn array_windows_4_over_6() { + let data: Vec = (0u8..6).collect(); + let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); + assert_eq!(ws.len(), 3); + assert_eq!(ws[0], &[0, 1, 2, 3]); + assert_eq!(ws[1], &[1, 2, 3, 4]); + assert_eq!(ws[2], &[2, 3, 4, 5]); + } + + #[test] + fn array_windows_short_buffer_empty_iter() { + let data: Vec = (0u8..3).collect(); + let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); + assert_eq!(ws.len(), 0); + } + + #[test] + fn array_windows_exact_n_yields_one() { + let data: [u8; 4] = [10, 20, 30, 40]; + let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); + assert_eq!(ws.len(), 1); + assert_eq!(ws[0], &[10, 20, 30, 40]); + } + + #[test] + fn array_windows_empty_buffer() { + let data: [u8; 0] = []; + let ws: Vec<&[u8; 4]> = array_windows::(&data).collect(); + assert_eq!(ws.len(), 0); + } + + #[test] + fn array_windows_n_zero_yields_empty() { + // N == 0 returns empty iter from the unchecked 
variant — matches + // slice::windows(0) panic-avoidance contract documented above. + let data: [u8; 8] = [1; 8]; + let ws: Vec<&[u8; 0]> = array_windows::(&data).collect(); + assert_eq!(ws.len(), 0); + } + + #[test] + fn array_windows_checked_accepts_long_enough() { + let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + let it = array_windows_checked::(&data).expect("8 >= 4"); + let ws: Vec<&[u8; 4]> = it.collect(); + assert_eq!(ws.len(), 5); + assert_eq!(ws[4], &[4, 5, 6, 7]); + } + + #[test] + fn array_windows_checked_rejects_short_buffer() { + assert!(array_windows_checked::(&[0u8; 3]).is_err()); + assert!(array_windows_checked::(&[0u8; 0]).is_err()); + } + + #[test] + fn array_windows_checked_rejects_zero_n() { + assert!(array_windows_checked::(&[0u8; 8]).is_err()); + assert!(array_windows_checked::(&[]).is_err()); + } +} + +// ════════════════════════════════════════════════════════════════════ +// FMA-into-accumulator tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod add_mul_tests { + use super::*; + + fn ref_add_mul_f32(acc: &[f32], a: &[f32], b: &[f32]) -> Vec { + acc.iter() + .zip(a) + .zip(b) + .map(|((&c, &x), &y)| x.mul_add(y, c)) + .collect() + } + + fn ref_add_mul_f64(acc: &[f64], a: &[f64], b: &[f64]) -> Vec { + acc.iter() + .zip(a) + .zip(b) + .map(|((&c, &x), &y)| x.mul_add(y, c)) + .collect() + } + + #[test] + fn add_mul_f32_aligned() { + let mut acc = vec![1.0f32; 32]; + let a = vec![2.0f32; 32]; + let b = vec![3.0f32; 32]; + let expected = ref_add_mul_f32(&acc, &a, &b); + add_mul_f32(&mut acc, &a, &b); + for (got, want) in acc.iter().zip(expected.iter()) { + assert!((got - want).abs() < 1e-5, "got={got}, want={want}"); + } + } + + #[test] + fn add_mul_f32_tail() { + // 17 elements — exercises 16-wide chunk + 1-element scalar tail. 
+ for &len in &[0usize, 1, 15, 16, 17, 31, 32, 33, 100] { + let mut acc: Vec = (0..len).map(|i| i as f32 + 0.5).collect(); + let a: Vec = (0..len).map(|i| (i as f32) * 0.1).collect(); + let b: Vec = (0..len).map(|i| (i as f32) * 0.2 - 0.3).collect(); + let expected = ref_add_mul_f32(&acc, &a, &b); + add_mul_f32(&mut acc, &a, &b); + for (i, (got, want)) in acc.iter().zip(expected.iter()).enumerate() { + assert!((got - want).abs() < 1e-4, "len={len} idx={i}: got={got} want={want}"); + } + } + } + + #[test] + fn add_mul_f32_mismatched_lengths_takes_min() { + let mut acc = vec![0.0f32; 5]; + let a = vec![1.0f32; 10]; + let b = vec![2.0f32; 3]; + add_mul_f32(&mut acc, &a, &b); + // First 3 lanes accumulate; remaining 2 unchanged. + assert_eq!(&acc[..3], &[2.0, 2.0, 2.0]); + assert_eq!(&acc[3..], &[0.0, 0.0]); + } + + #[test] + fn add_mul_f64_aligned_and_tail() { + for &len in &[0usize, 1, 7, 8, 9, 16, 17, 50] { + let mut acc: Vec = (0..len).map(|i| i as f64 + 0.25).collect(); + let a: Vec = (0..len).map(|i| (i as f64) * 0.5).collect(); + let b: Vec = (0..len).map(|i| (i as f64) * 0.7 - 1.0).collect(); + let expected = ref_add_mul_f64(&acc, &a, &b); + add_mul_f64(&mut acc, &a, &b); + for (i, (got, want)) in acc.iter().zip(expected.iter()).enumerate() { + assert!((got - want).abs() < 1e-10, "len={len} idx={i}: got={got} want={want}"); + } + } + } }