Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 112 additions & 24 deletions src/hpc/amx_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,32 @@ pub unsafe fn tile_release() {

/// Load tile from memory.
///
/// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with
/// a SIB byte selecting `[rcx + rax]`. The ModR/M `/r` field encodes the
/// destination tile via `reg = N` (3-bit tile index). Per-tile bytes:
///
/// tmm0: C4 E2 7B 4B **04** 08
/// tmm1: C4 E2 7B 4B **0C** 08
/// tmm2: C4 E2 7B 4B **14** 08
///
/// `04 | (N << 3)` gives the ModR/M byte; the `08` SIB is the same
/// across tiles. tmm0 was added when codex flagged the accumulator-
/// preservation bug on PR #184 (`tile_zero(0)` + `tile_store(0, c)`
/// discarded any pre-existing C values — the fix is `tile_load(0, c)`
/// instead of `tile_zero(0)` so TDPBUSD/TDPBF16PS truly accumulate as
/// the documented `C += A·B` semantics promise).
///
/// # Safety
/// Pointer must be valid, stride must match tile config.
#[inline]
pub unsafe fn tile_load(tile: u8, ptr: *const u8, stride: usize) {
match tile {
// TILELOADD tmm0, [ptr + stride*row]
// Encoding: VEX.128.F2.0F38.W0 4B /r with memory operand
0 => asm!(
".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x08",
in("rcx") ptr,
in("rax") stride,
options(nostack),
),
1 => asm!(
".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x08",
in("rcx") ptr,
Expand Down Expand Up @@ -193,6 +212,32 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
}
}

/// Pack B[K, N] i8 row-major into K/4 × (N*4) VNNI quads for `TDPBUSD`.
///
/// Output layout required by `TDPBUSD` tile 2 (16 rows × 64 bytes):
/// dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
///
/// For N=16 (AMX tile width), each output "row" holds 16 i8 quads = 64
/// bytes (matches the 64-byte tile row width). K must be a multiple of
/// 4. The same layout is used for `u8` operands (just bit-cast through
/// — VNNI doesn't care about sign at the packing layer; sign
/// interpretation happens inside TDPBUSD which treats A as u8 and B
/// as i8 for the multiply).
#[inline]
pub fn vnni_pack_i8(src: &[i8], dst: &mut [i8], k: usize, n: usize) {
debug_assert_eq!(src.len(), k * n);
debug_assert_eq!(dst.len(), k * n);
debug_assert_eq!(k % 4, 0, "K must be multiple of 4 for VNNI INT8 quads");
for kb in 0..(k / 4) {
let dst_row = kb * n * 4;
for j in 0..n {
for p in 0..4 {
dst[dst_row + j * 4 + p] = src[(4 * kb + p) * n + j];
}
}
}
}

// ═══════════════════════════════════════════════════════════════════════════
// Public ndarray-typed matmul API (sprint A4 / Burn parity item 6)
// ═══════════════════════════════════════════════════════════════════════════
Expand All @@ -207,7 +252,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
// strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked
// into contiguous staging buffers before the kernel runs.

use crate::hpc::quantized::{bf16_gemm_f32, int8_gemm_i32, BF16};
use crate::hpc::quantized::{bf16_gemm_f32, BF16};
use crate::{ArrayView2, ArrayViewMut2};

/// Errors returned by the public AMX matmul API.
Expand Down Expand Up @@ -537,14 +582,17 @@ pub fn matmul_f32(

/// Matrix multiply i8 × i8 → i32: `out = lhs · rhs`.
///
/// On AMX hosts uses `TDPBUSD` (256 MACs/instr); otherwise falls back to
/// the scalar `int8_gemm_i32`.
/// On AMX hosts with 16/16/64-aligned shapes uses `TDPBUSD` via the
/// 16×16 tile kernel in [`crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16`]
/// — 16 384 MACs per instruction. Mis-aligned shapes (or non-AMX hosts)
/// fall back to the scalar i8×i8 → i32 reference.
///
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For the
/// signed-by-signed surface required here, the LHS is shifted into the
/// unsigned domain and the bias subtracted from the accumulator (only on
/// the AMX path; the scalar path operates directly in i8). The public
/// result is identical.
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For
/// the signed-by-signed surface required here, the LHS is shifted into
/// the unsigned domain (i8 + 128 → u8) and the bias `128 · sum(B[:, j]
/// over k)` is subtracted from the accumulator. The public result is
/// bit-identical to the scalar reference because all arithmetic stays
/// in i32 (no float rounding).
///
/// `out` must be row-contiguous; inputs may be strided.
pub fn matmul_i8_to_i32(
Expand All @@ -556,26 +604,45 @@ pub fn matmul_i8_to_i32(
let b_i8 = pack_contig(&rhs);
let mut c = vec![0i32; m * n];

if amx_available() {
// AMX TDPBUSD path: shift LHS i8 → u8 via (+128) and subtract the
// bias 128·sum(B[:, j] over k) afterwards. This keeps numerics exact.
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
// Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
// tile-GEMM via int8_tile_gemm_16x16, subtract bias.
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();

// Compute C' = A_u8 · B_i8 in i32, then subtract 128 · colsum(B).
int8_gemm_i32(&a_u8, &b_i8, &mut c, m, n, k);
let mut colsum = vec![0i32; n];
for p in 0..k {
for j in 0..n {
colsum[j] += b_i8[p * n + j] as i32;
let mut b_tile = vec![0i8; k * 16];
let mut tile_c = vec![0i32; 256];

for j_tile in (0..n).step_by(16) {
for kk in 0..k {
let row = kk * n + j_tile;
b_tile[kk * 16..(kk + 1) * 16]
.copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
}
}
for i in 0..m {
for j in 0..n {
c[i * n + j] -= 128 * colsum[j];
for i_tile in (0..m).step_by(16) {
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
tile_c.fill(0);
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
for ii in 0..16 {
let dst_off = (i_tile + ii) * n + j_tile;
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
}
}
}
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") {
// Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
// shape-alignment requirement (M/N/K all handled via per-block
// trim and scalar K-tail). Same sign-shift bias trick as AMX.
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
// SAFETY: runtime feature-detected avx512vnni above.
unsafe {
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
}
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
} else {
// Scalar i8×i8 → i32 reference.
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
// pre-AVX-512 silicon, or shapes that don't satisfy either of
// the SIMD tiers' alignment requirements.
for i in 0..m {
for p in 0..k {
let av = a_i8[i * k + p] as i32;
Expand All @@ -597,6 +664,27 @@ pub fn matmul_i8_to_i32(
Ok(())
}

/// Subtract `128 · colsum(B[:, j])` from each `c[i, j]` lane.
///
/// Used by both the AMX and AVX-512-VNNI arms of `matmul_i8_to_i32`
/// to undo the LHS sign-shift bias (A_i8 → A_u8 via +128 means
/// `A_u8 · B = (A_i8 + 128) · B = A_i8 · B + 128 · sum_k B[k, j]`).
/// Pure integer arithmetic, no rounding — the public result is
/// bit-identical to the scalar i8 × i8 → i32 reference.
fn subtract_i8_to_u8_bias(c: &mut [i32], b_i8: &[i8], m: usize, n: usize, k: usize) {
let mut colsum = vec![0i32; n];
for p in 0..k {
for j in 0..n {
colsum[j] += b_i8[p * n + j] as i32;
}
}
for i in 0..m {
for j in 0..n {
c[i * n + j] -= 128 * colsum[j];
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
11 changes: 10 additions & 1 deletion src/hpc/bf16_tile_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,23 @@ pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: us
// ═════════════════════════════════════════════════════════════════════

/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_bf16`).
/// **Accumulates** into the caller's `c` buffer — matches the
/// documented `C += A·B` semantics. The C tile (tmm0) is preloaded
/// from `c` before the TDPBF16PS loop so any pre-existing values are
/// preserved. (Same accumulator-preservation fix the int8 sibling
/// got after codex P1 on PR #184: prior `tile_zero(0)` discarded
/// pre-existing C values even though docs promised accumulation.)
///
/// # Safety
/// Caller must have verified `amx_available() == true`.
#[inline]
unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
// Tile config: shapes at K_bytes=64 match BF16 K=32 case
let cfg = TileConfig::for_dpbusd(64);
tile_loadconfig(&cfg);
tile_zero(0);
// Preload C accumulator from caller's buffer (was tile_zero(0)
// pre-fix — see method-level note above).
tile_load(0, c.as_ptr() as *const u8, 64);

// Accumulate over K/32 tile blocks
let k_blocks = k / 32;
Expand Down
Loading
Loading