Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,27 @@ rayon = ["dep:rayon", "std"]
# cfg-dispatch in `simd.rs` remains the production path.
nightly-simd = ["std"]

# Runtime SIMD dispatch — release-binary distribution path. Compiles all
# x86_64 backends into one artifact and selects per-op kernels via
# `LazyLock<fn>` trampolines that read `simd_caps()` on first call. The
# `crate::simd_runtime::*` module becomes reachable under this feature;
# the existing compile-time `crate::simd::*` / `crate::simd_ops::*`
# cascade is unchanged (additive). Use case: shipping one binary that
# adapts across heterogeneous deployment silicon (AVX-512 server +
# AVX2-only laptop) from the same artifact.
#
# Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill
# replaces the architecture-specific intrinsics that the runtime
# trampolines select between; they can't coexist coherently).
#
# Per-call overhead: ~2-3 ns indirect-call through a static fn pointer
# (LazyLock fires once at first call, every subsequent call is a
# pointer deref). Invisible against any SIMD op's actual work.
#
# See `.claude/knowledge/simd-dispatch-architecture.md` § 7.1 / Phase 5
# for the design rationale.
runtime-dispatch = ["std"]

# HPC extras: p64 palette/NARS bridge + fractal manifold.
# (blake3 was previously listed here; it is now part of `std` directly
# because the cognitive substrate modules under hpc/ that import blake3
Expand Down
38 changes: 15 additions & 23 deletions src/hpc/amx_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,28 +606,10 @@ pub fn matmul_i8_to_i32(

if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
// Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
// tile-GEMM via int8_tile_gemm_16x16, subtract bias.
// delegate to the shared int8_gemm_amx_tiled helper, subtract
// the sign-shift bias.
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();

let mut b_tile = vec![0i8; k * 16];
let mut tile_c = vec![0i32; 256];

for j_tile in (0..n).step_by(16) {
for kk in 0..k {
let row = kk * n + j_tile;
b_tile[kk * 16..(kk + 1) * 16]
.copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
}
for i_tile in (0..m).step_by(16) {
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
tile_c.fill(0);
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
for ii in 0..16 {
let dst_off = (i_tile + ii) * n + j_tile;
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
}
}
}
crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(&a_u8, &b_i8, &mut c, m, n, k);
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") {
// Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
Expand All @@ -639,9 +621,19 @@ pub fn matmul_i8_to_i32(
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
}
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avxvnni") {
// Tier 3 — AVX-VNNI ymm VPDPBUSD: 32 MACs per instruction.
// Arrow Lake, Meteor Lake U, Alder Lake silicon that has
// AVX-VNNI but dropped AVX-512. Same sign-shift bias trick.
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
// SAFETY: runtime feature-detected avxvnni above.
unsafe {
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm(&a_u8, &b_i8, &mut c, m, n, k);
}
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
} else {
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
// pre-AVX-512 silicon, or shapes that don't satisfy either of
// Tier 4 — Scalar i8×i8 → i32 reference for non-x86 hosts,
// pre-AVX-VNNI silicon, or shapes that don't satisfy any of
// the SIMD tiers' alignment requirements.
for i in 0..m {
for p in 0..k {
Expand Down
215 changes: 215 additions & 0 deletions src/hpc/int8_tile_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,97 @@ pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m:
}
}

// ═════════════════════════════════════════════════════════════════════
// VPDPBUSD-ymm AVX-VNNI tier (Arrow Lake / Meteor Lake U / Alder Lake)
// ═════════════════════════════════════════════════════════════════════

/// `u8 × i8 → i32` matrix product on the AVX-VNNI (ymm) `VPDPBUSD` path.
///
/// Computes `C = A · B` for row-major `A` (`m × k`, unsigned bytes),
/// row-major `B` (`k × n`, signed bytes) and row-major `C` (`m × n`,
/// `i32`). No alignment restriction is placed on `m`, `n` or `k`.
///
/// Each `_mm256_dpbusd_avx_epi32` feeds 8 `i32` accumulator lanes, and
/// every lane absorbs 4 `u8 × i8` products — 32 MACs per instruction.
/// This is the tier for silicon that exposes AVX-VNNI without AVX-512.
///
/// Strategy, per 8-column block of the output:
/// 1. Pack `B` once: for each group of 4 consecutive K-rows, one `i32`
///    per output column holding those 4 `B` bytes (little-endian).
///    Lanes past the right edge of `B` are packed as zero.
/// 2. For every row of `A`, broadcast each 4-byte group of that row and
///    accumulate against the packed `B` lanes.
/// 3. Fold the `k % 4` leftover K-rows in with scalar arithmetic.
///
/// `c` is overwritten, never accumulated into: every element of the
/// `m × n` output region is stored exactly once, so the caller does not
/// need to clear it beforehand.
///
/// # Panics
/// Slice accesses are bounds-checked, so undersized `a_u8`, `b_i8` or
/// `c` cause an index panic (possibly after part of `c` was written).
///
/// # Safety
/// Caller must have feature-detected `avxvnni + avx2` at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avxvnni,avx2")]
pub unsafe fn int8_gemm_vpdpbusd_ymm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    use core::arch::x86_64::{
        __m256i, _mm256_dpbusd_avx_epi32, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_setzero_si256,
        _mm256_storeu_si256,
    };

    /// Number of `i32` lanes in one ymm register.
    const LANES: usize = 8;
    /// Number of byte products summed into each lane per instruction.
    const QUAD: usize = 4;

    let full_quads = k / QUAD;
    let tail_start = full_quads * QUAD;

    // Packed-B scratch: one 8-lane group per full K-quad. At least one
    // group is always allocated so the buffer is never empty.
    let mut packed_b = vec![0i32; full_quads.max(1) * LANES];
    let mut lane_out = [0i32; LANES];

    for col0 in (0..n).step_by(LANES) {
        // Columns of this block that actually exist in B / C.
        let live = LANES.min(n - col0);

        // Step 1 — gather 4 consecutive K-rows of B into one i32 per column.
        for (q, group) in packed_b.chunks_exact_mut(LANES).take(full_quads).enumerate() {
            let first = QUAD * q * n + col0;
            for (lane, slot) in group.iter_mut().enumerate() {
                *slot = if lane < live {
                    let at = first + lane;
                    i32::from_le_bytes([
                        b_i8[at] as u8,
                        b_i8[at + n] as u8,
                        b_i8[at + 2 * n] as u8,
                        b_i8[at + 3 * n] as u8,
                    ])
                } else {
                    // Past the right edge: contributes nothing.
                    0
                };
            }
        }

        for row in 0..m {
            let a_off = row * k;

            // Step 2 — vector accumulation over the full K-quads.
            let mut acc = _mm256_setzero_si256();
            for (q, group) in packed_b.chunks_exact(LANES).take(full_quads).enumerate() {
                let p = a_off + QUAD * q;
                let a_quad = i32::from_le_bytes([a_u8[p], a_u8[p + 1], a_u8[p + 2], a_u8[p + 3]]);
                let a_v = _mm256_set1_epi32(a_quad);
                // SAFETY: `group` is exactly 8 contiguous i32 (32 bytes);
                // the load is unaligned-tolerant.
                let b_v = _mm256_loadu_si256(group.as_ptr().cast::<__m256i>());
                acc = _mm256_dpbusd_avx_epi32(acc, a_v, b_v);
            }
            // SAFETY: `lane_out` is exactly 8 contiguous i32 (32 bytes);
            // the store is unaligned-tolerant.
            _mm256_storeu_si256(lane_out.as_mut_ptr().cast::<__m256i>(), acc);

            // Step 3 — scalar clean-up for the `k % 4` trailing K-rows
            // (the range is empty when K is a multiple of 4).
            for kk in tail_start..k {
                let a_val = i32::from(a_u8[a_off + kk]);
                let b_row = kk * n + col0;
                for (lane, slot) in lane_out.iter_mut().take(live).enumerate() {
                    *slot += a_val * i32::from(b_i8[b_row + lane]);
                }
            }

            // Only the live lanes are written back to C.
            let dst = row * n + col0;
            c[dst..dst + live].copy_from_slice(&lane_out[..live]);
        }
    }
}

