diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs
index 64dcd4d0..3710b26e 100644
--- a/src/simd_avx512.rs
+++ b/src/simd_avx512.rs
@@ -2405,6 +2405,22 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
             }
             return;
         }
+        // Middle tier: pure AVX-512F bit-shift (Skylake-X, Cascade Lake,
+        // Ice Lake-SP — all AVX-512F CPUs without the bf16 extension).
+        // BF16 → f32 is lossless: BF16 IS the upper 16 bits of f32, so
+        // `(bf16_u16 as u32) << 16` reinterpreted as f32 IS the exact
+        // value. Vectorized: one _mm512_cvtepu16_epi32 zero-extends 16
+        // u16 → 16 u32, one _mm512_slli_epi32::<16> shifts each lane left
+        // by 16, _mm512_castsi512_ps reinterprets the i32 bit pattern as
+        // f32. Three AVX-512F instructions per 16-lane chunk vs 16
+        // scalar shifts in the fallback below.
+        if is_x86_feature_detected!("avx512f") {
+            // SAFETY: feature detection confirmed avx512f.
+            unsafe {
+                convert_bf16_to_f32_avx512f(input, output);
+            }
+            return;
+        }
     }
 
     // Scalar fallback (all platforms, all CPUs)
@@ -2413,6 +2429,36 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
     }
 }
 
+/// Pure-AVX-512F BF16 → f32 conversion of the first
+/// `min(input.len(), output.len())` lanes. Bit-exact against
+/// `bf16_to_f32_scalar`: BF16 is `f32_bits >> 16`, so the inverse
+/// `(bf16 as u32) << 16` reinterpreted as f32 is exact. 16-lane main
+/// loop: `_mm512_cvtepu16_epi32` (zero-extend) + `_mm512_slli_epi32::<16>`
+/// (shift up) + `_mm512_castsi512_ps` (bit-cast); scalar tail for the
+/// last `n % 16` lanes. Safety: caller must have detected `avx512f`.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn convert_bf16_to_f32_avx512f(input: &[u16], output: &mut [f32]) {
+    let n = input.len().min(output.len());
+    let mut i = 0usize;
+
+    // Main 16-wide loop; `n` is clamped above so every load/store is in bounds.
+    while i + 16 <= n {
+        let raw256 = _mm256_loadu_si256(input.as_ptr().add(i) as *const __m256i);
+        let extended = _mm512_cvtepu16_epi32(raw256);
+        let shifted = _mm512_slli_epi32::<16>(extended);
+        let as_f32 = _mm512_castsi512_ps(shifted);
+        _mm512_storeu_ps(output.as_mut_ptr().add(i), as_f32);
+        i += 16;
+    }
+
+    // Scalar tail (0..15 remaining lanes).
+    while i < n {
+        *output.get_unchecked_mut(i) = bf16_to_f32_scalar(*input.get_unchecked(i));
+        i += 1;
+    }
+}
+
 /// Batch f32 → BF16 conversion: same pattern.
 pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) {
     assert!(output.len() >= input.len(), "output must be >= input length");
@@ -2707,6 +2753,55 @@ mod bf16_tests {
         }
     }
 
