Skip to content

Commit bb7b9b7

Browse files
committed
feat(hpc): VPDPBUSD-zmm middle tier for matmul_i8_to_i32
Completes the per-CPU dispatch chain for `matmul_i8_to_i32`. Per PR #180's table the middle tier between AMX TDPBUSD (Sapphire Rapids+) and the scalar reference is `_mm512_dpbusd_epi32` (zmm form, avx512vnni feature) — covers Cooper Lake, Cascade Lake, Ice Lake-SP, Zen 4+ silicon that has AVX-512 VNNI but not AMX. Mirrors the VDPBF16PS arm structure that landed for BF16 in PR #182's `bf16_gemm_dispatch`. New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm`: * One VPDPBUSD instruction: 16 i32 accumulator lanes, each receiving 4 u8×i8 products = 64 MACs per instruction. * Maps the 16 output lanes to a row of 16 j-columns of `c[i, ·]`, one i row processed at a time, K-quad inner loop accumulating into the same 16 i32 lanes across iterations. * B-column packing: pre-packs B for the current j-block into `b_col_quads[k_quad * 16 + j] = i32 (4 bytes of B[4k_quad.., j_base+j] packed bottom-to-top)` once per j-block; reused across all M i-iterations so the gather cost amortizes. * A row quad broadcast: `_mm512_set1_epi32` of (4 u8 bytes packed) every K-iter — same quad seen by every output column. * K-tail (k % 4 != 0) handled with scalar u8×i8 multiplies per output cell; N-tail (j_count < 16) handled by trimming the store width — padding lanes still receive VPDPBUSD updates but aren't written back. * Stable intrinsic `_mm512_dpbusd_epi32` under `target_feature = "avx512vnni,avx512f"` — no asm-byte needed. Wiring `matmul_i8_to_i32` to three-tier dispatch: 1. amx_available() + 16/16/64-aligned shapes → int8_tile_gemm_16x16 → TDPBUSD asm-byte (16 384 MACs/instr, this commit reuses the kernel from PR #184 fe334de... wait, same PR — from b1979d7 in THIS PR) 2. is_x86_feature_detected!("avx512vnni") → int8_gemm_vpdpbusd_zmm → _mm512_dpbusd_epi32 stable intrinsic (64 MACs/instr, arbitrary shapes, K-tail handled scalar, N-tail handled by per-iteration j_count trim) 3. scalar i8×i8 → i32 reference for non-x86, pre-AVX-512 hosts, or shapes that don't satisfy either SIMD tier's requirements Factored the shared sign-shift bias subtraction into a private `subtract_i8_to_u8_bias(c, b_i8, m, n, k)` helper: both Tier 1 (AMX) and Tier 2 (VNNI) shift LHS i8 → u8 via (+128) then need to subtract 128·colsum(B) from the accumulator. Pure integer arithmetic, bit-identical to the scalar i8×i8 → i32 reference. Verification: * Default v3 build: 2093 lib tests pass (was 2092 — +1 new test `vpdpbusd_zmm_matches_scalar` that exercises the new arm directly with shapes spanning aligned cases, K-tail (k % 4), N-tail (n % 16), and small shapes; asserts byte-equal output vs scalar reference). * Existing `matmul_i8_to_i32_16x16_exact` continues to pass through the AMX tier on this host (which has amx_int8). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit: matmul_bf16_to_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 VPDPBUSD | scalar (b1979d7) | (THIS COMMIT) | (always) So all three of the public matmul entry points now have full three-tier dispatch on x86_64. Out of scope (separate PRs): * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level u8×i8 surface from PR #182) — it's u8×i8 natively, no sign- shift bias needed, simpler than matmul_i8_to_i32. * AVX-VNNI ymm arm (Arrow Lake / Meteor Lake U: avxvnni without avx512vnni) — the `vnni2_*` functions exist in simd_amx.rs but need to be assembled into a m×n×k VNNI-ymm GEMM. Same shape as the avx512vnni arm just with ymm width. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
1 parent 33a2bbb commit bb7b9b7

2 files changed

Lines changed: 183 additions & 23 deletions

File tree

