
hpc: TD-T2 — AMX TDPBUSD tile kernel + matmul_i8_to_i32 wiring#184

Merged
AdaWorldAPI merged 4 commits into master from claude/continue-ndarray-x0Oaw on May 21, 2026

@AdaWorldAPI
Owner

Summary

TD-T2 — the int8 mirror of the BF16 AMX work that landed in PR #182. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the AMX primitives existing in simd_amx from day one) and wires matmul_i8_to_i32's AMX arm through it.

What's new

hpc::int8_tile_gemm module (new):

  • int8_tile_gemm_16x16(a_u8, b_i8, c, k): public 16×16 tile kernel; K must be a multiple of 64. Mirrors the shape of bf16_tile_gemm_16x16 for the u8 × i8 → i32 operand family that TDPBUSD natively supports. One TDPBUSD performs 16 384 multiply-accumulates (16×16 output cells × 64 K-elements each, consumed as 16 groups of 4 bytes per inner product), which is 256× the 64 multiply-accumulates of one VPDPBUSD-zmm instruction.
  • AMX path: TileConfig::for_dpbusd(64) → tile_loadconfig → tile_zero → K/64 iterations of (tile_load A, tile_load B, tile_dpbusd) → tile_store → tile_release.
  • Scalar fallback path for non-AMX hosts.

amx_matmul::vnni_pack_i8(src, dst, k, n) (new primitive):

  • Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad layout required by TDPBUSD tile 2.
  • dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
  • Sibling of vnni_pack_bf16. Both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry (BF16 = 2B × 2, INT8 = 1B × 4).
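The index formula above can be sketched as a plain triple loop. This is an illustrative stand-in, not the crate's implementation: the name `vnni_pack_i8_sketch` and the length assertions are assumptions added for the example.

```rust
/// Illustrative stand-in for `vnni_pack_i8` (not the crate's code).
/// `src` is row-major K x N; `dst` receives K/4 rows of N*4 bytes,
/// so the four K-consecutive values of each column end up adjacent.
pub fn vnni_pack_i8_sketch(src: &[i8], dst: &mut [i8], k: usize, n: usize) {
    assert!(k % 4 == 0, "K must be a multiple of 4");
    assert_eq!(src.len(), k * n);
    assert_eq!(dst.len(), k * n);
    for kb in 0..k / 4 {
        for j in 0..n {
            for p in 0..4 {
                // dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
                dst[kb * n * 4 + j * 4 + p] = src[(4 * kb + p) * n + j];
            }
        }
    }
}
```

For an 8 × 4 input holding 0..32, the first packed quad is column 0 of rows 0 to 3, that is 0, 4, 8, 12.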

matmul_i8_to_i32 AMX arm — was placebo, now real

Before this PR, the AMX branch shifted i8 → u8, then called the SCALAR int8_gemm_i32 reference and subtracted the bias, so TDPBUSD itself was never reached even on real AMX silicon. After this PR:

  1. Shift A: i8 → u8 via (+128).
  2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations.
  3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]).
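The three steps rest on the identity Σ (a + 128)·b = Σ a·b + 128·Σ b. The scalar sketch below checks that identity without any AMX code; both function names are hypothetical, and the u8 × i8 product in the middle stands in for whatever kernel the dispatch selects.

```rust
/// Scalar reference: C = A (i8, M x K) * B (i8, K x N), accumulated in i32.
pub fn gemm_i8_reference(a: &[i8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
    let mut c = vec![0i32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a[i * k + p] as i32 * b[p * n + j] as i32;
            }
        }
    }
    c
}

/// Same product via the sign-shift route (hypothetical helper name).
pub fn gemm_i8_via_u8_shift(a: &[i8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
    // Step 1: a + 128 maps [-128, 127] onto [0, 255].
    let a_u8: Vec<u8> = a.iter().map(|&v| (v as i16 + 128) as u8).collect();
    // Step 2: u8 x i8 -> i32 product (the part a TDPBUSD kernel would compute).
    let mut c = vec![0i32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a_u8[i * k + p] as i32 * b[p * n + j] as i32;
            }
        }
    }
    // Step 3: c[i, j] -= 128 * colsum(B[:, j]).
    for j in 0..n {
        let colsum: i32 = (0..k).map(|p| b[p * n + j] as i32).sum();
        for i in 0..m {
            c[i * n + j] -= 128 * colsum;
        }
    }
    c
}
```

Because every step is integer arithmetic, the two routes agree exactly, including at the extremes −128 and 127.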

The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling (the same Phase-4 work that TD-T1 deferred).

Test plan

  • Default v3 (x86-64-v3 AVX2): 2092 lib tests pass (was 2087 — +5 new tests).
    • 4 new in hpc::int8_tile_gemm::tests: fallback_matches_scalar_reference_k64, public_api_runs_on_any_hardware_k64, public_api_diagonal_k128, vnni_pack_i8_roundtrip.
    • Existing matmul_i8_to_i32_16x16_exact now exercises the actual TDPBUSD path because this host has amx_int8 + amx_bf16 + amx_tile in /proc/cpuinfo; the test continues to pass with bit-identical results to the scalar reference.
  • cargo clippy --lib -- -D warnings clean.
  • cargo fmt --all --check clean.

Per-CPU dispatch state after this PR

After this PR + #182 + #183, the AMX tier is wired for BOTH operand families on Sapphire Rapids+:

| Op | SPR/GNR | Zen4/CPL | Others |
|---|---|---|---|
| matmul_bf16_to_f32 | ✅ TDPBF16PS (#182) | ✅ VDPBF16PS (#182) | scalar |
| matmul_f32 (BF16 compute) | ✅ TDPBF16PS (#182) | ✅ VDPBF16PS (#182) | scalar |
| matmul_i8_to_i32 | TDPBUSD (this PR) | ⏳ VPDPBUSD zmm | scalar |

Out of scope (separate PRs)

  • VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel functions vnni_dot_u8_i8 and vnni_matvec exist in simd_amx.rs already — just need to assemble them into an m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm in bf16_gemm_dispatch).
  • AMX tile path for simd_int_ops::gemm_u8_i8 (the slice-level u8×i8 surface from PR #182, "simd: agnostic gemm_u8_i8 surface, integer-slice-op lift, per-CPU matrix, BF16 AMX wiring"): it's u8 × i8 natively, so no sign-shift is needed and it is simpler to wire than matmul_i8_to_i32.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u


Generated by Claude Code

Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the
integer operand family. Builds the missing int8 tile kernel from
scratch (the BF16 equivalent shipped in PR #104; the int8 one had
never been built despite the primitives existing in simd_amx since
day one) and wires matmul_i8_to_i32's AMX arm through it.

New module `hpc::int8_tile_gemm`:

  * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel,
    K must be multiple of 64. Mirror shape of
    `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand
    family that TDPBUSD natively supports. **One TDPBUSD = 16 384
    multiply-accumulates per instruction** (16×16 output cells × 64
    K-elements each, consumed as 16 groups of 4 bytes). That's 256×
    the 64 multiply-accumulates of one VPDPBUSD-zmm instruction.
  * Internal `amx_path()` uses the existing primitives in
    `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig →
    tile_zero → K/64 iterations of (tile_load A, tile_load B,
    tile_dpbusd) → tile_store → tile_release.
  * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32
    triple-loop reference.
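A scalar fallback of the kind described can be sketched as follows. This is a minimal illustration under assumptions: the function name, the row-major K × 16 layout for B, and the length checks are not taken from the crate.

```rust
pub const TILE: usize = 16;

/// C (16 x 16, i32) += A (16 x K, u8) * B (K x 16, i8), all row-major.
/// Illustrative only; the real `fallback_path` may lay out B differently.
pub fn int8_tile_fallback_sketch(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
    assert!(k % 64 == 0, "K must be a multiple of 64");
    assert_eq!(a_u8.len(), TILE * k);
    assert_eq!(b_i8.len(), k * TILE);
    assert_eq!(c.len(), TILE * TILE);
    for i in 0..TILE {
        for j in 0..TILE {
            // Start from the existing value so the call accumulates.
            let mut acc = c[i * TILE + j];
            for p in 0..k {
                acc += a_u8[i * k + p] as i32 * b_i8[p * TILE + j] as i32;
            }
            c[i * TILE + j] = acc;
        }
    }
}

/// Multiply-accumulates covered by one TDPBUSD on full tiles:
/// 16 x 16 outputs, 64 K-bytes contributing to each.
pub const fn macs_per_tdpbusd() -> usize {
    TILE * TILE * 64
}
```

Calling the sketch twice on the same `c` doubles each cell, which is the `C += A·B` contract the tile kernel is documented to follow.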

New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`:

  * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad
    layout required by TDPBUSD tile 2.
  * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]`
  * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout
    for TDPBF16PS — both kernels reach the same 64-byte tile row
    width via element-width × pack-factor symmetry: BF16 is 2B × 2,
    INT8 is 1B × 4).

Wiring `matmul_i8_to_i32`'s AMX arm (was placebo):

Before this commit, the AMX branch shifted i8 → u8, then called the
SCALAR `int8_gemm_i32` reference and subtracted the bias, so TDPBUSD
itself was never reached even on real AMX silicon. Now:

  1. Shift A: i8 → u8 via (+128).
  2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling
     int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block
     extracted into K × 16 scratch once per j_tile, reused across
     i_tile iterations.
  3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]).

The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0;
misaligned shapes fall back to the scalar reference. Phase-4 work
will land mixed AMX-tile + per-axis scalar tail handling for
arbitrary shapes (the same Phase-4 work that TD-T1 deferred).

Verification:
  * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new
    tests: 4 in int8_tile_gemm + the existing matmul_i8_to_i32 test
    now exercises the actual TDPBUSD path because this host has
    amx_int8 + amx_tile in /proc/cpuinfo; the test continues to
    pass with bit-identical results to the scalar reference).
  * `vnni_pack_i8_roundtrip` test verifies the pack layout matches
    the spec exactly for an 8 × 4 sample.
  * `fallback_matches_scalar_reference_k64` test verifies the
    non-AMX path produces the same i32 output as a hand-written
    reference for a 64-K, pseudo-random u8/i8 matrix pair.
  * `public_api_diagonal_k128` test asserts a structured pattern
    (A = identity-like, B = constant 2) gives the expected
    accumulation through the full dispatch chain.
  * `cargo clippy --lib -- -D warnings` clean.
  * `cargo fmt --all --check` clean.

Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX
arm no longer falls back to it (the scalar else-branch uses an
inline triple-loop directly).

After this commit, the per-CPU dispatch table from PR #180 has the
AMX tier wired for BOTH operand families on Sapphire Rapids+:

  BF16 GEMM:  SPR+ → TDPBF16PS  (TD-T1 / TD-T1b in PR #182)
  INT8 GEMM:  SPR+ → TDPBUSD    (this commit)

Out of scope (separate PRs):
  * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade
    Lake / Zen 4+ (avx512vnni without AMX). The kernel functions
    `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; they just
    need to be assembled into an m×n×k GEMM and wired as the
    middle dispatch tier (analogous to the VDPBF16PS arm in PR
    #182's bf16_gemm_dispatch).
  * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level
    surface from PR #182) — it's u8 × i8 natively so no sign-shift
    needed, simpler to wire than matmul_i8_to_i32.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b1979d799e


Comment thread src/hpc/int8_tile_gemm.rs Outdated
// byte — same 64-byte row width either way).
let cfg = TileConfig::for_dpbusd(64);
tile_loadconfig(&cfg);
tile_zero(0);

P1: Preserve existing C accumulator in AMX int8 tile path

int8_tile_gemm_16x16 is documented (and its fallback implements) C += A·B, but the AMX branch unconditionally executes tile_zero(0) before compute and then stores tmm0 back to c, which discards any pre-existing values in c on AMX-enabled hosts. This creates hardware-dependent behavior: callers that rely on accumulation (e.g., blocked GEMM or repeated partial-K updates) will get incorrect results only when AMX is available.


claude added 3 commits May 21, 2026 05:52
Two clippy-as-error issues blocking PR #184 CI:

1. `src/hpc/int8_tile_gemm.rs:147` (mine, from b1979d7) —
   `clippy::unused_parens` flagged the closure body `(((i*11+5) % 256)
   as u8 as i8)` in the `fallback_matches_scalar_reference_k64` test.
   The outer parens around the cast chain are redundant; rustfmt
   re-broke the line to multi-line after removal so it stays readable.

2. `tests/par_rayon.rs:9` (pre-existing) — `clippy::manual_div_ceil`
   flagged `(M + CHUNK_SIZE - 1) / CHUNK_SIZE`. Replaced with
   `M.div_ceil(CHUNK_SIZE)` per the clippy hint. This file was
   already in tree; the lint became active in clippy 1.95 (Rust
   stable) which CI now uses, so prior PRs weren't blocked by it
   but the rayon-features test build is now.

Both fixes are mechanical / no behaviour change:
  * `cargo clippy --tests --features rayon,native -- -D warnings`
    clean.
  * `cargo fmt --all --check` clean.

Stashed work-in-progress on the VPDPBUSD-zmm middle tier for
`matmul_i8_to_i32` (the natural symmetric next step after TD-T2,
analogous to the VDPBF16PS arm shipped in PR #182's
`bf16_gemm_dispatch`); will follow up in a separate commit once
CI is unblocked.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Completes the per-CPU dispatch chain for `matmul_i8_to_i32`. Per
PR #180's table the middle tier between AMX TDPBUSD (Sapphire
Rapids+) and the scalar reference is `_mm512_dpbusd_epi32` (zmm
form, avx512vnni feature) — covers Cooper Lake, Cascade Lake, Ice
Lake-SP, Zen 4+ silicon that has AVX-512 VNNI but not AMX. Mirrors
the VDPBF16PS arm structure that landed for BF16 in PR #182's
`bf16_gemm_dispatch`.

New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm`:
  * One VPDPBUSD instruction: 16 i32 accumulator lanes, each
    receiving 4 u8×i8 products = 64 MACs per instruction.
  * Maps the 16 output lanes to a row of 16 j-columns of `c[i, ·]`,
    one i row processed at a time, K-quad inner loop accumulating
    into the same 16 i32 lanes across iterations.
  * B-column packing: pre-packs B for the current j-block into
    `b_col_quads[k_quad * 16 + j] = i32 (4 bytes of B[4k_quad..,
    j_base+j] packed bottom-to-top)` once per j-block; reused
    across all M i-iterations so the gather cost amortizes.
  * A row quad broadcast: `_mm512_set1_epi32` of (4 u8 bytes
    packed) every K-iter — same quad seen by every output column.
  * K-tail (k % 4 != 0) handled with scalar u8×i8 multiplies per
    output cell; N-tail (j_count < 16) handled by trimming the
    store width — padding lanes still receive VPDPBUSD updates
    but aren't written back.
  * Stable intrinsic `_mm512_dpbusd_epi32` under
    `target_feature = "avx512vnni,avx512f"` — no asm-byte needed.
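The per-row scheme in the bullets above can be modelled in scalar Rust without any intrinsics. The sketch below is an assumption-laden illustration: `vpdpbusd_model`, `pack_b_quads` and `row_block_sketch` are hypothetical names, the packed layout is held as arrays rather than i32 words, and the block is written to `c_row` with overwrite semantics, which the commit message does not specify.

```rust
/// Scalar model of one VPDPBUSD on zmm registers: 16 i32 lanes, each
/// gaining the sum of four u8 x i8 products, with no saturation.
pub fn vpdpbusd_model(acc: &mut [i32; 16], a: &[[u8; 4]; 16], b: &[[i8; 4]; 16]) {
    for lane in 0..16 {
        let mut sum = 0i32;
        for p in 0..4 {
            sum += a[lane][p] as i32 * b[lane][p] as i32;
        }
        acc[lane] = acc[lane].wrapping_add(sum);
    }
}

/// Packs columns j_base..j_base + j_count of row-major B (K x N) into one
/// 16-lane group of quads per K-quad. Padding lanes stay zero.
pub fn pack_b_quads(b: &[i8], n: usize, k: usize, j_base: usize, j_count: usize) -> Vec<[[i8; 4]; 16]> {
    let mut packed = vec![[[0i8; 4]; 16]; k / 4];
    for (kq, quads) in packed.iter_mut().enumerate() {
        for j in 0..j_count {
            for p in 0..4 {
                quads[j][p] = b[(4 * kq + p) * n + j_base + j];
            }
        }
    }
    packed
}

/// One row of A against one j-block: quad loop, scalar K-tail, trimmed store.
pub fn row_block_sketch(
    a_row: &[u8],
    packed: &[[[i8; 4]; 16]],
    b: &[i8],
    n: usize,
    j_base: usize,
    j_count: usize,
    c_row: &mut [i32],
) {
    let k = a_row.len();
    let mut acc = [0i32; 16];
    for (kq, b_quads) in packed.iter().enumerate() {
        // Broadcast the same A quad to every lane, like _mm512_set1_epi32.
        let quad = [a_row[4 * kq], a_row[4 * kq + 1], a_row[4 * kq + 2], a_row[4 * kq + 3]];
        vpdpbusd_model(&mut acc, &[quad; 16], b_quads);
    }
    // K-tail: leftover k % 4 elements handled with scalar multiplies.
    for p in packed.len() * 4..k {
        for j in 0..j_count {
            acc[j] += a_row[p] as i32 * b[p * n + j_base + j] as i32;
        }
    }
    // N-tail: only the first j_count lanes are written back.
    c_row[j_base..j_base + j_count].copy_from_slice(&acc[..j_count]);
}
```

In a full GEMM the packed block would be built once per j-block and reused for every row of A, which is where the amortisation mentioned above comes from.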

Wiring `matmul_i8_to_i32` to three-tier dispatch:
  1. amx_available() + 16/16/64-aligned shapes
     → int8_tile_gemm_16x16 → TDPBUSD asm-byte (16 384 MACs/instr;
       reuses the kernel from b1979d7 earlier in this PR)
  2. is_x86_feature_detected!("avx512vnni")
     → int8_gemm_vpdpbusd_zmm → _mm512_dpbusd_epi32 stable
       intrinsic (64 MACs/instr, arbitrary shapes, K-tail handled
       scalar, N-tail handled by per-iteration j_count trim)
  3. scalar i8×i8 → i32 reference for non-x86, pre-AVX-512 hosts,
     or shapes that don't satisfy either SIMD tier's requirements
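The tier selection can be sketched as a pure function of the shape and two capability flags. The enum and function names below are hypothetical; AMX detection is passed in as a flag because the crate's `amx_available()` gate is not reproduced here.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Int8Tier {
    AmxTile,
    VnniZmm,
    Scalar,
}

/// Picks a tier in the order described above: AMX for 16/16/64-aligned
/// shapes, then AVX-512 VNNI for any shape, then the scalar reference.
pub fn select_int8_tier(m: usize, n: usize, k: usize, has_amx: bool, has_avx512vnni: bool) -> Int8Tier {
    let amx_aligned = m % 16 == 0 && n % 16 == 0 && k % 64 == 0;
    if has_amx && amx_aligned {
        Int8Tier::AmxTile
    } else if has_avx512vnni {
        Int8Tier::VnniZmm
    } else {
        Int8Tier::Scalar
    }
}

/// Runtime probe for the middle tier; always false off x86_64.
#[cfg(target_arch = "x86_64")]
pub fn detect_avx512vnni() -> bool {
    std::arch::is_x86_feature_detected!("avx512vnni")
}

#[cfg(not(target_arch = "x86_64"))]
pub fn detect_avx512vnni() -> bool {
    false
}
```

Keeping the decision separate from the probes makes the fall-through behaviour for misaligned shapes easy to test on any host.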

Factored the shared sign-shift bias subtraction into a private
`subtract_i8_to_u8_bias(c, b_i8, m, n, k)` helper: both Tier 1
(AMX) and Tier 2 (VNNI) shift LHS i8 → u8 via (+128) then need to
subtract 128·colsum(B) from the accumulator. Pure integer
arithmetic, bit-identical to the scalar i8×i8 → i32 reference.

Verification:
  * Default v3 build: 2093 lib tests pass (was 2092 — +1 new test
    `vpdpbusd_zmm_matches_scalar` that exercises the new arm
    directly with shapes spanning aligned cases, K-tail (k % 4),
    N-tail (n % 16), and small shapes; asserts byte-equal output
    vs scalar reference).
  * Existing `matmul_i8_to_i32_16x16_exact` continues to pass
    through the AMX tier on this host (which has amx_int8).
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo fmt --all --check clean.

Per-CPU dispatch state after this commit:

  matmul_bf16_to_f32:  SPR+ AMX  | Zen4/CPL VDPBF16PS | scalar
                       (PR #182) | (PR #182)          | (always)
  matmul_f32:          SPR+ AMX  | Zen4/CPL VDPBF16PS | scalar
                       (PR #182) | (PR #182)          | (always)
  matmul_i8_to_i32:    SPR+ AMX  | CPL/Zen4 VPDPBUSD  | scalar
                       (b1979d7) | (THIS COMMIT)      | (always)

So all three of the public matmul entry points now have full
three-tier dispatch on x86_64.

Out of scope (separate PRs):
  * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level
    u8×i8 surface from PR #182) — it's u8×i8 natively, no sign-
    shift bias needed, simpler than matmul_i8_to_i32.
  * AVX-VNNI ymm arm (Arrow Lake / Meteor Lake U: avxvnni without
    avx512vnni) — the `vnni2_*` functions exist in simd_amx.rs but
    need to be assembled into an m×n×k VNNI-ymm GEMM. Same shape as
    the avx512vnni arm just with ymm width.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Per codex review on PR #184: `int8_tile_gemm_16x16` is documented
as `C += A·B` and the scalar `fallback_path` correctly accumulates,
but the AMX `amx_path` did `tile_zero(0)` + `tile_store(0, c, 64)`
which **overwrote** any pre-existing values in `c` on AMX-enabled
hosts. Hardware-dependent behavior: callers relying on accumulation
(blocked GEMM, repeated partial-K updates) would get incorrect
results only when AMX was active.

Same bug in the BF16 sibling `bf16_tile_gemm::amx_path` (shipped
in PR #104) — fixing both.

The fix:
  1. Add a `tile 0` case to `tile_load` (encoding `C4 E2 7B 4B 04
     08` — same SIB byte as the existing tmm1/tmm2 cases, ModR/M
     `04` = `mod=00, reg=000 (tmm0), r/m=100 (SIB follows)`).
  2. Both AMX paths replace `tile_zero(0)` with
     `tile_load(0, c.as_ptr() as *const u8, 64)` — preloads tmm0
     from caller's C buffer. TDPBUSD / TDPBF16PS then accumulate
     into the pre-loaded values; `tile_store(0, c, 64)` writes back
     the true `+=` result.
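The ModR/M breakdown quoted in step 1 can be checked with a few lines of bit arithmetic. The helpers below are illustrative and not part of the crate; they only split a byte into the standard mod/reg/rm and scale/index/base fields.

```rust
#[derive(Debug, PartialEq, Eq)]
pub struct ModRm {
    pub mode: u8,
    pub reg: u8,
    pub rm: u8,
}

/// Splits a ModR/M byte into mod (bits 7..6), reg (5..3) and r/m (2..0).
pub fn decode_modrm(byte: u8) -> ModRm {
    ModRm { mode: byte >> 6, reg: (byte >> 3) & 0b111, rm: byte & 0b111 }
}

/// Builds a ModR/M byte from its three fields.
pub fn encode_modrm(mode: u8, reg: u8, rm: u8) -> u8 {
    ((mode & 0b11) << 6) | ((reg & 0b111) << 3) | (rm & 0b111)
}

#[derive(Debug, PartialEq, Eq)]
pub struct Sib {
    pub scale: u8,
    pub index: u8,
    pub base: u8,
}

/// Splits a SIB byte into scale (bits 7..6), index (5..3) and base (2..0).
pub fn decode_sib(byte: u8) -> Sib {
    Sib { scale: byte >> 6, index: (byte >> 3) & 0b111, base: byte & 0b111 }
}
```

Decoding `04` gives mod=00, reg=000, r/m=100 as stated, and `08` gives scale=0, index=001, base=000. By the same arithmetic, reg=001 and reg=010 with the same mod and r/m encode to `0C` and `14`.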

Consumer impact: zero. Both `matmul_bf16_to_f32` and
`matmul_i8_to_i32` (the only callers of these kernels in this
crate) `tile_c.fill(0)` before each call — so the now-accumulating
behavior + zero-initialized C = same result as the prior overwrite
semantics. The fix removes a latent trap for future blocked-GEMM /
partial-K consumers without changing any shipped behavior.

New regression test `amx_path_preserves_c_accumulator`: pre-loads
C with a known non-zero marker pattern, runs A·B where B=0 (so the
contribution is 0), asserts the marker is preserved. Would fail on
the pre-fix code because the tile_store would zero everything.
Passes on this host (which has amx_int8).
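The shape of that regression test can be sketched against scalar stand-ins. Both kernels below are hypothetical models: one accumulates as documented, the other zeroes first to mimic the pre-fix `tile_zero` + `tile_store` behaviour. Neither touches AMX.

```rust
const T: usize = 16;

/// Model of the fixed behaviour: C += A * B.
pub fn tile_kernel_accumulate(a: &[u8], b: &[i8], c: &mut [i32], k: usize) {
    for i in 0..T {
        for j in 0..T {
            for p in 0..k {
                c[i * T + j] += a[i * k + p] as i32 * b[p * T + j] as i32;
            }
        }
    }
}

/// Model of the pre-fix behaviour: the accumulator starts from zero,
/// so whatever was in C is lost.
pub fn tile_kernel_overwrite(a: &[u8], b: &[i8], c: &mut [i32], k: usize) {
    c.fill(0);
    tile_kernel_accumulate(a, b, c, k);
}

/// Pre-loads C with a non-zero marker, multiplies by B = 0, and reports
/// whether the marker survived.
pub fn preserves_marker(kernel: fn(&[u8], &[i8], &mut [i32], usize)) -> bool {
    let k = 64;
    let a = vec![1u8; T * k];
    let b = vec![0i8; k * T];
    let mut c: Vec<i32> = (0..(T * T) as i32).map(|i| 1000 + i).collect();
    let before = c.clone();
    kernel(&a, &b, &mut c, k);
    c == before
}
```

With B set to zero the product contributes nothing, so any change to C can only come from the kernel discarding its input.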

Verification:
  * 2094 lib tests pass (was 2093 — +1 regression test).
  * 11 amx_matmul tests pass (consumers' fill(0)-then-call pattern
    continues to produce correct results).
  * 2 bf16_tile_gemm tests pass.
  * 6 int8_tile_gemm tests pass.
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo fmt --all --check clean.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
…able

Rebased onto master post-#181, #182, #183. Replaces the polyfill-based
add_mul_f32/f64 with LazyLock-cached function pointers picking real
hardware FMA per silicon, and adds two more LazyLock-cached
primitives the consumer needs: is_amx_available() and vnni_dot_u8_i8.

WHY: F32x16::mul_add on AVX2 builds drops to per-lane scalar
f32::mul_add (simd_avx2.rs:586). The polyfill abstracts lane width
but cannot pick between _mm256_fmadd_ps and _mm512_fmadd_ps — that
is an instruction-family choice, not a lane-width one. LazyLock
amortises a one-time simd_caps() read into a frozen fn pointer;
every subsequent call is a single indirect jump with zero
is_x86_feature_detected! overhead. No SimdProfile exposed at the
consumer surface — agnostic contract preserved.

add_mul_f32(acc, a, b) — acc[i] += a[i]*b[i]
  AVX-512F+FMA  → _mm512_fmadd_ps 16-wide + 8-wide tail + scalar tail
  AVX2+FMA      → _mm256_fmadd_ps 8-wide + scalar tail
  NEON          → vfmaq_f32 4-wide + scalar tail
  scalar        → f32::mul_add per lane
  no_std build  → preserves the polyfill F32x16::mul_add path
                  (LazyLock requires std)

add_mul_f64(acc, a, b) — f64 sibling, same shape with 8/4/2 lanes.

is_amx_available() — wraps simd_amx::amx_available() (CPUID +
OSXSAVE + XCR0[17,18] + Linux arch_prctl(XCOMP_PERM)) in
LazyLock<bool>. The 4-step gate, including the syscall, fires
exactly once per process. Always false on non-x86_64.

vnni_dot_u8_i8(a, b) — i32 dot of u8 × i8 slices:
  AVX-512 VNNI  → delegates to simd_amx::vnni_dot_u8_i8 wrapped with
                  scalar tail handling (the existing kernel processes
                  only n - (n%64) since its cognitive-shader caller
                  pre-aligns rows; general-purpose callers need the
                  tail)
  AVX-VNNI 256  → delegates to simd_amx::vnni2_dot_u8_i8 directly
                  (that one already handles its scalar tail)
  scalar        → simd_amx::vnni_dot_u8_i8_scalar
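The tail handling described for the AVX-512 arm can be illustrated with a scalar model of a kernel that only consumes whole 64-element blocks. All three functions are hypothetical stand-ins; none of them calls the crate's `simd_amx` kernels.

```rust
/// Plain widen-and-multiply dot product, used as the reference.
pub fn dot_scalar(a: &[u8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

/// Models a block kernel that processes only n - (n % 64) elements
/// and silently ignores the rest.
pub fn dot_blocks_of_64(a: &[u8], b: &[i8]) -> i32 {
    let full = a.len() - a.len() % 64;
    dot_scalar(&a[..full], &b[..full])
}

/// Wrapper that adds the missing tail so arbitrary lengths are correct.
pub fn dot_blocks_with_tail(a: &[u8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len());
    let full = a.len() - a.len() % 64;
    dot_blocks_of_64(a, b) + dot_scalar(&a[full..], &b[full..])
}
```

A caller that pre-aligns rows to 64 never sees the difference, which is why the wrapper only matters on a general-purpose surface.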

No intrinsic code is duplicated. The dispatcher composes existing
simd_amx::* kernels (which #182/#184 also build on) into a safe
LazyLock-cached consumer-facing wrapper. simd_amx::matvec_dispatch
runs the same selection logic but uses is_x86_feature_detected! per
call; this wrapper amortises that to once at startup.

PARITY CONTRACT:
  - add_mul_f32 / add_mul_f64: bit-identical to f32::mul_add /
    f64::mul_add per lane via to_bits() assertion. All vector
    backends emit single-rounded IEEE-754 FMA.
  - vnni_dot_u8_i8: bit-identical i32 to scalar widen-and-multiply.
    VPDPBUSD does not saturate the accumulator (each u8*i8 product
    lies in [-32640, 32385], each four-element sum in [-130560, 129540]).
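The product range can be confirmed by brute force over all 256 × 256 operand pairs. The helper name is hypothetical; the sketch only enumerates u8 × i8 products and reports the extremes.

```rust
/// Returns (min, max) over every u8 x i8 product, widened to i32.
pub fn u8_i8_product_bounds() -> (i32, i32) {
    let mut lo = i32::MAX;
    let mut hi = i32::MIN;
    for a in 0..=u8::MAX {
        for b in i8::MIN..=i8::MAX {
            let p = a as i32 * b as i32;
            lo = lo.min(p);
            hi = hi.max(p);
        }
    }
    (lo, hi)
}
```

The positive extreme is 255 × 127 = 32385 and the negative extreme is 255 × (−128) = −32640, so a four-element sum stays within [−130560, 129540], far inside the i32 range.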

Tests: 2101/2101 lib pass (7 new lazylock_dispatch_tests over 12
problem sizes / tail lengths). cargo clippy --lib clean under
default and --features cpu-spr. On Sapphire Rapids host the
LazyLock resolved to AVX-512+FMA for add_mul, AVX-512 VNNI for
vnni_dot; AMX is_amx_available returns false (hypervisor masks
XCR0[17,18]) — matches the Risk #3 demotion from 61b4563.

This commit was rebased atop master after the parallel session
shipped PR #182 (BF16 AMX tile kernels), #183 (F16C cast batch), and
prepared #184 (TDPBUSD int8 tile + matmul_i8_to_i32 wiring). The
earlier 469ecc7 (coarse + SimdTier) and 77e3971 (mul_add_f32_into +
walkback) and be65595 (is_amx_available + vnni_dot duplicating
intrinsics) are subsumed by this single clean commit: no public
SimdProfile / SimdTier re-export, no duplicated intrinsic code, no
mul_add_f32_into (master's add_mul_f32 shape is the right primitive).
AdaWorldAPI merged commit ddf0905 into master on May 21, 2026
17 checks passed
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
Extends the u8×i8 → i32 dispatch chain from PR #182's compile-time
cascade (avx512vnni → avxvnni → scalar) by adding a top-tier AMX
runtime check. Brings the SPR/GNR TDPBUSD path (16 384 MACs per
instruction) to the slice-level surface that downstream consumers
(lance-graph, etc.) use, completing the symmetry with PR #184's
matmul_i8_to_i32 wiring.

`gemm_u8_i8` is u8×i8 natively — no sign-shift bias trick needed
(unlike `matmul_i8_to_i32` which is i8×i8 and has to convert via
+128 then subtract `128·colsum(B)`). That makes the AMX path here
a direct call with no bias correction.

New helper `hpc::int8_tile_gemm::int8_gemm_amx_tiled(a_u8, b_i8,
c, m, n, k)` factors out the tile-decomposition logic that was
previously inlined in `matmul_i8_to_i32`. Both consumers now share
the same helper:

  matmul_i8_to_i32:
    1. shift A: i8 → u8 (+128)
    2. int8_gemm_amx_tiled(a_u8, b, c, m, n, k)
    3. subtract_i8_to_u8_bias(c, b, m, n, k)

  gemm_u8_i8 (AMX tier added in this commit):
    1. int8_gemm_amx_tiled(a, b, c, m, n, k) — no shift, no bias

The helper handles arbitrary 16/16/64-aligned shapes via a
j_tile × i_tile loop calling int8_tile_gemm_16x16 per (16, 16)
block. B sub-block extracted into K × 16 scratch once per j-tile,
reused across all M i-tiles. **Overwrite semantics**: c is written
not accumulated (the underlying int8_tile_gemm_16x16 accumulates
into its tile buffer, but we zero the tile buffer before each call
so the per-tile write to c is pure overwrite).
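The tile decomposition can be sketched with a scalar tile kernel standing in for `int8_tile_gemm_16x16`. The names `tile_kernel` and `int8_gemm_tiled_sketch` are hypothetical, and the row-major layouts are assumptions made for the example.

```rust
const T: usize = 16;

/// Scalar stand-in for the 16 x 16 tile kernel: C += A * B.
fn tile_kernel(a: &[u8], b: &[i8], c: &mut [i32], k: usize) {
    for i in 0..T {
        for j in 0..T {
            for p in 0..k {
                c[i * T + j] += a[i * k + p] as i32 * b[p * T + j] as i32;
            }
        }
    }
}

/// j_tile x i_tile loop over a 16/16/64-aligned problem. `c` is overwritten.
pub fn int8_gemm_tiled_sketch(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    assert!(m % T == 0 && n % T == 0 && k % 64 == 0, "shape must be 16/16/64-aligned");
    let mut b_block = vec![0i8; k * T];
    let mut tile_c = [0i32; T * T];
    for j_tile in 0..n / T {
        // Extract the K x 16 sub-block of B once per j_tile.
        for p in 0..k {
            let src = p * n + j_tile * T;
            b_block[p * T..(p + 1) * T].copy_from_slice(&b[src..src + T]);
        }
        for i_tile in 0..m / T {
            // Sixteen consecutive rows of A are contiguous in row-major order.
            let a_rows = &a[i_tile * T * k..(i_tile + 1) * T * k];
            // Zero the tile buffer so the accumulating kernel acts as overwrite.
            tile_c.fill(0);
            tile_kernel(a_rows, &b_block, &mut tile_c, k);
            for r in 0..T {
                let dst = (i_tile * T + r) * n + j_tile * T;
                c[dst..dst + T].copy_from_slice(&tile_c[r * T..(r + 1) * T]);
            }
        }
    }
}
```

Running it on a `c` pre-filled with arbitrary values gives the same result as a fresh buffer, which is the overwrite behaviour described above.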

Dispatch placement in gemm_u8_i8:
  * Tier 0 (this commit): runtime amx_available() check at the
    top of the function. AMX availability depends on CPUID + XCR0 + a
    Linux arch_prctl permission request, which a compile-time
    target_feature gate cannot express.
  * Tiers 1-3: existing compile-time cfg-cascade (avx512vnni zmm
    → avxvnni ymm → scalar i8_gemm_i32). Unchanged.

Misaligned shapes (m/n not multiples of 16, k not multiple of 64)
or non-AMX hosts fall through to the compile-time cascade as
before.

Also fixed pre-existing clippy::manual_is_multiple_of warnings
that surfaced in the new alignment check — switched from `% 16
== 0` to `.is_multiple_of(16)` etc. per the clippy hint (Rust
1.95 promoted this from `pedantic` to active warn).

Verification:
  * 2095 lib tests pass (was 2094 — +1 new
    `gemm_u8_i8_amx_aligned_32x32x128` test exercising the AMX
    arm with a 32×32×128 shape that hits the AMX tier on this
    host's amx_int8 silicon).
  * 11 amx_matmul tests pass (matmul_i8_to_i32 refactored to call
    the shared helper — same behavior as before).
  * 4 gemm_u8_i8 tests pass (the existing ones still hit the
    compile-time cascade since their shapes aren't AMX-aligned).
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo fmt --all --check clean.

Per-CPU dispatch state after this commit:

  matmul_bf16_to_f32:  SPR+ AMX  | Zen4/CPL VDPBF16PS | scalar
                       (PR #182) | (PR #182)          | (always)
  matmul_f32:          SPR+ AMX  | Zen4/CPL VDPBF16PS | scalar
                       (PR #182) | (PR #182)          | (always)
  matmul_i8_to_i32:    SPR+ AMX  | CPL/Zen4 VPDPBUSD  | scalar
                       (PR #184) | (PR #184)          | (always)
  gemm_u8_i8 (slice):  SPR+ AMX  | CPL/Zen4 VPDPBUSD  | ARL ymm | scalar
                       (THIS)    | (PR #182)          | (PR #182) | (PR #182)

Out of scope (separate PRs):
  * AVX-VNNI ymm arm for matmul_i8_to_i32 — `vnni2_*` helpers
    exist in simd_amx.rs but need assembling into an m×n×k GEMM.
    Same shape as the avx512vnni arm just with ymm width.
  * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
Completes the per-CPU dispatch chain for `matmul_i8_to_i32` by
adding the AVX-VNNI ymm tier — Arrow Lake, Meteor Lake U, Alder
Lake silicon that has AVX-VNNI but dropped AVX-512. Mirrors the
shape of the avx512vnni-zmm arm shipped in PR #184 with the
narrower 8-wide kernel.

New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm`:
  * One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator
    lanes, each receiving 4 u8×i8 products = 32 MACs per
    instruction. Half the throughput-per-instruction of the
    `_mm512_dpbusd_epi32` zmm version.
  * Same B-pre-pack scheme (quad-interleaved per 8-wide j-block),
    same K-tail / N-tail handling. Just narrower.
  * Stable intrinsic under `target_feature = "avxvnni,avx2"` — no
    asm-byte needed.

Wiring `matmul_i8_to_i32`'s dispatch as Tier 3:
  1. amx_available() + 16/16/64-aligned → AMX TDPBUSD
     (PR #184: int8_gemm_amx_tiled, 16 384 MACs/instr)
  2. is_x86_feature_detected!("avx512vnni") → VPDPBUSD-zmm
     (PR #184: int8_gemm_vpdpbusd_zmm, 64 MACs/instr)
  3. is_x86_feature_detected!("avxvnni") → VPDPBUSD-ymm
     (THIS COMMIT: int8_gemm_vpdpbusd_ymm, 32 MACs/instr)
  4. scalar i8×i8 → i32 reference (was Tier 3)

All three SIMD tiers share the sign-shift bias trick: shift LHS
i8 → u8 (+128), run the kernel, subtract 128·colsum(B). Same
`subtract_i8_to_u8_bias` helper (factored in PR #184).

New direct test `vpdpbusd_ymm_matches_scalar` mirrors the zmm
version's test: sweeps shapes spanning 8-aligned, K-tail (k % 4),
N-tail (n % 8), and small shapes, asserts byte-equal output vs
scalar reference.

Verification:
  * Default v3 (this host has avx512vnni so the new arm doesn't
    fire from matmul_i8_to_i32 — Tier 2 catches first): 2096 lib
    tests pass (was 2095 — +1 new direct test).
  * Direct test exercises int8_gemm_vpdpbusd_ymm on this host
    since avxvnni is present alongside avx512vnni.
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo fmt --all --check clean.

Per-CPU dispatch state after this commit (final on the int8 side):

  matmul_i8_to_i32:  SPR+ AMX  | CPL/Zen4 zmm | ARL ymm | scalar
                     (PR #184) | (PR #184)    | (THIS)  | (always)

The matmul_i8_to_i32 column of PR #180's dispatch table is now
fully filled. The gemm_u8_i8 slice surface (in PR #185) already
has AVX-VNNI ymm via its existing compile-time cascade — both
i8-related public surfaces now cover every x86_64 tier with a
hardware-accelerated arm.

Out of scope (separate PRs):
  * NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b, needs
    aarch64 CI runner verification.
  * TD-T6: real _mm256_* for AVX2 BLAS-1 (scal/nrm2/asum).

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
**Alternative** to the compile-time cascade in `crate::simd::*` /
`crate::simd_ops::*`. **Additive**: gated under
`--features runtime-dispatch`, does not touch any existing path.
Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill
replaces the architecture-specific intrinsics that the runtime
trampolines select between).

Use case: ship ONE binary that adapts across heterogeneous
deployment silicon (AVX-512 server + AVX2-only laptop + Arrow Lake
desktop + Sapphire Rapids workstation) from the same artifact. The
existing compile-time `v3` / `v4` / `native` / `nightly-simd`
configs target a single class of CPU per build; the runtime layer
targets the union via per-op LazyLock<fn ptr> trampolines.

Design from `.claude/knowledge/simd-dispatch-architecture.md` § 7.1
/ Phase 5, building on the precedent set by
`hpc::bgz17_bridge::{L1_KERNEL, L1_WEIGHTED_KERNEL, ...}`
(`LazyLock<L1Fn>` pattern, lines 75-86) already proven in tree.

# Dispatch model

One `LazyLock<fn ptr>` per public surface. First call fires the
closure which reads `simd_caps()` and selects a backend; every
subsequent call is one pointer-deref + indirect call. Per-call
overhead: ~2-3 ns (LazyLock atomic-acquire load that's cache-
resident after first hit + indirect-call branch-target predict).
Invisible against any SIMD op's actual work (~100+ cycles).
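A minimal sketch of the `LazyLock<fn ptr>` pattern follows, assuming a std build. Everything named here is hypothetical: `has_hw_fma` stands in for a `simd_caps()` read, and both backends are scalar loops built on `f32::mul_add`, so the selection is only a demonstration of the mechanism, not of real vector kernels.

```rust
use std::sync::LazyLock;

type AddMulF32 = fn(&mut [f32], &[f32], &[f32]);

/// Per-lane reference backend: acc[i] = a[i] * b[i] + acc[i], single rounding.
fn add_mul_scalar(acc: &mut [f32], a: &[f32], b: &[f32]) {
    for i in 0..acc.len() {
        acc[i] = a[i].mul_add(b[i], acc[i]);
    }
}

/// Stand-in for a vector backend: 8-wide body plus scalar tail.
fn add_mul_chunked(acc: &mut [f32], a: &[f32], b: &[f32]) {
    let n = acc.len();
    let full = n - n % 8;
    for base in (0..full).step_by(8) {
        for lane in 0..8 {
            let i = base + lane;
            acc[i] = a[i].mul_add(b[i], acc[i]);
        }
    }
    for i in full..n {
        acc[i] = a[i].mul_add(b[i], acc[i]);
    }
}

/// Hypothetical capability probe; always false off x86_64.
#[cfg(target_arch = "x86_64")]
fn has_hw_fma() -> bool {
    std::arch::is_x86_feature_detected!("fma")
}

#[cfg(not(target_arch = "x86_64"))]
fn has_hw_fma() -> bool {
    false
}

/// The closure runs once, on first use; later calls only load the pointer.
static ADD_MUL_F32: LazyLock<AddMulF32> = LazyLock::new(|| {
    let chosen: AddMulF32 = if has_hw_fma() { add_mul_chunked } else { add_mul_scalar };
    chosen
});

/// Public entry point: one pointer load plus an indirect call per invocation.
pub fn add_mul_f32_sketch(acc: &mut [f32], a: &[f32], b: &[f32]) {
    assert!(acc.len() == a.len() && a.len() == b.len());
    (*ADD_MUL_F32)(acc, a, b)
}
```

Because the selected pointer is frozen after the first call, the feature probe is never repeated, which is the amortisation the dispatch model relies on.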

# Module layout

  src/simd_runtime/
    mod.rs       — module entry, mutual-exclusion check vs
                   nightly-simd, public re-exports
    vnni_dot.rs  — u8×i8 → i32 dot (the proposal's canonical
                   example): 3 backends, the AVX-512 arm
                   wraps `simd_amx::vnni_dot_u8_i8` with a
                   scalar tail because the existing kernel
                   silently drops n%64 lanes (its matvec
                   caller pre-aligns rows; a general-purpose
                   dispatch surface cannot assume that)
    add_mul.rs   — slice-level FMA (acc += a × b) for f32/f64;
                   the ONLY new kernel code in this module —
                   4 backends per type (avx512 / avx2+fma /
                   neon / scalar), each ~15 LoC of direct
                   intrinsics
    matmul.rs    — thin trampolines for matmul_bf16_to_f32 /
                   matmul_f32 / matmul_i8_to_i32 / gemm_u8_i8
                   delegating to existing functions that
                   already runtime-dispatch internally
                   (PR #182 / #184 / #185)
    casts.rs     — trampolines for the four half-precision
                   batch casts delegating to PR #183's already-
                   runtime-dispatched implementations

# Backend reuse — no kernel duplication

Every dispatch arm delegates to a kernel that already exists in
tree. The runtime layer is just the trampoline. The only NEW
kernel code is `add_mul_f32` / `add_mul_f64` (no pre-existing
slice-level FMA primitive in tree to delegate to — the compile-
time `crate::simd_ops::add_mul_f32` from PR #182 polyfills through
the F32x16 lane wrapper; the runtime version skips that
indirection for one more inlined intrinsic per chunk).

# Invariants preserved from this PR series

  * No-FP32-roundtrip on BF16/F16 arithmetic — backends respect
    the bit-exact mantissa rule
  * Asm-byte encoding for nightly-gated AMX / FP16 — selected
    backends keep their existing asm-byte fast paths
  * Little-endian byte contracts for half-precision carriers
  * Accumulator-preservation in tile paths (codex P1 from #184)
  * Boundary assertions on safe public fns (codex P1 from #185) —
    the public `vnni_dot_u8_i8(a, b)` etc. inherit the asserts
    transparently via the call chain

# Verification

  * Default build (no feature): 2087 lib tests pass — the
    `simd_runtime` module is gated out, zero impact on existing
    paths.
  * `cargo test --lib --features runtime-dispatch`: **2105 lib
    tests pass** (+8 new in `simd_runtime::*::tests`).
  * `cargo clippy --lib --tests --features rayon,native -- -D warnings`
    clean (default).
  * `cargo clippy --lib --tests --features rayon,native,runtime-dispatch
    -- -D warnings` clean.
  * `cargo fmt --all --check` clean.
  * Mutual-exclusion enforced via `compile_error!` in
    `simd_runtime/mod.rs` — `--features runtime-dispatch,nightly-simd`
    fails to compile with a clear error.

# What's NOT in this PR (deferred)

  * Sweep the remaining ~15-20 SIMD/HPC public surfaces (min_i8,
    max_i8, add_i8, dot_i8, etc.). Each is ~30-50 LoC of trampoline;
    pattern is established here. Estimated ~700-900 more LoC across
    the full surface map.
  * CI matrix entry for `runtime-dispatch-portable` (per
    simd-dispatch-architecture.md § 7 / TD-SIMD-9). Job builds
    with `--features runtime-dispatch` on a v3 baseline runner and
    asserts every trampoline lands on its expected backend.
  * `simd_caps()` snapshot logging at process start (debug-only)
    to aid release-binary deployment debugging — "which arm did
    you actually pick?"

# Cost summary

  src/simd_runtime/                 +537 LoC (4 modules)
  src/lib.rs                        +9 LoC (cfg-gated mod decl)
  Cargo.toml                        +21 LoC (feature decl + doc)
  Total                             ~570 LoC

  Trampoline LoC per surface (this PR's sample):
    vnni_dot           170 LoC (LazyLock + 3 arms + wrapper + tests)
    add_mul (f32+f64)  218 LoC (LazyLock×2 + 4 arms×2 + tests; the ONLY new kernels)
    matmul (4 ops)     100 LoC (thin delegations + tests)
    casts (4 ops)       75 LoC (thin delegations + tests)

Out-of-tree estimate for the full sweep (per § 7 of the design
doc): ~1400 LoC total once all ~25 public SIMD/HPC surfaces are
wired. This PR establishes ~40% of that budget with the canonical
patterns.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u