+    /// Direct test for the AVX-512F bit-shift BF16 → f32 arm, exercising
+    /// the path the dispatcher would skip when avx512bf16 is available.
+    /// Verifies bit-exact parity against the scalar reference across a
+    /// pathological corpus (subnormal, NaN, Inf, sign ±0, every exponent
+    /// boundary) and a 16-aligned-plus-tail length.
+    #[cfg(target_arch = "x86_64")]
+    #[test]
+    fn batch_bf16_to_f32_avx512f_matches_scalar() {
+        if !is_x86_feature_detected!("avx512f") {
+            eprintln!("avx512f not detected on this host; skipping");
+            return;
+        }
+        // Build a corpus: every bf16 value of interest. The dispatcher's
+        // 16-wide loop is what matters most; pick a non-aligned total so
+        // we also exercise the scalar tail.
+        let mut input: Vec<u16> = Vec::new();
+        // Sign × exponent × representative mantissa sweep
+        for sign in [0u16, 0x8000] {
+            for exp in 0..256u16 {
+                for &mant in &[0u16, 1, 0x40, 0x7F] {
+                    input.push(sign | (exp << 7) | mant);
+                }
+            }
+        }
+        // Add 5 elements of tail to land on a non-16-aligned length.
+        input.extend_from_slice(&[0x3F80, 0xBF80, 0x4000, 0xC000, 0x7F80]);
+
+        let mut output = vec![0.0f32; input.len()];
+        // SAFETY: avx512f confirmed above.
+        unsafe { convert_bf16_to_f32_avx512f(&input, &mut output) };
+
+        for (i, &bf16) in input.iter().enumerate() {
+            let expected = bf16_to_f32_scalar(bf16);
+            // BF16 → f32 is lossless: bits must be byte-equal (incl. NaN
+            // payloads).
+            assert_eq!(
+                output[i].to_bits(),
+                expected.to_bits(),
+                "mismatch at index {} (bf16=0x{:04x}): got {} (0x{:08x}) vs {} (0x{:08x})",
+                i,
+                bf16,
+                output[i],
+                output[i].to_bits(),
+                expected,
+                expected.to_bits()
+            );
+        }
+    }
+
     // ─────────────────────────────────────────────────────────────────────
     // RNE certification tests — byte-equality with `_mm512_cvtneps_pbh`.
     // ─────────────────────────────────────────────────────────────────────
diff --git a/src/simd_half.rs b/src/simd_half.rs
index 223b50a1..a6121c46 100644
--- a/src/simd_half.rs
+++ b/src/simd_half.rs
@@ -351,18 +351,31 @@ pub fn cast_bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) {
 
 /// Batch convert F16 → f32.
 ///
-/// Uses F16x16 for chunks of 16, scalar tail for remainder.
+/// On x86_64 with F16C (every CPU from Ivy Bridge 2013 / Piledriver 2012
+/// onward), dispatches to `_mm256_cvtph_ps` — one hardware instruction
+/// converts 8 F16 lanes to 8 F32 lanes, IEEE-754 exact. The scalar
+/// fallback uses the bit-fiddle [`F16::to_f32`] which is also IEEE-754
+/// exact, just slower.
 pub fn cast_f16_to_f32_batch(src: &[F16], dst: &mut [f32]) {
     let n = src.len().min(dst.len());
-    let chunks = n / 16;
-    for c in 0..chunks {
-        let off = c * 16;
-        let v = F16x16::from_slice(&src[off..]);
-        let f = v.to_f32x16();
-        dst[off..off + 16].copy_from_slice(&f);
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
+            // SAFETY: `F16` is `#[repr(transparent)] struct F16(pub u16)`
+            // (per `hpc::quantized::F16`). Slice reinterpretation is
+            // bit-pattern preserving. Runtime feature detection above
+            // confirms F16C + AVX before calling the target-feature fn.
+            let src_u16: &[u16] = unsafe { core::slice::from_raw_parts(src.as_ptr() as *const u16, src.len()) };
+            unsafe {
+                cast_f16_to_f32_batch_f16c(&src_u16[..n], &mut dst[..n]);
+            }
+            return;
+        }
     }
-    // Scalar tail
-    for i in (chunks * 16)..n {
+
+    // Scalar fallback (non-x86_64 or pre-F16C silicon).
+    for i in 0..n {
         dst[i] = src[i].to_f32();
     }
 }
@@ -376,13 +389,134 @@ pub fn cast_f32_to_bf16_batch(src: &[f32], dst: &mut [BF16]) {
 }
 
 /// Batch convert f32 → F16 (round-to-nearest-even).
+///
+/// On x86_64 with F16C, dispatches to `_mm256_cvtps_ph::<0>` (RNE;
+/// MXCSR is saved and restored around the SIMD region) — one hardware
+/// instruction converts 8 F32 lanes to 8 F16 lanes with IEEE 754
+/// round-to-nearest-even. Scalar fallback uses [`F16::from_f32_rounded`],
+/// which matches RNE bit-for-bit (including subnormal / NaN / Inf).
 pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) {
     let n = src.len().min(dst.len());
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
+            // SAFETY: same as cast_f16_to_f32_batch — `F16` is
+            // repr(transparent) over u16; runtime feature gate ensures
+            // F16C is present.
+            let dst_u16: &mut [u16] =
+                unsafe { core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u16, dst.len()) };
+            unsafe {
+                cast_f32_to_f16_batch_f16c(&src[..n], &mut dst_u16[..n]);
+            }
+            return;
+        }
+    }
+
     for i in 0..n {
         dst[i] = F16::from_f32_rounded(src[i]);
     }
 }
 
