Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions .cargo/config-avx512.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
[build]
# Explicit AVX-512 config — `x86-64-v4`. Use with:
# Explicit AVX-512 config — Sapphire Rapids baseline. Use with:
# cargo --config .cargo/config-avx512.toml build
# cargo --config .cargo/config-avx512.toml test
#
# Compiles `target_feature = "avx512f"` on, so `src/simd.rs` selects the
# `simd_avx512` backend with native `__m512` / `__m512d` / `__m512i`
# storage. Required for the Sapphire Rapids / Granite Rapids hot paths
# (`f32_to_bf16_batch_rne`, the AVX-512BF16 BF16 lanes, the AMX tiles).
# `-Ctarget-cpu=sapphirerapids` enables, in addition to the
# `x86-64-v4` AVX-512 baseline (F + BW + CD + DQ + VL):
#
# Binary produced here will SIGILL on AVX2-only silicon — only use on
# hosts that report `avx512f` in `/proc/cpuinfo`. For shipping a single
# release artifact that adapts at process start, see the LazyLock runtime
# dispatch path in § 7.1 of the architecture doc instead.
# - AVX-512 VNNI (VPDPBUSD u8×i8 → i32)
# - AVX-512 BF16 (VDPBF16PS, VCVTNE2PS2BF16)
# - AVX-512 FP16 (16-wide native FP16 arithmetic)
# - AVX-512 VBMI / VBMI2 (byte permute)
# - AVX-512 IFMA, BITALG, VPOPCNTDQ, GFNI, VAES, VPCLMUL
# - AVX-VNNI (ymm VPDPBUSD on Alder/Sapphire client)
# - AMX-TILE + AMX-INT8 + AMX-BF16 (16×16×k tile kernels)
#
# Effect on the agnostic surfaces in `src/simd_*ops.rs`:
#
# - `simd_int_ops::gemm_u8_i8` resolves to the AVX-512 VNNI `VPDPBUSD`
# zmm kernel (`hpc::vnni_gemm::int8_gemm_vnni_avx512`). When the
# planned `amx-int8` arm lands, it will preempt this one and route
# to `TDPBUSD` instead — same source, no caller changes.
# - BF16 / FP16 lane ops in `src/simd_avx512.rs` light up.
# - `simd_amx::*` tile primitives are usable without further gating.
#
# Pure `x86-64-v4` is NOT used here — the Skylake generation
# (Skylake-X / Skylake-SP / Skylake-W) is the only mainstream AVX-512
# family without VNNI, and the project's design pins VNNI as the lowest
# common denominator above the scalar reference. Users of those parts
# either build with `-Ctarget-cpu=x86-64-v4` explicitly (and accept the
# scalar arm for `gemm_u8_i8`) or run a runtime-LazyLock dispatch binary.
#
# Binary produced here will SIGILL on CPUs that lack any of the
# enabled feature sets — i.e. anything pre-Sapphire-Rapids on x86_64:
#
# - Cascade Lake / Ice Lake-SP (no BF16, FP16 or AMX)
# - Cooper Lake (has AVX-512 BF16, but no FP16 or AMX)
# - Skylake-X / Skylake-SP / Skylake-W (no VNNI either)
# - Zen 4 / Zen 5 (no FP16, no AMX)
# - Alder Lake / Arrow Lake (no AVX-512 at all)
# - Haswell ⇢ Coffee Lake (AVX2 only)
#
# Only deploy artifacts built with this config to hosts that report
# `amx_tile amx_int8 amx_bf16 avx512_bf16 avx512_fp16 avx512_vnni` in
# `/proc/cpuinfo`. For Cascade Lake, Ice Lake-SP or Zen 4 silicon
# (AVX-512 + VNNI but no AMX/FP16; of these only Zen 4 also has
# AVX-512 BF16), build with `-Ctarget-cpu=cascadelake` or
# `-Ctarget-cpu=znver4` instead. For
# shipping a single release artifact that adapts at process start,
# see the LazyLock runtime dispatch path in § 7.1 of the architecture
# doc instead.
[target.'cfg(target_arch = "x86_64")']
rustflags = ["-Ctarget-cpu=x86-64-v4"]
rustflags = ["-Ctarget-cpu=sapphirerapids"]
548 changes: 548 additions & 0 deletions .claude/knowledge/agnostic-surface-cpu-matrix.md

Large diffs are not rendered by default.