// ═════════════════════════════════════════════════════════════════════
// Scalar fallback (i32 reference)
// ═════════════════════════════════════════════════════════════════════
Expand All @@ -231,6 +322,71 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
}
}

// ═════════════════════════════════════════════════════════════════════
// AMX tiled helper — arbitrary 16/16/64-aligned M × N × K via 16×16 tile loop
// ═════════════════════════════════════════════════════════════════════

/// Tiled `u8 × i8 → i32` GEMM built on the 16×16 AMX tile kernel.
///
/// Intended for shapes with `m % 16 == 0 && n % 16 == 0 && k % 64 == 0`.
/// The `m × n` output is walked as a grid of 16×16 tiles; each tile is
/// produced by one call to [`int8_tile_gemm_16x16`]. For a given column
/// tile, the matching `k × 16` panel of `B` is copied into contiguous
/// scratch once and then reused for every row tile, so the strided
/// gather is paid once per column tile rather than once per output tile.
///
/// **Overwrite semantics**: `c` is written, not accumulated into, and
/// does not need to be cleared by the caller. The tile accumulator is
/// zeroed before every tile-kernel call and then copied over the
/// corresponding 16×16 region of `c`.
///
/// # Panics
/// - Always: if `a_u8.len() < m * k`, `b_i8.len() < k * n` or
///   `c.len() < m * n`.
/// - Debug builds only: if `crate::hpc::amx_matmul::amx_available()`
///   is false, or if `m` / `n` / `k` violate the 16 / 16 / 64 alignment.
///   Release builds skip these checks, so callers must verify AMX
///   availability and alignment themselves.
pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    /// Edge length of one square output tile.
    const TILE: usize = 16;

    // Boundary contract: a safe public function must reject buffers that
    // are too small for the claimed shape before any tile work starts.
    assert!(a_u8.len() >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}", a_u8.len(), m * k);
    assert!(b_i8.len() >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}", b_i8.len(), k * n);
    assert!(c.len() >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}", c.len(), m * n);

    // Capability and alignment preconditions are only verified in debug builds.
    debug_assert!(crate::hpc::amx_matmul::amx_available());
    debug_assert_eq!(m % 16, 0, "int8_gemm_amx_tiled: M must be multiple of 16");
    debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16");
    debug_assert_eq!(k % 64, 0, "int8_gemm_amx_tiled: K must be multiple of 64");

    // Scratch: one K × 16 panel of B and one 16 × 16 accumulator tile.
    let mut b_panel = vec![0i8; k * TILE];
    let mut acc_tile = vec![0i32; TILE * TILE];

    for col0 in (0..n).step_by(TILE) {
        // Gather B[0..k, col0..col0 + 16] into contiguous 16-wide rows.
        for (kk, panel_row) in b_panel.chunks_exact_mut(TILE).enumerate() {
            let src = kk * n + col0;
            panel_row.copy_from_slice(&b_i8[src..src + TILE]);
        }

        for row0 in (0..m).step_by(TILE) {
            // 16 full rows of A, each K bytes long.
            let a_rows = &a_u8[row0 * k..(row0 + TILE) * k];

            // The tile kernel accumulates, so start every tile from zero.
            acc_tile.fill(0);
            int8_tile_gemm_16x16(a_rows, &b_panel, &mut acc_tile, k);

            // Scatter the row-major 16 × 16 tile into row-major C.
            for (ii, tile_row) in acc_tile.chunks_exact(TILE).enumerate() {
                let dst = (row0 + ii) * n + col0;
                c[dst..dst + TILE].copy_from_slice(tile_row);
            }
        }
    }
}