+/// F16C-vectorized F16 → f32 batch.
+///
+/// 8 F16 lanes per `_mm256_cvtph_ps` instruction (one xmm load + one
+/// ymm store). Scalar tail handles the remaining `n % 8` lanes via the
+/// bit-fiddle reference. **F16C result is bit-identical to the scalar
+/// reference per IEEE 754 binary16 → binary32 spec** (lossless widening,
+/// no rounding possible).
+///
+/// # MXCSR preservation
+/// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D`
+/// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle
+/// reference [`F16::to_f32`] does not touch. To preserve the scalar
+/// path's contract of "no observable FP control/status side effects,"
+/// the MXCSR is saved before the SIMD region and restored after. Net
+/// effect: callers see no MXCSR change vs. the scalar path. (See
+/// codex review on PR #183.)
+///
+/// # Safety
+/// Caller must have feature-detected `f16c` + `avx` at runtime.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "f16c,avx")]
+unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
+    use core::arch::asm;
+    use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128};
+    let mut saved_mxcsr: u32 = 0;
+    // SAFETY: STMXCSR writes the 32-bit MXCSR control/status register
+    // to the provided memory location; available on any SSE host
+    // (baseline x86_64).
+    asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
+    let n = src.len().min(dst.len());
+    let chunks = n / 8;
+    for c in 0..chunks {
+        let off = c * 8;
+        let h = _mm_loadu_si128(src.as_ptr().add(off) as *const __m128i);
+        let f = _mm256_cvtph_ps(h);
+        _mm256_storeu_ps(dst.as_mut_ptr().add(off), f);
+    }
+    // Scalar tail (0..7 remaining lanes).
+    for i in (chunks * 8)..n {
+        dst[i] = F16(src[i]).to_f32();
+    }
+    // SAFETY: LDMXCSR reads the value we saved at the top — preserves
+    // every bit of the original MXCSR (rounding mode, exception masks,
+    // flush-to-zero etc.), clearing any exception flags the SIMD path
+    // may have set.
+    asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
+}
+
+/// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding.
+///
+/// 8 F32 lanes per `_mm256_cvtps_ph::<0>` instruction (one ymm load +
+/// one xmm store). The const `IMM8 = 0` selects
+/// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the
+/// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every
+/// input.
+///
+/// # IMM8 encoding limit
+/// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits!
+/// (IMM8, 3)` in the Rust stdarch wrapper), i.e. `0..=7`. Values `0..=3`
+/// pick a fixed rounding mode (RNE, down, up, truncate); bit 2 of the
+/// VCVTPS2PH immediate instead defers to `MXCSR.RC`, and the Intel SDM
+/// lists bits 3-7 as ignored — so there is no `_MM_FROUND_NO_EXC` (8)
+/// here; that is an AVX-512 SAE convention F16C lacks. Sticky exception
+/// flags are handled at the MXCSR level (below).
+///
+/// # MXCSR preservation
+/// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow),
+/// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The
+/// scalar reference [`F16::from_f32_rounded`] is pure bit
+/// manipulation and never touches MXCSR. We save/restore MXCSR around
+/// the SIMD region so callers see no observable control/status side
+/// effects regardless of input data. (See codex review on PR #183.)
+///
+/// # Safety
+/// Caller must have feature-detected `f16c` + `avx` at runtime.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "f16c,avx")]
+unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) {
+    use core::arch::asm;
+    use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128};
+    let mut saved_mxcsr: u32 = 0;
+    // SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op.
+    asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
+    let n = src.len().min(dst.len());
+    let chunks = n / 8;
+    for c in 0..chunks {
+        let off = c * 8;
+        let f = _mm256_loadu_ps(src.as_ptr().add(off));
+        let h = _mm256_cvtps_ph::<0>(f);
+        _mm_storeu_si128(dst.as_mut_ptr().add(off) as *mut __m128i, h);
+    }
+    // Scalar tail.
+    for i in (chunks * 8)..n {
+        dst[i] = F16::from_f32_rounded(src[i]).0;
+    }
+    // SAFETY: LDMXCSR restores the saved value bit-for-bit.
+    asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
+}
+
 // ============================================================================
 // Tests
 // ============================================================================