202 changes: 183 additions & 19 deletions src/hpc/amx_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,15 @@ fn write_contig<A: Copy>(view: &mut ArrayViewMut2<'_, A>, src: &[A]) {

/// Matrix multiply BF16 × BF16 → f32: `out = lhs · rhs`.
///
/// Uses AMX `TDPBF16PS` (256 mul-adds per instruction) when available,
/// otherwise falls back to [`bf16_gemm_f32`].
/// On AMX hardware (Sapphire Rapids+, Granite Rapids), when M and N are
/// multiples of 16 and K is a multiple of 32, every 16×16 output tile
/// dispatches to [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] which
/// emits `TDPBF16PS` via the asm-byte path in `simd_amx.rs` — 16×16×32 =
/// 8 192 BF16×BF16 multiply-accumulates per instruction into a 16×16
/// (256-lane) f32 accumulator tile.
///
/// The AMX arm is all-or-nothing: if any dimension is not
/// 16/16/32-aligned, or the host has no AMX, the entire matmul is handled
/// by the next tier instead — the AVX-512 BF16 `VDPBF16PS` kernel when
/// `avx512bf16` is detected at runtime, otherwise the validated scalar
/// [`crate::hpc::quantized::bf16_gemm_f32`] reference.
///
/// `out` must be row-contiguous (column stride = 1); inputs may be strided.
pub fn matmul_bf16_to_f32(
Expand All @@ -310,26 +317,180 @@ pub fn matmul_bf16_to_f32(
let b = pack_contig(&rhs);
let mut c = vec![0.0f32; m * n];

// AMX path: a tiled 16×16 kernel exists in `bf16_tile_gemm` for sizes that
// fit cleanly. For any leftover tail (or hosts without AMX), defer to the
// scalar `bf16_gemm_f32`. The tile kernel itself is maintained alongside
// the low-level primitives at the top of this file; the public surface
// intentionally goes through the validated scalar path so we always
// produce a numerically-stable f32 result.
if amx_available() {
// Future: AMX-tiled fast path. Today we route through the same
// f32 reference kernel; correctness is identical regardless of
// hardware. The `amx_available()` branch is preserved so callers
// can be sure the AMX detection runs.
bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
} else {
bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
}
bf16_gemm_dispatch(&a, &b, &mut c, m, n, k);

write_contig(&mut out, &c);
Ok(())
}

/// BF16 × BF16 → f32 GEMM with three-tier dispatch (AMX → VDPBF16PS → scalar).
///
/// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c`
/// is M × N row-major and is overwritten (not accumulated).
///
/// Tiers are tried in order and the first match computes the *whole*
/// product — tiers are never mixed block-by-block:
///
/// 1. **AMX `TDPBF16PS`** (Sapphire Rapids+, Granite Rapids) when
///    `amx_available()` is true AND `m % 16 == 0`, `n % 16 == 0` and
///    `k % 32 == 0`. Each 16 × 16 output tile is dispatched through
///    [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] →
///    `simd_amx::tile_dpbf16ps` via asm-byte (`TDPBF16PS` intrinsic is
///    nightly-only on Rust 1.95). 8 192 BF16×BF16 multiplies + 256 f32
///    accumulates per instruction.
/// 2. **`VDPBF16PS`** (Cooper Lake, Sapphire Rapids+, Zen 4+) when
///    `is_x86_feature_detected!("avx512bf16")` is true. Cascade Lake and
///    Ice Lake-SP do not implement AVX-512 BF16 and therefore land in
///    tier 3. The intrinsic `_mm512_dpbf16_ps` is stable on Rust 1.95 (no
///    asm-byte needed). Per instruction: 32 BF16×BF16 multiplies + 16 f32
///    accumulates. Handles arbitrary shapes — M / N tails fall through the
///    per-iteration j-block trimming; K-tail (odd K) is handled with a
///    final scalar multiply per output cell. This tier also receives the
///    shapes that tier 1 rejects for alignment on hosts that report
///    `avx512bf16`.
/// 3. **Scalar reference** [`bf16_gemm_f32`] (called with `alpha = 1.0`,
///    `beta = 0.0`) when neither of the above applies.
///
/// The per-tier dispatch table comes from PR #180's BF16 GEMM column.
///
/// # Panics
///
/// The AMX arm indexes `a`, `b` and `c` with checked slicing, so it panics
/// if a slice is too short for the given `m` / `n` / `k`.
fn bf16_gemm_dispatch(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
    // Edge length of one AMX output tile (granularity of M and N) and the
    // K granularity the AMX arm requires.
    const TILE: usize = 16;
    const K_STEP: usize = 32;

    if amx_available() && m % TILE == 0 && n % TILE == 0 && k % K_STEP == 0 {
        // SAFETY: relies on BF16 being `#[repr(transparent)] struct BF16(pub u16)`
        // (per `hpc::quantized::BF16`), so reinterpreting `&[BF16]` as
        // `&[u16]` with the same length is bit-pattern preserving.
        let a_u16: &[u16] = unsafe { core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len()) };

        // B is packed row-major K × N; the 16×16 tile kernel wants a
        // K × 16 contiguous sub-block. Both scratch buffers are allocated
        // once and reused for every tile.
        let mut b_tile = vec![0u16; k * TILE];
        let mut tile_c = vec![0.0f32; TILE * TILE];

        for j_tile in (0..n).step_by(TILE) {
            // Pack b[0..k, j_tile..j_tile+16] into K rows of 16 values.
            // `b_tile` holds exactly `k` chunks, so `kk` runs over 0..k.
            // Slicing the source row once keeps the out-of-bounds panic
            // while avoiding a bounds check per element.
            for (kk, packed_row) in b_tile.chunks_exact_mut(TILE).enumerate() {
                let start = kk * n + j_tile;
                for (dst, src) in packed_row.iter_mut().zip(&b[start..start + TILE]) {
                    *dst = src.0;
                }
            }

            // The packed B panel is reused across every row tile of A.
            for i_tile in (0..m).step_by(TILE) {
                // A_tile = a[i_tile..i_tile+16, 0..k] — already contiguous
                // since `a` is packed row-major M × K.
                let a_tile = &a_u16[i_tile * k..(i_tile + TILE) * k];
                // Zeroed before each call so the stored tile contains only
                // this tile's product.
                tile_c.fill(0.0);
                crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);

                // Scatter tile_c (16 × 16, row-major) into c (M × N, row-major).
                for (ii, tile_row) in tile_c.chunks_exact(TILE).enumerate() {
                    let dst_off = (i_tile + ii) * n + j_tile;
                    c[dst_off..dst_off + TILE].copy_from_slice(tile_row);
                }
            }
        }
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512bf16") {
            // SAFETY: feature-detected at runtime; the kernel is
            // `#[target_feature(enable = "avx512bf16,avx512f")]`.
            unsafe {
                bf16_gemm_vdpbf16ps(a, b, c, m, n, k);
            }
            return;
        }
    }

    // Portable fallback: alpha = 1.0, beta = 0.0.
    bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0);
}

