simd_int_ops, hpc: AMX TDPBUSD arm for gemm_u8_i8 slice surface #185
Extends the u8×i8 → i32 dispatch chain from PR #182's compile-time cascade (avx512vnni → avxvnni → scalar) by adding a top-tier AMX runtime check. Brings the SPR/GNR TDPBUSD path (16 384 MACs per instruction) to the slice-level surface that downstream consumers (lance-graph, etc.) use, completing the symmetry with PR #184's matmul_i8_to_i32 wiring.

`gemm_u8_i8` is u8×i8 natively, so no sign-shift bias trick is needed (unlike `matmul_i8_to_i32`, which is i8×i8 and has to convert via +128 then subtract `128·colsum(B)`). That makes the AMX path here a direct call with no bias correction.

New helper `hpc::int8_tile_gemm::int8_gemm_amx_tiled(a_u8, b_i8, c, m, n, k)` factors out the tile-decomposition logic that was previously inlined in `matmul_i8_to_i32`. Both consumers now share the same helper:

matmul_i8_to_i32:
1. shift A: i8 → u8 (+128)
2. int8_gemm_amx_tiled(a_u8, b, c, m, n, k)
3. subtract_i8_to_u8_bias(c, b, m, n, k)

gemm_u8_i8 (AMX tier added in this commit):
1. int8_gemm_amx_tiled(a, b, c, m, n, k), with no shift and no bias

The helper handles arbitrary 16/16/64-aligned shapes via a j_tile × i_tile loop calling int8_tile_gemm_16x16 per (16, 16) block. The B sub-block is extracted into a K × 16 scratch buffer once per j-tile and reused across all M i-tiles.

**Overwrite semantics**: c is written, not accumulated (the underlying int8_tile_gemm_16x16 accumulates into its tile buffer, but we zero the tile buffer before each call, so the per-tile write to c is pure overwrite).

Dispatch placement in gemm_u8_i8:
* Tier 0 (this commit): runtime amx_available() check at the top of the function. AMX requires CPUID + XCR0 + Linux prctl, which can't fit a target_feature compile-time gate.
* Tiers 1-3: existing compile-time cfg-cascade (avx512vnni zmm → avxvnni ymm → scalar i8_gemm_i32). Unchanged.

Misaligned shapes (m/n not multiples of 16, k not a multiple of 64) or non-AMX hosts fall through to the compile-time cascade as before.

Also fixed pre-existing clippy::manual_is_multiple_of warnings that surfaced in the new alignment check: switched from `% 16 == 0` to `.is_multiple_of(16)` etc. per the clippy hint (Rust 1.95 promoted this from `pedantic` to active warn).

Verification:
* 2095 lib tests pass (was 2094; +1 new `gemm_u8_i8_amx_aligned_32x32x128` test exercising the AMX arm with a 32×32×128 shape that hits the AMX tier on this host's amx_int8 silicon).
* 11 amx_matmul tests pass (matmul_i8_to_i32 refactored to call the shared helper, same behavior as before).
* 4 gemm_u8_i8 tests pass (the existing ones still hit the compile-time cascade since their shapes aren't AMX-aligned).
* cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
* cargo fmt --all --check clean.

Per-CPU dispatch state after this commit:
matmul_bf16_to_f32: SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182) | scalar (always)
matmul_f32: SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182) | scalar (always)
matmul_i8_to_i32: SPR+ AMX (PR #184) | CPL/Zen4 VPDPBUSD (PR #184) | scalar (always)
gemm_u8_i8 (slice): SPR+ AMX (THIS) | CPL/Zen4 VPDPBUSD (PR #182) | ARL ymm (PR #182) | scalar (PR #182)

Out of scope (separate PRs):
* AVX-VNNI ymm arm for matmul_i8_to_i32: `vnni2_*` helpers exist in simd_amx.rs but need assembling into an m×n×k GEMM. Same shape as the avx512vnni arm, just with ymm width.
* NEON BFMMLA / SDOT on aarch64 via asm-byte (Phase 3b).

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
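The sign-shift identity behind the three `matmul_i8_to_i32` steps can be checked with a scalar model. This is only a sketch, not the crate's code: `gemm_u8_i8_scalar` and `matmul_i8_via_u8` are hypothetical names, and a plain triple loop stands in for the AMX kernel. It relies on (a + 128)·b = a·b + 128·b, so after summing over k every entry in column j carries an excess of 128·colsum(B)[j].

```rust
/// Scalar u8×i8 → i32 GEMM, row-major, overwrite semantics.
/// Hypothetical stand-in for the accelerated kernel.
fn gemm_u8_i8_scalar(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    assert!(a.len() >= m * k && b.len() >= k * n && c.len() >= m * n);
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0i32;
            for kk in 0..k {
                acc += a[i * k + kk] as i32 * b[kk * n + j] as i32;
            }
            c[i * n + j] = acc;
        }
    }
}

/// i8×i8 → i32 GEMM built on the u8×i8 kernel via the +128 shift.
fn matmul_i8_via_u8(a: &[i8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    // Step 1: shift A from i8 to u8. Adding 128 modulo 256 maps -128..=127 to 0..=255.
    let a_u8: Vec<u8> = a.iter().map(|&x| (x as u8).wrapping_add(128)).collect();
    // Step 2: run the unsigned-by-signed kernel on the shifted operand.
    gemm_u8_i8_scalar(&a_u8, b, c, m, n, k);
    // Step 3: remove the excess 128 * colsum(B)[j] from every row of column j.
    for j in 0..n {
        let colsum: i32 = (0..k).map(|kk| b[kk * n + j] as i32).sum();
        for i in 0..m {
            c[i * n + j] -= 128 * colsum;
        }
    }
}
```

Calling `matmul_i8_via_u8` on a 2×3 by 3×2 problem gives the same result as multiplying the signed inputs directly, including at the extreme values -128 and 127.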
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f8e94531ca
    for kk in 0..k {
        let row = kk * n + j_tile;
        b_tile[kk * 16..(kk + 1) * 16]
            .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
Validate slice lengths before unchecked B-tile reads
int8_gemm_amx_tiled is a safe public function but it reads b_i8 via from_raw_parts(b_i8.as_ptr().add(row), 16) without any length checks, so callers that pass mismatched (m,n,k) and slice lengths can trigger out-of-bounds reads/UB instead of a panic. This was previously only in a private call path with shape checks, but the new helper is now reusable from other call sites; adding explicit a_u8/b_i8/c length assertions (like gemm_u8_i8 does) or avoiding unchecked raw slices would keep the API memory-safe.
Fixed in e8f9ce0. The concern is valid — int8_gemm_amx_tiled is a safe pub fn but its inner loop reached for b_i8 via raw core::slice::from_raw_parts(...) arithmetic without any length check, so mismatched (m, n, k) vs slice lengths could trigger out-of-bounds reads / UB instead of a clean panic. The function lived as a private call path inside matmul_i8_to_i32 before being factored out for gemm_u8_i8 in this PR; the public refactor inherited the missing boundary check.
The fix mirrors gemm_u8_i8's contract — three boundary assertions at function entry:
    assert!(a_u8.len() >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}", a_u8.len(), m * k);
    assert!(b_i8.len() >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}", b_i8.len(), k * n);
    assert!(c.len() >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}", c.len(), m * n);

The fix also replaces the `unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) }` line with safe `&b_i8[row..row + 16]` slicing. The inner-loop bounds check should fold away once the entry-level asserts prove the invariant, but either way the public surface is now panicking-safe rather than UB-on-misuse.
New regression test amx_tiled_panics_on_undersized_b constructs a b: Vec<i8> half a j-tile shorter than what (m, n, k) claims, calls the helper, and asserts the expected panic message via #[should_panic(expected = "b_i8.len()")]. Works on any host since the boundary check fires before the debug_assert!(amx_available()), so non-AMX CI runners catch the same regression. The matmul_i8_to_i32 call site inherits the assertions transparently — its existing tests continue to pass on valid input with no behavior change.
The # Panics docstring was updated to list the boundary panics alongside the existing debug-only AMX / alignment assertions.
Generated by Claude Code
Completes the per-CPU dispatch chain for `matmul_i8_to_i32` by adding the AVX-VNNI ymm tier: Arrow Lake, Meteor Lake U, Alder Lake silicon that has AVX-VNNI but dropped AVX-512. Mirrors the shape of the avx512vnni-zmm arm shipped in PR #184 with the narrower 8-wide kernel.

New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm`:
* One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes, each receiving 4 u8×i8 products = 32 MACs per instruction. Half the throughput-per-instruction of the `_mm512_dpbusd_epi32` zmm version.
* Same B-pre-pack scheme (quad-interleaved per 8-wide j-block), same K-tail / N-tail handling. Just narrower.
* Stable intrinsic under `target_feature = "avxvnni,avx2"`, so no asm-byte is needed.

Wiring `matmul_i8_to_i32`'s dispatch as Tier 3:
1. amx_available() + 16/16/64-aligned → AMX TDPBUSD (PR #184: int8_gemm_amx_tiled, 16 384 MACs/instr)
2. is_x86_feature_detected!("avx512vnni") → VPDPBUSD-zmm (PR #184: int8_gemm_vpdpbusd_zmm, 64 MACs/instr)
3. is_x86_feature_detected!("avxvnni") → VPDPBUSD-ymm (THIS COMMIT: int8_gemm_vpdpbusd_ymm, 32 MACs/instr)
4. scalar i8×i8 → i32 reference (was Tier 3)

All three SIMD tiers share the sign-shift bias trick: shift LHS i8 → u8 (+128), run the kernel, subtract 128·colsum(B). Same `subtract_i8_to_u8_bias` helper (factored in PR #184).

New direct test `vpdpbusd_ymm_matches_scalar` mirrors the zmm version's test: sweeps shapes spanning 8-aligned, K-tail (k % 4), N-tail (n % 8), and small shapes, and asserts byte-equal output vs the scalar reference.

Verification:
* Default v3 (this host has avx512vnni, so the new arm doesn't fire from matmul_i8_to_i32; Tier 2 catches first): 2096 lib tests pass (was 2095; +1 new direct test).
* Direct test exercises int8_gemm_vpdpbusd_ymm on this host since avxvnni is present alongside avx512vnni.
* cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
* cargo fmt --all --check clean.

Per-CPU dispatch state after this commit (final on the int8 side):
matmul_i8_to_i32: SPR+ AMX (PR #184) | CPL/Zen4 zmm (PR #184) | ARL ymm (THIS) | scalar (always)

The matmul_i8_to_i32 column of PR #180's dispatch table is now fully filled. The gemm_u8_i8 slice surface (in PR #185) already has AVX-VNNI ymm via its existing compile-time cascade, so both i8-related public surfaces now cover every x86_64 tier with a hardware-accelerated arm.

Out of scope (separate PRs):
* NEON BFMMLA / SDOT on aarch64 via asm-byte (Phase 3b, needs aarch64 CI runner verification).
* TD-T6: real _mm256_* for AVX2 BLAS-1 (scal/nrm2/asum).

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
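The MACs-per-instruction figures in the tier list follow from the lane arithmetic. A scalar model of what one ymm VPDPBUSD does makes the "8 lanes × 4 products" count concrete. This is a sketch for illustration only: `dpbusd_ymm_model` and the constant names are hypothetical, and the real kernel uses the `_mm256_dpbusd_avx_epi32` intrinsic rather than a loop.

```rust
/// MACs per instruction, derived from the operand shapes quoted above.
pub const AMX_TDPBUSD_MACS: usize = 16 * 16 * 64; // 16×16 i32 tile, 64 products per entry
pub const ZMM_DPBUSD_MACS: usize = 16 * 4; // 16 i32 lanes, 4 products per lane
pub const YMM_DPBUSD_MACS: usize = 8 * 4; // 8 i32 lanes, 4 products per lane

/// Scalar model of one non-saturating ymm VPDPBUSD: each i32 lane accumulates
/// four adjacent u8×i8 products on top of its previous value.
pub fn dpbusd_ymm_model(acc: [i32; 8], a: [u8; 32], b: [i8; 32]) -> [i32; 8] {
    let mut out = acc;
    for lane in 0..8 {
        for q in 0..4 {
            let idx = lane * 4 + q;
            // Unsigned byte times signed byte, widened to i32 before the add.
            out[lane] = out[lane].wrapping_add(a[idx] as i32 * b[idx] as i32);
        }
    }
    out
}
```

With every `a` byte set to 2, every `b` byte set to -3 and accumulators starting at 10, each lane ends at 10 + 4·(-6) = -14.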
Per codex review on PR #185: `int8_gemm_amx_tiled` is a safe public function (no `unsafe` in the signature) but its inner loop read `b_i8` via `core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16)` without any length check. Callers passing mismatched (m, n, k) vs slice lengths could trigger out-of-bounds reads / UB instead of a panic. Before PR #185 this logic lived only in `matmul_i8_to_i32`'s private AMX arm (where the public `pack_contig` preceded it and bounded everything), but the factored helper is now reachable from `gemm_u8_i8` and any future caller.

Fix:
1. Add three boundary assertions at function entry matching `gemm_u8_i8`'s contract:
   a_u8.len() >= m * k
   b_i8.len() >= k * n
   c.len() >= m * n
   These panic with descriptive messages on undersized input. The safety contract is now enforced at the public function boundary, not at the unsafe pointer-arithmetic site inside the hot loop.
2. Replace the `unsafe { core::slice::from_raw_parts(...) }` B-pack line with safe `b_tile[..].copy_from_slice(&b_i8[row..row + 16])`. The bounds-check inside the loop is now redundant given the function-entry assertions, but the compiler should elide it once the invariant is proven; either way the code becomes panicking-safe instead of UB-on-misuse.
3. Update the doc-comment `# Panics` section to list the boundary panics alongside the existing debug-only AMX / alignment assertions.

New regression test `amx_tiled_panics_on_undersized_b`:
* Constructs `b: Vec<i8>` half-a-j_tile shorter than the claimed `k * n`.
* Calls `int8_gemm_amx_tiled` and asserts the expected panic fires before any unsafe slice arithmetic.
* `#[should_panic(expected = "b_i8.len()")]` catches the exact assertion message; works on any host (the boundary check fires before the `debug_assert!(amx_available())`, so the test passes on AMX-less CI runners too).

Verification:
* 2097 lib tests pass (was 2096; +1 new regression test).
* cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
* cargo fmt --all --check clean.

The matmul_i8_to_i32 path that delegates to int8_gemm_amx_tiled inherits the assertions transparently via the call chain. No behavior change for valid input; only mismatched-shape callers that would have hit UB now get a clean panic instead.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
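The shape of the fix can be sketched in isolation. The function below is a hypothetical extract, not the helper itself: `pack_b_tile` is an illustrative name, the `j_tile + 16 <= n` check is an added assumption for the standalone version, and only the B-packing step is shown. It combines an entry-level length assertion with safe slicing, so an undersized `b_i8` panics instead of reading out of bounds.

```rust
/// Copy the 16-column block starting at column `j_tile` out of a row-major
/// K × N matrix into a contiguous K × 16 scratch buffer.
/// Hypothetical standalone version of the B-pack step.
pub fn pack_b_tile(b_i8: &[i8], n: usize, k: usize, j_tile: usize) -> Vec<i8> {
    // Boundary checks at function entry: misuse panics with a clear message.
    assert!(
        b_i8.len() >= k * n,
        "pack_b_tile: b_i8.len()={} < k*n={}",
        b_i8.len(),
        k * n
    );
    assert!(j_tile + 16 <= n, "pack_b_tile: j_tile={} + 16 > n={}", j_tile, n);

    let mut b_tile = vec![0i8; k * 16];
    for kk in 0..k {
        let row = kk * n + j_tile;
        // Safe slicing: still bounds-checked, no raw-pointer arithmetic.
        b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]);
    }
    b_tile
}
```

A buffer that is half a j-tile (8 elements) short of `k * n` trips the first assertion before any element is copied, which is the behavior the regression test pins down for the real helper.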
**Alternative** to the compile-time cascade in `crate::simd::*` /
`crate::simd_ops::*`. **Additive**: gated under
`--features runtime-dispatch`, does not touch any existing path.
Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill
replaces the architecture-specific intrinsics that the runtime
trampolines select between).
Use case: ship ONE binary that adapts across heterogeneous
deployment silicon (AVX-512 server + AVX2-only laptop + Arrow Lake
desktop + Sapphire Rapids workstation) from the same artifact. The
existing compile-time `v3` / `v4` / `native` / `nightly-simd`
configs target a single class of CPU per build; the runtime layer
targets the union via per-op LazyLock<fn ptr> trampolines.
Design from `.claude/knowledge/simd-dispatch-architecture.md` § 7.1
/ Phase 5, building on the precedent set by
`hpc::bgz17_bridge::{L1_KERNEL, L1_WEIGHTED_KERNEL, ...}`
(`LazyLock<L1Fn>` pattern, lines 75-86) already proven in tree.
# Dispatch model
One `LazyLock<fn ptr>` per public surface. First call fires the
closure which reads `simd_caps()` and selects a backend; every
subsequent call is one pointer-deref + indirect call. Per-call
overhead: ~2-3 ns (LazyLock atomic-acquire load that's cache-
resident after first hit + indirect-call branch-target predict).
Invisible against any SIMD op's actual work (~100+ cycles).
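A minimal sketch of the per-op trampoline pattern, under stated assumptions: the names `DOT_KERNEL`, `select_dot`, `dot_scalar` and `dot_unrolled4` are hypothetical, the standard library's feature-detection macro stands in for `simd_caps()`, and the "fast" backend is a plain unrolled loop rather than real intrinsics.

```rust
use std::sync::LazyLock;

type DotFn = fn(&[u8], &[i8]) -> i32;

/// Portable fallback backend.
fn dot_scalar(a: &[u8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

/// Stand-in for an accelerated backend (a real one would use intrinsics).
fn dot_unrolled4(a: &[u8], b: &[i8]) -> i32 {
    let full = a.len() - a.len() % 4;
    let mut acc = 0i32;
    for i in (0..full).step_by(4) {
        for q in 0..4 {
            acc += a[i + q] as i32 * b[i + q] as i32;
        }
    }
    acc + dot_scalar(&a[full..], &b[full..])
}

/// Runs once, on first use: probe the CPU and pick a backend.
fn select_dot() -> DotFn {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return dot_unrolled4;
        }
    }
    dot_scalar
}

/// One lazily initialised function pointer per public surface.
static DOT_KERNEL: LazyLock<DotFn> = LazyLock::new(select_dot);

/// Public entry point: after the first call this is a load plus an indirect call.
pub fn dot_u8_i8(a: &[u8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "dot_u8_i8: length mismatch");
    (*DOT_KERNEL)(a, b)
}
```

Whichever backend is selected, `dot_u8_i8(&[1, 2, 3], &[4, -5, 6])` evaluates to 12; the selection only affects speed, which is the property a trampoline has to preserve.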
# Module layout
src/simd_runtime/
mod.rs — module entry, mutual-exclusion check vs
nightly-simd, public re-exports
vnni_dot.rs — u8×i8 → i32 dot (the proposal's canonical
example): 3 backends, the AVX-512 arm
wraps `simd_amx::vnni_dot_u8_i8` with a
scalar tail because the existing kernel
silently drops n%64 lanes (its matvec
caller pre-aligns rows; a general-purpose
dispatch surface cannot assume that)
add_mul.rs — slice-level FMA (acc += a × b) for f32/f64;
the ONLY new kernel code in this module —
4 backends per type (avx512 / avx2+fma /
neon / scalar), each ~15 LoC of direct
intrinsics
matmul.rs — thin trampolines for matmul_bf16_to_f32 /
matmul_f32 / matmul_i8_to_i32 / gemm_u8_i8
delegating to existing functions that
already runtime-dispatch internally
(PR #182 / #184 / #185)
casts.rs — trampolines for the four half-precision
batch casts delegating to PR #183's already-
runtime-dispatched implementations
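The scalar-tail wrapper described for `vnni_dot.rs` can be sketched as follows. This is an illustration under assumptions: `dot_full_chunks_only` is a hypothetical stand-in that mimics a kernel processing only complete 64-element chunks, and `dot_with_tail` is an illustrative name for the wrapper, not the module's actual function.

```rust
const CHUNK: usize = 64;

/// Stand-in for a kernel that handles only complete 64-element chunks and
/// silently ignores the trailing `len % 64` elements.
fn dot_full_chunks_only(a: &[u8], b: &[i8]) -> i32 {
    let full = a.len() / CHUNK * CHUNK;
    a[..full]
        .iter()
        .zip(&b[..full])
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum()
}

/// General-purpose wrapper: chunked kernel for the aligned prefix,
/// scalar loop for the remainder.
pub fn dot_with_tail(a: &[u8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "dot_with_tail: length mismatch");
    let full = a.len() - a.len() % CHUNK;
    let head = dot_full_chunks_only(&a[..full], &b[..full]);
    let tail: i32 = a[full..]
        .iter()
        .zip(&b[full..])
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();
    head + tail
}
```

For a 70-element input of ones against twos, the chunked stand-in alone returns 128 while the wrapper returns the full 140, which is the gap the scalar tail closes.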
# Backend reuse — no kernel duplication
Every dispatch arm delegates to a kernel that already exists in
tree. The runtime layer is just the trampoline. The only NEW
kernel code is `add_mul_f32` / `add_mul_f64` (no pre-existing
slice-level FMA primitive in tree to delegate to — the compile-
time `crate::simd_ops::add_mul_f32` from PR #182 polyfills through
the F32x16 lane wrapper; the runtime version skips that
indirection for one more inlined intrinsic per chunk).
# Invariants preserved from this PR series
* No-FP32-roundtrip on BF16/F16 arithmetic — backends respect
the bit-exact mantissa rule
* Asm-byte encoding for nightly-gated AMX / FP16 — selected
backends keep their existing asm-byte fast paths
* Little-endian byte contracts for half-precision carriers
* Accumulator-preservation in tile paths (codex P1 from #184)
* Boundary assertions on safe public fns (codex P1 from #185) —
the public `vnni_dot_u8_i8(a, b)` etc. inherit the asserts
transparently via the call chain
# Verification
* Default build (no feature): 2097 lib tests pass; the
`simd_runtime` module is gated out, zero impact on existing
paths.
* `cargo test --lib --features runtime-dispatch`: **2105 lib
tests pass** (+8 new in `simd_runtime::*::tests`).
* `cargo clippy --lib --tests --features rayon,native -- -D warnings`
clean (default).
* `cargo clippy --lib --tests --features rayon,native,runtime-dispatch
-- -D warnings` clean.
* `cargo fmt --all --check` clean.
* Mutual-exclusion enforced via `compile_error!` in
`simd_runtime/mod.rs` — `--features runtime-dispatch,nightly-simd`
fails to compile with a clear error.
# What's NOT in this PR (deferred)
* Sweep the remaining ~15-20 SIMD/HPC public surfaces (min_i8,
max_i8, add_i8, dot_i8, etc.). Each is ~30-50 LoC of trampoline;
pattern is established here. Estimated ~700-900 more LoC across
the full surface map.
* CI matrix entry for `runtime-dispatch-portable` (per
simd-dispatch-architecture.md § 7 / TD-SIMD-9). Job builds
with `--features runtime-dispatch` on a v3 baseline runner and
asserts every trampoline lands on its expected backend.
* `simd_caps()` snapshot logging at process start (debug-only)
to aid release-binary deployment debugging — "which arm did
you actually pick?"
# Cost summary
src/simd_runtime/ +537 LoC (4 modules)
src/lib.rs +9 LoC (cfg-gated mod decl)
Cargo.toml +21 LoC (feature decl + doc)
Total ~570 LoC
Trampoline LoC per surface (this PR's sample):
vnni_dot 170 LoC (LazyLock + 3 arms + wrapper + tests)
add_mul (f32+f64) 218 LoC (LazyLock×2 + 4 arms×2 + tests; the ONLY new kernels)
matmul (4 ops) 100 LoC (thin delegations + tests)
casts (4 ops) 75 LoC (thin delegations + tests)
Out-of-tree estimate for the full sweep (per § 7 of the design
doc): ~1400 LoC total once all ~25 public SIMD/HPC surfaces are
wired. This PR establishes ~40% of that budget with the canonical
patterns.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
… CPU table
Two additions, scoped together because they're the same idea — using
scraped CPU metadata to drive runtime dispatch:
# Piece A: matrix doc § M (GCC-grounded aarch64 enumeration)
The matrix had three aarch64 columns (A53 / A72 / A76) covering
*dispatch tiers* (multiple physical cores share each tier's SIMD
primitive set). The authoritative per-core feature membership lives
in GCC's `gcc/config/aarch64/aarch64-cores.def` — scraped 2026-05-21
and recorded as a new § M table covering 28 cores:
* V8.0-A baseline (A53, A72)
* V8.2-A dotprod+fp16 (A76, A78, X1, Neoverse-N1, Apple M1)
* V8.5-A baseline (Apple M1 specifically — V8.5 includes V8.2's
fp16+dotprod but NOT bf16+i8mm; corrects a wrong "+bf16" claim
on the existing A76 row of the column legend)
* V8.6-A baseline incl. bf16+i8mm (Apple M2/M3, Oryon-1 / Snapdragon
X Elite, Ampere1+, Cortex-A510/A710/A715, X2/X3, Neoverse-N2/V2)
* V8.7-A (Apple M4, Ampere1B)
* V9.0-A SVE2 baseline + explicit bf16+i8mm flags (Cortex-A510-A715,
X2/X3, Neoverse-N2/V2, Grace)
* V8.4-A SVE tier (Neoverse-V1 / Graviton 3 — only V8.4 core with
explicit SVE+bf16+i8mm)
* V9.2-A (Cortex-A520/A720/A725, X4, X925, Neoverse-N3/V3)
Each entry verbatim from the GCC FEATURE_STRING column. Cross-
referencing with the V8.X-A baseline rules (V8.6+ includes bf16+i8mm
implicitly; V9.0 includes SVE2 implicitly) gives the canonical
"which silicon has what" table. The note flags that a new dispatch
column for the V8.6+/V9-bf16-i8mm tier needs to land alongside the
NEON BFMMLA / BFDOT asm-byte arm in Phase 3b.
The A76 column legend (line 26 of the matrix) was corrected: removed
the wrong "+bf16" (A76 itself is V8.2-A, NO bf16 — bf16 came in
V8.6-A).
# Piece B: CpuOps DTO — third dispatch pattern
Adds `src/simd_runtime/cpu_ops.rs` exposing a per-CPU operations DTO
distinct from the existing patterns:
Pattern 1 (`crate::simd::*`): compile-time `#[cfg(target_feature)]`
cascade. Direct monomorphized calls.
Pattern 2 (`crate::simd_runtime::vnni_dot_u8_i8` etc., from #185):
per-op LazyLock<fn ptr>. One CPUID +
atomic-load per op the first time
called.
Pattern 3 (THIS COMMIT): per-CPU `&'static CpuOps` selected
once at first access. Every op is a
fn-ptr field on the struct.
Why the third pattern?
* Per-op LazyLock: N ops touched = N atomic-load setup costs over
the process lifetime.
* CpuOps DTO: ONE atomic-load total at first `cpu_ops()` call;
every subsequent op is a direct fn-ptr deref through the cached
`&'static CpuOps`. The OpenBLAS / MKL dispatch model — wins for
dense-op consumers (linear-algebra pipelines touching every
BLAS-1/2/3 kernel).
* All three coexist. Consumers pick by import path.
Seven tiers baked as static const `CpuOps` instances:
x86_64: amx_int8, avx512vnni, avx512f, avxvnni, avx2_fma
aarch64: neon
universal: scalar
Each instance points at the existing trampolines in
`crate::simd_runtime::{vnni_dot, add_mul}` — no kernel duplication;
this module is pure dispatch glue. Backend ops referenced:
vnni_dot_u8_i8 (3 backends: avx512+tail / avxvnni / scalar)
add_mul_f32 (4 backends: avx512 / avx2+fma / neon / scalar)
add_mul_f64 (4 backends: avx512 / avx2+fma / neon / scalar)
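A minimal sketch of the DTO pattern, assuming a much smaller struct than the real one: the field set, the two tiers shown, and the helper names (`select_ops`, `dot_scalar`, `add_mul_f32_scalar`) are hypothetical, and both tier instances point at the same scalar stand-ins where the real table would reference per-tier kernels.

```rust
use std::sync::LazyLock;

/// Per-CPU operations table: every op is a plain function-pointer field.
pub struct CpuOps {
    pub tier: &'static str,
    pub vnni_dot_u8_i8: fn(&[u8], &[i8]) -> i32,
    pub add_mul_f32: fn(&mut [f32], &[f32], &[f32]),
}

fn dot_scalar(a: &[u8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

/// acc[i] += a[i] * b[i]
fn add_mul_f32_scalar(acc: &mut [f32], a: &[f32], b: &[f32]) {
    for ((c, &x), &y) in acc.iter_mut().zip(a).zip(b) {
        *c += x * y;
    }
}

// One static instance per tier. Both reuse the scalar stand-ins here.
static AVX2_FMA_OPS: CpuOps = CpuOps {
    tier: "avx2_fma",
    vnni_dot_u8_i8: dot_scalar,
    add_mul_f32: add_mul_f32_scalar,
};
static SCALAR_OPS: CpuOps = CpuOps {
    tier: "scalar",
    vnni_dot_u8_i8: dot_scalar,
    add_mul_f32: add_mul_f32_scalar,
};

/// Probe the CPU once and pick the whole table in one decision.
fn select_ops() -> &'static CpuOps {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return &AVX2_FMA_OPS;
        }
    }
    &SCALAR_OPS
}

static CPU_OPS: LazyLock<&'static CpuOps> = LazyLock::new(select_ops);

/// Single lazy initialisation; every op afterwards is a field load plus call.
pub fn cpu_ops() -> &'static CpuOps {
    *CPU_OPS
}
```

A consumer holds on to the returned reference and calls through its fields, for example `(cpu_ops().vnni_dot_u8_i8)(a, b)`, so touching many ops costs one initialisation in total rather than one per op.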
# The naughty data-driven part
`cpu_ops_for_cpu(name: &str) -> Option<&'static CpuOps>` maps GCC
CPU codenames to the dispatch tier they land in, sourced from § M's
scrape. Spot-checks (each verified by the test suite):
sapphirerapids / graniterapids / emeraldrapids → amx_int8
cascadelake / cooperlake / icelake-* / tigerlake / rocketlake
/ znver4 / znver5 → avx512vnni
alderlake / raptorlake / meteorlake / arrowlake / arrowlake-s
/ lunarlake / pantherlake / sierraforest → avxvnni
haswell / broadwell / skylake / znver1-3 → avx2_fma
apple-m1..m4 / oryon-1 / cortex-a76..a725
/ cortex-x1..x925 / neoverse-n1..v3 / grace
/ ampere1..1b → neon
Returns `None` for unknown CPUs — caller can fall back to
`cpu_ops_for_tier("scalar")` if a "best-effort" answer is needed.
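The name-to-tier lookup amounts to a static match over codenames. The sketch below is a simplification under assumptions: `tier_for_cpu` is a hypothetical function that returns the tier name as a string rather than a `&'static CpuOps`, and it covers only a subset of the codenames listed above.

```rust
/// Map a GCC CPU codename to the dispatch tier it lands in.
/// Returns `None` for names not in the table.
pub fn tier_for_cpu(name: &str) -> Option<&'static str> {
    match name {
        "sapphirerapids" | "graniterapids" | "emeraldrapids" => Some("amx_int8"),
        "cascadelake" | "cooperlake" | "tigerlake" | "rocketlake" | "znver4" | "znver5" => {
            Some("avx512vnni")
        }
        "alderlake" | "raptorlake" | "meteorlake" | "arrowlake" | "arrowlake-s"
        | "lunarlake" | "pantherlake" | "sierraforest" => Some("avxvnni"),
        "haswell" | "broadwell" | "skylake" | "znver1" | "znver2" | "znver3" => Some("avx2_fma"),
        "apple-m1" | "oryon-1" | "cortex-a76" | "neoverse-n1" | "grace" | "ampere1" => {
            Some("neon")
        }
        _ => None,
    }
}

/// Best-effort variant: unknown names fall back to the scalar tier.
pub fn tier_for_cpu_or_scalar(name: &str) -> &'static str {
    tier_for_cpu(name).unwrap_or("scalar")
}
```

Keeping the `Option` in the primary function lets introspection tools distinguish "unknown CPU" from "known CPU that really is scalar-only", while the second function gives the best-effort answer in one call.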
Use cases for `cpu_ops_for_cpu`:
* "What would $CPU pick?" introspection without running on $CPU.
* Cross-compilation reports + deployment-planning tools.
* Integration tests asserting tier selection for named targets.
* Explicit-tier-pinning ("force AVX2 even though AMX is available,
to measure overhead").
Future: code-gen the table from a `build.rs` that fetches GCC's
latest core list. Today the table is hand-rolled from the scrape
recorded in matrix doc § M.
# Verification
* `cargo test --lib --features runtime-dispatch`: 2147 tests pass
(was 2105 — +5 new cpu_ops tests + 37 carried over from prior
feature-gated tests now compiled-in too).
* 5 new cpu_ops tests:
cpu_ops_resolves_on_this_host
cpu_ops_stable_across_calls (LazyLock fires once)
cpu_ops_for_tier_known_names
cpu_ops_for_cpu_data_driven_lookup (spot-checks the GCC scrape)
cpu_ops_call_through_dto (full indirect-call exercise)
* cargo clippy --lib --tests --features rayon,native,runtime-dispatch
-- -D warnings clean.
* cargo fmt --all --check clean.
* Default build (no feature) unchanged: zero impact on existing
paths — the entire `simd_runtime` module is gated out.
# Backward-compat for the existing per-op LazyLock surface
The pub(super) wrappers in `vnni_dot.rs` and `add_mul.rs`
(`*_safe` / `*_safe_wrapper` / `*_scalar_wrapper`) are new but
purely additive — every existing public function in `simd_runtime`
keeps its prior signature and dispatch behavior.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Summary
Extends the `u8 × i8 → i32` dispatch chain from PR #182's compile-time cascade by adding a top-tier AMX runtime check. Brings the SPR/GNR TDPBUSD path (16 384 MACs per instruction) to the slice-level surface that downstream consumers use, completing the symmetry with PR #184's `matmul_i8_to_i32` wiring.

`gemm_u8_i8` is u8×i8 natively, so no sign-shift bias trick is needed (unlike `matmul_i8_to_i32`, which is i8×i8 and has to convert via +128 then subtract `128·colsum(B)`). The AMX path here is a direct call with no bias correction.

New helper `hpc::int8_tile_gemm::int8_gemm_amx_tiled`

Factors out the tile-decomposition logic that was previously inlined in `matmul_i8_to_i32`. Both consumers now share the same helper:

matmul_i8_to_i32:
1. shift A: i8 → u8 (+128)
2. int8_gemm_amx_tiled(a_u8, b, c, m, n, k)
3. subtract_i8_to_u8_bias(c, b, m, n, k)

gemm_u8_i8 (AMX tier added in this PR):
1. int8_gemm_amx_tiled(a, b, c, m, n, k), with no shift and no bias

Handles arbitrary 16/16/64-aligned shapes via a `j_tile × i_tile` loop calling `int8_tile_gemm_16x16` per (16, 16) block. The B sub-block is extracted into a K × 16 scratch buffer once per j-tile and reused across all M i-tiles. Overwrite semantics: c is written, not accumulated.

Dispatch in `gemm_u8_i8`

* Tier 0 (this PR): runtime `amx_available()` check at the top (AMX requires CPUID + XCR0 + Linux `prctl`, which can't fit a `target_feature` compile-time gate).
* Tiers 1-3: existing compile-time cfg-cascade (avx512vnni zmm → avxvnni ymm → scalar `int8_gemm_i32`).

Misaligned shapes or non-AMX hosts fall through to the compile-time cascade.
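The tile decomposition can be sketched in scalar form. This is an illustration, not the crate's helper: `gemm_u8_i8_tiled` and `tile_16x16` are hypothetical names, a triple loop stands in for the AMX tile kernel, and alignment is checked with release-mode assertions. It shows the j_tile × i_tile loop order, the B sub-block packed once per j-tile, and the zeroed tile buffer that gives overwrite semantics on `c`.

```rust
const T: usize = 16;

/// Stand-in for the 16×16 tile kernel: accumulates into `acc`.
/// `a_rows` holds 16 rows of length k; `b_tile` is a packed K × 16 block.
fn tile_16x16(a_rows: &[u8], b_tile: &[i8], k: usize, acc: &mut [i32; T * T]) {
    for i in 0..T {
        for j in 0..T {
            for kk in 0..k {
                acc[i * T + j] += a_rows[i * k + kk] as i32 * b_tile[kk * T + j] as i32;
            }
        }
    }
}

/// Tiled u8×i8 → i32 GEMM over row-major inputs; writes (not accumulates) `c`.
pub fn gemm_u8_i8_tiled(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    assert!(a.len() >= m * k && b.len() >= k * n && c.len() >= m * n);
    assert!(m % T == 0 && n % T == 0 && k % 64 == 0, "shape must be 16/16/64-aligned");

    let mut b_tile = vec![0i8; k * T];
    for j_tile in (0..n).step_by(T) {
        // Pack the K × 16 sub-block of B once per j-tile.
        for kk in 0..k {
            let row = kk * n + j_tile;
            b_tile[kk * T..(kk + 1) * T].copy_from_slice(&b[row..row + T]);
        }
        // Reuse the packed block across every i-tile.
        for i_tile in (0..m).step_by(T) {
            // Fresh zeroed tile buffer, so the store into c is a pure overwrite.
            let mut acc = [0i32; T * T];
            tile_16x16(&a[i_tile * k..(i_tile + T) * k], &b_tile, k, &mut acc);
            for i in 0..T {
                let dst = (i_tile + i) * n + j_tile;
                c[dst..dst + T].copy_from_slice(&acc[i * T..(i + 1) * T]);
            }
        }
    }
}
```

Pre-filling `c` with a sentinel before the call and comparing against a naive triple loop afterwards is a simple way to confirm both the arithmetic and the overwrite behavior.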
Drive-by clippy fix

Pre-existing `clippy::manual_is_multiple_of` warnings surfaced in the new alignment check; switched from `% 16 == 0` to `.is_multiple_of(16)` per the clippy hint (Rust 1.95 promoted this from `pedantic` to active warn).

Test plan

* `gemm_u8_i8_amx_aligned_32x32x128`: new test, exercises the AMX arm with a 32×32×128 shape that hits the AMX tier on this host's `amx_int8` silicon.
* `cargo clippy --lib --tests --features rayon,native -- -D warnings` clean.
* `cargo fmt --all --check` clean.

Per-CPU dispatch state after this PR

matmul_bf16_to_f32: SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182) | scalar (always)
matmul_f32 (BF16 compute): SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182) | scalar (always)
matmul_i8_to_i32: SPR+ AMX (PR #184) | CPL/Zen4 VPDPBUSD (PR #184) | scalar (always)
gemm_u8_i8 (slice): SPR+ AMX (THIS) | CPL/Zen4 VPDPBUSD (PR #182) | ARL ymm (PR #182) | scalar (PR #182)

Out of scope (separate PRs)

* AVX-VNNI ymm arm for `matmul_i8_to_i32`: the `vnni2_*` helpers exist in `simd_amx.rs` but need assembling into an m×n×k GEMM. Same shape as the avx512vnni arm, just with ymm width.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u