feat(simd): BF16x16 + F16x16 SIMD vectors + slice ops (sprint W3-A) #126
…int W3-A)

Closes parity items 2 + 3. Scalar dispatch (upcast f32 -> op -> downcast). SIMD-accelerated paths (AVX2 emulation, AVX-512-BF16 native, NEON +fp16) are a follow-up. The scalar implementation is correct and portable, and unblocks burn's NdArrayElement bound for half types.

- src/simd_half.rs: 691 LOC new module
- src/lib.rs: pub mod simd_half declaration
- src/simd.rs: re-exports

21 new tests, all passing. Total lib tests: 1817+ pass.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3358057a9c
    #[cfg(all(feature = "std", not(all(target_arch = "x86_64", target_feature = "avx512bf16"))))]
    pub use crate::simd_half::BF16x16 as BF16x16;
Keep BF16x16 API stable across target features
crate::simd::BF16x16 now resolves to two incompatible types depending on compile flags: this line hides the new portable simd_half::BF16x16 when target_feature="avx512bf16" is set, so AVX-512-BF16 builds get simd_avx512::BF16x16 (unsafe load/convert-only API) instead of the new arithmetic API (from_slice, add, mul, copy_to_slice). Any consumer code written against the newly introduced BF16x16 methods will compile on scalar/NEON/AVX2 targets and fail on AVX-512-BF16 targets, which breaks the cross-target SIMD dispatch parity this change is meant to provide.
Closes the dispatch-table gap for BF16 decode on AVX-512F silicon
without the BF16 extension (Skylake-X, Cascade Lake, Ice Lake-SP).
Before this commit, `bf16_to_f32_batch` was two-tier: avx512bf16
SIMD path (Cooper Lake, SPR+, Zen 4+) or scalar lane-by-lane
fallback. The middle tier (Intel AVX-512 CPUs from 2017 to 2021
that lack the BF16 extension) was forced through scalar even
though the BF16 → f32 conversion is just a 16-bit left shift and
AVX-512F has had `_mm512_cvtepu16_epi32` + `_mm512_slli_epi32` since
day one.
The new `convert_bf16_to_f32_avx512f` uses five intrinsics per
16-lane chunk, which lower to four instructions (the cast is a
no-op):
    _mm256_loadu_si256        // 16 u16 → __m256i
    _mm512_cvtepu16_epi32     // zero-extend to 16 u32 → __m512i
    _mm512_slli_epi32::<16>   // shift left by 16 (BF16 → f32 bits)
    _mm512_castsi512_ps       // bit-cast i32 → f32 (no instruction)
    _mm512_storeu_ps          // store 16 f32
Plus a scalar tail for the last n % 16 lanes (handled via the
existing `bf16_to_f32_scalar` reference).
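
The loop shape described above (16-lane body plus scalar tail) can be sketched in portable Rust without intrinsics. This is an illustration only, not the PR's `convert_bf16_to_f32_avx512f`: the function names are hypothetical, and each lane does by hand what the vector sequence does to 16 lanes at once (zero-extend, shift left by 16, bit-cast).

```rust
/// Hypothetical scalar reference: a BF16 value is the upper 16 bits of an f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    // Zero-extend to 32 bits, shift into the high half, reinterpret as f32.
    f32::from_bits((bits as u32) << 16)
}

/// Hypothetical chunked converter mirroring the 16-lane loop + scalar tail.
/// Panics if `dst` is shorter than `src`.
fn bf16_to_f32_chunked(src: &[u16], dst: &mut [f32]) {
    assert!(dst.len() >= src.len(), "destination too short");
    let dst = &mut dst[..src.len()];

    let mut src_chunks = src.chunks_exact(16);
    let mut dst_chunks = dst.chunks_exact_mut(16);

    // Main loop: one iteration stands in for one 512-bit vector of 16 lanes.
    for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
        for lane in 0..16 {
            d[lane] = bf16_bits_to_f32(s[lane]);
        }
    }

    // Scalar tail: the remaining `src.len() % 16` lanes.
    let tail_src = src_chunks.remainder();
    let tail_dst = dst_chunks.into_remainder();
    for (s, d) in tail_src.iter().zip(tail_dst.iter_mut()) {
        *d = bf16_bits_to_f32(*s);
    }
}
```

Calling `bf16_to_f32_chunked` on a 21-element slice runs the body once and the tail five times, which is the same split the intrinsic version would make.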
BF16 → f32 is mathematically exact (BF16 IS the upper 16 bits of
f32), so the AVX-512F path is byte-equal to the scalar reference on
every input, including subnormal, NaN, ±Inf, ±0 — verified in the
new direct test against a corpus that sweeps every (sign × exponent
× representative-mantissa) triple plus a 5-element tail to exercise
both the 16-aligned loop and the scalar tail.
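
A sketch of how such a corpus could be built and checked. The mantissa representatives, helper names, and tail values are assumptions for illustration; the PR's actual test corpus is not shown here. The check relies only on the fact stated above: the upper 16 bits of the f32 result must equal the input, and the lower 16 bits must be zero.

```rust
/// Hypothetical scalar reference conversion (BF16 bits -> f32).
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Build a corpus covering every sign x exponent pair with a few
/// representative 7-bit mantissas, plus a 5-element tail so the total
/// length is not a multiple of 16.
fn build_corpus() -> Vec<u16> {
    // Assumed representatives: zero, smallest, quiet-NaN bit, all ones.
    const MANTISSAS: [u16; 4] = [0x00, 0x01, 0x40, 0x7F];
    let mut corpus = Vec::new();
    for sign in 0..2u16 {
        for exp in 0..256u16 {
            for &m in MANTISSAS.iter() {
                corpus.push((sign << 15) | (exp << 7) | m);
            }
        }
    }
    // Tail: +0, -0, +Inf, -Inf, a quiet NaN.
    corpus.extend_from_slice(&[0x0000, 0x8000, 0x7F80, 0xFF80, 0x7FC0]);
    corpus
}

/// Returns true when every conversion is bit-exact.
fn corpus_is_bit_exact(corpus: &[u16]) -> bool {
    corpus.iter().all(|&b| {
        let out = bf16_bits_to_f32(b).to_bits();
        // Compare bit patterns, not float values, so NaN payloads count.
        (out >> 16) as u16 == b && (out & 0xFFFF) == 0
    })
}
```

Comparing bit patterns instead of float values matters because `NaN != NaN` under float comparison, so a value-based assertion would reject correct NaN output.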
Dispatch order after this commit:
1. avx512bf16 + avx512vl → `_mm512_cvtpbh_ps` path (best — 1 op)
2. avx512f → bit-shift path (this commit — 4 ops, no rounding)
3. scalar lane-by-lane fallback
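
The priority order above can be modelled as a pure function. This is a minimal sketch with hypothetical names (`Bf16DecodePath`, `CpuFeatures`, `select_bf16_decode_path`); it shows only the ordering, not how the crate actually detects features or gates code with `target_feature`.

```rust
/// Which conversion path a call would take (names are illustrative).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Bf16DecodePath {
    Avx512Bf16, // native conversion intrinsic
    Avx512F,    // zero-extend + shift
    Scalar,     // lane-by-lane fallback
}

/// CPU capabilities relevant to the choice. In real code these would come
/// from runtime detection or `cfg(target_feature)`; here they are plain data.
#[derive(Debug, Clone, Copy, Default)]
struct CpuFeatures {
    avx512f: bool,
    avx512vl: bool,
    avx512bf16: bool,
}

/// Pick the best available path, checking the tiers from best to worst.
fn select_bf16_decode_path(cpu: CpuFeatures) -> Bf16DecodePath {
    if cpu.avx512bf16 && cpu.avx512vl {
        Bf16DecodePath::Avx512Bf16
    } else if cpu.avx512f {
        Bf16DecodePath::Avx512F
    } else {
        Bf16DecodePath::Scalar
    }
}
```

A Cascade Lake-like configuration (avx512f and avx512vl set, avx512bf16 clear) selects `Avx512F`, which is the tier this commit adds.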
Verification:
* Direct test `batch_bf16_to_f32_avx512f_matches_scalar` runs on
the `cascadelake` config (avx512f + bw + vl, no bf16) and
passes — asserts byte-equal output against scalar reference
across the full corpus.
* Existing `batch_conversion_matches_scalar` test on this host
(avx512_bf16 present) still hits the avx512bf16 path; the new
arm is dead code there, which is correct — the dispatch order
prefers the better intrinsic when available.
* Default v3 build (no AVX-512): 2087 lib tests pass; the new arm
isn't compiled because the surrounding test module is gated on
`target_feature = "avx512f"`.
* cargo clippy -- -D warnings clean.
* cargo fmt --all --check clean.
The symmetric f32 → BF16 direction already had its AVX-512F-only
RNE path (`f32_to_bf16_batch_rne` shipped in PR #126, byte-exact
vs `_mm512_cvtneps_pbh`). This commit closes the asymmetry so both
directions have AVX-512F-only paths on top of the avx512bf16 fast
path.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Closes parity items (2)+(3): half-precision SIMD vector types so burn's `NdArrayElement::F16/BF16` enum variants can dispatch through ndarray's SIMD layer.

What ships:

- `src/simd_half.rs` (691 LOC): `BF16x16` and `F16x16` types, scalar dispatch (upcast f32 → op → downcast)
- `add_bf16_inplace`, `mul_bf16_inplace`, `add_f16_inplace`, `mul_f16_inplace`, `cast_*_to_*_batch` (8 helpers)
- `src/simd.rs`: re-exports

Tests: 21 new, all passing. Total lib: 1817+ pass.

SIMD-accelerated paths (AVX2 emulation, AVX-512-BF16 native, NEON +fp16) are a follow-up. The scalar implementation is correct and portable, and it unblocks burn's `NdArrayElement` bound for half types.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
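
The "upcast f32 → op → downcast" pattern can be sketched for BF16 as follows. This is not the code in `src/simd_half.rs`: the function names are hypothetical, values are held as raw `u16` bit patterns rather than a dedicated half type, and round-to-nearest-even on the downcast is an assumption.

```rust
/// BF16 bits -> f32: exact, since BF16 is the upper half of an f32.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// f32 -> BF16 bits with round-to-nearest-even (assumed rounding mode).
fn f32_to_bf16_rne(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep the sign and top payload bits, force a quiet NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit so exact ties round to even.
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Hypothetical stand-in for an in-place slice op such as `add_bf16_inplace`.
/// Panics if the slices differ in length.
fn add_bf16_inplace_sketch(dst: &mut [u16], rhs: &[u16]) {
    assert_eq!(dst.len(), rhs.len(), "length mismatch");
    for (d, r) in dst.iter_mut().zip(rhs) {
        // Upcast both operands, add in f32, downcast the result.
        *d = f32_to_bf16_rne(bf16_to_f32(*d) + bf16_to_f32(*r));
    }
}
```

Doing the arithmetic in f32 keeps the scalar path portable; only the final downcast loses precision, and that rounding step is what a later SIMD path would have to reproduce.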