/// AVX-512BF16 BF16 GEMM using `_mm512_dpbf16_ps` (`VDPBF16PS`).
///
/// Computes `c = a · b` where `a` is M × K and `b` is K × N, both packed
/// row-major; `c` is M × N row-major and every cell is overwritten.
///
/// One VDPBF16PS instruction: 16 f32 accumulator lanes each receive
/// `acc[j] += a.bf16[2j] * b.bf16[2j] + a.bf16[2j+1] * b.bf16[2j+1]`.
/// NOTE(review): the instruction's rounding and denormal handling are
/// defined by the ISA and may differ from the scalar reference — confirm
/// against the Intel SDM before relying on bit-exact agreement.
/// The kernel maps the 16 output lanes to a row of 16
/// j-columns of C[i, ·], with one i row processed at a time and a K-pair
/// inner loop accumulating into the same 16 f32 lanes across iterations.
///
/// B-column packing: VDPBF16PS wants the 32 B BF16s per call laid out
/// as 16 lane-pairs (lane j contains `B[2k_pair, j_base+j]` in the low
/// 16 bits followed by `B[2k_pair+1, j_base+j]` in the high 16 bits,
/// packed into one u32). We pre-pack B for
/// the current j-block into `b_col_pairs[k_pair * 16 + j] = u32` once
/// per j_block and reuse across all i — amortizes the gather cost.
///
/// K-tail (when K is odd) is handled with a final scalar BF16 multiply
/// per output cell, added in f32 after the vector accumulation; N-tail
/// (when the j-block has < 16 valid columns)
/// is handled by trimming the store after the VDPBF16PS chain.
///
/// # Panics
/// Reads of `a` / `b` and writes to `c` use checked indexing, so slices
/// that are too short for the given `m` / `n` / `k` panic.
///
/// # Safety
/// Caller must have feature-detected `avx512bf16` at runtime; the function
/// is compiled with `avx512bf16` and `avx512f` enabled.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512bf16,avx512f")]
unsafe fn bf16_gemm_vdpbf16ps(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
    use core::arch::x86_64::{
        __m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_ps, _mm512_storeu_ps,
    };

    // Number of full (even, odd) K pairs consumed by VDPBF16PS, and whether
    // one trailing K element (odd K) is left for the scalar tail.
    let k_pairs = k / 2;
    let k_tail = k % 2;

    // SAFETY: relies on BF16 being repr(transparent) over u16, so the
    // reinterpreted slices cover the same memory with the same length.
    let a_u16: &[u16] = core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len());
    let b_u16: &[u16] = core::slice::from_raw_parts(b.as_ptr() as *const u16, b.len());

    // Pre-pack scratch: 16 u32 lanes per k_pair, holding (b_lo | b_hi << 16).
    // `.max(1)` keeps the buffer at least one 16-lane row long even when
    // K < 2 (no pairs at all).
    let mut b_col_pairs = vec![0u32; k_pairs.max(1) * 16];
    // Scratch for the 16-wide store + N-tail trim.
    let mut out_buf = [0.0f32; 16];

    // Walk the output in blocks of up to 16 columns.
    for j_base in (0..n).step_by(16) {
        // Valid columns in this block (< 16 only for the last block).
        let j_count = 16.min(n - j_base);

        // Pack B columns [j_base..j_base+j_count] in pair-interleaved layout.
        // For lanes j >= j_count (the N-tail of this j_block), pad with 0 —
        // they're not stored back, but the VDPBF16PS still touches them.
        for k_pair in 0..k_pairs {
            // Offsets of the even and odd K rows of this pair within `b`.
            let row_lo = 2 * k_pair * n;
            let row_hi = (2 * k_pair + 1) * n;
            for jj in 0..j_count {
                let b_lo = b_u16[row_lo + j_base + jj] as u32;
                let b_hi = b_u16[row_hi + j_base + jj] as u32;
                b_col_pairs[k_pair * 16 + jj] = (b_hi << 16) | b_lo;
            }
            for jj in j_count..16 {
                b_col_pairs[k_pair * 16 + jj] = 0;
            }
        }

        // One output row at a time, reusing the packed B block.
        for i in 0..m {
            let mut acc = _mm512_setzero_ps();
            let a_row_off = i * k;
            for k_pair in 0..k_pairs {
                // Broadcast A[i, 2k_pair..2k_pair+2] as the (BF16 lo, BF16 hi)
                // pair across all 16 lanes.
                let a_lo = a_u16[a_row_off + 2 * k_pair] as u32;
                let a_hi = a_u16[a_row_off + 2 * k_pair + 1] as u32;
                let pair = (a_hi << 16) | a_lo;
                // `pair as i32` is a bit-preserving cast; the transmutes only
                // relabel 512-bit integer vectors as BF16 vectors.
                let a_bh: __m512bh = core::mem::transmute(_mm512_set1_epi32(pair as i32));
                // Unaligned 64-byte load of row `k_pair` of the packed block;
                // in bounds because k_pair < k_pairs and the buffer holds
                // k_pairs.max(1) * 16 u32 values.
                let b_bh: __m512bh =
                    core::mem::transmute(_mm512_loadu_si512(b_col_pairs.as_ptr().add(k_pair * 16) as *const __m512i));
                acc = _mm512_dpbf16_ps(acc, a_bh, b_bh);
            }
            // Spill all 16 lanes to the stack scratch so the tail fix-up and
            // the trimmed store can work on plain f32 values.
            _mm512_storeu_ps(out_buf.as_mut_ptr(), acc);

            // K-tail: one extra scalar BF16 multiply for the last K index
            // (k - 1 == k_pairs * 2), added to each valid lane in f32.
            if k_tail == 1 {
                let a_last_f32 = BF16(a_u16[a_row_off + k - 1]).to_f32();
                let tail_row = (k - 1) * n;
                for jj in 0..j_count {
                    let b_last_f32 = BF16(b_u16[tail_row + j_base + jj]).to_f32();
                    out_buf[jj] += a_last_f32 * b_last_f32;
                }
            }

            // Store the j_count valid lanes (drops N-tail padding lanes).
            let dst_off = i * n + j_base;
            c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
        }
    }
}

// ── f32 → f32 (BF16 compute on AMX) ────────────────────────────────────────

/// Matrix multiply f32 × f32 → f32: `out = lhs · rhs`.
Expand All @@ -349,10 +510,13 @@ pub fn matmul_f32(
let mut c = vec![0.0f32; m * n];

if amx_available() {
// AMX path: down-cast to BF16, run BF16 GEMM, accumulate in f32.
// AMX path: down-cast to BF16 (RNE, ~1 ULP at BF16 mantissa
// precision), then dispatch through the shared BF16 helper
// which picks `TDPBF16PS` tile kernel for 16/16/32-aligned
// shapes and the scalar `bf16_gemm_f32` reference otherwise.
let a_bf16: Vec<BF16> = a_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
let b_bf16: Vec<BF16> = b_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
bf16_gemm_f32(&a_bf16, &b_bf16, &mut c, m, n, k, 1.0, 0.0);
bf16_gemm_dispatch(&a_bf16, &b_bf16, &mut c, m, n, k);
} else {
// Pure f32 reference path.
for i in 0..m {
Expand Down
Loading
Loading