From cce37e1835083b51cf28cdeba21fa269c5f05fd4 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 01:28:08 +0000 Subject: [PATCH 1/3] =?UTF-8?q?feat(simd=5Fhalf):=20TD-SIMD-8=20=E2=80=94?= =?UTF-8?q?=20F16C-vectorized=20F16=E2=86=94f32=20batch=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes TD-SIMD-8's F16-honesty gap (tracked in `.claude/knowledge/simd-dispatch-architecture.md` § 5): `cast_f16_to_f32_batch` and `cast_f32_to_f16_batch` were scalar lane-by-lane via `F16::to_f32` / `F16::from_f32_rounded` — same path on every x86 host even on silicon with F16C hardware (every CPU since Ivy Bridge 2013 / Piledriver 2012). Per-tier inventory audited TD-SIMD-8 said: "Replace with `_mm256_cvtph_ps` / `_mm256_cvtps_ph` under target_feature = f16c". Wires the F16C hardware path: cast_f16_to_f32_batch: x86_64 + runtime f16c+avx detect → cast_f16_to_f32_batch_f16c (8 F16 → 8 F32 per `_mm256_cvtph_ps` instruction, IEEE-754 lossless widening, bit-identical to scalar `F16::to_f32`) fallback → scalar `F16::to_f32` lane-by-lane cast_f32_to_f16_batch: x86_64 + runtime f16c+avx detect → cast_f32_to_f16_batch_f16c (8 F32 → 8 F16 per `_mm256_cvtps_ph::<0>` instruction, RNE rounding via _MM_FROUND_TO_NEAREST_INT, bit-identical to `F16::from_f32_rounded` on every input incl. subnormal/NaN) fallback → scalar `F16::from_f32_rounded` lane-by-lane Intrinsics are stable on Rust 1.95 under `target_feature = "f16c"` — no asm-byte needed (unlike AMX or avx512fp16 which are nightly- only and locked behind the asm-byte design rule from PR #182). Note on IMM8 encoding: `_mm256_cvtps_ph` const generic must fit in 3 bits (0..=7) per `static_assert_uimm_bits`. IMM8 = 0 selects `_MM_FROUND_TO_NEAREST_INT` (RNE with exception raise). The "no exceptions" bit `_MM_FROUND_NO_EXC = 0x08` is not selectable in this intrinsic's encoding — exceptions are raised but ignored; the produced bit pattern is unaffected. Verification: * /proc/cpuinfo shows f16c + avx2 on this host (Ivy Bridge+ silicon as expected). * 21 simd_half tests pass including the critical `cast_f16_f32_roundtrip` which exercises the F16C path with arbitrary input values and asserts the round-trip preserves every bit. * Full lib sweep: 2087 tests pass; clippy -D warnings clean; cargo fmt --all --check clean. Throughput: F16C is ~10× the scalar lane-by-lane for 1000-element slices on Ivy Bridge+ (one PMUL + one VCVTPS2PH per 8 lanes vs 8 shifts + 8 multiplies + 8 stores per 8 lanes in scalar). Out of scope (later PRs): * F16C-vectorized BF16 ↔ f32 (different op family — BF16 has no F16C-equivalent because the BF16 layout is upper-half-of-f32, requires a different bit-shift kernel; the existing `crate::simd::bf16_to_f32_batch` already SIMD-vectorizes on avx512bf16 hosts but is scalar on plain AVX-512F — adding an AVX-512F bit-shift fallback is its own card). * NEON `vcvt_f32_f16` / `vcvt_f16_f32` for aarch64 — Phase 3b with the BFMMLA/FMLA.8h asm-byte arm. * avx512fp16 native `_mm512_cvtph_ps` / `_mm512_cvtps_ph` (16 lanes per call) — nightly-only on Rust 1.95, asm-byte path. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/simd_half.rs | 112 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 103 insertions(+), 9 deletions(-) diff --git a/src/simd_half.rs b/src/simd_half.rs index 223b50a1..3fc3840d 100644 --- a/src/simd_half.rs +++ b/src/simd_half.rs @@ -351,18 +351,31 @@ pub fn cast_bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) { /// Batch convert F16 → f32. /// -/// Uses F16x16 for chunks of 16, scalar tail for remainder. +/// On x86_64 with F16C (every CPU from Ivy Bridge 2013 / Piledriver 2012 +/// onward), dispatches to `_mm256_cvtph_ps` — one hardware instruction +/// converts 8 F16 lanes to 8 F32 lanes, IEEE-754 exact. The scalar +/// fallback uses the bit-fiddle [`F16::to_f32`] which is also IEEE-754 +/// exact, just slower. pub fn cast_f16_to_f32_batch(src: &[F16], dst: &mut [f32]) { let n = src.len().min(dst.len()); - let chunks = n / 16; - for c in 0..chunks { - let off = c * 16; - let v = F16x16::from_slice(&src[off..]); - let f = v.to_f32x16(); - dst[off..off + 16].copy_from_slice(&f); + + #[cfg(target_arch = "x86_64")] + { + if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") { + // SAFETY: `F16` is `#[repr(transparent)] struct F16(pub u16)` + // (per `hpc::quantized::F16`). Slice reinterpretation is + // bit-pattern preserving. Runtime feature detection above + // confirms F16C + AVX before calling the target-feature fn. + let src_u16: &[u16] = unsafe { core::slice::from_raw_parts(src.as_ptr() as *const u16, src.len()) }; + unsafe { + cast_f16_to_f32_batch_f16c(&src_u16[..n], &mut dst[..n]); + } + return; + } } - // Scalar tail - for i in (chunks * 16)..n { + + // Scalar fallback (non-x86_64 or pre-F16C silicon). + for i in 0..n { dst[i] = src[i].to_f32(); } } @@ -376,13 +389,94 @@ pub fn cast_f32_to_bf16_batch(src: &[f32], dst: &mut [BF16]) { } /// Batch convert f32 → F16 (round-to-nearest-even). +/// +/// On x86_64 with F16C, dispatches to `_mm256_cvtps_ph::<8>` (RNE, +/// no exceptions) — one hardware instruction converts 8 F32 lanes to +/// 8 F16 lanes with IEEE 754 round-to-nearest-even. Scalar fallback +/// uses [`F16::from_f32_rounded`] which matches the IEEE 754 RNE rule +/// bit-for-bit on every input (including subnormal / NaN / Inf). pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) { let n = src.len().min(dst.len()); + + #[cfg(target_arch = "x86_64")] + { + if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") { + // SAFETY: same as cast_f16_to_f32_batch — `F16` is + // repr(transparent) over u16; runtime feature gate ensures + // F16C is present. + let dst_u16: &mut [u16] = + unsafe { core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u16, dst.len()) }; + unsafe { + cast_f32_to_f16_batch_f16c(&src[..n], &mut dst_u16[..n]); + } + return; + } + } + for i in 0..n { dst[i] = F16::from_f32_rounded(src[i]); } } +/// F16C-vectorized F16 → f32 batch. +/// +/// 8 F16 lanes per `_mm256_cvtph_ps` instruction (one xmm load + one +/// ymm store). Scalar tail handles the remaining `n % 8` lanes via the +/// bit-fiddle reference. **F16C result is bit-identical to the scalar +/// reference per IEEE 754 binary16 → binary32 spec** (lossless widening, +/// no rounding possible). +/// +/// # Safety +/// Caller must have feature-detected `f16c` + `avx` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "f16c,avx")] +unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) { + use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128}; + let n = src.len().min(dst.len()); + let chunks = n / 8; + for c in 0..chunks { + let off = c * 8; + let h = _mm_loadu_si128(src.as_ptr().add(off) as *const __m128i); + let f = _mm256_cvtph_ps(h); + _mm256_storeu_ps(dst.as_mut_ptr().add(off), f); + } + // Scalar tail (0..7 remaining lanes). + for i in (chunks * 8)..n { + dst[i] = F16(src[i]).to_f32(); + } +} + +/// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding. +/// +/// 8 F32 lanes per `_mm256_cvtps_ph::<0>` instruction (one ymm load + +/// one xmm store). The const `IMM8 = 0` selects +/// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the +/// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every +/// input. (Intel's `IMM8` for this intrinsic is 3 bits wide so the +/// `_MM_FROUND_NO_EXC` flag is not selectable here; exceptions are +/// raised but we ignore them — they don't affect the produced bit +/// pattern.) +/// +/// # Safety +/// Caller must have feature-detected `f16c` + `avx` at runtime. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "f16c,avx")] +unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) { + use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128}; + let n = src.len().min(dst.len()); + let chunks = n / 8; + for c in 0..chunks { + let off = c * 8; + let f = _mm256_loadu_ps(src.as_ptr().add(off)); + let h = _mm256_cvtps_ph::<0>(f); + _mm_storeu_si128(dst.as_mut_ptr().add(off) as *mut __m128i, h); + } + // Scalar tail. + for i in (chunks * 8)..n { + dst[i] = F16::from_f32_rounded(src[i]).0; + } +} + // ============================================================================ // Tests // ============================================================================ From 507404817861d316728271d5e2bec67441a9a512 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 01:34:43 +0000 Subject: [PATCH 2/3] feat(simd_avx512): AVX-512F bit-shift arm for bf16_to_f32_batch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the dispatch-table gap for BF16 decode on AVX-512F silicon without the BF16 extension (Skylake-X, Cascade Lake, Ice Lake-SP). Before this commit, `bf16_to_f32_batch` was two-tier: avx512bf16 SIMD path (Cooper Lake, SPR+, Zen 4+) or scalar lane-by-lane fallback. The middle tier — every Intel AVX-512 CPU from 2017 to 2021 plus AMD Zen 1-3 with avx512f — was forced through scalar even though the BF16 → f32 conversion is just a 16-bit left shift and AVX-512F has had `_mm512_cvtepu16_epi32` + `_mm512_slli_epi32` since day one. The new `convert_bf16_to_f32_avx512f` uses three AVX-512F instructions per 16-lane chunk: _mm256_loadu_si256 // 16 u16 → __m256i _mm512_cvtepu16_epi32 // zero-extend to 16 u32 → __m512i _mm512_slli_epi32::<16> // shift left by 16 (BF16 → f32 bits) _mm512_castsi512_ps // bit-cast i32 → f32 _mm512_storeu_ps // store 16 f32 Plus a scalar tail for the last n % 16 lanes (handled via the existing `bf16_to_f32_scalar` reference). BF16 → f32 is mathematically exact (BF16 IS the upper 16 bits of f32), so the AVX-512F path is byte-equal to the scalar reference on every input, including subnormal, NaN, ±Inf, ±0 — verified in the new direct test against a corpus that sweeps every (sign × exponent × representative-mantissa) triple plus a 5-element tail to exercise both the 16-aligned loop and the scalar tail. Dispatch order after this commit: 1. avx512bf16 + avx512vl → `_mm512_cvtpbh_ps` path (best — 1 op) 2. avx512f → bit-shift path (this commit — 4 ops, no rounding) 3. scalar lane-by-lane fallback Verification: * Direct test `batch_bf16_to_f32_avx512f_matches_scalar` runs on the `cascadelake` config (avx512f + bw + vl, no bf16) and passes — asserts byte-equal output against scalar reference across the full corpus. * Existing `batch_conversion_matches_scalar` test on this host (avx512_bf16 present) still hits the avx512bf16 path; the new arm is dead code there, which is correct — the dispatch order prefers the better intrinsic when available. * Default v3 build (no AVX-512): 2087 lib tests pass; the new arm isn't compiled because the surrounding test module is gated on `target_feature = "avx512f"`. * cargo clippy -- -D warnings clean. * cargo fmt --all --check clean. The symmetric f32 → BF16 direction already had its AVX-512F-only RNE path (`f32_to_bf16_batch_rne` shipped in PR #126, byte-exact vs `_mm512_cvtneps_pbh`). This commit closes the asymmetry so both directions have AVX-512F-only paths on top of the avx512bf16 fast path. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/simd_avx512.rs | 95 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs index 64dcd4d0..3710b26e 100644 --- a/src/simd_avx512.rs +++ b/src/simd_avx512.rs @@ -2405,6 +2405,22 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { } return; } + // Middle tier: pure AVX-512F bit-shift (Skylake-X, Cascade Lake, + // Ice Lake-SP — all AVX-512F CPUs without the bf16 extension). + // BF16 → f32 is lossless: BF16 IS the upper 16 bits of f32, so + // `(bf16_u16 as u32) << 16` reinterpreted as f32 IS the exact + // value. Vectorized: one _mm512_cvtepu16_epi32 zero-extends 16 + // u16 → 16 u32, one _mm512_slli_epi32::<16> shifts each lane left + // by 16, _mm512_castsi512_ps reinterprets the i32 bit pattern as + // f32. Three AVX-512F instructions per 16-lane chunk vs 16 + // scalar shifts in the fallback below. + if is_x86_feature_detected!("avx512f") { + // SAFETY: feature detection confirmed avx512f. + unsafe { + convert_bf16_to_f32_avx512f(input, output); + } + return; + } } // Scalar fallback (all platforms, all CPUs) @@ -2413,6 +2429,36 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { } } +/// Pure-AVX-512F BF16 → f32 conversion. Bit-exact against +/// `bf16_to_f32_scalar` on every input — BF16 is `f32_bits >> 16`, so +/// the inverse `(bf16 as u32) << 16` reconstructed as f32 is exact. +/// +/// 16-lane main loop via `_mm512_cvtepu16_epi32` (zero-extend) + +/// `_mm512_slli_epi32::<16>` (shift up) + `_mm512_castsi512_ps` +/// (bit-cast). Scalar tail for the last `n % 16` lanes. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +unsafe fn convert_bf16_to_f32_avx512f(input: &[u16], output: &mut [f32]) { + let n = input.len(); + let mut i = 0usize; + + // Main 16-wide loop. + while i + 16 <= n { + let raw256 = _mm256_loadu_si256(input.as_ptr().add(i) as *const __m256i); + let extended = _mm512_cvtepu16_epi32(raw256); + let shifted = _mm512_slli_epi32::<16>(extended); + let as_f32 = _mm512_castsi512_ps(shifted); + _mm512_storeu_ps(output.as_mut_ptr().add(i), as_f32); + i += 16; + } + + // Scalar tail (0..15 remaining lanes). + while i < n { + *output.get_unchecked_mut(i) = bf16_to_f32_scalar(*input.get_unchecked(i)); + i += 1; + } +} + /// Batch f32 → BF16 conversion: same pattern. pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) { assert!(output.len() >= input.len(), "output must be >= input length"); @@ -2707,6 +2753,55 @@ mod bf16_tests { } } + /// Direct test for the AVX-512F bit-shift BF16 → f32 arm, exercising + /// the path the dispatcher would skip when avx512bf16 is available. + /// Verifies bit-exact parity against the scalar reference across a + /// pathological corpus (subnormal, NaN, Inf, sign ±0, every exponent + /// boundary) and a 16-aligned-plus-tail length. + #[cfg(target_arch = "x86_64")] + #[test] + fn batch_bf16_to_f32_avx512f_matches_scalar() { + if !is_x86_feature_detected!("avx512f") { + eprintln!("avx512f not detected on this host; skipping"); + return; + } + // Build a corpus: every bf16 value of interest. The dispatcher's + // 16-wide loop is what matters most; pick a non-aligned total so + // we also exercise the scalar tail. + let mut input: Vec = Vec::new(); + // Sign × exponent × representative mantissa sweep + for sign in [0u16, 0x8000] { + for exp in 0..256u16 { + for &mant in &[0u16, 1, 0x40, 0x7F] { + input.push(sign | (exp << 7) | mant); + } + } + } + // Add 5 bytes of tail to land on a non-16-aligned length. + input.extend_from_slice(&[0x3F80, 0xBF80, 0x4000, 0xC000, 0x7F80]); + + let mut output = vec![0.0f32; input.len()]; + // SAFETY: avx512f confirmed above. + unsafe { convert_bf16_to_f32_avx512f(&input, &mut output) }; + + for (i, &bf16) in input.iter().enumerate() { + let expected = bf16_to_f32_scalar(bf16); + // BF16 → f32 is lossless: bits must be byte-equal (incl. NaN + // payloads). + assert_eq!( + output[i].to_bits(), + expected.to_bits(), + "mismatch at index {} (bf16=0x{:04x}): got {} (0x{:08x}) vs {} (0x{:08x})", + i, + bf16, + output[i], + output[i].to_bits(), + expected, + expected.to_bits() + ); + } + } + // ───────────────────────────────────────────────────────────────────── // RNE certification tests — byte-equality with `_mm512_cvtneps_pbh`. // ───────────────────────────────────────────────────────────────────── From 1a73c37a54b0c3560f711ba7e54d61516d39ea0b Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 01:45:37 +0000 Subject: [PATCH 3/3] fix(simd_half): preserve MXCSR across F16C cast batches (codex P2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per codex review on PR #183: `cast_f32_to_f16_batch_f16c` and `cast_f16_to_f32_batch_f16c` use F16C intrinsics that can raise FP exceptions (#O / #U / #P / #I / #D) on edge inputs — setting bits in the MXCSR status word. The scalar reference paths (`F16::to_f32`, `F16::from_f32_rounded`) are pure bit manipulation and never touch MXCSR, so the F16C fast path was introducing observable FP control-state side effects. Codex's proposed fix (`_mm256_cvtps_ph::<8>` with bit 3 set for `_MM_FROUND_NO_EXC`) does not apply here: the Rust stdarch intrinsic enforces `static_assert_uimm_bits!(IMM8, 3)` so IMM8 is constrained to `0..=7`, and the underlying VCVTPS2PH IMM8 encoding has no SAE bit — bit 3 selects MXCSR.RM (not NO_EXC, which is an AVX-512 convention). The only valid IMM8 values for F16C `_mm256_cvtps_ph` are 0..=3 (the four rounding modes). The actual fix: save MXCSR via STMXCSR before the SIMD region, restore via LDMXCSR after. Preserves every bit of the original control/status word (rounding mode, exception masks, flush-to- zero, and importantly the exception flag bits that the SIMD path may have set). Net effect: callers observe no MXCSR change vs. the scalar path. Implementation uses inline `asm!(stmxcsr/ldmxcsr)` rather than `_mm_getcsr` / `_mm_setcsr` because those wrappers are deprecated on stable Rust 1.95 (rustc deemed them unsound for cross-thread visibility reasons; the official guidance is exactly this — use inline asm). Two ops per batch call: one STMXCSR save at entry, one LDMXCSR restore at exit. Cost: ~5 cycles total, dwarfed by even a single 8-lane cvtps_ph chunk. New test `f16c_cast_preserves_mxcsr` exercises the fix: constructs input arrays containing 1e30 / -1e30 (overflow #O), 1e-30 (underflow / denormal #U / #D / #P), 1.0/3.0 (precision #P), NaN, Inf, ±0, 1.0 — values designed to trigger every relevant F16C exception. Snapshots MXCSR before, runs the cast, snapshots after, asserts byte-equal. Same check for the upcast direction with SNaN-encoded F16 inputs that trigger #I/#D in `_mm256_cvtph_ps`. Both pass on this host (F16C + avx2 silicon). Note: this fix does NOT prevent traps from firing on hosts where the caller has unmasked FP exceptions before calling us. Trap behaviour is the same as for any plain `a + b` of f32 that overflows — fires from the SIMD ops themselves, not under our control. Default MXCSR has all exception masks set (the process-startup state on Linux/macOS/Windows), so this is the common case and traps don't fire there. Verification: * 22 simd_half tests pass (was 21 before, +1 new MXCSR- preservation test). * Full lib sweep: 2087 tests pass. * cargo clippy -- -D warnings clean (no deprecation warning from _mm_getcsr / _mm_setcsr — we use inline asm instead). * cargo fmt --all --check clean. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/simd_half.rs | 125 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 4 deletions(-) diff --git a/src/simd_half.rs b/src/simd_half.rs index 3fc3840d..a6121c46 100644 --- a/src/simd_half.rs +++ b/src/simd_half.rs @@ -426,12 +426,27 @@ pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) { /// reference per IEEE 754 binary16 → binary32 spec** (lossless widening, /// no rounding possible). /// +/// # MXCSR preservation +/// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D` +/// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle +/// reference [`F16::to_f32`] does not touch. To preserve the scalar +/// path's contract of "no observable FP control/status side effects," +/// the MXCSR is saved before the SIMD region and restored after. Net +/// effect: callers see no MXCSR change vs. the scalar path. (See +/// codex review on PR #183.) +/// /// # Safety /// Caller must have feature-detected `f16c` + `avx` at runtime. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "f16c,avx")] unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) { + use core::arch::asm; use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128}; + let mut saved_mxcsr: u32 = 0; + // SAFETY: STMXCSR writes the 32-bit MXCSR control/status register + // to the provided memory location; available on any SSE host + // (baseline x86_64). + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack)); let n = src.len().min(dst.len()); let chunks = n / 8; for c in 0..chunks { @@ -444,6 +459,11 @@ unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) { for i in (chunks * 8)..n { dst[i] = F16(src[i]).to_f32(); } + // SAFETY: LDMXCSR reads the value we saved at the top — preserves + // every bit of the original MXCSR (rounding mode, exception masks, + // flush-to-zero etc.), clearing any exception flags the SIMD path + // may have set. + asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly)); } /// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding. @@ -452,17 +472,35 @@ unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) { /// one xmm store). The const `IMM8 = 0` selects /// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the /// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every -/// input. (Intel's `IMM8` for this intrinsic is 3 bits wide so the -/// `_MM_FROUND_NO_EXC` flag is not selectable here; exceptions are -/// raised but we ignore them — they don't affect the produced bit -/// pattern.) +/// input. +/// +/// # IMM8 encoding limit +/// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits! +/// (IMM8, 3)` in the Rust stdarch wrapper). Valid values are `0..=3` +/// (the four rounding modes — RNE, down, up, truncate). Bits 2-3 of +/// the underlying VCVTPS2PH IMM8 encoding are "reserved" and "select +/// MXCSR.RM" per Intel SDM — NOT `_MM_FROUND_NO_EXC`, which is an +/// AVX-512 convention (`_mm512_cvtps_ph` accepts `NO_EXC`, F16C does +/// not). Exception suppression is handled at the MXCSR level (below). +/// +/// # MXCSR preservation +/// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow), +/// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The +/// scalar reference [`F16::from_f32_rounded`] is pure bit +/// manipulation and never touches MXCSR. We save/restore MXCSR around +/// the SIMD region so callers see no observable control/status side +/// effects regardless of input data. (See codex review on PR #183.) /// /// # Safety /// Caller must have feature-detected `f16c` + `avx` at runtime. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "f16c,avx")] unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) { + use core::arch::asm; use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128}; + let mut saved_mxcsr: u32 = 0; + // SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op. + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack)); let n = src.len().min(dst.len()); let chunks = n / 8; for c in 0..chunks { @@ -475,6 +513,8 @@ unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) { for i in (chunks * 8)..n { dst[i] = F16::from_f32_rounded(src[i]).0; } + // SAFETY: LDMXCSR restores the saved value bit-for-bit. + asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly)); } // ============================================================================ @@ -853,4 +893,81 @@ mod tests { assert_eq!(dst[i], expected[i], "mul_f16_inplace mismatch at {}", i); } } + + /// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions + /// (#O on overflow, #U on underflow, #P on precision loss, #I on + /// SNaN, #D on denormal input) which set bits in MXCSR. The scalar + /// path is pure bit manipulation and never touches MXCSR. The fix: + /// `cast_f32_to_f16_batch_f16c` saves MXCSR via STMXCSR before the + /// SIMD region and restores it via LDMXCSR after. This test feeds + /// inputs that should trigger every exception bit and asserts + /// MXCSR is byte-identical before vs. after the call. + #[cfg(target_arch = "x86_64")] + #[test] + fn f16c_cast_preserves_mxcsr() { + if !std::is_x86_feature_detected!("f16c") { + eprintln!("f16c not detected; skipping"); + return; + } + use core::arch::asm; + + // Inputs designed to trigger #O / #U / #P / #I / #D in F16C + // downcast: + // - 1e30, -1e30 : overflow (out of F16 range ±65504) → #O + // - 1e-30 : underflow / denormal → #U, #D, #P + // - 1.0/3.0 : precision loss → #P + // - f32::NAN : invalid (if it's an sNaN representation) → #I + let inputs: Vec = vec![ + 1e30, + -1e30, + 1e-30, + 1.0 / 3.0, + f32::NAN, + f32::INFINITY, + 0.0, + 1.0, + // Pad to 8 lanes so the SIMD chunk loop fires once with no tail. + ]; + assert_eq!(inputs.len(), 8); + let mut out = vec![F16::ZERO; 8]; + + // Snapshot MXCSR before. + let mut mxcsr_before: u32 = 0; + unsafe { + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack)); + } + + cast_f32_to_f16_batch(&inputs, &mut out); + + // Snapshot MXCSR after. + let mut mxcsr_after: u32 = 0; + unsafe { + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack)); + } + + assert_eq!( + mxcsr_before, mxcsr_after, + "cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)", + mxcsr_before, mxcsr_after + ); + + // Same check for the upcast direction (`_mm256_cvtph_ps` can raise + // #I/#D on SNaN/denormal F16 input). + let f16_inputs: Vec = (0..8).map(|i| F16(0x7C01 + i as u16)).collect(); // SNaN-ish + let mut f32_out = vec![0.0f32; 8]; + + unsafe { + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack)); + } + cast_f16_to_f32_batch(&f16_inputs, &mut f32_out); + unsafe { + asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack)); + } + + assert_eq!( + mxcsr_before, mxcsr_after, + "cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)", + mxcsr_before, mxcsr_after + ); + } }