// ═════════════════════════════════════════════════════════════════════
// Tests
// ═════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -370,6 +526,65 @@ mod tests {
}
}

/// Regression guard (Codex P1, PR #185): `int8_gemm_amx_tiled` is a
/// safe public function, so a `(m, n, k)` triple that claims more
/// data than the slices hold must be rejected by the length
/// assertions at the function boundary rather than surfacing deeper
/// in the tile loops. A deliberately short `B` buffer is passed and
/// `#[should_panic]` checks that the panic message names `b_i8.len()`.
#[test]
#[should_panic(expected = "b_i8.len()")]
fn amx_tiled_panics_on_undersized_b() {
    let (m, n, k) = (16usize, 32usize, 64usize);
    let lhs = vec![0u8; m * k];
    // B only covers n - 16 columns: one full 16-column tile is missing.
    let short_rhs = vec![0i8; k * (n - 16)];
    let mut out = vec![0i32; m * n];
    // The length assertions run before the debug-only amx_available()
    // check, so the expected panic also fires on hosts without AMX.
    int8_gemm_amx_tiled(&lhs, &short_rhs, &mut out, m, n, k);
}

/// Bit-exactness check for the AVX-VNNI ymm kernel
/// (`int8_gemm_vpdpbusd_ymm`) against a straightforward scalar
/// triple loop. Skipped at runtime on hosts without AVX-VNNI.
#[cfg(target_arch = "x86_64")]
#[test]
fn vpdpbusd_ymm_matches_scalar() {
    if !std::is_x86_feature_detected!("avxvnni") {
        eprintln!("avxvnni not detected; skipping");
        return;
    }

    /// Scalar `u8 × i8 → i32` reference product (row-major operands).
    fn scalar_reference(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
        let mut out = vec![0i32; m * n];
        for i in 0..m {
            for j in 0..n {
                out[i * n + j] = (0..k)
                    .map(|kk| i32::from(a[i * k + kk]) * i32::from(b[kk * n + j]))
                    .sum();
            }
        }
        out
    }

    // Shapes chosen to hit: fully 8-aligned, K with a remainder mod 4,
    // N with a remainder mod 8, a single-row case, and a minimal
    // one-quad case.
    let shapes = [(16, 8, 64), (3, 5, 7), (17, 33, 100), (1, 17, 12), (8, 8, 4)];
    for (m, n, k) in shapes {
        let lhs: Vec<u8> = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect();
        let rhs: Vec<i8> = (0..k * n).map(|i| ((i * 17 + 3) % 256) as u8 as i8).collect();
        let expected = scalar_reference(&lhs, &rhs, m, n, k);

        let mut got = vec![0i32; m * n];
        // SAFETY: avxvnni confirmed at the top of the test.
        unsafe { int8_gemm_vpdpbusd_ymm(&lhs, &rhs, &mut got, m, n, k) };

        assert_eq!(got, expected, "VPDPBUSD-ymm mismatch at (M={}, N={}, K={})", m, n, k);
    }
}

