From f8e94531cac02019583a9944e5a530b522e952ab Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 06:43:54 +0000 Subject: [PATCH 1/4] feat(simd_int_ops, hpc): AMX TDPBUSD arm for gemm_u8_i8 slice surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the u8×i8 → i32 dispatch chain from PR #182's compile-time cascade (avx512vnni → avxvnni → scalar) by adding a top-tier AMX runtime check. Brings the SPR/GNR TDPBUSD path (16 384 MACs per instruction) to the slice-level surface that downstream consumers (lance-graph, etc.) use, completing the symmetry with PR #184's matmul_i8_to_i32 wiring. `gemm_u8_i8` is u8×i8 natively — no sign-shift bias trick needed (unlike `matmul_i8_to_i32` which is i8×i8 and has to convert via +128 then subtract `128·colsum(B)`). That makes the AMX path here a direct call with no bias correction. New helper `hpc::int8_tile_gemm::int8_gemm_amx_tiled(a_u8, b_i8, c, m, n, k)` factors out the tile-decomposition logic that was previously inlined in `matmul_i8_to_i32`. Both consumers now share the same helper: matmul_i8_to_i32: 1. shift A: i8 → u8 (+128) 2. int8_gemm_amx_tiled(a_u8, b, c, m, n, k) 3. subtract_i8_to_u8_bias(c, b, m, n, k) gemm_u8_i8 (AMX tier added in this commit): 1. int8_gemm_amx_tiled(a, b, c, m, n, k) — no shift, no bias The helper handles arbitrary 16/16/64-aligned shapes via a j_tile × i_tile loop calling int8_tile_gemm_16x16 per (16, 16) block. B sub-block extracted into K × 16 scratch once per j-tile, reused across all M i-tiles. **Overwrite semantics**: c is written not accumulated (the underlying int8_tile_gemm_16x16 accumulates into its tile buffer, but we zero the tile buffer before each call so the per-tile write to c is pure overwrite). Dispatch placement in gemm_u8_i8: * Tier 0 (this commit): runtime amx_available() check at the top of the function. AMX requires CPUID + XCR0 + Linux prctl which can't fit a target_feature compile-time gate. 
* Tiers 1-3: existing compile-time cfg-cascade (avx512vnni zmm → avxvnni ymm → scalar i8_gemm_i32). Unchanged. Misaligned shapes (m/n not multiples of 16, k not multiple of 64) or non-AMX hosts fall through to the compile-time cascade as before. Also fixed pre-existing clippy::manual_is_multiple_of warnings that surfaced in the new alignment check — switched from `% 16 == 0` to `.is_multiple_of(16)` etc. per the clippy hint (Rust 1.95 promoted this from `pedantic` to active warn). Verification: * 2095 lib tests pass (was 2094 — +1 new `gemm_u8_i8_amx_aligned_32x32x128` test exercising the AMX arm with a 32×32×128 shape that hits the AMX tier on this host's amx_int8 silicon). * 11 amx_matmul tests pass (matmul_i8_to_i32 refactored to call the shared helper — same behavior as before). * 4 gemm_u8_i8 tests pass (the existing ones still hit the compile-time cascade since their shapes aren't AMX-aligned). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit: matmul_bf16_to_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 VPDPBUSD | scalar (PR #184) | (PR #184) | (always) gemm_u8_i8 (slice): SPR+ AMX | CPL/Zen4 VPDPBUSD | ARL ymm | scalar (THIS) | (PR #182) | (PR #182) | (PR #182) Out of scope (separate PRs): * AVX-VNNI ymm arm for matmul_i8_to_i32 — `vnni2_*` helpers exist in simd_amx.rs but need assembling into a m×n×k GEMM. Same shape as the avx512vnni arm just with ymm width. * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b. 
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/amx_matmul.rs | 24 +++--------------- src/hpc/int8_tile_gemm.rs | 52 +++++++++++++++++++++++++++++++++++++++ src/simd_int_ops.rs | 49 +++++++++++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 24 deletions(-) diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index 4939559e..5fa10d94 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -606,28 +606,10 @@ pub fn matmul_i8_to_i32( if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 { // Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128), - // tile-GEMM via int8_tile_gemm_16x16, subtract bias. + // delegate to the shared int8_gemm_amx_tiled helper, subtract + // the sign-shift bias. let a_u8: Vec = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect(); - - let mut b_tile = vec![0i8; k * 16]; - let mut tile_c = vec![0i32; 256]; - - for j_tile in (0..n).step_by(16) { - for kk in 0..k { - let row = kk * n + j_tile; - b_tile[kk * 16..(kk + 1) * 16] - .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) }); - } - for i_tile in (0..m).step_by(16) { - let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; - tile_c.fill(0); - crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); - for ii in 0..16 { - let dst_off = (i_tile + ii) * n + j_tile; - c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); - } - } - } + crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(&a_u8, &b_i8, &mut c, m, n, k); subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k); } else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") { // Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index 6dc2e76e..847f4e37 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -231,6 +231,58 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) { 
} } +// ═════════════════════════════════════════════════════════════════════ +// AMX tiled helper — arbitrary 16/16/64-aligned M × N × K via 16×16 tile loop +// ═════════════════════════════════════════════════════════════════════ + +/// `u8 × i8 → i32` GEMM using AMX `TDPBUSD` for arbitrary M × N × K +/// shapes that satisfy `m % 16 == 0 && n % 16 == 0 && k % 64 == 0`. +/// +/// Tile-decomposes the M × N output into 16×16 blocks and calls +/// [`int8_tile_gemm_16x16`] per (i_tile, j_tile). B sub-block extracted +/// into K × 16 scratch once per j-tile, reused across all M i-tiles — +/// amortizes the column gather cost. +/// +/// **Overwrite semantics**: `c` is written, not accumulated. Caller +/// does NOT need to zero `c` beforehand. (The underlying +/// `int8_tile_gemm_16x16` accumulates into its tile buffer, but we +/// zero the tile buffer before each call so the per-tile write to `c` +/// is pure overwrite.) +/// +/// # Panics +/// Debug-asserts AMX availability and the 16/16/64 shape constraints. +/// Production builds rely on the caller's runtime check +/// (`crate::hpc::amx_matmul::amx_available()`). +pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + debug_assert!(crate::hpc::amx_matmul::amx_available()); + debug_assert_eq!(m % 16, 0, "int8_gemm_amx_tiled: M must be multiple of 16"); + debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16"); + debug_assert_eq!(k % 64, 0, "int8_gemm_amx_tiled: K must be multiple of 64"); + + let mut b_tile = vec![0i8; k * 16]; + let mut tile_c = vec![0i32; 256]; + + for j_tile in (0..n).step_by(16) { + // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows (contiguous + // memory for int8_tile_gemm_16x16's input shape). 
+ for kk in 0..k { + let row = kk * n + j_tile; + b_tile[kk * 16..(kk + 1) * 16] + .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) }); + } + for i_tile in (0..m).step_by(16) { + let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; + tile_c.fill(0); + int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); + // Write tile_c (16 × 16, row-major) into c (M × N, row-major). + for ii in 0..16 { + let dst_off = (i_tile + ii) * n + j_tile; + c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); + } + } + } +} + // ═════════════════════════════════════════════════════════════════════ // Tests // ═════════════════════════════════════════════════════════════════════ diff --git a/src/simd_int_ops.rs b/src/simd_int_ops.rs index b78ffe85..b9763640 100644 --- a/src/simd_int_ops.rs +++ b/src/simd_int_ops.rs @@ -255,9 +255,31 @@ pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usiz assert!(b.len() >= k * n, "gemm_u8_i8: b.len()={} < k*n={}", b.len(), k * n); assert!(c.len() >= m * n, "gemm_u8_i8: c.len()={} < m*n={}", c.len(), m * n); - // Compile-time dispatch chain. Exactly one arm survives per build; - // the others are stripped by `#[cfg]` so the compiler emits a direct - // call to the chosen kernel with no runtime branch. + // Tier 0 — runtime AMX check. AMX is a different feature class than + // the rest of the dispatch chain: it requires CPUID + XCR0 + a Linux + // `prctl(ARCH_REQ_XCOMP_PERM, 18)` to be granted, none of which fit + // a `target_feature` compile-time gate. The check is one CPUID + + // one XGETBV + one prctl (idempotent, cached after first call). On + // aligned shapes (16/16/64) this dispatches to TDPBUSD via the + // shared `int8_gemm_amx_tiled` helper — 16 384 MACs per instruction + // vs VPDPBUSD-zmm's 64. 
Since `gemm_u8_i8` is u8×i8 natively (no + // sign-shift bias needed), the AMX path is a direct call with no + // bias correction — simpler than `matmul_i8_to_i32`'s i8×i8 path. + #[cfg(target_arch = "x86_64")] + { + if crate::hpc::amx_matmul::amx_available() + && m.is_multiple_of(16) + && n.is_multiple_of(16) + && k.is_multiple_of(64) + { + crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(a, b, c, m, n, k); + return; + } + } + + // Compile-time dispatch chain (tiers 1-3). Exactly one arm survives + // per build; the others are stripped by `#[cfg]` so the compiler + // emits a direct call to the chosen kernel with no runtime branch. #[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))] { @@ -731,4 +753,25 @@ mod tests { ); } } + + /// Exercises the AMX dispatch tier added on top of `gemm_u8_i8`'s + /// compile-time cascade. On AMX-enabled silicon (Sapphire Rapids+ + /// with the right OS prctl), 16/16/64-aligned shapes go through + /// TDPBUSD via `int8_gemm_amx_tiled`. Anywhere else this falls back + /// to the compile-time cascade — the assertion still holds because + /// the scalar reference is exact integer arithmetic. 
+ #[test] + fn gemm_u8_i8_amx_aligned_32x32x128() { + let m = 32; // 2 × 16-wide M-tiles + let n = 32; // 2 × 16-wide N-tiles + let k = 128; // 2 × 64-wide K-blocks per tile + let a: Vec = (0..m * k).map(|i| ((i * 13 + 7) % 256) as u8).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i * 19 + 11) % 256) as u8 as i8) + .collect(); + let expected = ref_gemm_u8_i8(&a, &b, m, n, k); + let mut c = vec![0i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + assert_eq!(c, expected, "gemm_u8_i8 AMX path mismatch"); + } } From 38d4800cd02ab187efa2b85b6da4fe1614c39ab2 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 06:50:56 +0000 Subject: [PATCH 2/4] feat(hpc): VPDPBUSD-ymm AVX-VNNI tier for matmul_i8_to_i32 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes the per-CPU dispatch chain for `matmul_i8_to_i32` by adding the AVX-VNNI ymm tier — Arrow Lake, Meteor Lake U, Alder Lake silicon that has AVX-VNNI but dropped AVX-512. Mirrors the shape of the avx512vnni-zmm arm shipped in PR #184 with the narrower 8-wide kernel. New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm`: * One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes, each receiving 4 u8×i8 products = 32 MACs per instruction. Half the throughput-per-instruction of the `_mm512_dpbusd_epi32` zmm version. * Same B-pre-pack scheme (quad-interleaved per 8-wide j-block), same K-tail / N-tail handling. Just narrower. * Stable intrinsic under `target_feature = "avxvnni,avx2"` — no asm-byte needed. Wiring `matmul_i8_to_i32`'s dispatch as Tier 3: 1. amx_available() + 16/16/64-aligned → AMX TDPBUSD (PR #184: int8_gemm_amx_tiled, 16 384 MACs/instr) 2. is_x86_feature_detected!("avx512vnni") → VPDPBUSD-zmm (PR #184: int8_gemm_vpdpbusd_zmm, 64 MACs/instr) 3. is_x86_feature_detected!("avxvnni") → VPDPBUSD-ymm (THIS COMMIT: int8_gemm_vpdpbusd_ymm, 32 MACs/instr) 4. 
scalar i8×i8 → i32 reference (was Tier 3) All three SIMD tiers share the sign-shift bias trick: shift LHS i8 → u8 (+128), run the kernel, subtract 128·colsum(B). Same `subtract_i8_to_u8_bias` helper (factored in PR #184). New direct test `vpdpbusd_ymm_matches_scalar` mirrors the zmm version's test: sweeps shapes spanning 8-aligned, K-tail (k % 4), N-tail (n % 8), and small shapes, asserts byte-equal output vs scalar reference. Verification: * Default v3 (this host has avx512vnni so the new arm doesn't fire from matmul_i8_to_i32 — Tier 2 catches first): 2096 lib tests pass (was 2095 — +1 new direct test). * Direct test exercises int8_gemm_vpdpbusd_ymm on this host since avxvnni is present alongside avx512vnni. * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit (final on the int8 side): matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 zmm | ARL ymm | scalar (PR #184) | (PR #184) | (THIS) | (always) The matmul_i8_to_i32 column of PR #180's dispatch table is now fully filled. The gemm_u8_i8 slice surface (in PR #185) already has AVX-VNNI ymm via its existing compile-time cascade — both i8-related public surfaces now cover every x86_64 tier with a hardware-accelerated arm. Out of scope (separate PRs): * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b, needs aarch64 CI runner verification. * TD-T6: real _mm256_* for AVX2 BLAS-1 (scal/nrm2/asum). 
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/amx_matmul.rs | 14 +++- src/hpc/int8_tile_gemm.rs | 130 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 2 deletions(-) diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index 5fa10d94..673eab4f 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -621,9 +621,19 @@ pub fn matmul_i8_to_i32( crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k); } subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k); + } else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avxvnni") { + // Tier 3 — AVX-VNNI ymm VPDPBUSD: 32 MACs per instruction. + // Arrow Lake, Meteor Lake U, Alder Lake silicon that has + // AVX-VNNI but dropped AVX-512. Same sign-shift bias trick. + let a_u8: Vec = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect(); + // SAFETY: runtime feature-detected avxvnni above. + unsafe { + crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm(&a_u8, &b_i8, &mut c, m, n, k); + } + subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k); } else { - // Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts, - // pre-AVX-512 silicon, or shapes that don't satisfy either of + // Tier 4 — Scalar i8×i8 → i32 reference for non-x86 hosts, + // pre-AVX-VNNI silicon, or shapes that don't satisfy any of // the SIMD tiers' alignment requirements. for i in 0..m { for p in 0..k { diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index 847f4e37..96d71b52 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -215,6 +215,97 @@ pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: } } +// ═════════════════════════════════════════════════════════════════════ +// VPDPBUSD-ymm AVX-VNNI tier (Arrow Lake / Meteor Lake U / Alder Lake) +// ═════════════════════════════════════════════════════════════════════ + +/// AVX-VNNI ymm `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K. 
+/// +/// One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes, +/// each receiving the sum of 4 `u8 × i8` products = **32 MACs per +/// instruction**. Half the throughput-per-instruction of the +/// `_mm512_dpbusd_epi32` zmm version (which does 64 MACs); fires on +/// Arrow Lake / Meteor Lake U / Alder Lake silicon that has AVX-VNNI +/// but NOT AVX-512. +/// +/// Same B pre-packing scheme as the zmm version (quad-interleaved per +/// 8-wide j-block), same K-tail and N-tail handling, just narrower. +/// Mirrors the `vnni2_dot_u8_i8` shape in `simd_amx.rs` but as a +/// matrix-product instead of single-row dot. +/// +/// Output behavior: overwrites `c` (does NOT accumulate). Caller's +/// responsibility to zero `c` first if needed. +/// +/// # Safety +/// Caller must have feature-detected `avxvnni + avx2` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avxvnni,avx2")] +pub unsafe fn int8_gemm_vpdpbusd_ymm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + use core::arch::x86_64::{ + __m256i, _mm256_dpbusd_avx_epi32, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_setzero_si256, + _mm256_storeu_si256, + }; + + let k_quads = k / 4; + let k_tail = k % 4; + + // Pre-pack scratch: 8 i32 lanes per k_quad (vs 16 in the zmm + // version). Same per-lane layout: each i32 holds 4 consecutive + // B K-bytes for output column j+lane. 
+ let mut b_col_quads = vec![0i32; k_quads.max(1) * 8]; + let mut out_buf = [0i32; 8]; + + for j_base in (0..n).step_by(8) { + let j_count = 8.min(n - j_base); + + for k_quad in 0..k_quads { + let row0 = 4 * k_quad * n; + let row1 = (4 * k_quad + 1) * n; + let row2 = (4 * k_quad + 2) * n; + let row3 = (4 * k_quad + 3) * n; + for jj in 0..j_count { + let b0 = b_i8[row0 + j_base + jj] as u8 as u32; + let b1 = b_i8[row1 + j_base + jj] as u8 as u32; + let b2 = b_i8[row2 + j_base + jj] as u8 as u32; + let b3 = b_i8[row3 + j_base + jj] as u8 as u32; + b_col_quads[k_quad * 8 + jj] = (b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)) as i32; + } + for jj in j_count..8 { + b_col_quads[k_quad * 8 + jj] = 0; + } + } + + for i in 0..m { + let mut acc = _mm256_setzero_si256(); + let a_row_off = i * k; + for k_quad in 0..k_quads { + let a0 = a_u8[a_row_off + 4 * k_quad] as u32; + let a1 = a_u8[a_row_off + 4 * k_quad + 1] as u32; + let a2 = a_u8[a_row_off + 4 * k_quad + 2] as u32; + let a3 = a_u8[a_row_off + 4 * k_quad + 3] as u32; + let packed_a = a0 | (a1 << 8) | (a2 << 16) | (a3 << 24); + let a_v = _mm256_set1_epi32(packed_a as i32); + let b_v = _mm256_loadu_si256(b_col_quads.as_ptr().add(k_quad * 8) as *const __m256i); + acc = _mm256_dpbusd_avx_epi32(acc, a_v, b_v); + } + _mm256_storeu_si256(out_buf.as_mut_ptr() as *mut __m256i, acc); + + if k_tail > 0 { + for kk in (k_quads * 4)..k { + let a_val = a_u8[a_row_off + kk] as i32; + let tail_row = kk * n; + for jj in 0..j_count { + out_buf[jj] += a_val * b_i8[tail_row + j_base + jj] as i32; + } + } + } + + let dst_off = i * n + j_base; + c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]); + } + } +} + // ═════════════════════════════════════════════════════════════════════ // Scalar fallback (i32 reference) // ═════════════════════════════════════════════════════════════════════ @@ -422,6 +513,45 @@ mod tests { } } + /// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of + /// `matmul_i8_to_i32`). 
Same shape / bit-exactness contract as + /// the zmm version's test, just on the narrower 8-wide kernel. + #[cfg(target_arch = "x86_64")] + #[test] + fn vpdpbusd_ymm_matches_scalar() { + if !std::is_x86_feature_detected!("avxvnni") { + eprintln!("avxvnni not detected; skipping"); + return; + } + + fn ref_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0i32; m * n]; + for i in 0..m { + for kk in 0..k { + let av = a[i * k + kk] as i32; + for j in 0..n { + c[i * n + j] += av * b[kk * n + j] as i32; + } + } + } + c + } + + // Sweep shapes spanning 8-aligned, K-tail (k % 4), N-tail + // (n % 8), and small shapes to exercise every code path. + for (m, n, k) in [(16, 8, 64), (3, 5, 7), (17, 33, 100), (1, 17, 12), (8, 8, 4)] { + let a: Vec = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i * 17 + 3) % 256) as u8 as i8) + .collect(); + let expected = ref_gemm(&a, &b, m, n, k); + let mut got = vec![0i32; m * n]; + // SAFETY: avxvnni confirmed at the top of the test. + unsafe { int8_gemm_vpdpbusd_ymm(&a, &b, &mut got, m, n, k) }; + assert_eq!(got, expected, "VPDPBUSD-ymm mismatch at (M={}, N={}, K={})", m, n, k); + } + } + #[test] fn vnni_pack_i8_roundtrip() { // Pack then verify the VNNI layout matches the spec: From e8f9ce07a05913e31819acdcceed378a16b04e04 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 07:12:06 +0000 Subject: [PATCH 3/4] fix(hpc): validate int8_gemm_amx_tiled slice lengths (codex P1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per codex review on PR #185: `int8_gemm_amx_tiled` is a safe public function (no `unsafe` in the signature) but its inner loop read `b_i8` via `core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16)` without any length check. Callers passing mismatched (m, n, k) vs slice lengths could trigger out-of-bounds reads / UB instead of a panic. 
Before PR #185 this logic lived only in `matmul_i8_to_i32`'s private AMX arm (where the public `pack_contig` preceded it and bounded everything), but the factored helper is now reachable from `gemm_u8_i8` and any future caller. Fix: 1. Add three boundary assertions at function entry matching `gemm_u8_i8`'s contract: a_u8.len() >= m * k b_i8.len() >= k * n c.len() >= m * n These panic with descriptive messages on undersized input — the safety contract is now enforced at the public function boundary, not at the unsafe pointer-arithmetic site inside the hot loop. 2. Replace the `unsafe { core::slice::from_raw_parts(...) }` B-pack line with safe `b_tile[..].copy_from_slice(&b_i8[row..row + 16])`. The bounds-check inside the loop is now redundant given the function-entry assertions, but the compiler should elide it once the invariant is proven; either way the code becomes panicking-safe instead of UB-on-misuse. 3. Update the doc-comment `# Panics` section to list the boundary panics alongside the existing debug-only AMX / alignment assertions. New regression test `amx_tiled_panics_on_undersized_b`: * Constructs `b: Vec<i8>` one full j_tile (`k * 16` elements) shorter than the claimed `k * n`. * Calls `int8_gemm_amx_tiled` and asserts the expected panic fires before any unsafe slice arithmetic. * `#[should_panic(expected = "b_i8.len()")]` catches the exact assertion message; works on any host (the boundary check fires before the `debug_assert!(amx_available())` so the test passes on AMX-less CI runners too). Verification: * 2097 lib tests pass (was 2096 — +1 new regression test). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. The matmul_i8_to_i32 path that delegates to int8_gemm_amx_tiled inherits the assertions transparently via the call chain. No behavior change for valid input — only mismatched-shape callers that would have hit UB now get a clean panic instead. 
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/int8_tile_gemm.rs | 47 +++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index 96d71b52..ee778531 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -341,10 +341,22 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) { /// is pure overwrite.) /// /// # Panics -/// Debug-asserts AMX availability and the 16/16/64 shape constraints. -/// Production builds rely on the caller's runtime check -/// (`crate::hpc::amx_matmul::amx_available()`). +/// Panics if `a_u8`, `b_i8`, or `c` are too small for the requested +/// `(m, n, k)`, mirroring the boundary contract from `gemm_u8_i8`. Also +/// panics in debug builds when AMX isn't OS-enabled or when the shape +/// alignment constraints aren't met (production builds skip those for +/// performance — callers must runtime-check +/// `crate::hpc::amx_matmul::amx_available()` and the 16/16/64 +/// alignment themselves). pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + // Length assertions (codex P1 from PR #185 — the function reads + // `b_i8` via a 16-wide window per (kk, j_tile) iteration and a_u8 + // via a 16-row slice per i_tile, so mismatched shapes would + // trigger out-of-bounds reads without these gates). 
+ assert!(a_u8.len() >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}", a_u8.len(), m * k); + assert!(b_i8.len() >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}", b_i8.len(), k * n); + assert!(c.len() >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}", c.len(), m * n); + debug_assert!(crate::hpc::amx_matmul::amx_available()); debug_assert_eq!(m % 16, 0, "int8_gemm_amx_tiled: M must be multiple of 16"); debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16"); @@ -354,12 +366,13 @@ pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: let mut tile_c = vec![0i32; 256]; for j_tile in (0..n).step_by(16) { - // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows (contiguous - // memory for int8_tile_gemm_16x16's input shape). + // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows + // (contiguous memory for int8_tile_gemm_16x16's input shape). + // Safe slicing — the row..row+16 range is bounded by + // `b_i8.len() >= k * n` asserted at function entry. for kk in 0..k { let row = kk * n + j_tile; - b_tile[kk * 16..(kk + 1) * 16] - .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) }); + b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]); } for i_tile in (0..m).step_by(16) { let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; @@ -513,6 +526,26 @@ mod tests { } } + /// Codex P1 regression on PR #185: `int8_gemm_amx_tiled` is a + /// safe public function — mismatched (m, n, k) vs slice lengths + /// must panic at the function boundary, not trigger UB inside + /// the unsafe slice/pointer arithmetic in the inner loop. This + /// test passes deliberately-undersized buffers and expects a + /// panic (which `#[should_panic]` catches). 
+ #[test] + #[should_panic(expected = "b_i8.len()")] + fn amx_tiled_panics_on_undersized_b() { + let m = 16; + let n = 32; + let k = 64; + let a = vec![0u8; m * k]; + let b = vec![0i8; k * (n - 16)]; // half a j_tile short of what's claimed + let mut c = vec![0i32; m * n]; + // Even on non-AMX hosts the assertion fires before reaching + // the (debug-asserted) amx_available() check. + int8_gemm_amx_tiled(&a, &b, &mut c, m, n, k); + } + /// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of /// `matmul_i8_to_i32`). Same shape / bit-exactness contract as /// the zmm version's test, just on the narrower 8-wide kernel. From c18e3facf2a2658738087715966216e5420dc7ac Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 07:30:57 +0000 Subject: [PATCH 4/4] =?UTF-8?q?feat(simd=5Fruntime):=20runtime=20SIMD=20di?= =?UTF-8?q?spatch=20=E2=80=94=20release-binary=20alternative?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Alternative** to the compile-time cascade in `crate::simd::*` / `crate::simd_ops::*`. **Additive**: gated under `--features runtime-dispatch`, does not touch any existing path. Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill replaces the architecture-specific intrinsics that the runtime trampolines select between). Use case: ship ONE binary that adapts across heterogeneous deployment silicon (AVX-512 server + AVX2-only laptop + Arrow Lake desktop + Sapphire Rapids workstation) from the same artifact. The existing compile-time `v3` / `v4` / `native` / `nightly-simd` configs target a single class of CPU per build; the runtime layer targets the union via per-op LazyLock trampolines. Design from `.claude/knowledge/simd-dispatch-architecture.md` § 7.1 / Phase 5, building on the precedent set by `hpc::bgz17_bridge::{L1_KERNEL, L1_WEIGHTED_KERNEL, ...}` (`LazyLock` pattern, lines 75-86) already proven in tree. # Dispatch model One `LazyLock` per public surface. 
First call fires the closure which reads `simd_caps()` and selects a backend; every subsequent call is one pointer-deref + indirect call. Per-call overhead: ~2-3 ns (LazyLock atomic-acquire load that's cache- resident after first hit + indirect-call branch-target predict). Invisible against any SIMD op's actual work (~100+ cycles). # Module layout src/simd_runtime/ mod.rs — module entry, mutual-exclusion check vs nightly-simd, public re-exports vnni_dot.rs — u8×i8 → i32 dot (the proposal's canonical example): 3 backends, the AVX-512 arm wraps `simd_amx::vnni_dot_u8_i8` with a scalar tail because the existing kernel silently drops n%64 lanes (its matvec caller pre-aligns rows; a general-purpose dispatch surface cannot assume that) add_mul.rs — slice-level FMA (acc += a × b) for f32/f64; the ONLY new kernel code in this module — 4 backends per type (avx512 / avx2+fma / neon / scalar), each ~15 LoC of direct intrinsics matmul.rs — thin trampolines for matmul_bf16_to_f32 / matmul_f32 / matmul_i8_to_i32 / gemm_u8_i8 delegating to existing functions that already runtime-dispatch internally (PR #182 / #184 / #185) casts.rs — trampolines for the four half-precision batch casts delegating to PR #183's already- runtime-dispatched implementations # Backend reuse — no kernel duplication Every dispatch arm delegates to a kernel that already exists in tree. The runtime layer is just the trampoline. The only NEW kernel code is `add_mul_f32` / `add_mul_f64` (no pre-existing slice-level FMA primitive in tree to delegate to — the compile- time `crate::simd_ops::add_mul_f32` from PR #182 polyfills through the F32x16 lane wrapper; the runtime version skips that indirection for one more inlined intrinsic per chunk). 
# Invariants preserved from this PR series * No-FP32-roundtrip on BF16/F16 arithmetic — backends respect the bit-exact mantissa rule * Asm-byte encoding for nightly-gated AMX / FP16 — selected backends keep their existing asm-byte fast paths * Little-endian byte contracts for half-precision carriers * Accumulator-preservation in tile paths (codex P1 from #184) * Boundary assertions on safe public fns (codex P1 from #185) — the public `vnni_dot_u8_i8(a, b)` etc. inherit the asserts transparently via the call chain # Verification * Default build (no feature): 2087 lib tests pass — the `simd_runtime` module is gated out, zero impact on existing paths. * `cargo test --lib --features runtime-dispatch`: **2105 lib tests pass** (+8 new in `simd_runtime::*::tests`). * `cargo clippy --lib --tests --features rayon,native -- -D warnings` clean (default). * `cargo clippy --lib --tests --features rayon,native,runtime-dispatch -- -D warnings` clean. * `cargo fmt --all --check` clean. * Mutual-exclusion enforced via `compile_error!` in `simd_runtime/mod.rs` — `--features runtime-dispatch,nightly-simd` fails to compile with a clear error. # What's NOT in this PR (deferred) * Sweep the remaining ~15-20 SIMD/HPC public surfaces (min_i8, max_i8, add_i8, dot_i8, etc.). Each is ~30-50 LoC of trampoline; pattern is established here. Estimated ~700-900 more LoC across the full surface map. * CI matrix entry for `runtime-dispatch-portable` (per simd-dispatch-architecture.md § 7 / TD-SIMD-9). Job builds with `--features runtime-dispatch` on a v3 baseline runner and asserts every trampoline lands on its expected backend. * `simd_caps()` snapshot logging at process start (debug-only) to aid release-binary deployment debugging — "which arm did you actually pick?" 
# Cost summary src/simd_runtime/ +537 LoC (4 modules) src/lib.rs +9 LoC (cfg-gated mod decl) Cargo.toml +21 LoC (feature decl + doc) Total ~570 LoC Trampoline LoC per surface (this PR's sample): vnni_dot 170 LoC (LazyLock + 3 arms + wrapper + tests) add_mul (f32+f64)218 LoC (LazyLock×2 + 4 arms×2 + tests — the ONLY new kernels) matmul (4 ops) 100 LoC (thin delegations + tests) casts (4 ops) 75 LoC (thin delegations + tests) Out-of-tree estimate for the full sweep (per § 7 of the design doc): ~1400 LoC total once all ~25 public SIMD/HPC surfaces are wired. This PR establishes ~40% of that budget with the canonical patterns. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- Cargo.toml | 21 +++ src/lib.rs | 9 ++ src/simd_runtime/add_mul.rs | 276 +++++++++++++++++++++++++++++++++++ src/simd_runtime/casts.rs | 85 +++++++++++ src/simd_runtime/matmul.rs | 101 +++++++++++++ src/simd_runtime/mod.rs | 95 ++++++++++++ src/simd_runtime/vnni_dot.rs | 173 ++++++++++++++++++++++ 7 files changed, 760 insertions(+) create mode 100644 src/simd_runtime/add_mul.rs create mode 100644 src/simd_runtime/casts.rs create mode 100644 src/simd_runtime/matmul.rs create mode 100644 src/simd_runtime/mod.rs create mode 100644 src/simd_runtime/vnni_dot.rs diff --git a/Cargo.toml b/Cargo.toml index 97c4514f..ee7cc8f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -201,6 +201,27 @@ rayon = ["dep:rayon", "std"] # cfg-dispatch in `simd.rs` remains the production path. nightly-simd = ["std"] +# Runtime SIMD dispatch — release-binary distribution path. Compiles all +# x86_64 backends into one artifact and selects per-op kernels via +# `LazyLock` trampolines that read `simd_caps()` on first call. The +# `crate::simd_runtime::*` module becomes reachable under this feature; +# the existing compile-time `crate::simd::*` / `crate::simd_ops::*` +# cascade is unchanged (additive). 
Use case: shipping one binary that +# adapts across heterogeneous deployment silicon (AVX-512 server + +# AVX2-only laptop) from the same artifact. +# +# Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill +# replaces the architecture-specific intrinsics that the runtime +# trampolines select between; they can't coexist coherently). +# +# Per-call overhead: ~2-3 ns indirect-call through a static fn pointer +# (LazyLock fires once at first call, every subsequent call is a +# pointer deref). Invisible against any SIMD op's actual work. +# +# See `.claude/knowledge/simd-dispatch-architecture.md` § 7.1 / Phase 5 +# for the design rationale. +runtime-dispatch = ["std"] + # HPC extras: p64 palette/NARS bridge + fractal manifold. # (blake3 was previously listed here; it is now part of `std` directly # because the cognitive substrate modules under hpc/ that import blake3 diff --git a/src/lib.rs b/src/lib.rs index 426bdae9..60edbcac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -313,6 +313,15 @@ pub mod simd_ops; #[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)] pub mod simd_half; +/// Runtime SIMD dispatch — release-binary distribution path. +/// +/// Gated under `--features runtime-dispatch`. Mutually exclusive with +/// `nightly-simd` (the cfg in `simd_runtime/mod.rs` enforces this with +/// a `compile_error!`). See `.claude/knowledge/simd-dispatch-architecture.md` +/// § 7.1 / Phase 5 for the design. +#[cfg(feature = "runtime-dispatch")] +pub mod simd_runtime; + /// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS). #[cfg(feature = "std")] pub mod backend; diff --git a/src/simd_runtime/add_mul.rs b/src/simd_runtime/add_mul.rs new file mode 100644 index 00000000..799f65a5 --- /dev/null +++ b/src/simd_runtime/add_mul.rs @@ -0,0 +1,276 @@ +//! Runtime-dispatched `add_mul_f32` / `add_mul_f64` — slice-level FMA. +//! +//! Semantics: `acc[i] += a[i] * b[i]` for each lane, single-rounded +//! per IEEE 754 FMA. 
Four backends per type, selected once at first +//! call via `LazyLock`: +//! +//! 1. **AVX-512** (`avx512f`) → 16-wide / 8-wide `_mm512_fmadd_ps` / +//! `_mm512_fmadd_pd` (AVX-512 implies FMA on every CPU that ships +//! it; no separate FMA feature check needed). +//! 2. **AVX2 + FMA** (`avx2 + fma`) → 8-wide / 4-wide +//! `_mm256_fmadd_ps` / `_mm256_fmadd_pd`. +//! 3. **NEON** (aarch64 + neon) → 4-wide / 2-wide `vfmaq_f32` / +//! `vfmaq_f64`. +//! 4. **Scalar** → `f32::mul_add` / `f64::mul_add` (always correct, +//! always single-rounded per IEEE). +//! +//! Unlike the other surfaces in this module, the FMA backends are NEW +//! kernel code — no pre-existing slice-level FMA primitive in tree to +//! delegate to. The compile-time `crate::simd_ops::add_mul_f32` from +//! PR #182 polyfills through the `F32x16` lane wrapper; the runtime +//! version skips that indirection for one more inlined intrinsic per +//! chunk. + +use std::sync::LazyLock; + +type AddMulF32Fn = unsafe fn(&mut [f32], &[f32], &[f32]); +type AddMulF64Fn = unsafe fn(&mut [f64], &[f64], &[f64]); + +static ADD_MUL_F32_DISPATCH: LazyLock = LazyLock::new(|| { + let _caps = crate::hpc::simd_caps::simd_caps(); + #[cfg(target_arch = "x86_64")] + { + if _caps.avx512f { + return add_mul_f32_avx512 as AddMulF32Fn; + } + if _caps.avx2 && _caps.fma { + return add_mul_f32_avx2_fma as AddMulF32Fn; + } + } + #[cfg(target_arch = "aarch64")] + { + if _caps.neon { + return add_mul_f32_neon as AddMulF32Fn; + } + } + add_mul_f32_scalar as AddMulF32Fn +}); + +static ADD_MUL_F64_DISPATCH: LazyLock = LazyLock::new(|| { + let _caps = crate::hpc::simd_caps::simd_caps(); + #[cfg(target_arch = "x86_64")] + { + if _caps.avx512f { + return add_mul_f64_avx512 as AddMulF64Fn; + } + if _caps.avx2 && _caps.fma { + return add_mul_f64_avx2_fma as AddMulF64Fn; + } + } + #[cfg(target_arch = "aarch64")] + { + if _caps.neon { + return add_mul_f64_neon as AddMulF64Fn; + } + } + add_mul_f64_scalar as AddMulF64Fn +}); + +/// `acc[i] += 
a[i] * b[i]` (f32, single-rounded FMA). +/// +/// Slice lengths truncated to the shortest. Wrapping behaviour +/// matches `f32::mul_add` lane-by-lane (NaN propagates, infinities +/// handled per IEEE 754). +#[inline] +pub fn add_mul_f32(acc: &mut [f32], a: &[f32], b: &[f32]) { + // SAFETY: dispatch closure verified the runtime CPU caps satisfy + // the selected function's `#[target_feature]` requirements. + unsafe { (*ADD_MUL_F32_DISPATCH)(acc, a, b) } +} + +/// `acc[i] += a[i] * b[i]` (f64, single-rounded FMA). +#[inline] +pub fn add_mul_f64(acc: &mut [f64], a: &[f64], b: &[f64]) { + unsafe { (*ADD_MUL_F64_DISPATCH)(acc, a, b) } +} + +// ──────────────────────────────────────────────────────────────────────── +// AVX-512 backends (16 / 8 lanes per FMA instruction) +// ──────────────────────────────────────────────────────────────────────── + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +unsafe fn add_mul_f32_avx512(acc: &mut [f32], a: &[f32], b: &[f32]) { + use core::arch::x86_64::{_mm512_fmadd_ps, _mm512_loadu_ps, _mm512_storeu_ps}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 16; + for c in 0..chunks { + let off = c * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + let vc = _mm512_loadu_ps(acc.as_ptr().add(off)); + let r = _mm512_fmadd_ps(va, vb, vc); + _mm512_storeu_ps(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 16)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +unsafe fn add_mul_f64_avx512(acc: &mut [f64], a: &[f64], b: &[f64]) { + use core::arch::x86_64::{_mm512_fmadd_pd, _mm512_loadu_pd, _mm512_storeu_pd}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 8; + for c in 0..chunks { + let off = c * 8; + let va = _mm512_loadu_pd(a.as_ptr().add(off)); + let vb = _mm512_loadu_pd(b.as_ptr().add(off)); + let vc = _mm512_loadu_pd(acc.as_ptr().add(off)); + let 
r = _mm512_fmadd_pd(va, vb, vc); + _mm512_storeu_pd(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 8)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +// ──────────────────────────────────────────────────────────────────────── +// AVX2 + FMA backends (8 / 4 lanes per FMA instruction) +// ──────────────────────────────────────────────────────────────────────── + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +unsafe fn add_mul_f32_avx2_fma(acc: &mut [f32], a: &[f32], b: &[f32]) { + use core::arch::x86_64::{_mm256_fmadd_ps, _mm256_loadu_ps, _mm256_storeu_ps}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 8; + for c in 0..chunks { + let off = c * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + let vc = _mm256_loadu_ps(acc.as_ptr().add(off)); + let r = _mm256_fmadd_ps(va, vb, vc); + _mm256_storeu_ps(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 8)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +unsafe fn add_mul_f64_avx2_fma(acc: &mut [f64], a: &[f64], b: &[f64]) { + use core::arch::x86_64::{_mm256_fmadd_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 4; + for c in 0..chunks { + let off = c * 4; + let va = _mm256_loadu_pd(a.as_ptr().add(off)); + let vb = _mm256_loadu_pd(b.as_ptr().add(off)); + let vc = _mm256_loadu_pd(acc.as_ptr().add(off)); + let r = _mm256_fmadd_pd(va, vb, vc); + _mm256_storeu_pd(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 4)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +// ──────────────────────────────────────────────────────────────────────── +// NEON backends (4 / 2 lanes per vfmaq instruction) +// ──────────────────────────────────────────────────────────────────────── + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn add_mul_f32_neon(acc: 
&mut [f32], a: &[f32], b: &[f32]) { + use core::arch::aarch64::{vfmaq_f32, vld1q_f32, vst1q_f32}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 4; + for c in 0..chunks { + let off = c * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + let vc = vld1q_f32(acc.as_ptr().add(off)); + let r = vfmaq_f32(vc, va, vb); + vst1q_f32(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 4)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn add_mul_f64_neon(acc: &mut [f64], a: &[f64], b: &[f64]) { + use core::arch::aarch64::{vfmaq_f64, vld1q_f64, vst1q_f64}; + let n = acc.len().min(a.len()).min(b.len()); + let chunks = n / 2; + for c in 0..chunks { + let off = c * 2; + let va = vld1q_f64(a.as_ptr().add(off)); + let vb = vld1q_f64(b.as_ptr().add(off)); + let vc = vld1q_f64(acc.as_ptr().add(off)); + let r = vfmaq_f64(vc, va, vb); + vst1q_f64(acc.as_mut_ptr().add(off), r); + } + for i in (chunks * 2)..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +// ──────────────────────────────────────────────────────────────────────── +// Scalar fallback (always correct, IEEE 754 FMA via core) +// ──────────────────────────────────────────────────────────────────────── + +unsafe fn add_mul_f32_scalar(acc: &mut [f32], a: &[f32], b: &[f32]) { + let n = acc.len().min(a.len()).min(b.len()); + for i in 0..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +unsafe fn add_mul_f64_scalar(acc: &mut [f64], a: &[f64], b: &[f64]) { + let n = acc.len().min(a.len()).min(b.len()); + for i in 0..n { + acc[i] = a[i].mul_add(b[i], acc[i]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ref_add_mul_f32(acc: &[f32], a: &[f32], b: &[f32]) -> Vec { + let n = acc.len().min(a.len()).min(b.len()); + (0..n).map(|i| a[i].mul_add(b[i], acc[i])).collect() + } + + fn ref_add_mul_f64(acc: &[f64], a: &[f64], b: &[f64]) -> Vec { + let n = 
acc.len().min(a.len()).min(b.len()); + (0..n).map(|i| a[i].mul_add(b[i], acc[i])).collect() + } + + #[test] + fn add_mul_f32_matches_scalar() { + // Multiple sizes spanning aligned + tail for each backend + // (16/8/4 lane widths → 16, 17, 25 hit boundaries). + for n in [0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] { + let acc_init: Vec = (0..n).map(|i| i as f32 * 0.1).collect(); + let a: Vec = (0..n).map(|i| (i as f32 * 0.5) - 1.0).collect(); + let b: Vec = (0..n).map(|i| (i as f32 * 0.3) + 0.7).collect(); + let expected = ref_add_mul_f32(&acc_init, &a, &b); + let mut acc = acc_init.clone(); + add_mul_f32(&mut acc, &a, &b); + for (i, (got, want)) in acc.iter().zip(expected.iter()).enumerate() { + assert_eq!(got.to_bits(), want.to_bits(), "n={n} i={i}: got {got} want {want}"); + } + } + } + + #[test] + fn add_mul_f64_matches_scalar() { + for n in [0, 1, 3, 4, 5, 7, 8, 16, 32, 100] { + let acc_init: Vec = (0..n).map(|i| i as f64 * 0.1).collect(); + let a: Vec = (0..n).map(|i| (i as f64 * 0.5) - 1.0).collect(); + let b: Vec = (0..n).map(|i| (i as f64 * 0.3) + 0.7).collect(); + let expected = ref_add_mul_f64(&acc_init, &a, &b); + let mut acc = acc_init.clone(); + add_mul_f64(&mut acc, &a, &b); + for (i, (got, want)) in acc.iter().zip(expected.iter()).enumerate() { + assert_eq!(got.to_bits(), want.to_bits(), "n={n} i={i}: got {got} want {want}"); + } + } + } +} diff --git a/src/simd_runtime/casts.rs b/src/simd_runtime/casts.rs new file mode 100644 index 00000000..cc4055a5 --- /dev/null +++ b/src/simd_runtime/casts.rs @@ -0,0 +1,85 @@ +//! Runtime-dispatched batch cast trampolines. +//! +//! Wraps the four existing half-precision cast surfaces. Each +//! underlying function already has internal runtime dispatch (F16C on +//! `cast_*_f16_*`, avx512bf16 + AVX-512F bit-shift on +//! `bf16_to_f32_batch`, AVX-512F-only RNE on `f32_to_bf16_batch_rne`), +//! shipped in PR #183. The wrappers here are `#[inline(always)]` so +//! 
they collapse to one call to the inner dispatcher at the consumer +//! site. +//! +//! All four are bit-exact per IEEE 754: +//! - F16 ↔ f32: `_mm256_cvtph_ps` / `_mm256_cvtps_ph::<0>` on F16C hosts, +//! scalar `F16::to_f32` / `F16::from_f32_rounded` elsewhere. +//! - BF16 → f32: `_mm512_cvtpbh_ps` on avx512bf16 hosts, AVX-512F +//! bit-shift on avx512f hosts, scalar bit-shift otherwise. +//! - f32 → BF16 RNE: pure AVX-512F bit-fiddle (byte-exact vs +//! `_mm512_cvtneps_pbh`), scalar IEEE RNE elsewhere. + +use crate::hpc::quantized::{BF16, F16}; + +/// F16 → f32 batch (lossless, IEEE 754 widening). +#[inline(always)] +pub fn cast_f16_to_f32_batch(src: &[F16], dst: &mut [f32]) { + crate::simd_half::cast_f16_to_f32_batch(src, dst) +} + +/// f32 → F16 batch (IEEE 754 RNE, bit-exact vs `F16::from_f32_rounded`). +#[inline(always)] +pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) { + crate::simd_half::cast_f32_to_f16_batch(src, dst) +} + +/// BF16 → f32 batch (lossless bit-shift). +/// +/// The underlying `crate::simd::bf16_to_f32_batch` takes `&[u16]`; the +/// trampoline accepts `&[BF16]` for symmetry with the F16 surface and +/// reinterprets via repr(transparent) layout equivalence. +#[inline(always)] +pub fn bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) { + // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`. + let src_u16: &[u16] = unsafe { core::slice::from_raw_parts(src.as_ptr() as *const u16, src.len()) }; + crate::simd::bf16_to_f32_batch(src_u16, dst) +} + +/// f32 → BF16 batch (RNE, byte-exact vs `_mm512_cvtneps_pbh`). +/// +/// Underlying `crate::simd::f32_to_bf16_batch_rne` takes `&mut [u16]`; +/// trampoline accepts `&mut [BF16]` and reinterprets. +#[inline(always)] +pub fn f32_to_bf16_batch_rne(src: &[f32], dst: &mut [BF16]) { + // SAFETY: BF16 is repr(transparent) over u16. 
+ let dst_u16: &mut [u16] = unsafe { core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u16, dst.len()) }; + crate::simd::f32_to_bf16_batch_rne(src, dst_u16) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn f16_roundtrip_via_runtime_trampolines() { + let inputs: Vec = (0..33).map(|i| F16::from_f32(i as f32 * 0.75)).collect(); + let mut f32_buf = vec![0.0f32; 33]; + cast_f16_to_f32_batch(&inputs, &mut f32_buf); + let mut back = vec![F16::ZERO; 33]; + cast_f32_to_f16_batch(&f32_buf, &mut back); + for i in 0..33 { + assert_eq!(back[i], inputs[i], "F16 roundtrip mismatch at {i}"); + } + } + + #[test] + fn bf16_roundtrip_via_runtime_trampolines() { + let inputs: Vec = (0..33) + .map(|i| BF16::from_f32_rounded(i as f32 * 0.75)) + .collect(); + let mut f32_buf = vec![0.0f32; 33]; + bf16_to_f32_batch(&inputs, &mut f32_buf); + let mut back = vec![BF16::ZERO; 33]; + f32_to_bf16_batch_rne(&f32_buf, &mut back); + for i in 0..33 { + assert_eq!(back[i], inputs[i], "BF16 roundtrip mismatch at {i}"); + } + } +} diff --git a/src/simd_runtime/matmul.rs b/src/simd_runtime/matmul.rs new file mode 100644 index 00000000..a5f1925f --- /dev/null +++ b/src/simd_runtime/matmul.rs @@ -0,0 +1,101 @@ +//! Runtime-dispatched matmul trampolines. +//! +//! These are thin wrappers over the existing public matmul surfaces +//! that ALREADY have internal runtime dispatch (PR #182 / #184 / #185 +//! shipped the per-tier kernels and the dispatch helpers). The +//! trampolines here exist for two reasons: +//! +//! 1. **Consistent surface under `crate::simd_runtime::*`** — +//! consumers using the runtime-dispatch path get a uniform import +//! site for every matmul + every vector op. +//! 2. **Inline-elision opportunity** — these wrappers are +//! `#[inline(always)]` so the call collapses to the inner dispatch +//! call without any extra indirection at the consumer site (the +//! matmul entry points themselves are NOT `#[inline]` because +//! 
they're large; the wrapper is one branch). +//! +//! Backend chains (all already in tree, this module adds nothing new): +//! +//! - `matmul_bf16_to_f32`: AMX TDPBF16PS → VDPBF16PS → scalar (PR #182). +//! - `matmul_f32` (BF16 compute on AMX hosts): same chain (PR #182). +//! - `matmul_i8_to_i32`: AMX TDPBUSD → VPDPBUSD-zmm → VPDPBUSD-ymm → scalar (PR #184/#185). +//! - `gemm_u8_i8`: AMX TDPBUSD → compile-time avx512vnni / avxvnni / scalar (PR #185). +//! +//! Cost: zero on top of what the underlying functions already pay. + +use crate::{ArrayView2, ArrayViewMut2}; + +/// BF16 × BF16 → f32 matmul. Runtime-dispatched. +/// +/// Delegates to [`crate::hpc::amx_matmul::matmul_bf16_to_f32`], which +/// already runtime-dispatches AMX TDPBF16PS → VDPBF16PS → scalar. +#[inline(always)] +pub fn matmul_bf16_to_f32( + lhs: ArrayView2<'_, crate::hpc::quantized::BF16>, rhs: ArrayView2<'_, crate::hpc::quantized::BF16>, + out: ArrayViewMut2<'_, f32>, +) -> Result<(), crate::hpc::amx_matmul::MatmulError> { + crate::hpc::amx_matmul::matmul_bf16_to_f32(lhs, rhs, out) +} + +/// f32 × f32 → f32 matmul (BF16 compute on AMX hosts). +/// Runtime-dispatched per the underlying tier chain. +#[inline(always)] +pub fn matmul_f32( + lhs: ArrayView2<'_, f32>, rhs: ArrayView2<'_, f32>, out: ArrayViewMut2<'_, f32>, +) -> Result<(), crate::hpc::amx_matmul::MatmulError> { + crate::hpc::amx_matmul::matmul_f32(lhs, rhs, out) +} + +/// i8 × i8 → i32 matmul. Runtime-dispatched to AMX TDPBUSD → VPDPBUSD-zmm → +/// VPDPBUSD-ymm → scalar with the sign-shift bias trick. +#[inline(always)] +pub fn matmul_i8_to_i32( + lhs: ArrayView2<'_, i8>, rhs: ArrayView2<'_, i8>, out: ArrayViewMut2<'_, i32>, +) -> Result<(), crate::hpc::amx_matmul::MatmulError> { + crate::hpc::amx_matmul::matmul_i8_to_i32(lhs, rhs, out) +} + +/// `C = A · B` where A is M×K u8, B is K×N i8, C is M×N i32 (overwrite). +/// +/// Delegates to [`crate::simd_int_ops::gemm_u8_i8`]. 
Tier 0 (runtime +/// AMX detection) was added by PR #185; tiers 1-3 (compile-time +/// avx512vnni / avxvnni / scalar) come from PR #182. +#[inline(always)] +pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + crate::simd_int_ops::gemm_u8_i8(a, b, c, m, n, k) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Array2; + + #[test] + fn matmul_bf16_trampoline_works() { + use crate::hpc::quantized::BF16; + let a: Array2 = Array2::from_shape_fn((16, 64), |(_i, _j)| BF16::from_f32(1.0)); + let b: Array2 = Array2::from_shape_fn((64, 16), |(_i, _j)| BF16::from_f32(0.5)); + let mut c: Array2 = Array2::zeros((16, 16)); + matmul_bf16_to_f32(a.view(), b.view(), c.view_mut()).unwrap(); + for &v in c.iter() { + // 64 lanes × 1.0 × 0.5 = 32.0; BF16 of 1.0 and 0.5 are exact. + assert!((v - 32.0).abs() < 1e-3, "expected ~32.0, got {v}"); + } + } + + #[test] + fn gemm_u8_i8_trampoline_matches_scalar() { + let m = 16; + let n = 16; + let k = 64; + let a: Vec = (0..m * k).map(|i| ((i * 7 + 3) % 256) as u8).collect(); + let b: Vec = (0..k * n) + .map(|i| ((i * 11 + 5) % 256) as u8 as i8) + .collect(); + let mut c = vec![0i32; m * n]; + gemm_u8_i8(&a, &b, &mut c, m, n, k); + // Spot-check c[0]: sum over k of a[k] * b[k*n] + let expected_c0: i32 = (0..k).map(|kk| a[kk] as i32 * b[kk * n] as i32).sum(); + assert_eq!(c[0], expected_c0, "c[0] mismatch"); + } +} diff --git a/src/simd_runtime/mod.rs b/src/simd_runtime/mod.rs new file mode 100644 index 00000000..428020d3 --- /dev/null +++ b/src/simd_runtime/mod.rs @@ -0,0 +1,95 @@ +//! Runtime SIMD dispatch — release-binary distribution path. +//! +//! **Alternative** to the compile-time cascade in `crate::simd::*` / +//! `crate::simd_ops::*`. **Additive**: gated under +//! `--features runtime-dispatch`, does not change any existing path. +//! +//! Use case: ship ONE binary that adapts across heterogeneous deployment +//! silicon — AVX-512 server + AVX2-only laptop + Arrow Lake desktop + +//! 
Sapphire Rapids workstation — all from the same artifact. The +//! existing compile-time `v3` / `v4` / `native` / `nightly-simd` +//! configs target a single class of CPU per build; the runtime layer +//! targets the union. +//! +//! # Dispatch model +//! +//! One `LazyLock` per public surface. First call fires the +//! closure which reads [`crate::hpc::simd_caps::simd_caps`] and selects +//! a backend; every subsequent call is one pointer-deref + indirect +//! call. Per-call overhead: ~2-3 ns (the indirect call is a branch +//! target predict; LazyLock deref is one atomic-acquire load that's +//! cache-resident after first hit). Invisible against any SIMD op's +//! actual work (~100+ cycles). +//! +//! ```text +//! consumer site: crate::simd_runtime::vnni_dot_u8_i8(&a, &b) +//! ↓ inline +//! unsafe { (*VNNI_DOT_U8_I8_DISPATCH)(a, b) } +//! ↓ (indirect call, target predicted after first hit) +//! vnni_dot_u8_i8_avx512_with_tail (or vnni2 or scalar) +//! ↓ delegate to existing kernel +//! crate::simd_amx::vnni_dot_u8_i8(aligned_slice) +//! ↓ _mm512_dpbusd_epi32 × ceil(aligned/64) +//! _mm512_reduce_add_epi32 +//! ``` +//! +//! # Backend reuse — no kernel duplication +//! +//! Every dispatch arm in this module delegates to a kernel that already +//! exists in tree. The runtime layer is just the trampoline; the +//! kernels were added by: +//! +//! | Op | Backends (from where) | +//! |---|---| +//! | `vnni_dot_u8_i8` | `simd_amx::{vnni_dot_u8_i8, vnni2_dot_u8_i8, vnni_dot_u8_i8_scalar}` (PR #143) | +//! | `matmul_i8_to_i32` | `hpc::int8_tile_gemm::{int8_gemm_amx_tiled, int8_gemm_vpdpbusd_zmm, int8_gemm_vpdpbusd_ymm}` (PR #184/#185) | +//! | `matmul_bf16_to_f32` | `hpc::amx_matmul::matmul_bf16_to_f32` (already runtime-dispatched internally, PR #182) | +//! | `gemm_u8_i8` | `simd_int_ops::gemm_u8_i8` (already runtime-dispatched internally, PR #185) | +//! 
| `cast_*_f16` | `simd_half::{cast_f16_to_f32_batch, cast_f32_to_f16_batch}` (already runtime-dispatched, PR #183) | +//! | `add_mul_f32` / `add_mul_f64` | NEW direct-FMA backends in `add_mul.rs` (no pre-existing slice kernel) | +//! +//! The only NEW kernel code in this module is `add_mul_f32` / `add_mul_f64` — +//! the rest is trampolines around existing functions. +//! +//! # Mutually exclusive with `nightly-simd` +//! +//! The portable-SIMD polyfill (`--features nightly-simd`) replaces the +//! architecture-specific intrinsics in `simd_avx2.rs` / `simd_avx512.rs` +//! / `simd_neon.rs` with `core::simd::*`. The runtime trampolines in +//! this module select BETWEEN those architecture-specific kernels, +//! so the two features are semantically incompatible. A `compile_error!` +//! gate at the top of this module enforces this at build time. +//! +//! # When to use which dispatch model +//! +//! - **Compile-time `crate::simd::*` / `crate::simd_ops::*`** (default): +//! benchmarking, library that gets `target-cpu`-tuned per deployment, +//! embedded / fixed-target builds. Direct monomorphized calls; no +//! runtime branch on the SIMD op. +//! - **Runtime `crate::simd_runtime::*`** (this module): production +//! release binary that ships once and runs on heterogeneous silicon. +//! One indirect call per op; ~2-3 ns overhead. +//! +//! See `.claude/knowledge/simd-dispatch-architecture.md` § 7.1 / Phase 5 +//! for the full design rationale. + +#![cfg(feature = "runtime-dispatch")] + +#[cfg(all(feature = "runtime-dispatch", feature = "nightly-simd"))] +compile_error!( + "features `runtime-dispatch` and `nightly-simd` are mutually exclusive — \ + the portable-SIMD polyfill (nightly-simd) replaces the architecture-specific \ + intrinsics that runtime-dispatch selects between." +); + +pub mod add_mul; +pub mod casts; +pub mod matmul; +pub mod vnni_dot; + +// Re-export the public trampoline entry points at the module root so +// consumers can `use crate::simd_runtime::*` and get every op flat. 
+pub use add_mul::{add_mul_f32, add_mul_f64}; +pub use casts::{bf16_to_f32_batch, cast_f16_to_f32_batch, cast_f32_to_f16_batch, f32_to_bf16_batch_rne}; +pub use matmul::{gemm_u8_i8, matmul_bf16_to_f32, matmul_f32, matmul_i8_to_i32}; +pub use vnni_dot::vnni_dot_u8_i8; diff --git a/src/simd_runtime/vnni_dot.rs b/src/simd_runtime/vnni_dot.rs new file mode 100644 index 00000000..4abafede --- /dev/null +++ b/src/simd_runtime/vnni_dot.rs @@ -0,0 +1,173 @@ +//! Runtime-dispatched `vnni_dot_u8_i8` — `u8 × i8 → i32` dot product. +//! +//! Three backends, selected once at first call via `LazyLock`: +//! +//! 1. **AVX-512 VNNI** (`avx512f + avx512vnni`) → `simd_amx::vnni_dot_u8_i8` +//! wrapped with a scalar tail loop. The existing kernel processes +//! `n - (n%64)` lanes and silently drops the K-tail (its current +//! matvec caller pre-aligns rows so the tail is never the bug; +//! a general-purpose dispatch surface cannot assume that). +//! 2. **AVX-VNNI ymm** (`avx2 + avxvnniint8`) → `simd_amx::vnni2_dot_u8_i8` +//! which already includes its own tail handling. +//! 3. **Scalar** → `simd_amx::vnni_dot_u8_i8_scalar` (always correct, +//! no SIMD, the safety floor). +//! +//! On this host (Sapphire Rapids: AMX + AVX-512 + VNNI), tier 1 fires. +//! On Arrow Lake / Meteor Lake U (AVX-VNNI but no AVX-512), tier 2. +//! On Pi 4 / pre-2013 x86 / any non-x86 target, tier 3. + +use std::sync::LazyLock; + +type VnniDotFn = unsafe fn(&[u8], &[i8]) -> i32; + +/// Static dispatch pointer. Set once on first call to +/// [`vnni_dot_u8_i8`]; every subsequent call is one indirect call +/// through this pointer. 
+static VNNI_DOT_U8_I8_DISPATCH: LazyLock = LazyLock::new(|| { + let _caps = crate::hpc::simd_caps::simd_caps(); + #[cfg(target_arch = "x86_64")] + { + if _caps.avx512f && _caps.avx512vnni { + return vnni_dot_u8_i8_avx512_with_tail as VnniDotFn; + } + if _caps.avx2 && _caps.avxvnniint8 { + return vnni2_dot_u8_i8_safe_wrapper as VnniDotFn; + } + } + vnni_dot_u8_i8_scalar_safe_wrapper as VnniDotFn +}); + +/// `u8 × i8 → i32` dot product, runtime-dispatched to the best +/// available backend on the current CPU. +/// +/// Returns the i32 sum of `a[i] * b[i]` for `i in 0..min(a.len(), b.len())`. +/// Wrapping arithmetic on accumulation; `127 × 255 × n` fits in i32 for +/// `n < ~32K`, so this is safe for any realistic vector length. +/// +/// # Performance +/// First call: ~1µs (LazyLock initialization + CPUID). +/// Subsequent calls: ~2-3ns overhead + the chosen backend's runtime +/// (typically ~1 cycle per 4-8 lanes on the SIMD tiers). +#[inline] +pub fn vnni_dot_u8_i8(a: &[u8], b: &[i8]) -> i32 { + // SAFETY: the dispatch closure picked a function whose + // `#[target_feature]` requirement is satisfied by the runtime CPU + // caps it checked. Each backend handles its own slice boundary. + unsafe { (*VNNI_DOT_U8_I8_DISPATCH)(a, b) } +} + +// ──────────────────────────────────────────────────────────────────────── +// Tier 1 — AVX-512 VNNI with scalar tail +// ──────────────────────────────────────────────────────────────────────── + +/// Wraps `simd_amx::vnni_dot_u8_i8` to handle the K-tail. +/// +/// The underlying kernel computes `chunks = n / 64` and reduces the +/// chunked sum, silently dropping `n % 64` lanes. That's fine for its +/// matvec caller (which pre-aligns rows), but a general-purpose +/// dispatch surface needs the tail too. +/// +/// # Safety +/// Caller must have feature-detected `avx512f + avx512vnni` at runtime. +/// The dispatch closure above checks this before installing this +/// function as the dispatch target. 
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f,avx512vnni")] +unsafe fn vnni_dot_u8_i8_avx512_with_tail(a: &[u8], b: &[i8]) -> i32 { + let n = a.len().min(b.len()); + let aligned = n - (n % 64); + // SAFETY: avx512f + avx512vnni enabled on this function via + // target_feature, satisfying simd_amx::vnni_dot_u8_i8's contract. + // `aligned` is a multiple of 64 by construction; slicing both + // operands to `aligned` length stays within their bounds. + let mut total: i32 = if aligned > 0 { + crate::simd_amx::vnni_dot_u8_i8(&a[..aligned], &b[..aligned]) + } else { + 0 + }; + for i in aligned..n { + total = total.wrapping_add((a[i] as i32) * (b[i] as i32)); + } + total +} + +// ──────────────────────────────────────────────────────────────────────── +// Tier 2 — AVX-VNNI 256-bit +// ──────────────────────────────────────────────────────────────────────── + +/// Thin safe wrapper around `simd_amx::vnni2_dot_u8_i8` — that kernel +/// already handles its K-tail (chunked at 32 + scalar remainder), so +/// no additional logic needed here. The wrapper exists only to give +/// the dispatch table a `VnniDotFn` pointer with the right signature. +/// +/// # Safety +/// Caller must have feature-detected `avx2 + avxvnniint8` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,avxvnniint8")] +unsafe fn vnni2_dot_u8_i8_safe_wrapper(a: &[u8], b: &[i8]) -> i32 { + crate::simd_amx::vnni2_dot_u8_i8(a, b) +} + +// ──────────────────────────────────────────────────────────────────────── +// Tier 3 — Scalar fallback (always correct) +// ──────────────────────────────────────────────────────────────────────── + +/// Thin wrapper around `simd_amx::vnni_dot_u8_i8_scalar` for dispatch +/// table type-compatibility (the scalar function is `fn`, not `unsafe +/// fn`, so we wrap to match the `VnniDotFn` pointer type). +/// +/// # Safety +/// No actual unsafety — the wrapped function is safe. 
The `unsafe` +/// qualifier exists only because the dispatch table's pointer type is +/// `unsafe fn` (to accommodate the AVX-512 / AVX-VNNI arms above which +/// genuinely need it). +unsafe fn vnni_dot_u8_i8_scalar_safe_wrapper(a: &[u8], b: &[i8]) -> i32 { + crate::simd_amx::vnni_dot_u8_i8_scalar(a, b) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Verify the runtime trampoline produces the same result as the + /// scalar reference for a few representative shapes — aligned and + /// non-aligned K (which exercises the AVX-512 wrapper's scalar + /// tail handling). + #[test] + fn vnni_dot_matches_scalar_reference() { + // Aligned K = 64: exercises the chunked-only path on tier 1. + let a: Vec = (0..64).map(|i| ((i * 7 + 3) % 256) as u8).collect(); + let b: Vec = (0..64).map(|i| ((i * 11 + 5) % 256) as u8 as i8).collect(); + let got = vnni_dot_u8_i8(&a, &b); + let expected: i32 = (0..64).map(|i| (a[i] as i32) * (b[i] as i32)).sum(); + assert_eq!(got, expected, "aligned K=64 mismatch"); + + // K = 100 (k % 64 = 36): exercises the scalar tail wrapper on + // tier 1, and the existing tail handling on tier 2. + let a: Vec = (0..100).map(|i| ((i * 13 + 1) % 256) as u8).collect(); + let b: Vec = (0..100).map(|i| ((i * 17 + 9) % 256) as u8 as i8).collect(); + let got = vnni_dot_u8_i8(&a, &b); + let expected: i32 = (0..100).map(|i| (a[i] as i32) * (b[i] as i32)).sum(); + assert_eq!(got, expected, "K=100 (tail) mismatch"); + + // Small K = 7: less than any SIMD chunk, exercises the tail + // path entirely (or the scalar tier directly). + let a: Vec = vec![1, 2, 3, 4, 5, 6, 7]; + let b: Vec = vec![-1, 2, -3, 4, -5, 6, -7]; + let got = vnni_dot_u8_i8(&a, &b); + let expected: i32 = (0..a.len()).map(|i| a[i] as i32 * b[i] as i32).sum(); + assert_eq!(got, expected, "small K=7 mismatch"); + } + + /// Confirm the LazyLock fires exactly once — subsequent calls + /// reuse the same dispatch target. + #[test] + fn dispatch_target_stable() { + // Force first call. 
+ let _ = vnni_dot_u8_i8(&[0u8; 1], &[0i8; 1]); + let ptr1 = *VNNI_DOT_U8_I8_DISPATCH as *const (); + let _ = vnni_dot_u8_i8(&[1u8; 1], &[1i8; 1]); + let ptr2 = *VNNI_DOT_U8_I8_DISPATCH as *const (); + assert_eq!(ptr1, ptr2, "dispatch target changed between calls"); + } +}