From b1979d799e1311a084c983bade932e85f1757ba0 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 05:45:33 +0000 Subject: [PATCH 1/4] =?UTF-8?q?feat(hpc):=20TD-T2=20=E2=80=94=20AMX=20TDPB?= =?UTF-8?q?USD=20tile=20kernel=20+=20matmul=5Fi8=5Fto=5Fi32=20wiring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the integer operand family. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the primitives existing in simd_amx since day one) and wires matmul_i8_to_i32's AMX arm through it. New module `hpc::int8_tile_gemm`: * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel, K must be a multiple of 64. Mirror shape of `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand family that TDPBUSD natively supports. **One TDPBUSD = 16 384 multiply-accumulates per instruction** (16×16 output cells × 64 K-elements per cell, i.e. 16 VNNI quads × 4 bytes along each 64-byte A row). That's 256× the VPDPBUSD-zmm throughput per instruction (16 lanes × 4 = 64 MACs). * Internal `amx_path()` uses the existing primitives in `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig → tile_zero → K/64 iterations of (tile_load A, tile_load B, tile_dpbusd) → tile_store → tile_release. * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32 triple-loop reference. New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`: * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad layout required by TDPBUSD tile 2. * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]` * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout for TDPBF16PS — both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry: BF16 is 2B × 2, INT8 is 1B × 4). 
Wiring `matmul_i8_to_i32`'s AMX arm (was placebo): Before this commit, the AMX branch shifted i8 → u8, then called the SCALAR `int8_gemm_i32` reference and subtracted the bias — TDPBUSD itself was never reached even on real AMX silicon. Now: 1. Shift A: i8 → u8 via (+128). 2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations. 3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]). The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling for arbitrary shapes (the same kind of Phase-4 work that TD-T1 deferred). Verification: * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new tests, 4 of them in int8_tile_gemm). The existing matmul_i8_to_i32 test now exercises the actual TDPBUSD path because this host has amx_int8 + amx_tile in /proc/cpuinfo; it continues to pass with bit-identical results to the scalar reference. * `vnni_pack_i8_roundtrip` test verifies the pack layout matches the spec exactly for an 8 × 4 sample. * `fallback_matches_scalar_reference_k64` test verifies the non-AMX path produces the same i32 output as a hand-written reference for a 64-K, pseudo-random u8/i8 matrix pair. * `public_api_diagonal_k128` test asserts a structured pattern (A = identity-like, B = constant 2) gives the expected accumulation through the full dispatch chain. * `cargo clippy --lib -D warnings` clean. * `cargo fmt --all --check` clean. Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX arm no longer falls back to it (the scalar else-branch uses an inline triple-loop directly). 
After this commit, the per-CPU dispatch table from PR #180 has the AMX tier wired for BOTH operand families on Sapphire Rapids+: BF16 GEMM: SPR+ → TDPBF16PS (TD-T1 / TD-T1b in PR #182) INT8 GEMM: SPR+ → TDPBUSD (this commit) Out of scope (separate PRs): * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel function `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; just need to assemble them into a m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm in PR #182's bf16_gemm_dispatch). * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level surface from PR #182) — it's u8 × i8 natively so no sign-shift needed, simpler to wire than matmul_i8_to_i32. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/amx_matmul.rs | 84 ++++++++++++--- src/hpc/int8_tile_gemm.rs | 214 ++++++++++++++++++++++++++++++++++++++ src/hpc/mod.rs | 4 + 3 files changed, 288 insertions(+), 14 deletions(-) create mode 100644 src/hpc/int8_tile_gemm.rs diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index 4c64177d..df959db0 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -193,6 +193,32 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) { } } +/// Pack B[K, N] i8 row-major into K/4 × (N*4) VNNI quads for `TDPBUSD`. +/// +/// Output layout required by `TDPBUSD` tile 2 (16 rows × 64 bytes): +/// dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j] +/// +/// For N=16 (AMX tile width), each output "row" holds 16 i8 quads = 64 +/// bytes (matches the 64-byte tile row width). K must be a multiple of +/// 4. The same layout is used for `u8` operands (just bit-cast through +/// — VNNI doesn't care about sign at the packing layer; sign +/// interpretation happens inside TDPBUSD which treats A as u8 and B +/// as i8 for the multiply). 
+#[inline] +pub fn vnni_pack_i8(src: &[i8], dst: &mut [i8], k: usize, n: usize) { + debug_assert_eq!(src.len(), k * n); + debug_assert_eq!(dst.len(), k * n); + debug_assert_eq!(k % 4, 0, "K must be multiple of 4 for VNNI INT8 quads"); + for kb in 0..(k / 4) { + let dst_row = kb * n * 4; + for j in 0..n { + for p in 0..4 { + dst[dst_row + j * 4 + p] = src[(4 * kb + p) * n + j]; + } + } + } +} + // ═══════════════════════════════════════════════════════════════════════════ // Public ndarray-typed matmul API (sprint A4 / Burn parity item 6) // ═══════════════════════════════════════════════════════════════════════════ @@ -207,7 +233,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) { // strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked // into contiguous staging buffers before the kernel runs. -use crate::hpc::quantized::{bf16_gemm_f32, int8_gemm_i32, BF16}; +use crate::hpc::quantized::{bf16_gemm_f32, BF16}; use crate::{ArrayView2, ArrayViewMut2}; /// Errors returned by the public AMX matmul API. @@ -537,14 +563,17 @@ pub fn matmul_f32( /// Matrix multiply i8 × i8 → i32: `out = lhs · rhs`. /// -/// On AMX hosts uses `TDPBUSD` (256 MACs/instr); otherwise falls back to -/// the scalar `int8_gemm_i32`. +/// On AMX hosts with 16/16/64-aligned shapes uses `TDPBUSD` via the +/// 16×16 tile kernel in [`crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16`] +/// — 16 384 MACs per instruction. Mis-aligned shapes (or non-AMX hosts) +/// fall back to the scalar i8×i8 → i32 reference. /// -/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For the -/// signed-by-signed surface required here, the LHS is shifted into the -/// unsigned domain and the bias subtracted from the accumulator (only on -/// the AMX path; the scalar path operates directly in i8). The public -/// result is identical. +/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). 
For +/// the signed-by-signed surface required here, the LHS is shifted into +/// the unsigned domain (i8 + 128 → u8) and the bias `128 · sum(B[:, j] +/// over k)` is subtracted from the accumulator. The public result is +/// bit-identical to the scalar reference because all arithmetic stays +/// in i32 (no float rounding). /// /// `out` must be row-contiguous; inputs may be strided. pub fn matmul_i8_to_i32( @@ -556,13 +585,39 @@ pub fn matmul_i8_to_i32( let b_i8 = pack_contig(&rhs); let mut c = vec![0i32; m * n]; - if amx_available() { - // AMX TDPBUSD path: shift LHS i8 → u8 via (+128) and subtract the - // bias 128·sum(B[:, j] over k) afterwards. This keeps numerics exact. + if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 { + // AMX TDPBUSD path: shift LHS i8 → u8 via (+128), tile-GEMM into + // i32, subtract bias 128·colsum(B). The tile kernel zeroes its + // internal accumulator (TILEZERO + TDPBUSD accumulate); we need + // fresh per-tile output here so we tile manually over M/N and + // call int8_tile_gemm_16x16 per (i, j) block. let a_u8: Vec = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect(); - // Compute C' = A_u8 · B_i8 in i32, then subtract 128 · colsum(B). - int8_gemm_i32(&a_u8, &b_i8, &mut c, m, n, k); + // B sub-block extraction per j-tile (B is row-major K × N; the + // tile kernel wants K × 16 contiguous). Reused across i-tiles. + let mut b_tile = vec![0i8; k * 16]; + let mut tile_c = vec![0i32; 256]; + + for j_tile in (0..n).step_by(16) { + // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows. + for kk in 0..k { + let row = kk * n + j_tile; + b_tile[kk * 16..(kk + 1) * 16] + .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) }); + } + for i_tile in (0..m).step_by(16) { + let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; + tile_c.fill(0); + crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); + // Write tile_c (16 × 16) into c at (i_tile, j_tile). 
+ for ii in 0..16 { + let dst_off = (i_tile + ii) * n + j_tile; + c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); + } + } + } + + // Subtract bias: c[i, j] -= 128 · colsum(B[:, j]). let mut colsum = vec![0i32; n]; for p in 0..k { for j in 0..n { @@ -575,7 +630,8 @@ pub fn matmul_i8_to_i32( } } } else { - // Scalar i8×i8 → i32 reference. + // Scalar i8×i8 → i32 reference — used for non-AMX hosts and for + // shapes that don't fit the 16/16/64 tile alignment. for i in 0..m { for p in 0..k { let av = a_i8[i * k + p] as i32; diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs new file mode 100644 index 00000000..05e355b0 --- /dev/null +++ b/src/hpc/int8_tile_gemm.rs @@ -0,0 +1,214 @@ +//! INT8 tile GEMM polyfill — AMX (TDPBUSD) tile kernel. +//! +//! Mirror of `hpc::bf16_tile_gemm` for the `u8 × i8 → i32` shape, the +//! native TDPBUSD operand type. One TDPBUSD: 16×16 output tile × 64 +//! K-elements per A row × 4 K-elements per inner product = **16 384 +//! multiply-accumulates per instruction**. That's 256× the VPDPBUSD +//! zmm throughput per instruction (which does 16 × 4 = 64 MACs). +//! +//! Public surface: +//! * [`int8_tile_gemm_16x16`] — the 16×16 tile kernel; M=16, N=16, +//! K a multiple of 64. AMX path requires runtime feature +//! detection (`amx_available()`); falls back to a scalar reference +//! when AMX isn't OS-enabled. +//! +//! Caller responsibility: +//! * B comes in row-major K × 16 i8; the kernel pre-packs it into +//! VNNI quad layout via [`super::amx_matmul::vnni_pack_i8`]. +//! * A is row-major 16 × K u8 (TDPBUSD's unsigned operand). +//! * C accumulates into the caller's i32 buffer (16 × 16 = 256 i32). +//! +//! Same shape as `bf16_tile_gemm::bf16_tile_gemm_16x16`. The two kernels +//! together cover the SPR/GNR AMX dispatch tier for both `BF16 × BF16 +//! → f32` and `u8 × i8 → i32` — the two operand families that AMX +//! supports natively. 
+ +use crate::hpc::amx_matmul::{ + amx_available, tile_dpbusd, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_i8, + TileConfig, +}; + +// ═════════════════════════════════════════════════════════════════════ +// Public API — safe dispatching wrapper +// ═════════════════════════════════════════════════════════════════════ + +/// Compute C[16, 16] += A[16, K] × B[K, 16] where A is u8 row-major, +/// B is i8 row-major, C is i32 row-major. K must be a multiple of 64. +/// +/// Tier dispatch (runtime): +/// AMX available → TDPBUSD tile GEMM (16×16 × K/64 tile iterations, +/// 16 384 MACs per instruction) +/// AMX unavailable → scalar u8 × i8 → i32 reference +/// +/// Output behavior: this function **accumulates** into `c` (does NOT +/// zero it first). Callers wanting fresh `C = A·B` semantics should +/// zero `c` before calling, the same convention `bf16_tile_gemm_16x16` +/// uses. +pub fn int8_tile_gemm_16x16(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) { + assert_eq!(k % 64, 0, "K must be multiple of 64 for TDPBUSD tiles"); + assert_eq!(a_u8.len(), 16 * k); + assert_eq!(b_i8.len(), k * 16); + assert_eq!(c.len(), 16 * 16); + + if amx_available() { + // AMX path: pack B into VNNI quad layout, call tile GEMM. + let mut b_vnni = vec![0i8; k * 16]; + vnni_pack_i8(b_i8, &mut b_vnni, k, 16); + // SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl. + unsafe { + amx_path(a_u8, &b_vnni, c, k); + } + } else { + fallback_path(a_u8, b_i8, c, k); + } +} + +// ═════════════════════════════════════════════════════════════════════ +// AMX path (TDPBUSD) +// ═════════════════════════════════════════════════════════════════════ + +/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_i8`). +/// # Safety +/// Caller must have verified `amx_available() == true`. 
+#[inline] +unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) { + // Tile config: 16×64-byte tiles, identical shape to the BF16 tile + // (BF16 is 32 elements × 2 bytes per row, INT8 is 64 elements × 1 + // byte — same 64-byte row width either way). + let cfg = TileConfig::for_dpbusd(64); + tile_loadconfig(&cfg); + tile_zero(0); + + // Accumulate over K/64 tile blocks. Each TDPBUSD consumes 64 + // K-elements per A row × 4 K-elements per inner-product = 256 MACs + // per output cell × 16 × 16 = 16 384 MACs per instruction. + let k_blocks = k / 64; + let a_stride = k; // bytes per A row (u8 = 1 byte each) + let b_stride = 64usize; // VNNI: 16 columns × 4 bytes per row + + for kb in 0..k_blocks { + let a_ptr = a_u8.as_ptr().add(kb * 64); + // B sits in VNNI layout: K/4 outer rows × 64 bytes. Each + // 64-K-element block spans 16 outer rows × 64 bytes = 1024 + // bytes. + let b_ptr = b_vnni.as_ptr().add(kb * 16 * 64) as *const u8; + tile_load(1, a_ptr, a_stride); + tile_load(2, b_ptr, b_stride); + tile_dpbusd(); + } + + tile_store(0, c.as_mut_ptr() as *mut u8, 64); + tile_release(); +} + +// ═════════════════════════════════════════════════════════════════════ +// Scalar fallback (i32 reference) +// ═════════════════════════════════════════════════════════════════════ + +/// Direct scalar u8 × i8 → i32 reference. Accumulates into `c`. +fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) { + for i in 0..16 { + for kk in 0..k { + let a_val = a_u8[i * k + kk] as i32; + for j in 0..16 { + c[i * 16 + j] += a_val * b_i8[kk * 16 + j] as i32; + } + } + } +} + +// ═════════════════════════════════════════════════════════════════════ +// Tests +// ═════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + /// Reference: scalar u8 × i8 → i32 (matches `fallback_path`). 
+ fn ref_gemm(a: &[u8], b: &[i8], c: &mut [i32], k: usize) { + for i in 0..16 { + for j in 0..16 { + let mut s = 0i32; + for kk in 0..k { + s += a[i * k + kk] as i32 * b[kk * 16 + j] as i32; + } + c[i * 16 + j] = s; + } + } + } + + #[test] + fn fallback_matches_scalar_reference_k64() { + let k = 64; + // Deterministic pseudo-random inputs covering the u8 / i8 ranges. + let a: Vec = (0..16 * k).map(|i| ((i * 7 + 3) % 256) as u8).collect(); + let b: Vec = (0..k * 16) + .map(|i| (((i * 11 + 5) % 256) as u8 as i8)) + .collect(); + + let mut c_ref = vec![0i32; 256]; + ref_gemm(&a, &b, &mut c_ref, k); + + let mut c_fb = vec![0i32; 256]; + fallback_path(&a, &b, &mut c_fb, k); + + for i in 0..256 { + assert_eq!(c_fb[i], c_ref[i], "fallback mismatch at {}", i); + } + } + + #[test] + fn public_api_runs_on_any_hardware_k64() { + let k = 64; + let a = vec![0u8; 16 * k]; + let b = vec![0i8; k * 16]; + let mut c = vec![0i32; 256]; + int8_tile_gemm_16x16(&a, &b, &mut c, k); + for v in c.iter() { + assert_eq!(*v, 0, "zero × zero must be 0"); + } + } + + #[test] + fn public_api_diagonal_k128() { + // A = identity-like (only A[i, i] = 1, but we need 16 × 128), so + // pick A[i, i*8..i*8+8] = 1 (8 ones per i-row). B = constant 2. + // Expected: C[i, j] = sum_{kk in i*8..i*8+8}(1 × 2) = 16. 
+ let k = 128; + let mut a = vec![0u8; 16 * k]; + for i in 0..16 { + for off in 0..8 { + a[i * k + i * 8 + off] = 1; + } + } + let b = vec![2i8; k * 16]; + let mut c = vec![0i32; 256]; + int8_tile_gemm_16x16(&a, &b, &mut c, k); + for i in 0..16 { + for j in 0..16 { + assert_eq!(c[i * 16 + j], 16, "diagonal accumulator at ({}, {})", i, j); + } + } + } + + #[test] + fn vnni_pack_i8_roundtrip() { + // Pack then verify the VNNI layout matches the spec: + // dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j] + let k = 8usize; + let n = 4usize; + let src: Vec = (0..(k * n) as i8).collect(); + let mut dst = vec![0i8; k * n]; + vnni_pack_i8(&src, &mut dst, k, n); + for kb in 0..(k / 4) { + for j in 0..n { + for p in 0..4 { + let dst_idx = kb * n * 4 + j * 4 + p; + let expected = src[(4 * kb + p) * n + j]; + assert_eq!(dst[dst_idx], expected, "vnni quad mismatch at kb={} j={} p={}", kb, j, p); + } + } + } + } +} diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs index 02a264c2..11081ad6 100644 --- a/src/hpc/mod.rs +++ b/src/hpc/mod.rs @@ -66,6 +66,10 @@ pub mod heel_f64x8; pub mod amx_matmul; #[cfg(target_arch = "x86_64")] pub mod bf16_tile_gemm; +/// INT8 (`u8 × i8 → i32`) tile GEMM via AMX `TDPBUSD` — mirror of +/// `bf16_tile_gemm` for the integer operand family. +#[cfg(target_arch = "x86_64")] +pub mod int8_tile_gemm; #[allow(missing_docs)] pub mod bf16_truth; #[allow(missing_docs)] From 33a2bbbfe170b2921cbd9451b088fa76f22f1870 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 05:52:25 +0000 Subject: [PATCH 2/4] fix(ci): drop unneeded parens + use div_ceil on Rust 1.95 clippy gates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two clippy-as-error issues blocking PR #184 CI: 1. `src/hpc/int8_tile_gemm.rs:147` (mine, from b1979d7) — `clippy::unused_parens` flagged the closure body `(((i*11+5) % 256) as u8 as i8)` in the `fallback_matches_scalar_reference_k64` test. 
The outer parens around the cast chain are redundant; rustfmt re-broke the line to multi-line after removal so it stays readable. 2. `tests/par_rayon.rs:9` (pre-existing) — `clippy::manual_div_ceil` flagged `(M + CHUNK_SIZE - 1) / CHUNK_SIZE`. Replaced with `M.div_ceil(CHUNK_SIZE)` per the clippy hint. This file was already in tree; the lint became active in clippy 1.95 (Rust stable) which CI now uses, so prior PRs weren't blocked by it but the rayon-features test build is now. Both fixes are mechanical / no behaviour change: * `cargo clippy --tests --features rayon,native -- -D warnings` clean. * `cargo fmt --all --check` clean. Stashed work-in-progress on the VPDPBUSD-zmm middle tier for `matmul_i8_to_i32` (the natural symmetric next step after TD-T2, analogous to the VDPBF16PS arm shipped in PR #182's `bf16_gemm_dispatch`); will follow up in a separate commit once CI is unblocked. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/int8_tile_gemm.rs | 2 +- tests/par_rayon.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index 05e355b0..f1062007 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -144,7 +144,7 @@ mod tests { // Deterministic pseudo-random inputs covering the u8 / i8 ranges. 
let a: Vec = (0..16 * k).map(|i| ((i * 7 + 3) % 256) as u8).collect(); let b: Vec = (0..k * 16) - .map(|i| (((i * 11 + 5) % 256) as u8 as i8)) + .map(|i| ((i * 11 + 5) % 256) as u8 as i8) .collect(); let mut c_ref = vec![0i32; 256]; diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 2cfb33e2..776d7686 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -6,7 +6,7 @@ use ndarray::prelude::*; const M: usize = 1024 * 10; const N: usize = 100; const CHUNK_SIZE: usize = 100; -const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE; +const N_CHUNKS: usize = M.div_ceil(CHUNK_SIZE); #[test] fn test_axis_iter() { From bb7b9b7c4e6d2194d9ebadfa463ec65b8ec1c9cd Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 05:57:42 +0000 Subject: [PATCH 3/4] feat(hpc): VPDPBUSD-zmm middle tier for matmul_i8_to_i32 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes the per-CPU dispatch chain for `matmul_i8_to_i32`. Per PR #180's table the middle tier between AMX TDPBUSD (Sapphire Rapids+) and the scalar reference is `_mm512_dpbusd_epi32` (zmm form, avx512vnni feature) — covers Cooper Lake, Cascade Lake, Ice Lake-SP, Zen 4+ silicon that has AVX-512 VNNI but not AMX. Mirrors the VDPBF16PS arm structure that landed for BF16 in PR #182's `bf16_gemm_dispatch`. New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm`: * One VPDPBUSD instruction: 16 i32 accumulator lanes, each receiving 4 u8×i8 products = 64 MACs per instruction. * Maps the 16 output lanes to a row of 16 j-columns of `c[i, ·]`, one i row processed at a time, K-quad inner loop accumulating into the same 16 i32 lanes across iterations. * B-column packing: pre-packs B for the current j-block into `b_col_quads[k_quad * 16 + j] = i32 (4 bytes of B[4k_quad.., j_base+j] packed bottom-to-top)` once per j-block; reused across all M i-iterations so the gather cost amortizes. 
* A row quad broadcast: `_mm512_set1_epi32` of (4 u8 bytes packed) every K-iter — same quad seen by every output column. * K-tail (k % 4 != 0) handled with scalar u8×i8 multiplies per output cell; N-tail (j_count < 16) handled by trimming the store width — padding lanes still receive VPDPBUSD updates but aren't written back. * Stable intrinsic `_mm512_dpbusd_epi32` under `target_feature = "avx512vnni,avx512f"` — no asm-byte needed. Wiring `matmul_i8_to_i32` to three-tier dispatch: 1. amx_available() + 16/16/64-aligned shapes → int8_tile_gemm_16x16 → TDPBUSD asm-byte (16 384 MACs/instr; reuses the kernel added in b1979d7 earlier in this PR) 2. is_x86_feature_detected!("avx512vnni") → int8_gemm_vpdpbusd_zmm → _mm512_dpbusd_epi32 stable intrinsic (64 MACs/instr, arbitrary shapes, K-tail handled scalar, N-tail handled by per-iteration j_count trim) 3. scalar i8×i8 → i32 reference for non-x86, pre-AVX-512 hosts, or shapes that don't satisfy either SIMD tier's requirements Factored the shared sign-shift bias subtraction into a private `subtract_i8_to_u8_bias(c, b_i8, m, n, k)` helper: both Tier 1 (AMX) and Tier 2 (VNNI) shift LHS i8 → u8 via (+128) then need to subtract 128·colsum(B) from the accumulator. Pure integer arithmetic, bit-identical to the scalar i8×i8 → i32 reference. Verification: * Default v3 build: 2093 lib tests pass (was 2092 — +1 new test `vpdpbusd_zmm_matches_scalar` that exercises the new arm directly with shapes spanning aligned cases, K-tail (k % 4), N-tail (n % 16), and small shapes; asserts byte-equal output vs scalar reference). * Existing `matmul_i8_to_i32_16x16_exact` continues to pass through the AMX tier on this host (which has amx_int8). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. 
Per-CPU dispatch state after this commit: matmul_bf16_to_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 VPDPBUSD | scalar (b1979d7) | (THIS COMMIT) | (always) So all three of the public matmul entry points now have full three-tier dispatch on x86_64. Out of scope (separate PRs): * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level u8×i8 surface from PR #182) — it's u8×i8 natively, no sign- shift bias needed, simpler than matmul_i8_to_i32. * AVX-VNNI ymm arm (Arrow Lake / Meteor Lake U: avxvnni without avx512vnni) — the `vnni2_*` functions exist in simd_amx.rs but need to be assembled into a m×n×k VNNI-ymm GEMM. Same shape as the avx512vnni arm just with ymm width. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/amx_matmul.rs | 59 +++++++++------ src/hpc/int8_tile_gemm.rs | 147 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 23 deletions(-) diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index df959db0..eaff247e 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -586,20 +586,14 @@ pub fn matmul_i8_to_i32( let mut c = vec![0i32; m * n]; if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 { - // AMX TDPBUSD path: shift LHS i8 → u8 via (+128), tile-GEMM into - // i32, subtract bias 128·colsum(B). The tile kernel zeroes its - // internal accumulator (TILEZERO + TDPBUSD accumulate); we need - // fresh per-tile output here so we tile manually over M/N and - // call int8_tile_gemm_16x16 per (i, j) block. + // Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128), + // tile-GEMM via int8_tile_gemm_16x16, subtract bias. let a_u8: Vec = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect(); - // B sub-block extraction per j-tile (B is row-major K × N; the - // tile kernel wants K × 16 contiguous). Reused across i-tiles. 
let mut b_tile = vec![0i8; k * 16]; let mut tile_c = vec![0i32; 256]; for j_tile in (0..n).step_by(16) { - // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows. for kk in 0..k { let row = kk * n + j_tile; b_tile[kk * 16..(kk + 1) * 16] @@ -609,29 +603,27 @@ pub fn matmul_i8_to_i32( let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; tile_c.fill(0); crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); - // Write tile_c (16 × 16) into c at (i_tile, j_tile). for ii in 0..16 { let dst_off = (i_tile + ii) * n + j_tile; c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); } } } - - // Subtract bias: c[i, j] -= 128 · colsum(B[:, j]). - let mut colsum = vec![0i32; n]; - for p in 0..k { - for j in 0..n { - colsum[j] += b_i8[p * n + j] as i32; - } - } - for i in 0..m { - for j in 0..n { - c[i * n + j] -= 128 * colsum[j]; - } + subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k); + } else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") { + // Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no + // shape-alignment requirement (M/N/K all handled via per-block + // trim and scalar K-tail). Same sign-shift bias trick as AMX. + let a_u8: Vec = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect(); + // SAFETY: runtime feature-detected avx512vnni above. + unsafe { + crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k); } + subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k); } else { - // Scalar i8×i8 → i32 reference — used for non-AMX hosts and for - // shapes that don't fit the 16/16/64 tile alignment. + // Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts, + // pre-AVX-512 silicon, or shapes that don't satisfy either of + // the SIMD tiers' alignment requirements. for i in 0..m { for p in 0..k { let av = a_i8[i * k + p] as i32; @@ -653,6 +645,27 @@ pub fn matmul_i8_to_i32( Ok(()) } +/// Subtract `128 · colsum(B[:, j])` from each `c[i, j]` lane. 
+/// +/// Used by both the AMX and AVX-512-VNNI arms of `matmul_i8_to_i32` +/// to undo the LHS sign-shift bias (A_i8 → A_u8 via +128 means +/// `A_u8 · B = (A_i8 + 128) · B = A_i8 · B + 128 · sum_k B[k, j]`). +/// Pure integer arithmetic, no rounding — the public result is +/// bit-identical to the scalar i8 × i8 → i32 reference. +fn subtract_i8_to_u8_bias(c: &mut [i32], b_i8: &[i8], m: usize, n: usize, k: usize) { + let mut colsum = vec![0i32; n]; + for p in 0..k { + for j in 0..n { + colsum[j] += b_i8[p * n + j] as i32; + } + } + for i in 0..m { + for j in 0..n { + c[i * n + j] -= 128 * colsum[j]; + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index f1062007..d2997f57 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -101,6 +101,111 @@ unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) { tile_release(); } +// ═════════════════════════════════════════════════════════════════════ +// VPDPBUSD-zmm middle tier (avx512vnni without AMX) +// ═════════════════════════════════════════════════════════════════════ + +/// AVX-512 VNNI `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K. +/// +/// One `_mm512_dpbusd_epi32` instruction: 16 i32 accumulator lanes, +/// each receiving the sum of 4 `u8 × i8` products = **64 MACs per +/// instruction**. Pre-packs B in VNNI quad layout once per j-block +/// (16-wide column band) and reuses across all M i-iterations, +/// amortizing the gather cost. +/// +/// K-tail (when K is not a multiple of 4) handled with scalar +/// u8 × i8 multiplies per output cell; N-tail (when the j-block has +/// fewer than 16 valid columns) handled by trimming the store after +/// the VPDPBUSD chain. +/// +/// This is the middle dispatch tier between AMX TDPBUSD (Sapphire +/// Rapids+) and the scalar reference — covers Cooper Lake, Cascade +/// Lake, Ice Lake-SP, Zen 4+ silicon that has avx512vnni but not +/// AMX. 
Mirrors the VDPBF16PS arm structure shipped for BF16 in +/// PR #182. +/// +/// Output behavior: overwrites `c` (does NOT accumulate). Caller's +/// responsibility to zero `c` first if a fresh-write GEMM is wanted. +/// +/// # Safety +/// Caller must have feature-detected `avx512vnni + avx512f` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512vnni,avx512f")] +pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + use core::arch::x86_64::{ + __m512i, _mm512_dpbusd_epi32, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_si512, _mm512_storeu_si512, + }; + + let k_quads = k / 4; + let k_tail = k % 4; + + // Pre-pack scratch for B columns of the current j-block: + // 16 i32 lanes per k_quad, each holding 4 consecutive K-bytes + // packed (b[2q+0..2q+4] for output column j+lane). + let mut b_col_quads = vec![0i32; k_quads.max(1) * 16]; + // Scratch for the 16-wide store + N-tail trim. + let mut out_buf = [0i32; 16]; + + for j_base in (0..n).step_by(16) { + let j_count = 16.min(n - j_base); + + // Pack B[0..k, j_base..j_base+j_count] in quad-interleaved layout. + // For lanes j >= j_count (the N-tail of this j_block), pad with 0 + // so the VPDPBUSD doesn't read uninitialized memory; they're not + // stored back. + for k_quad in 0..k_quads { + let row0 = 4 * k_quad * n; + let row1 = (4 * k_quad + 1) * n; + let row2 = (4 * k_quad + 2) * n; + let row3 = (4 * k_quad + 3) * n; + for jj in 0..j_count { + let b0 = b_i8[row0 + j_base + jj] as u8 as u32; + let b1 = b_i8[row1 + j_base + jj] as u8 as u32; + let b2 = b_i8[row2 + j_base + jj] as u8 as u32; + let b3 = b_i8[row3 + j_base + jj] as u8 as u32; + // Pack as i32: bottom byte is k_quad*4+0, top is k_quad*4+3. 
+ b_col_quads[k_quad * 16 + jj] = (b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)) as i32; + } + for jj in j_count..16 { + b_col_quads[k_quad * 16 + jj] = 0; + } + } + + for i in 0..m { + let mut acc = _mm512_setzero_si512(); + let a_row_off = i * k; + for k_quad in 0..k_quads { + // Broadcast A[i, 4*k_quad..4*k_quad+4] (4 u8) across all + // 16 i32 lanes via _mm512_set1_epi32. + let a0 = a_u8[a_row_off + 4 * k_quad] as u32; + let a1 = a_u8[a_row_off + 4 * k_quad + 1] as u32; + let a2 = a_u8[a_row_off + 4 * k_quad + 2] as u32; + let a3 = a_u8[a_row_off + 4 * k_quad + 3] as u32; + let packed_a = a0 | (a1 << 8) | (a2 << 16) | (a3 << 24); + let a_v = _mm512_set1_epi32(packed_a as i32); + let b_v = _mm512_loadu_si512(b_col_quads.as_ptr().add(k_quad * 16) as *const __m512i); + acc = _mm512_dpbusd_epi32(acc, a_v, b_v); + } + _mm512_storeu_si512(out_buf.as_mut_ptr() as *mut __m512i, acc); + + // K-tail: scalar multiplies for k = k_quads*4 .. k. + if k_tail > 0 { + for kk in (k_quads * 4)..k { + let a_val = a_u8[a_row_off + kk] as i32; + let tail_row = kk * n; + for jj in 0..j_count { + out_buf[jj] += a_val * b_i8[tail_row + j_base + jj] as i32; + } + } + } + + // Store j_count valid lanes (drops N-tail padding lanes). + let dst_off = i * n + j_base; + c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]); + } + } +} + // ═════════════════════════════════════════════════════════════════════ // Scalar fallback (i32 reference) // ═════════════════════════════════════════════════════════════════════ @@ -192,6 +297,48 @@ mod tests { } } + /// Direct test for the VPDPBUSD-zmm arm, exercising the path the + /// `matmul_i8_to_i32` dispatcher would skip when AMX is available. + /// Verifies bit-exact parity against the scalar reference for + /// arbitrary (M, N, K) — including non-multiple-of-4 K (so the + /// scalar K-tail branch fires) and non-multiple-of-16 N (so the + /// j-count trim branch fires). 
+ #[cfg(target_arch = "x86_64")] + #[test] + fn vpdpbusd_zmm_matches_scalar() { + if !std::is_x86_feature_detected!("avx512vnni") { + eprintln!("avx512vnni not detected; skipping"); + return; + } + + fn ref_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> { + let mut c = vec![0i32; m * n]; + for i in 0..m { + for kk in 0..k { + let av = a[i * k + kk] as i32; + for j in 0..n { + c[i * n + j] += av * b[kk * n + j] as i32; + } + } + } + c + } + + // Sweep shapes spanning aligned cases, K-tail (k % 4), and + // N-tail (n % 16) to exercise every code path. + for (m, n, k) in [(16, 16, 64), (3, 5, 7), (17, 33, 100), (1, 17, 12), (8, 16, 4)] { + let a: Vec<u8> = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect(); + let b: Vec<i8> = (0..k * n) + .map(|i| ((i * 17 + 3) % 256) as u8 as i8) + .collect(); + let expected = ref_gemm(&a, &b, m, n, k); + let mut got = vec![0i32; m * n]; + // SAFETY: avx512vnni confirmed at the top of the test. + unsafe { int8_gemm_vpdpbusd_zmm(&a, &b, &mut got, m, n, k) }; + assert_eq!(got, expected, "VPDPBUSD-zmm mismatch at (M={}, N={}, K={})", m, n, k); + } + } + #[test] fn vnni_pack_i8_roundtrip() { // Pack then verify the VNNI layout matches the spec: From f987937aa2131148da015439cab0dd0342f6a40d Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 06:09:31 +0000 Subject: [PATCH 4/4] fix(hpc): preserve C accumulator in AMX tile paths (codex P1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per codex review on PR #184: `int8_tile_gemm_16x16` is documented as `C += A·B` and the scalar `fallback_path` correctly accumulates, but the AMX `amx_path` did `tile_zero(0)` + `tile_store(0, c, 64)` which **overwrote** any pre-existing values in `c` on AMX-enabled hosts. Hardware-dependent behavior: callers relying on accumulation (blocked GEMM, repeated partial-K updates) would get incorrect results only when AMX was active. 
Same bug in the BF16 sibling `bf16_tile_gemm::amx_path` (shipped in PR #104) — fixing both. The fix: 1. Add a `tile 0` case to `tile_load` (encoding `C4 E2 7B 4B 04 08` — same SIB byte as the existing tmm1/tmm2 cases, ModR/M `04` = `mod=00, reg=000 (tmm0), r/m=100 (SIB follows)`). 2. Both AMX paths replace `tile_zero(0)` with `tile_load(0, c.as_ptr() as *const u8, 64)` — preloads tmm0 from caller's C buffer. TDPBUSD / TDPBF16PS then accumulate into the pre-loaded values; `tile_store(0, c, 64)` writes back the true `+=` result. Consumer impact: zero. Both `matmul_bf16_to_f32` and `matmul_i8_to_i32` (the only callers of these kernels in this crate) `tile_c.fill(0)` before each call — so the now-accumulating behavior + zero-initialized C = same result as the prior overwrite semantics. The fix removes a latent trap for future blocked-GEMM / partial-K consumers without changing any shipped behavior. New regression test `amx_path_preserves_c_accumulator`: pre-loads C with a known non-zero marker pattern, runs A·B where B=0 (so the contribution is 0), asserts the marker is preserved. Would fail on the pre-fix code because the tile_store would zero everything. Passes on this host (which has amx_int8). Verification: * 2094 lib tests pass (was 2093 — +1 regression test). * 11 amx_matmul tests pass (consumers' fill(0)-then-call pattern continues to produce correct results). * 2 bf16_tile_gemm tests pass. * 6 int8_tile_gemm tests pass. * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. 
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/hpc/amx_matmul.rs | 23 +++++++++++++++++++++-- src/hpc/bf16_tile_gemm.rs | 11 ++++++++++- src/hpc/int8_tile_gemm.rs | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index eaff247e..4939559e 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -94,13 +94,32 @@ pub unsafe fn tile_release() { /// Load tile from memory. /// +/// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with +/// a SIB byte selecting `[rcx + rax]`. The ModR/M `/r` field encodes the +/// destination tile via `reg = N` (3-bit tile index). Per-tile bytes: +/// +/// tmm0: C4 E2 7B 4B **04** 08 +/// tmm1: C4 E2 7B 4B **0C** 08 +/// tmm2: C4 E2 7B 4B **14** 08 +/// +/// `04 | (N << 3)` gives the ModR/M byte; the `08` SIB is the same +/// across tiles. tmm0 was added when codex flagged the accumulator- +/// preservation bug on PR #184 (`tile_zero(0)` + `tile_store(0, c)` +/// discarded any pre-existing C values — the fix is `tile_load(0, c)` +/// instead of `tile_zero(0)` so TDPBUSD/TDPBF16PS truly accumulate as +/// the documented `C += A·B` semantics promise). +/// /// # Safety /// Pointer must be valid, stride must match tile config. 
#[inline] pub unsafe fn tile_load(tile: u8, ptr: *const u8, stride: usize) { match tile { - // TILELOADD tmm0, [ptr + stride*row] - // Encoding: VEX.128.F2.0F38.W0 4B /r with memory operand + 0 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x08", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), 1 => asm!( ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x08", in("rcx") ptr, diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index 59429391..60f6b9ea 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -60,6 +60,13 @@ pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: us // ═════════════════════════════════════════════════════════════════════ /// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_bf16`). +/// **Accumulates** into the caller's `c` buffer — matches the +/// documented `C += A·B` semantics. The C tile (tmm0) is preloaded +/// from `c` before the TDPBF16PS loop so any pre-existing values are +/// preserved. (Same accumulator-preservation fix the int8 sibling +/// got after codex P1 on PR #184: prior `tile_zero(0)` discarded +/// pre-existing C values even though docs promised accumulation.) +/// /// # Safety /// Caller must have verified `amx_available() == true`. #[inline] @@ -67,7 +74,9 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { // Tile config: shapes at K_bytes=64 match BF16 K=32 case let cfg = TileConfig::for_dpbusd(64); tile_loadconfig(&cfg); - tile_zero(0); + // Preload C accumulator from caller's buffer (was tile_zero(0) + // pre-fix — see method-level note above). 
+ tile_load(0, c.as_ptr() as *const u8, 64); // Accumulate over K/32 tile blocks let k_blocks = k / 32; diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index d2997f57..6dc2e76e 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -68,6 +68,11 @@ pub fn int8_tile_gemm_16x16(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) { // ═════════════════════════════════════════════════════════════════════ /// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_i8`). +/// **Accumulates** into the caller's `c` buffer — matches the +/// documented `C += A·B` semantics. The C tile (tmm0) is preloaded +/// from `c` before the TDPBUSD loop so any pre-existing values are +/// preserved. +/// /// # Safety /// Caller must have verified `amx_available() == true`. #[inline] @@ -77,7 +82,11 @@ unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) { // byte — same 64-byte row width either way). let cfg = TileConfig::for_dpbusd(64); tile_loadconfig(&cfg); - tile_zero(0); + // Preload C accumulator from caller's buffer so TDPBUSD truly + // accumulates into the existing values (fixes codex P1 from PR + // #184 — the prior `tile_zero(0)` discarded pre-existing C values + // even though the docs promise `C += A·B`). + tile_load(0, c.as_ptr() as *const u8, 64); // Accumulate over K/64 tile blocks. Each TDPBUSD consumes 64 // K-elements per A row × 4 K-elements per inner-product = 256 MACs @@ -275,6 +284,28 @@ mod tests { } } + /// Codex P1 regression on PR #184: `int8_tile_gemm_16x16` is + /// documented as `C += A·B`, but the AMX path used to `tile_zero(0)` + /// then `tile_store(0, c)`, **overwriting** `c` on AMX hosts (the + /// scalar fallback correctly accumulated). This test pre-loads C + /// with a known marker, runs A·B=0 (B is all zeros so the product + /// is zero), and asserts the marker is preserved — would fail on + /// the pre-fix AMX path because the tile_store would zero everything. 
+ #[test] + fn amx_path_preserves_c_accumulator() { + let k = 64; + let a = vec![1u8; 16 * k]; + let b = vec![0i8; k * 16]; // product is exactly 0 + // Pre-load C with a non-zero marker pattern. + let mut c: Vec<i32> = (0..256).map(|i| i as i32 * 7 - 100).collect(); + let snapshot = c.clone(); + int8_tile_gemm_16x16(&a, &b, &mut c, k); + // After: c[i] += 0 → c[i] unchanged from snapshot. + for i in 0..256 { + assert_eq!(c[i], snapshot[i], "accumulator marker clobbered at {}: {} → {}", i, snapshot[i], c[i]); + } + } + #[test] fn public_api_diagonal_k128() { // A = identity-like (only A[i, i] = 1, but we need 16 × 128), so