hpc: TD-T2 — AMX TDPBUSD tile kernel + matmul_i8_to_i32 wiring#184
Conversation
Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the integer operand family. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the primitives existing in simd_amx since day one) and wires matmul_i8_to_i32's AMX arm through it. New module `hpc::int8_tile_gemm`: * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel, K must be multiple of 64. Mirror shape of `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand family that TDPBUSD natively supports. **One TDPBUSD = 16 384 multiply-accumulates per instruction** (16×16 output tile × 64 K-elements per A row × 4 K-elements per inner-product). That's 256× the VPDPBUSD-zmm throughput per instruction. * Internal `amx_path()` uses the existing primitives in `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig → tile_zero → K/64 iterations of (tile_load A, tile_load B, tile_dpbusd) → tile_store → tile_release. * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32 triple-loop reference. New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`: * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad layout required by TDPBUSD tile 2. * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]` * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout for TDPBF16PS — both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry: BF16 is 2B × 2, INT8 is 1B × 4). Wiring `matmul_i8_to_i32`'s AMX arm (was placebo): Pre-commit the AMX branch shifted i8 → u8 then called the SCALAR `int8_gemm_i32` reference and subtracted the bias — TDPBUSD itself was never reached even on real AMX silicon. Now: 1. Shift A: i8 → u8 via (+128). 2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations. 3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]). The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling for arbitrary shapes (same shape of Phase-4 work TD-T1 deferred). Verification: * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new tests: 4 in int8_tile_gemm + the existing matmul_i8_to_i32 test now exercises the actual TDPBUSD path because this host has amx_int8 + amx_tile in /proc/cpuinfo; the test continues to pass with bit-identical results to the scalar reference). * `vnni_pack_i8_roundtrip` test verifies the pack layout matches the spec exactly for an 8 × 4 sample. * `fallback_matches_scalar_reference_k64` test verifies the non-AMX path produces the same i32 output as a hand-written reference for a 64-K, pseudo-random u8/i8 matrix pair. * `public_api_diagonal_k128` test asserts a structured pattern (A = identity-like, B = constant 2) gives the expected accumulation through the full dispatch chain. * `cargo clippy --lib -D warnings` clean. * `cargo fmt --all --check` clean. Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX arm no longer falls back to it (the scalar else-branch uses an inline triple-loop directly). After this commit, the per-CPU dispatch table from PR #180 has the AMX tier wired for BOTH operand families on Sapphire Rapids+: BF16 GEMM: SPR+ → TDPBF16PS (TD-T1 / TD-T1b in PR #182) INT8 GEMM: SPR+ → TDPBUSD (this commit) Out of scope (separate PRs): * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel function `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; just need to assemble them into a m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm in PR #182's bf16_gemm_dispatch). * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level surface from PR #182) — it's u8 × i8 natively so no sign-shift needed, simpler to wire than matmul_i8_to_i32. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b1979d799e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| // byte — same 64-byte row width either way). | ||
| let cfg = TileConfig::for_dpbusd(64); | ||
| tile_loadconfig(&cfg); | ||
| tile_zero(0); |
There was a problem hiding this comment.
Preserve existing C accumulator in AMX int8 tile path
int8_tile_gemm_16x16 is documented (and its fallback implements) C += A·B, but the AMX branch unconditionally executes tile_zero(0) before compute and then stores tmm0 back to c, which discards any pre-existing values in c on AMX-enabled hosts. This creates hardware-dependent behavior: callers that rely on accumulation (e.g., blocked GEMM or repeated partial-K updates) will get incorrect results only when AMX is available.
Useful? React with 👍 / 👎.
Two clippy-as-error issues blocking PR #184 CI: 1. `src/hpc/int8_tile_gemm.rs:147` (mine, from b1979d7) — `clippy::unused_parens` flagged the closure body `(((i*11+5) % 256) as u8 as i8)` in the `fallback_matches_scalar_reference_k64` test. The outer parens around the cast chain are redundant; rustfmt re-broke the line to multi-line after removal so it stays readable. 2. `tests/par_rayon.rs:9` (pre-existing) — `clippy::manual_div_ceil` flagged `(M + CHUNK_SIZE - 1) / CHUNK_SIZE`. Replaced with `M.div_ceil(CHUNK_SIZE)` per the clippy hint. This file was already in tree; the lint became active in clippy 1.95 (Rust stable) which CI now uses, so prior PRs weren't blocked by it but the rayon-features test build is now. Both fixes are mechanical / no behaviour change: * `cargo clippy --tests --features rayon,native -- -D warnings` clean. * `cargo fmt --all --check` clean. Stashed work-in-progress on the VPDPBUSD-zmm middle tier for `matmul_i8_to_i32` (the natural symmetric next step after TD-T2, analogous to the VDPBF16PS arm shipped in PR #182's `bf16_gemm_dispatch`); will follow up in a separate commit once CI is unblocked. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Completes the per-CPU dispatch chain for `matmul_i8_to_i32`. Per PR #180's table the middle tier between AMX TDPBUSD (Sapphire Rapids+) and the scalar reference is `_mm512_dpbusd_epi32` (zmm form, avx512vnni feature) — covers Cooper Lake, Cascade Lake, Ice Lake-SP, Zen 4+ silicon that has AVX-512 VNNI but not AMX. Mirrors the VDPBF16PS arm structure that landed for BF16 in PR #182's `bf16_gemm_dispatch`. New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm`: * One VPDPBUSD instruction: 16 i32 accumulator lanes, each receiving 4 u8×i8 products = 64 MACs per instruction. * Maps the 16 output lanes to a row of 16 j-columns of `c[i, ·]`, one i row processed at a time, K-quad inner loop accumulating into the same 16 i32 lanes across iterations. * B-column packing: pre-packs B for the current j-block into `b_col_quads[k_quad * 16 + j] = i32 (4 bytes of B[4k_quad.., j_base+j] packed bottom-to-top)` once per j-block; reused across all M i-iterations so the gather cost amortizes. * A row quad broadcast: `_mm512_set1_epi32` of (4 u8 bytes packed) every K-iter — same quad seen by every output column. * K-tail (k % 4 != 0) handled with scalar u8×i8 multiplies per output cell; N-tail (j_count < 16) handled by trimming the store width — padding lanes still receive VPDPBUSD updates but aren't written back. * Stable intrinsic `_mm512_dpbusd_epi32` under `target_feature = "avx512vnni,avx512f"` — no asm-byte needed. Wiring `matmul_i8_to_i32` to three-tier dispatch: 1. amx_available() + 16/16/64-aligned shapes → int8_tile_gemm_16x16 → TDPBUSD asm-byte (16 384 MACs/instr, this commit reuses the kernel from PR #184 fe334de... wait, same PR — from b1979d7 in THIS PR) 2. is_x86_feature_detected!("avx512vnni") → int8_gemm_vpdpbusd_zmm → _mm512_dpbusd_epi32 stable intrinsic (64 MACs/instr, arbitrary shapes, K-tail handled scalar, N-tail handled by per-iteration j_count trim) 3. scalar i8×i8 → i32 reference for non-x86, pre-AVX-512 hosts, or shapes that don't satisfy either SIMD tier's requirements Factored the shared sign-shift bias subtraction into a private `subtract_i8_to_u8_bias(c, b_i8, m, n, k)` helper: both Tier 1 (AMX) and Tier 2 (VNNI) shift LHS i8 → u8 via (+128) then need to subtract 128·colsum(B) from the accumulator. Pure integer arithmetic, bit-identical to the scalar i8×i8 → i32 reference. Verification: * Default v3 build: 2093 lib tests pass (was 2092 — +1 new test `vpdpbusd_zmm_matches_scalar` that exercises the new arm directly with shapes spanning aligned cases, K-tail (k % 4), N-tail (n % 16), and small shapes; asserts byte-equal output vs scalar reference). * Existing `matmul_i8_to_i32_16x16_exact` continues to pass through the AMX tier on this host (which has amx_int8). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit: matmul_bf16_to_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 VPDPBUSD | scalar (b1979d7) | (THIS COMMIT) | (always) So all three of the public matmul entry points now have full three-tier dispatch on x86_64. Out of scope (separate PRs): * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level u8×i8 surface from PR #182) — it's u8×i8 natively, no sign- shift bias needed, simpler than matmul_i8_to_i32. * AVX-VNNI ymm arm (Arrow Lake / Meteor Lake U: avxvnni without avx512vnni) — the `vnni2_*` functions exist in simd_amx.rs but need to be assembled into a m×n×k VNNI-ymm GEMM. Same shape as the avx512vnni arm just with ymm width. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Per codex review on PR #184: `int8_tile_gemm_16x16` is documented as `C += A·B` and the scalar `fallback_path` correctly accumulates, but the AMX `amx_path` did `tile_zero(0)` + `tile_store(0, c, 64)` which **overwrote** any pre-existing values in `c` on AMX-enabled hosts. Hardware-dependent behavior: callers relying on accumulation (blocked GEMM, repeated partial-K updates) would get incorrect results only when AMX was active. Same bug in the BF16 sibling `bf16_tile_gemm::amx_path` (shipped in PR #104) — fixing both. The fix: 1. Add a `tile 0` case to `tile_load` (encoding `C4 E2 7B 4B 04 08` — same SIB byte as the existing tmm1/tmm2 cases, ModR/M `04` = `mod=00, reg=000 (tmm0), r/m=100 (SIB follows)`). 2. Both AMX paths replace `tile_zero(0)` with `tile_load(0, c.as_ptr() as *const u8, 64)` — preloads tmm0 from caller's C buffer. TDPBUSD / TDPBF16PS then accumulate into the pre-loaded values; `tile_store(0, c, 64)` writes back the true `+=` result. Consumer impact: zero. Both `matmul_bf16_to_f32` and `matmul_i8_to_i32` (the only callers of these kernels in this crate) `tile_c.fill(0)` before each call — so the now-accumulating behavior + zero-initialized C = same result as the prior overwrite semantics. The fix removes a latent trap for future blocked-GEMM / partial-K consumers without changing any shipped behavior. New regression test `amx_path_preserves_c_accumulator`: pre-loads C with a known non-zero marker pattern, runs A·B where B=0 (so the contribution is 0), asserts the marker is preserved. Would fail on the pre-fix code because the tile_store would zero everything. Passes on this host (which has amx_int8). Verification: * 2094 lib tests pass (was 2093 — +1 regression test). * 11 amx_matmul tests pass (consumers' fill(0)-then-call pattern continues to produce correct results). * 2 bf16_tile_gemm tests pass. * 6 int8_tile_gemm tests pass. * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
…able Rebased onto master post-#181, #182, #183. Replaces the polyfill-based add_mul_f32/f64 with LazyLock-cached function pointers picking real hardware FMA per silicon, and adds two more LazyLock-cached primitives the consumer needs: is_amx_available() and vnni_dot_u8_i8. WHY: F32x16::mul_add on AVX2 builds drops to per-lane scalar f32::mul_add (simd_avx2.rs:586). The polyfill abstracts lane width but cannot pick between _mm256_fmadd_ps and _mm512_fmadd_ps — that is an instruction-family choice, not a lane-width one. LazyLock amortises a one-time simd_caps() read into a frozen fn pointer; every subsequent call is a single indirect jump with zero is_x86_feature_detected! overhead. No SimdProfile exposed at the consumer surface — agnostic contract preserved. add_mul_f32(acc, a, b) — acc[i] += a[i]*b[i] AVX-512F+FMA → _mm512_fmadd_ps 16-wide + 8-wide tail + scalar tail AVX2+FMA → _mm256_fmadd_ps 8-wide + scalar tail NEON → vfmaq_f32 4-wide + scalar tail scalar → f32::mul_add per lane no_std build → preserves the polyfill F32x16::mul_add path (LazyLock requires std) add_mul_f64(acc, a, b) — f64 sibling, same shape with 8/4/2 lanes. is_amx_available() — wraps simd_amx::amx_available() (CPUID + OSXSAVE + XCR0[17,18] + Linux arch_prctl(XCOMP_PERM)) in LazyLock<bool>. The 4-step gate, including the syscall, fires exactly once per process. Always false on non-x86_64. vnni_dot_u8_i8(a, b) — i32 dot of u8 × i8 slices: AVX-512 VNNI → delegates to simd_amx::vnni_dot_u8_i8 wrapped with scalar tail handling (the existing kernel processes only n - (n%64) since its cognitive-shader caller pre-aligns rows; general-purpose callers need the tail) AVX-VNNI 256 → delegates to simd_amx::vnni2_dot_u8_i8 directly (that one already handles its scalar tail) scalar → simd_amx::vnni_dot_u8_i8_scalar No intrinsic code is duplicated. The dispatcher composes existing simd_amx::* kernels (which #182/#184 also build on) into a safe LazyLock-cached consumer-facing wrapper. simd_amx::matvec_dispatch runs the same selection logic but uses is_x86_feature_detected! per call; this wrapper amortises that to once at startup. PARITY CONTRACT: - add_mul_f32 / add_mul_f64: bit-identical to f32::mul_add / f64::mul_add per lane via to_bits() assertion. All vector backends emit single-rounded IEEE-754 FMA. - vnni_dot_u8_i8: bit-identical i32 to scalar widen-and-multiply. VPDPBUSD does not saturate the accumulator (intermediate u8*i8 products bounded by 32385, four-element sums by 129540). Tests: 2101/2101 lib pass (7 new lazylock_dispatch_tests over 12 problem sizes / tail lengths). cargo clippy --lib clean under default and --features cpu-spr. On Sapphire Rapids host the LazyLock resolved to AVX-512+FMA for add_mul, AVX-512 VNNI for vnni_dot; AMX is_amx_available returns false (hypervisor masks XCR0[17,18]) — matches the Risk #3 demotion from 61b4563. This commit was rebased atop master after the parallel session shipped PR #182 (BF16 AMX tile kernels), #183 (F16C cast batch), and prepared #184 (TDPBUSD int8 tile + matmul_i8_to_i32 wiring). The earlier 469ecc7 (coarse + SimdTier) and 77e3971 (mul_add_f32_into + walkback) and be65595 (is_amx_available + vnni_dot duplicating intrinsics) are subsumed by this single clean commit: no public SimdProfile / SimdTier re-export, no duplicated intrinsic code, no mul_add_f32_into (master's add_mul_f32 shape is the right primitive).
Extends the u8×i8 → i32 dispatch chain from PR #182's compile-time cascade (avx512vnni → avxvnni → scalar) by adding a top-tier AMX runtime check. Brings the SPR/GNR TDPBUSD path (16 384 MACs per instruction) to the slice-level surface that downstream consumers (lance-graph, etc.) use, completing the symmetry with PR #184's matmul_i8_to_i32 wiring. `gemm_u8_i8` is u8×i8 natively — no sign-shift bias trick needed (unlike `matmul_i8_to_i32` which is i8×i8 and has to convert via +128 then subtract `128·colsum(B)`). That makes the AMX path here a direct call with no bias correction. New helper `hpc::int8_tile_gemm::int8_gemm_amx_tiled(a_u8, b_i8, c, m, n, k)` factors out the tile-decomposition logic that was previously inlined in `matmul_i8_to_i32`. Both consumers now share the same helper: matmul_i8_to_i32: 1. shift A: i8 → u8 (+128) 2. int8_gemm_amx_tiled(a_u8, b, c, m, n, k) 3. subtract_i8_to_u8_bias(c, b, m, n, k) gemm_u8_i8 (AMX tier added in this commit): 1. int8_gemm_amx_tiled(a, b, c, m, n, k) — no shift, no bias The helper handles arbitrary 16/16/64-aligned shapes via a j_tile × i_tile loop calling int8_tile_gemm_16x16 per (16, 16) block. B sub-block extracted into K × 16 scratch once per j-tile, reused across all M i-tiles. **Overwrite semantics**: c is written not accumulated (the underlying int8_tile_gemm_16x16 accumulates into its tile buffer, but we zero the tile buffer before each call so the per-tile write to c is pure overwrite). Dispatch placement in gemm_u8_i8: * Tier 0 (this commit): runtime amx_available() check at the top of the function. AMX requires CPUID + XCR0 + Linux prctl which can't fit a target_feature compile-time gate. * Tiers 1-3: existing compile-time cfg-cascade (avx512vnni zmm → avxvnni ymm → scalar i8_gemm_i32). Unchanged. Misaligned shapes (m/n not multiples of 16, k not multiple of 64) or non-AMX hosts fall through to the compile-time cascade as before. Also fixed pre-existing clippy::manual_is_multiple_of warnings that surfaced in the new alignment check — switched from `% 16 == 0` to `.is_multiple_of(16)` etc. per the clippy hint (Rust 1.95 promoted this from `pedantic` to active warn). Verification: * 2095 lib tests pass (was 2094 — +1 new `gemm_u8_i8_amx_aligned_32x32x128` test exercising the AMX arm with a 32×32×128 shape that hits the AMX tier on this host's amx_int8 silicon). * 11 amx_matmul tests pass (matmul_i8_to_i32 refactored to call the shared helper — same behavior as before). * 4 gemm_u8_i8 tests pass (the existing ones still hit the compile-time cascade since their shapes aren't AMX-aligned). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit: matmul_bf16_to_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_f32: SPR+ AMX | Zen4/CPL VDPBF16PS | scalar (PR #182) | (PR #182) | (always) matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 VPDPBUSD | scalar (PR #184) | (PR #184) | (always) gemm_u8_i8 (slice): SPR+ AMX | CPL/Zen4 VPDPBUSD | ARL ymm | scalar (THIS) | (PR #182) | (PR #182) | (PR #182) Out of scope (separate PRs): * AVX-VNNI ymm arm for matmul_i8_to_i32 — `vnni2_*` helpers exist in simd_amx.rs but need assembling into a m×n×k GEMM. Same shape as the avx512vnni arm just with ymm width. * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Completes the per-CPU dispatch chain for `matmul_i8_to_i32` by adding the AVX-VNNI ymm tier — Arrow Lake, Meteor Lake U, Alder Lake silicon that has AVX-VNNI but dropped AVX-512. Mirrors the shape of the avx512vnni-zmm arm shipped in PR #184 with the narrower 8-wide kernel. New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm`: * One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes, each receiving 4 u8×i8 products = 32 MACs per instruction. Half the throughput-per-instruction of the `_mm512_dpbusd_epi32` zmm version. * Same B-pre-pack scheme (quad-interleaved per 8-wide j-block), same K-tail / N-tail handling. Just narrower. * Stable intrinsic under `target_feature = "avxvnni,avx2"` — no asm-byte needed. Wiring `matmul_i8_to_i32`'s dispatch as Tier 3: 1. amx_available() + 16/16/64-aligned → AMX TDPBUSD (PR #184: int8_gemm_amx_tiled, 16 384 MACs/instr) 2. is_x86_feature_detected!("avx512vnni") → VPDPBUSD-zmm (PR #184: int8_gemm_vpdpbusd_zmm, 64 MACs/instr) 3. is_x86_feature_detected!("avxvnni") → VPDPBUSD-ymm (THIS COMMIT: int8_gemm_vpdpbusd_ymm, 32 MACs/instr) 4. scalar i8×i8 → i32 reference (was Tier 3) All three SIMD tiers share the sign-shift bias trick: shift LHS i8 → u8 (+128), run the kernel, subtract 128·colsum(B). Same `subtract_i8_to_u8_bias` helper (factored in PR #184). New direct test `vpdpbusd_ymm_matches_scalar` mirrors the zmm version's test: sweeps shapes spanning 8-aligned, K-tail (k % 4), N-tail (n % 8), and small shapes, asserts byte-equal output vs scalar reference. Verification: * Default v3 (this host has avx512vnni so the new arm doesn't fire from matmul_i8_to_i32 — Tier 2 catches first): 2096 lib tests pass (was 2095 — +1 new direct test). * Direct test exercises int8_gemm_vpdpbusd_ymm on this host since avxvnni is present alongside avx512vnni. * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. Per-CPU dispatch state after this commit (final on the int8 side): matmul_i8_to_i32: SPR+ AMX | CPL/Zen4 zmm | ARL ymm | scalar (PR #184) | (PR #184) | (THIS) | (always) The matmul_i8_to_i32 column of PR #180's dispatch table is now fully filled. The gemm_u8_i8 slice surface (in PR #185) already has AVX-VNNI ymm via its existing compile-time cascade — both i8-related public surfaces now cover every x86_64 tier with a hardware-accelerated arm. Out of scope (separate PRs): * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b, needs aarch64 CI runner verification. * TD-T6: real _mm256_* for AVX2 BLAS-1 (scal/nrm2/asum). https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
**Alternative** to the compile-time cascade in `crate::simd::*` /
`crate::simd_ops::*`. **Additive**: gated under
`--features runtime-dispatch`, does not touch any existing path.
Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill
replaces the architecture-specific intrinsics that the runtime
trampolines select between).
Use case: ship ONE binary that adapts across heterogeneous
deployment silicon (AVX-512 server + AVX2-only laptop + Arrow Lake
desktop + Sapphire Rapids workstation) from the same artifact. The
existing compile-time `v3` / `v4` / `native` / `nightly-simd`
configs target a single class of CPU per build; the runtime layer
targets the union via per-op LazyLock<fn ptr> trampolines.
Design from `.claude/knowledge/simd-dispatch-architecture.md` § 7.1
/ Phase 5, building on the precedent set by
`hpc::bgz17_bridge::{L1_KERNEL, L1_WEIGHTED_KERNEL, ...}`
(`LazyLock<L1Fn>` pattern, lines 75-86) already proven in tree.
# Dispatch model
One `LazyLock<fn ptr>` per public surface. First call fires the
closure which reads `simd_caps()` and selects a backend; every
subsequent call is one pointer-deref + indirect call. Per-call
overhead: ~2-3 ns (LazyLock atomic-acquire load that's cache-
resident after first hit + indirect-call branch-target predict).
Invisible against any SIMD op's actual work (~100+ cycles).
# Module layout
src/simd_runtime/
mod.rs — module entry, mutual-exclusion check vs
nightly-simd, public re-exports
vnni_dot.rs — u8×i8 → i32 dot (the proposal's canonical
example): 3 backends, the AVX-512 arm
wraps `simd_amx::vnni_dot_u8_i8` with a
scalar tail because the existing kernel
silently drops n%64 lanes (its matvec
caller pre-aligns rows; a general-purpose
dispatch surface cannot assume that)
add_mul.rs — slice-level FMA (acc += a × b) for f32/f64;
the ONLY new kernel code in this module —
4 backends per type (avx512 / avx2+fma /
neon / scalar), each ~15 LoC of direct
intrinsics
matmul.rs — thin trampolines for matmul_bf16_to_f32 /
matmul_f32 / matmul_i8_to_i32 / gemm_u8_i8
delegating to existing functions that
already runtime-dispatch internally
(PR #182 / #184 / #185)
casts.rs — trampolines for the four half-precision
batch casts delegating to PR #183's already-
runtime-dispatched implementations
# Backend reuse — no kernel duplication
Every dispatch arm delegates to a kernel that already exists in
tree. The runtime layer is just the trampoline. The only NEW
kernel code is `add_mul_f32` / `add_mul_f64` (no pre-existing
slice-level FMA primitive in tree to delegate to — the compile-
time `crate::simd_ops::add_mul_f32` from PR #182 polyfills through
the F32x16 lane wrapper; the runtime version skips that
indirection for one more inlined intrinsic per chunk).
# Invariants preserved from this PR series
* No-FP32-roundtrip on BF16/F16 arithmetic — backends respect
the bit-exact mantissa rule
* Asm-byte encoding for nightly-gated AMX / FP16 — selected
backends keep their existing asm-byte fast paths
* Little-endian byte contracts for half-precision carriers
* Accumulator-preservation in tile paths (codex P1 from #184)
* Boundary assertions on safe public fns (codex P1 from #185) —
the public `vnni_dot_u8_i8(a, b)` etc. inherit the asserts
transparently via the call chain
# Verification
* Default build (no feature): 2087 lib tests pass — the
`simd_runtime` module is gated out, zero impact on existing
paths.
* `cargo test --lib --features runtime-dispatch`: **2105 lib
tests pass** (+8 new in `simd_runtime::*::tests`).
* `cargo clippy --lib --tests --features rayon,native -- -D warnings`
clean (default).
* `cargo clippy --lib --tests --features rayon,native,runtime-dispatch
-- -D warnings` clean.
* `cargo fmt --all --check` clean.
* Mutual-exclusion enforced via `compile_error!` in
`simd_runtime/mod.rs` — `--features runtime-dispatch,nightly-simd`
fails to compile with a clear error.
# What's NOT in this PR (deferred)
* Sweep the remaining ~15-20 SIMD/HPC public surfaces (min_i8,
max_i8, add_i8, dot_i8, etc.). Each is ~30-50 LoC of trampoline;
pattern is established here. Estimated ~700-900 more LoC across
the full surface map.
* CI matrix entry for `runtime-dispatch-portable` (per
simd-dispatch-architecture.md § 7 / TD-SIMD-9). Job builds
with `--features runtime-dispatch` on a v3 baseline runner and
asserts every trampoline lands on its expected backend.
* `simd_caps()` snapshot logging at process start (debug-only)
to aid release-binary deployment debugging — "which arm did
you actually pick?"
# Cost summary
src/simd_runtime/ +537 LoC (4 modules)
src/lib.rs +9 LoC (cfg-gated mod decl)
Cargo.toml +21 LoC (feature decl + doc)
Total ~570 LoC
Trampoline LoC per surface (this PR's sample):
vnni_dot 170 LoC (LazyLock + 3 arms + wrapper + tests)
add_mul (f32+f64)218 LoC (LazyLock×2 + 4 arms×2 + tests — the ONLY new kernels)
matmul (4 ops) 100 LoC (thin delegations + tests)
casts (4 ops) 75 LoC (thin delegations + tests)
Out-of-tree estimate for the full sweep (per § 7 of the design
doc): ~1400 LoC total once all ~25 public SIMD/HPC surfaces are
wired. This PR establishes ~40% of that budget with the canonical
patterns.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Summary
TD-T2 — the int8 mirror of the BF16 AMX work that landed in PR #182. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the AMX primitives existing in
simd_amxfrom day one) and wiresmatmul_i8_to_i32's AMX arm through it.What's new
hpc::int8_tile_gemmmodule (new):int8_tile_gemm_16x16(a_u8, b_i8, c, k)— public 16×16 tile kernel, K must be multiple of 64. Mirror shape ofbf16_tile_gemm_16x16for theu8 × i8 → i32operand family TDPBUSD natively supports. One TDPBUSD = 16 384 multiply-accumulates per instruction (16×16 output tile × 64 K-elements per A row × 4 K-elements per inner-product) — 256× the VPDPBUSD-zmm throughput per instruction.amx_matmul::vnni_pack_i8(src, dst, k, n)(new primitive):dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]vnni_pack_bf16. Both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry (BF16 = 2B × 2, INT8 = 1B × 4).matmul_i8_to_i32AMX arm — was placebo, now realPre-this-PR: the AMX branch shifted i8 → u8 then called the SCALAR
int8_gemm_i32reference and subtracted the bias — TDPBUSD itself was never reached even on real AMX silicon. After this PR:int8_tile_gemm_16x16per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations.c[i, j] -= 128 × colsum(B[:, j]).Shape requirement is
m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling (same shape of Phase-4 work TD-T1 deferred).Test plan
x86-64-v3AVX2): 2092 lib tests pass (was 2087 — +5 new tests).hpc::int8_tile_gemm::tests:fallback_matches_scalar_reference_k64,public_api_runs_on_any_hardware_k64,public_api_diagonal_k128,vnni_pack_i8_roundtrip.matmul_i8_to_i32_16x16_exactnow exercises the actual TDPBUSD path because this host hasamx_int8 + amx_bf16 + amx_tilein/proc/cpuinfo; the test continues to pass with bit-identical results to the scalar reference.cargo clippy --lib -- -D warningsclean.cargo fmt --all --checkclean.Per-CPU dispatch state after this PR
After this PR + #182 + #183, the AMX tier is wired for BOTH operand families on Sapphire Rapids+:
matmul_bf16_to_f32matmul_f32(BF16 compute)matmul_i8_to_i32Out of scope (separate PRs)
matmul_i8_to_i32for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel functionsvnni_dot_u8_i8andvnni_matvecexist insimd_amx.rsalready — just need to assemble them into an m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm inbf16_gemm_dispatch).simd_int_ops::gemm_u8_i8(the slice-level u8×i8 surface from PR simd: agnostic gemm_u8_i8 surface, integer-slice-op lift, per-CPU matrix, BF16 AMX wiring #182) — it's u8 × i8 natively so no sign-shift needed, simpler to wire thanmatmul_i8_to_i32.https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Generated by Claude Code