Commit bb7b9b7
feat(hpc): VPDPBUSD-zmm middle tier for matmul_i8_to_i32
Completes the per-CPU dispatch chain for `matmul_i8_to_i32`. Per
PR #180's table the middle tier between AMX TDPBUSD (Sapphire
Rapids+) and the scalar reference is `_mm512_dpbusd_epi32` (zmm
form, avx512vnni feature) — covers Cooper Lake, Cascade Lake, Ice
Lake-SP, Zen 4+ silicon that has AVX-512 VNNI but not AMX. Mirrors
the VDPBF16PS arm structure that landed for BF16 in PR #182's
`bf16_gemm_dispatch`.
New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm`:
* One VPDPBUSD instruction: 16 i32 accumulator lanes, each
receiving 4 u8×i8 products = 64 MACs per instruction.
* Maps the 16 output lanes to a row of 16 j-columns of `c[i, ·]`,
one i row processed at a time, K-quad inner loop accumulating
into the same 16 i32 lanes across iterations.
* B-column packing: pre-packs B for the current j-block into
`b_col_quads[k_quad * 16 + j] = i32 (4 bytes of B[4k_quad..,
j_base+j] packed bottom-to-top)` once per j-block; reused
across all M i-iterations so the gather cost amortizes.
* A row quad broadcast: `_mm512_set1_epi32` of (4 u8 bytes
packed) every K-iter — same quad seen by every output column.
* K-tail (k % 4 != 0) handled with scalar u8×i8 multiplies per
output cell; N-tail (j_count < 16) handled by trimming the
store width — padding lanes still receive VPDPBUSD updates
but aren't written back.
* Stable intrinsic `_mm512_dpbusd_epi32` under
`target_feature = "avx512vnni,avx512f"` — no asm-byte needed.
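The loop structure in the bullets above can be modelled in portable scalar Rust. This is a sketch for illustration only, not the kernel from this commit: `dpbusd_lane` and `gemm_u8_i8_model` are hypothetical names, the row-major layout is an assumption, quads are held as `u32` rather than `i32` for byte-level convenience, and padding lanes are zero-filled here rather than left as don't-care values.

```rust
/// Scalar model of one VPDPBUSD lane: acc += sum of four (u8 * i8) products.
/// Both quads are packed lowest byte first. VPDPBUSD does not saturate,
/// so the accumulation wraps.
fn dpbusd_lane(acc: i32, a_quad: u32, b_quad: u32) -> i32 {
    let a = a_quad.to_le_bytes();
    let b = b_quad.to_le_bytes();
    let mut sum = acc;
    for t in 0..4 {
        sum = sum.wrapping_add(a[t] as i32 * (b[t] as i8) as i32);
    }
    sum
}

/// C[m x n] = A_u8[m x k] * B_i8[k x n], all row-major (assumed layout).
fn gemm_u8_i8_model(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
    let k_quads = k / 4;
    let mut b_col_quads = vec![0u32; k_quads * 16];
    for j_base in (0..n).step_by(16) {
        let j_count = (n - j_base).min(16);
        // Pack B once per j-block; it is reused for every row i below.
        b_col_quads.fill(0);
        for kq in 0..k_quads {
            for j in 0..j_count {
                let mut bytes = [0u8; 4];
                for t in 0..4 {
                    bytes[t] = b[(4 * kq + t) * n + j_base + j] as u8;
                }
                b_col_quads[kq * 16 + j] = u32::from_le_bytes(bytes);
            }
        }
        for i in 0..m {
            let mut acc = [0i32; 16];
            for kq in 0..k_quads {
                let p = i * k + 4 * kq;
                // "Broadcast": every output lane sees the same A quad.
                let a_quad = u32::from_le_bytes([a[p], a[p + 1], a[p + 2], a[p + 3]]);
                for lane in 0..16 {
                    acc[lane] = dpbusd_lane(acc[lane], a_quad, b_col_quads[kq * 16 + lane]);
                }
            }
            // K-tail: leftover k % 4 elements handled with scalar multiplies.
            for kk in (4 * k_quads)..k {
                for j in 0..j_count {
                    let prod = a[i * k + kk] as i32 * b[kk * n + j_base + j] as i32;
                    acc[j] = acc[j].wrapping_add(prod);
                }
            }
            // N-tail: only the first j_count lanes are written back.
            let row = i * n + j_base;
            c[row..row + j_count].copy_from_slice(&acc[..j_count]);
        }
    }
}
```

In the real kernel the inner `for lane in 0..16` loop would correspond to a single `_mm512_dpbusd_epi32` call, and the A quad would come from `_mm512_set1_epi32`.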
Wiring `matmul_i8_to_i32` to three-tier dispatch:
1. amx_available() + 16/16/64-aligned shapes
→ int8_tile_gemm_16x16 → TDPBUSD asm-byte (16 384 MACs/instr;
reuses the kernel from b1979d7, earlier in this PR)
2. is_x86_feature_detected!("avx512vnni")
→ int8_gemm_vpdpbusd_zmm → _mm512_dpbusd_epi32 stable
intrinsic (64 MACs/instr, arbitrary shapes, K-tail handled
scalar, N-tail handled by per-iteration j_count trim)
3. scalar i8×i8 → i32 reference for non-x86, pre-AVX-512 hosts,
or shapes that don't satisfy either SIMD tier's requirements
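The tier selection described above might look roughly like the following. This is a hedged sketch, not the code in this commit: `Tier`, `select_tier`, and `host_has_avx512vnni` are hypothetical names, AMX availability is passed in as a plain boolean standing in for `amx_available()`, and "16/16/64-aligned" is assumed to mean `m % 16 == 0`, `n % 16 == 0`, `k % 64 == 0`.

```rust
/// Hypothetical tier tag used only for this illustration.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Tier {
    Amx,
    Vnni512,
    Scalar,
}

/// Pick the fastest tier whose requirements are met.
/// Assumption: the AMX tile kernel needs m and n to be multiples of 16
/// and k to be a multiple of 64; the VNNI tier accepts arbitrary shapes.
fn select_tier(has_amx: bool, has_avx512vnni: bool, m: usize, n: usize, k: usize) -> Tier {
    if has_amx && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
        Tier::Amx
    } else if has_avx512vnni {
        Tier::Vnni512
    } else {
        Tier::Scalar
    }
}

/// Runtime check for the middle tier on x86_64 hosts.
#[cfg(target_arch = "x86_64")]
fn host_has_avx512vnni() -> bool {
    std::arch::is_x86_feature_detected!("avx512vnni")
        && std::arch::is_x86_feature_detected!("avx512f")
}

/// Non-x86 hosts always fall through to the scalar reference.
#[cfg(not(target_arch = "x86_64"))]
fn host_has_avx512vnni() -> bool {
    false
}
```

Keeping the selection a pure function of the feature flags and the shape makes the fall-through order testable on any host, independent of which CPU runs the tests.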
Factored the shared sign-shift bias subtraction into a private
`subtract_i8_to_u8_bias(c, b_i8, m, n, k)` helper: both Tier 1
(AMX) and Tier 2 (VNNI) shift LHS i8 → u8 via (+128) then need to
subtract 128·colsum(B) from the accumulator. Pure integer
arithmetic, bit-identical to the scalar i8×i8 → i32 reference.
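The correction follows from (a + 128)·b = a·b + 128·b: summing over k leaves an excess of 128·colsum(B) in each column, which is then subtracted. A minimal sketch of that step is below; the helper name and argument order come from this commit message, but the argument types, the row-major layout, and the companion `shift_i8_to_u8` function are assumptions made for illustration.

```rust
/// Sketch: c[i, j] -= 128 * sum over kk of b_i8[kk, j].
/// Assumes row-major c (m x n) and b_i8 (k x n).
fn subtract_i8_to_u8_bias(c: &mut [i32], b_i8: &[i8], m: usize, n: usize, k: usize) {
    // Column sums of B are shared by every output row.
    let mut colsum = vec![0i32; n];
    for kk in 0..k {
        for j in 0..n {
            colsum[j] += b_i8[kk * n + j] as i32;
        }
    }
    for i in 0..m {
        for j in 0..n {
            c[i * n + j] -= 128 * colsum[j];
        }
    }
}

/// Hypothetical companion: shift signed LHS bytes into the unsigned range.
fn shift_i8_to_u8(a: &[i8]) -> Vec<u8> {
    a.iter().map(|&x| (x as i32 + 128) as u8).collect()
}
```

Running a u8×i8 product on the shifted LHS and then applying the subtraction should reproduce the signed i8×i8 result exactly, since every step is integer arithmetic.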
Verification:
* Default v3 build: 2093 lib tests pass (was 2092 — +1 new test
`vpdpbusd_zmm_matches_scalar` that exercises the new arm
directly with shapes spanning aligned cases, K-tail (k % 4),
N-tail (n % 16), and small shapes; asserts byte-equal output
vs scalar reference).
* Existing `matmul_i8_to_i32_16x16_exact` continues to pass
through the AMX tier on this host (which has amx_int8).
* cargo clippy --lib --tests --features rayon,native -- -D warnings
clean.
* cargo fmt --all --check clean.
Per-CPU dispatch state after this commit:
matmul_bf16_to_f32: SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182)    | scalar (always)
matmul_f32:         SPR+ AMX (PR #182) | Zen4/CPL VDPBF16PS (PR #182)    | scalar (always)
matmul_i8_to_i32:   SPR+ AMX (b1979d7) | CPL/Zen4 VPDPBUSD (THIS COMMIT) | scalar (always)
So all three of the public matmul entry points now have full
three-tier dispatch on x86_64.
Out of scope (separate PRs):
* AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level
u8×i8 surface from PR #182) — it's u8×i8 natively, no sign-
shift bias needed, simpler than matmul_i8_to_i32.
* AVX-VNNI ymm arm (Arrow Lake / Meteor Lake U: avxvnni without
avx512vnni): the `vnni2_*` functions exist in simd_amx.rs but
still need to be assembled into an m×n×k VNNI-ymm GEMM. Same
shape as the avx512vnni arm, just with ymm width.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u1
parent 33a2bbb commit bb7b9b7
2 files changed
Lines changed: 183 additions & 23 deletions