#[test]
fn vnni_pack_i8_roundtrip() {
// Pack then verify the VNNI layout matches the spec:
Expand Down
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,15 @@ pub mod simd_ops;
#[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
pub mod simd_half;

/// Runtime SIMD dispatch — release-binary distribution path.
///
/// Gated under `--features runtime-dispatch`. Mutually exclusive with
/// `nightly-simd` (the cfg in `simd_runtime/mod.rs` enforces this with
/// a `compile_error!`). See `.claude/knowledge/simd-dispatch-architecture.md`
/// § 7.1 / Phase 5 for the design.
#[cfg(feature = "runtime-dispatch")]
pub mod simd_runtime;

/// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
#[cfg(feature = "std")]
pub mod backend;
Expand Down
49 changes: 46 additions & 3 deletions src/simd_int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,31 @@ pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usiz
assert!(b.len() >= k * n, "gemm_u8_i8: b.len()={} < k*n={}", b.len(), k * n);
assert!(c.len() >= m * n, "gemm_u8_i8: c.len()={} < m*n={}", c.len(), m * n);

// Compile-time dispatch chain. Exactly one arm survives per build;
// the others are stripped by `#[cfg]` so the compiler emits a direct
// call to the chosen kernel with no runtime branch.
// Tier 0 — runtime AMX check. AMX is a different feature class than
// the rest of the dispatch chain: it requires CPUID + XCR0 + a Linux
// `prctl(ARCH_REQ_XCOMP_PERM, 18)` to be granted, none of which fit
// a `target_feature` compile-time gate. The check is one CPUID +
// one XGETBV + one prctl (idempotent, cached after first call). On
// aligned shapes (16/16/64) this dispatches to TDPBUSD via the
// shared `int8_gemm_amx_tiled` helper — 16 384 MACs per instruction
// vs VPDPBUSD-zmm's 64. Since `gemm_u8_i8` is u8×i8 natively (no
// sign-shift bias needed), the AMX path is a direct call with no
// bias correction — simpler than `matmul_i8_to_i32`'s i8×i8 path.
#[cfg(target_arch = "x86_64")]
{
if crate::hpc::amx_matmul::amx_available()
&& m.is_multiple_of(16)
&& n.is_multiple_of(16)
&& k.is_multiple_of(64)
{
crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(a, b, c, m, n, k);
return;
}
}

// Compile-time dispatch chain (tiers 1-3). Exactly one arm survives
// per build; the others are stripped by `#[cfg]` so the compiler
// emits a direct call to the chosen kernel with no runtime branch.

#[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))]
{
Expand Down Expand Up @@ -731,4 +753,25 @@ mod tests {
);
}
}

/// Covers the runtime AMX tier that sits in front of `gemm_u8_i8`'s
/// compile-time dispatch chain. The shape is 16/16/64-aligned, so
/// on a host where AMX is available the call is routed through
/// `int8_gemm_amx_tiled`; on any other host it takes whichever
/// compile-time arm was built. The expectation is the same either
/// way because the reference is exact integer arithmetic.
#[test]
fn gemm_u8_i8_amx_aligned_32x32x128() {
    // 2 × 2 output tiles of 16 × 16, with two 64-wide K-blocks each.
    let (m, n, k) = (32usize, 32usize, 128usize);

    let lhs: Vec<u8> = (0..m * k).map(|i| ((i * 13 + 7) % 256) as u8).collect();
    let rhs: Vec<i8> = (0..k * n).map(|i| ((i * 19 + 11) % 256) as u8 as i8).collect();

    let expected = ref_gemm_u8_i8(&lhs, &rhs, m, n, k);

    let mut out = vec![0i32; m * n];
    gemm_u8_i8(&lhs, &rhs, &mut out, m, n, k);

    assert_eq!(out, expected, "gemm_u8_i8 AMX path mismatch");
}
}
Loading
Loading