@@ -759,4 +893,81 @@ mod tests {
             assert_eq!(dst[i], expected[i], "mul_f16_inplace mismatch at {}", i);
         }
     }
+
+    /// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions
+    /// (#O on overflow, #U on underflow, #P on precision loss, #I on
+    /// SNaN, #D on denormal input) which set bits in MXCSR. The scalar
+    /// path is pure bit manipulation and never touches MXCSR. The fix:
+    /// `cast_f32_to_f16_batch_f16c` saves MXCSR via STMXCSR before the
+    /// SIMD region and restores it via LDMXCSR after. This test feeds
+    /// inputs that should trigger every exception bit and asserts
+    /// MXCSR is byte-identical before vs. after the call.
+    #[cfg(target_arch = "x86_64")]
+    #[test]
+    fn f16c_cast_preserves_mxcsr() {
+        if !std::is_x86_feature_detected!("f16c") {
+            eprintln!("f16c not detected; skipping");
+            return;
+        }
+        use core::arch::asm;
+
+        // Inputs designed to trigger #O / #U / #P / #I / #D in F16C
+        // downcast:
+        //   - 1e30, -1e30 : overflow (out of F16 range ±65504) → #O
+        //   - 1e-30       : tiny result (underflow) → #U, #P; input is a normal f32
+        //   - 1.0/3.0     : precision loss → #P
+        //   - f32::NAN    : invalid (if it's an sNaN representation) → #I
+        let inputs: Vec<f32> = vec![
+            1e30,
+            -1e30,
+            1e-30,
+            1.0 / 3.0,
+            f32::NAN,
+            f32::INFINITY,
+            0.0,
+            1.0,
+            // Pad to 8 lanes so the SIMD chunk loop fires once with no tail.
+        ];
+        assert_eq!(inputs.len(), 8);
+        let mut out = vec![F16::ZERO; 8];
+
+        // Snapshot MXCSR before.
+        let mut mxcsr_before: u32 = 0;
+        unsafe {
+            asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
+        }
+
+        cast_f32_to_f16_batch(&inputs, &mut out);
+
+        // Snapshot MXCSR after.
+        let mut mxcsr_after: u32 = 0;
+        unsafe {
+            asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
+        }
+
+        assert_eq!(
+            mxcsr_before, mxcsr_after,
+            "cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
+            mxcsr_before, mxcsr_after
+        );
+
+        // Same check for the upcast direction (`_mm256_cvtph_ps` can raise
+        // #I/#D on SNaN/denormal F16 input).
+        let f16_inputs: Vec<F16> = (0..8).map(|i| F16(0x7C01 + i as u16)).collect(); // SNaN-ish
+        let mut f32_out = vec![0.0f32; 8];
+
+        unsafe {
+            asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
+        }
+        cast_f16_to_f32_batch(&f16_inputs, &mut f32_out);
+        unsafe {
+            asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
+        }
+
+        assert_eq!(
+            mxcsr_before, mxcsr_after,
+            "cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
+            mxcsr_before, mxcsr_after
+        );
+    }
 }