src/hpc/amx_matmul.rs

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -586,20 +586,14 @@ pub fn matmul_i8_to_i32(
586586
let mut c = vec![0i32; m * n];
587587

588588
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
589-
// AMX TDPBUSD path: shift LHS i8 → u8 via (+128), tile-GEMM into
590-
// i32, subtract bias 128·colsum(B). The tile kernel zeroes its
591-
// internal accumulator (TILEZERO + TDPBUSD accumulate); we need
592-
// fresh per-tile output here so we tile manually over M/N and
593-
// call int8_tile_gemm_16x16 per (i, j) block.
589+
// Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
590+
// tile-GEMM via int8_tile_gemm_16x16, subtract bias.
594591
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
595592

596-
// B sub-block extraction per j-tile (B is row-major K × N; the
597-
// tile kernel wants K × 16 contiguous). Reused across i-tiles.
598593
let mut b_tile = vec![0i8; k * 16];
599594
let mut tile_c = vec![0i32; 256];
600595

601596
for j_tile in (0..n).step_by(16) {
602-
// Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows.
603597
for kk in 0..k {
604598
let row = kk * n + j_tile;
605599
b_tile[kk * 16..(kk + 1) * 16]
@@ -609,29 +603,27 @@ pub fn matmul_i8_to_i32(
609603
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
610604
tile_c.fill(0);
611605
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
612-
// Write tile_c (16 × 16) into c at (i_tile, j_tile).
613606
for ii in 0..16 {
614607
let dst_off = (i_tile + ii) * n + j_tile;
615608
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
616609
}
617610
}
618611
}
619-
620-
// Subtract bias: c[i, j] -= 128 · colsum(B[:, j]).
621-
let mut colsum = vec![0i32; n];
622-
for p in 0..k {
623-
for j in 0..n {
624-
colsum[j] += b_i8[p * n + j] as i32;
625-
}
626-
}
627-
for i in 0..m {
628-
for j in 0..n {
629-
c[i * n + j] -= 128 * colsum[j];
630-
}
612+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
613+
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") {
614+
// Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
615+
// shape-alignment requirement (M/N/K all handled via per-block
616+
// trim and scalar K-tail). Same sign-shift bias trick as AMX.
617+
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
618+
// SAFETY: runtime feature-detected avx512vnni above.
619+
unsafe {
620+
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
631621
}
622+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
632623
} else {
633-
// Scalar i8×i8 → i32 reference — used for non-AMX hosts and for
634-
// shapes that don't fit the 16/16/64 tile alignment.
624+
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
625+
// pre-AVX-512 silicon, or shapes that don't satisfy either of
626+
// the SIMD tiers' alignment requirements.
635627
for i in 0..m {
636628
for p in 0..k {
637629
let av = a_i8[i * k + p] as i32;
@@ -653,6 +645,27 @@ pub fn matmul_i8_to_i32(
653645
Ok(())
654646
}
655647

648+
/// Subtract `128 · colsum(B[:, j])` from each `c[i, j]` lane.
649+
///
650+
/// Used by both the AMX and AVX-512-VNNI arms of `matmul_i8_to_i32`
651+
/// to undo the LHS sign-shift bias (A_i8 → A_u8 via +128 means
652+
/// `A_u8 · B = (A_i8 + 128) · B = A_i8 · B + 128 · sum_k B[k, j]`).
653+
/// Pure integer arithmetic, no rounding — the public result is
654+
/// bit-identical to the scalar i8 × i8 → i32 reference.
655+
fn subtract_i8_to_u8_bias(c: &mut [i32], b_i8: &[i8], m: usize, n: usize, k: usize) {
656+
let mut colsum = vec![0i32; n];
657+
for p in 0..k {
658+
for j in 0..n {
659+
colsum[j] += b_i8[p * n + j] as i32;
660+
}
661+
}
662+
for i in 0..m {
663+
for j in 0..n {
664+
c[i * n + j] -= 128 * colsum[j];
665+
}
666+
}
667+
}
668+
656669
#[cfg(test)]
657670
mod tests {
658671
use super::*;

src/hpc/int8_tile_gemm.rs

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,111 @@ unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) {
101101
tile_release();
102102
}
103103

104+
// ═════════════════════════════════════════════════════════════════════
105+
// VPDPBUSD-zmm middle tier (avx512vnni without AMX)
106+
// ═════════════════════════════════════════════════════════════════════
107+
108+
/// AVX-512 VNNI `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K.
109+
///
110+
/// One `_mm512_dpbusd_epi32` instruction: 16 i32 accumulator lanes,
111+
/// each receiving the sum of 4 `u8 × i8` products = **64 MACs per
112+
/// instruction**. Pre-packs B in VNNI quad layout once per j-block
113+
/// (16-wide column band) and reuses across all M i-iterations,
114+
/// amortizing the gather cost.
115+
///
116+
/// K-tail (when K is not a multiple of 4) handled with scalar
117+
/// u8 × i8 multiplies per output cell; N-tail (when the j-block has
118+
/// fewer than 16 valid columns) handled by trimming the store after
119+
/// the VPDPBUSD chain.
120+
///
121+
/// This is the middle dispatch tier between AMX TDPBUSD (Sapphire
122+
/// Rapids+) and the scalar reference — covers Cooper Lake, Cascade
123+
/// Lake, Ice Lake-SP, Zen 4+ silicon that has avx512vnni but not
124+
/// AMX. Mirrors the VDPBF16PS arm structure shipped for BF16 in
125+
/// PR #182.
126+
///
127+
/// Output behavior: overwrites `c` (does NOT accumulate). Caller's
128+
/// responsibility to zero `c` first if a fresh-write GEMM is wanted.
129+
///
130+
/// # Safety
131+
/// Caller must have feature-detected `avx512vnni + avx512f` at runtime.
132+
#[cfg(target_arch = "x86_64")]
133+
#[target_feature(enable = "avx512vnni,avx512f")]
134+
pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
135+
use core::arch::x86_64::{
136+
__m512i, _mm512_dpbusd_epi32, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_si512, _mm512_storeu_si512,
137+
};
138+
139+
let k_quads = k / 4;
140+
let k_tail = k % 4;
141+
142+
// Pre-pack scratch for B columns of the current j-block:
143+
// 16 i32 lanes per k_quad, each holding 4 consecutive K-bytes
144+
// packed (b[2q+0..2q+4] for output column j+lane).
145+
let mut b_col_quads = vec![0i32; k_quads.max(1) * 16];
146+
// Scratch for the 16-wide store + N-tail trim.
147+
let mut out_buf = [0i32; 16];
148+
149+
for j_base in (0..n).step_by(16) {
150+
let j_count = 16.min(n - j_base);
151+
152+
// Pack B[0..k, j_base..j_base+j_count] in quad-interleaved layout.
153+
// For lanes j >= j_count (the N-tail of this j_block), pad with 0
154+
// so the VPDPBUSD doesn't read uninitialized memory; they're not
155+
// stored back.
156+
for k_quad in 0..k_quads {
157+
let row0 = 4 * k_quad * n;
158+
let row1 = (4 * k_quad + 1) * n;
159+
let row2 = (4 * k_quad + 2) * n;
160+
let row3 = (4 * k_quad + 3) * n;
161+
for jj in 0..j_count {
162+
let b0 = b_i8[row0 + j_base + jj] as u8 as u32;
163+
let b1 = b_i8[row1 + j_base + jj] as u8 as u32;
164+
let b2 = b_i8[row2 + j_base + jj] as u8 as u32;
165+
let b3 = b_i8[row3 + j_base + jj] as u8 as u32;
166+
// Pack as i32: bottom byte is k_quad*4+0, top is k_quad*4+3.
167+
b_col_quads[k_quad * 16 + jj] = (b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)) as i32;
168+
}
169+
for jj in j_count..16 {
170+
b_col_quads[k_quad * 16 + jj] = 0;
171+
}
172+
}
173+
174+
for i in 0..m {
175+
let mut acc = _mm512_setzero_si512();
176+
let a_row_off = i * k;
177+
for k_quad in 0..k_quads {
178+
// Broadcast A[i, 4*k_quad..4*k_quad+4] (4 u8) across all
179+
// 16 i32 lanes via _mm512_set1_epi32.
180+
let a0 = a_u8[a_row_off + 4 * k_quad] as u32;
181+
let a1 = a_u8[a_row_off + 4 * k_quad + 1] as u32;
182+
let a2 = a_u8[a_row_off + 4 * k_quad + 2] as u32;
183+
let a3 = a_u8[a_row_off + 4 * k_quad + 3] as u32;
184+
let packed_a = a0 | (a1 << 8) | (a2 << 16) | (a3 << 24);
185+
let a_v = _mm512_set1_epi32(packed_a as i32);
186+
let b_v = _mm512_loadu_si512(b_col_quads.as_ptr().add(k_quad * 16) as *const __m512i);
187+
acc = _mm512_dpbusd_epi32(acc, a_v, b_v);
188+
}
189+
_mm512_storeu_si512(out_buf.as_mut_ptr() as *mut __m512i, acc);
190+
191+
// K-tail: scalar multiplies for k = k_quads*4 .. k.
192+
if k_tail > 0 {
193+
for kk in (k_quads * 4)..k {
194+
let a_val = a_u8[a_row_off + kk] as i32;
195+
let tail_row = kk * n;
196+
for jj in 0..j_count {
197+
out_buf[jj] += a_val * b_i8[tail_row + j_base + jj] as i32;
198+
}
199+
}
200+
}
201+
202+
// Store j_count valid lanes (drops N-tail padding lanes).
203+
let dst_off = i * n + j_base;
204+
c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
205+
}
206+
}
207+
}
208+
104209
// ═════════════════════════════════════════════════════════════════════
105210
// Scalar fallback (i32 reference)
106211
// ═════════════════════════════════════════════════════════════════════
@@ -192,6 +297,48 @@ mod tests {
192297
}
193298
}
194299

300+
/// Direct test for the VPDPBUSD-zmm arm, exercising the path the
301+
/// `matmul_i8_to_i32` dispatcher would skip when AMX is available.
302+
/// Verifies bit-exact parity against the scalar reference for
303+
/// arbitrary (M, N, K) — including non-multiple-of-4 K (so the
304+
/// scalar K-tail branch fires) and non-multiple-of-16 N (so the
305+
/// j-count trim branch fires).
306+
#[cfg(target_arch = "x86_64")]
307+
#[test]
308+
fn vpdpbusd_zmm_matches_scalar() {
309+
if !std::is_x86_feature_detected!("avx512vnni") {
310+
eprintln!("avx512vnni not detected; skipping");
311+
return;
312+
}
313+
314+
fn ref_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
315+
let mut c = vec![0i32; m * n];
316+
for i in 0..m {
317+
for kk in 0..k {
318+
let av = a[i * k + kk] as i32;
319+
for j in 0..n {
320+
c[i * n + j] += av * b[kk * n + j] as i32;
321+
}
322+
}
323+
}
324+
c
325+
}
326+
327+
// Sweep shapes spanning aligned cases, K-tail (k % 4), and
328+
// N-tail (n % 16) to exercise every code path.
329+
for (m, n, k) in [(16, 16, 64), (3, 5, 7), (17, 33, 100), (1, 17, 12), (8, 16, 4)] {
330+
let a: Vec<u8> = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect();
331+
let b: Vec<i8> = (0..k * n)
332+
.map(|i| ((i * 17 + 3) % 256) as u8 as i8)
333+
.collect();
334+
let expected = ref_gemm(&a, &b, m, n, k);
335+
let mut got = vec![0i32; m * n];
336+
// SAFETY: avx512vnni confirmed at the top of the test.
337+
unsafe { int8_gemm_vpdpbusd_zmm(&a, &b, &mut got, m, n, k) };
338+
assert_eq!(got, expected, "VPDPBUSD-zmm mismatch at (M={}, N={}, K={})", m, n, k);
339+
}
340+
}
341+
195342
#[test]
196343
fn vnni_pack_i8_roundtrip() {
197344
// Pack then verify the VNNI layout matches the spec:

0 commit comments

Comments
 (0)