Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/simd_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,22 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
}
return;
}
// Middle tier: pure AVX-512F bit-shift (Skylake-X, Cascade Lake,
// Ice Lake-SP — all AVX-512F CPUs without the bf16 extension).
// BF16 → f32 is lossless: BF16 IS the upper 16 bits of f32, so
// `(bf16_u16 as u32) << 16` reinterpreted as f32 IS the exact
// value. Vectorized: one _mm512_cvtepu16_epi32 zero-extends 16
// u16 → 16 u32, one _mm512_slli_epi32::<16> shifts each lane left
// by 16, _mm512_castsi512_ps reinterprets the i32 bit pattern as
// f32. Three AVX-512F instructions per 16-lane chunk vs 16
// scalar shifts in the fallback below.
if is_x86_feature_detected!("avx512f") {
// SAFETY: feature detection confirmed avx512f.
unsafe {
convert_bf16_to_f32_avx512f(input, output);
}
return;
}
}

// Scalar fallback (all platforms, all CPUs)
Expand All @@ -2413,6 +2429,36 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
}
}

/// Pure-AVX-512F BF16 → f32 conversion. Bit-exact against
/// `bf16_to_f32_scalar` on every input — BF16 is `f32_bits >> 16`, so
/// the inverse `(bf16 as u32) << 16` reconstructed as f32 is exact.
///
/// 16-lane main loop via `_mm512_cvtepu16_epi32` (zero-extend) +
/// `_mm512_slli_epi32::<16>` (shift up) + `_mm512_castsi512_ps`
/// (bit-cast). Scalar tail for the last `n % 16` lanes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn convert_bf16_to_f32_avx512f(input: &[u16], output: &mut [f32]) {
let n = input.len();
let mut i = 0usize;

// Main 16-wide loop.
while i + 16 <= n {
let raw256 = _mm256_loadu_si256(input.as_ptr().add(i) as *const __m256i);
let extended = _mm512_cvtepu16_epi32(raw256);
let shifted = _mm512_slli_epi32::<16>(extended);
let as_f32 = _mm512_castsi512_ps(shifted);
_mm512_storeu_ps(output.as_mut_ptr().add(i), as_f32);
i += 16;
}

// Scalar tail (0..15 remaining lanes).
while i < n {
*output.get_unchecked_mut(i) = bf16_to_f32_scalar(*input.get_unchecked(i));
i += 1;
}
}

/// Batch f32 → BF16 conversion: same pattern.
pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) {
assert!(output.len() >= input.len(), "output must be >= input length");
Expand Down Expand Up @@ -2707,6 +2753,55 @@ mod bf16_tests {
}
}

/// Direct test for the AVX-512F bit-shift BF16 → f32 arm, exercising
/// the path the dispatcher would skip when avx512bf16 is available.
/// Verifies bit-exact parity against the scalar reference across a
/// pathological corpus (subnormal, NaN, Inf, sign ±0, every exponent
/// boundary) and a 16-aligned-plus-tail length.
#[cfg(target_arch = "x86_64")]
#[test]
fn batch_bf16_to_f32_avx512f_matches_scalar() {
if !is_x86_feature_detected!("avx512f") {
eprintln!("avx512f not detected on this host; skipping");
return;
}
// Build a corpus: every bf16 value of interest. The dispatcher's
// 16-wide loop is what matters most; pick a non-aligned total so
// we also exercise the scalar tail.
let mut input: Vec<u16> = Vec::new();
// Sign × exponent × representative mantissa sweep
for sign in [0u16, 0x8000] {
for exp in 0..256u16 {
for &mant in &[0u16, 1, 0x40, 0x7F] {
input.push(sign | (exp << 7) | mant);
}
}
}
// Add 5 bytes of tail to land on a non-16-aligned length.
input.extend_from_slice(&[0x3F80, 0xBF80, 0x4000, 0xC000, 0x7F80]);

let mut output = vec![0.0f32; input.len()];
// SAFETY: avx512f confirmed above.
unsafe { convert_bf16_to_f32_avx512f(&input, &mut output) };

for (i, &bf16) in input.iter().enumerate() {
let expected = bf16_to_f32_scalar(bf16);
// BF16 → f32 is lossless: bits must be byte-equal (incl. NaN
// payloads).
assert_eq!(
output[i].to_bits(),
expected.to_bits(),
"mismatch at index {} (bf16=0x{:04x}): got {} (0x{:08x}) vs {} (0x{:08x})",
i,
bf16,
output[i],
output[i].to_bits(),
expected,
expected.to_bits()
);
}
}

// ─────────────────────────────────────────────────────────────────────
// RNE certification tests — byte-equality with `_mm512_cvtneps_pbh`.
// ─────────────────────────────────────────────────────────────────────
Expand Down
229 changes: 220 additions & 9 deletions src/simd_half.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,31 @@ pub fn cast_bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) {

/// Batch convert F16 → f32.
///
/// Uses F16x16 for chunks of 16, scalar tail for remainder.
/// On x86_64 with F16C (every CPU from Ivy Bridge 2013 / Piledriver 2012
/// onward), dispatches to `_mm256_cvtph_ps` — one hardware instruction
/// converts 8 F16 lanes to 8 F32 lanes, IEEE-754 exact. The scalar
/// fallback uses the bit-fiddle [`F16::to_f32`] which is also IEEE-754
/// exact, just slower.
pub fn cast_f16_to_f32_batch(src: &[F16], dst: &mut [f32]) {
let n = src.len().min(dst.len());
let chunks = n / 16;
for c in 0..chunks {
let off = c * 16;
let v = F16x16::from_slice(&src[off..]);
let f = v.to_f32x16();
dst[off..off + 16].copy_from_slice(&f);

#[cfg(target_arch = "x86_64")]
{
if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
// SAFETY: `F16` is `#[repr(transparent)] struct F16(pub u16)`
// (per `hpc::quantized::F16`). Slice reinterpretation is
// bit-pattern preserving. Runtime feature detection above
// confirms F16C + AVX before calling the target-feature fn.
let src_u16: &[u16] = unsafe { core::slice::from_raw_parts(src.as_ptr() as *const u16, src.len()) };
unsafe {
cast_f16_to_f32_batch_f16c(&src_u16[..n], &mut dst[..n]);
}
return;
}
}
// Scalar tail
for i in (chunks * 16)..n {

// Scalar fallback (non-x86_64 or pre-F16C silicon).
for i in 0..n {
dst[i] = src[i].to_f32();
}
}
Expand All @@ -376,13 +389,134 @@ pub fn cast_f32_to_bf16_batch(src: &[f32], dst: &mut [BF16]) {
}

/// Batch convert f32 → F16 (round-to-nearest-even).
///
/// On x86_64 with F16C, dispatches to `_mm256_cvtps_ph::<8>` (RNE,
/// no exceptions) — one hardware instruction converts 8 F32 lanes to
/// 8 F16 lanes with IEEE 754 round-to-nearest-even. Scalar fallback
/// uses [`F16::from_f32_rounded`] which matches the IEEE 754 RNE rule
/// bit-for-bit on every input (including subnormal / NaN / Inf).
pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) {
let n = src.len().min(dst.len());

#[cfg(target_arch = "x86_64")]
{
if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
// SAFETY: same as cast_f16_to_f32_batch — `F16` is
// repr(transparent) over u16; runtime feature gate ensures
// F16C is present.
let dst_u16: &mut [u16] =
unsafe { core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u16, dst.len()) };
unsafe {
cast_f32_to_f16_batch_f16c(&src[..n], &mut dst_u16[..n]);
}
return;
}
}

for i in 0..n {
dst[i] = F16::from_f32_rounded(src[i]);
}
}

/// F16C-vectorized F16 → f32 batch.
///
/// 8 F16 lanes per `_mm256_cvtph_ps` instruction (one xmm load + one
/// ymm store). Scalar tail handles the remaining `n % 8` lanes via the
/// bit-fiddle reference. **F16C result is bit-identical to the scalar
/// reference per IEEE 754 binary16 → binary32 spec** (lossless widening,
/// no rounding possible).
///
/// # MXCSR preservation
/// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D`
/// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle
/// reference [`F16::to_f32`] does not touch. To preserve the scalar
/// path's contract of "no observable FP control/status side effects,"
/// the MXCSR is saved before the SIMD region and restored after. Net
/// effect: callers see no MXCSR change vs. the scalar path. (See
/// codex review on PR #183.)
///
/// # Safety
/// Caller must have feature-detected `f16c` + `avx` at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "f16c,avx")]
unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
use core::arch::asm;
use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128};
let mut saved_mxcsr: u32 = 0;
// SAFETY: STMXCSR writes the 32-bit MXCSR control/status register
// to the provided memory location; available on any SSE host
// (baseline x86_64).
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
let n = src.len().min(dst.len());
let chunks = n / 8;
for c in 0..chunks {
let off = c * 8;
let h = _mm_loadu_si128(src.as_ptr().add(off) as *const __m128i);
let f = _mm256_cvtph_ps(h);
_mm256_storeu_ps(dst.as_mut_ptr().add(off), f);
}
// Scalar tail (0..7 remaining lanes).
for i in (chunks * 8)..n {
dst[i] = F16(src[i]).to_f32();
}
// SAFETY: LDMXCSR reads the value we saved at the top — preserves
// every bit of the original MXCSR (rounding mode, exception masks,
// flush-to-zero etc.), clearing any exception flags the SIMD path
// may have set.
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
}

/// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding.
///
/// 8 F32 lanes per `_mm256_cvtps_ph::<0>` instruction (one ymm load +
/// one xmm store). The const `IMM8 = 0` selects
/// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the
/// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every
/// input.
///
/// # IMM8 encoding limit
/// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits!
/// (IMM8, 3)` in the Rust stdarch wrapper). Valid values are `0..=3`
/// (the four rounding modes — RNE, down, up, truncate). Bits 2-3 of
/// the underlying VCVTPS2PH IMM8 encoding are "reserved" and "select
/// MXCSR.RM" per Intel SDM — NOT `_MM_FROUND_NO_EXC`, which is an
/// AVX-512 convention (`_mm512_cvtps_ph` accepts `NO_EXC`, F16C does
/// not). Exception suppression is handled at the MXCSR level (below).
///
/// # MXCSR preservation
/// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow),
/// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The
/// scalar reference [`F16::from_f32_rounded`] is pure bit
/// manipulation and never touches MXCSR. We save/restore MXCSR around
/// the SIMD region so callers see no observable control/status side
/// effects regardless of input data. (See codex review on PR #183.)
///
/// # Safety
/// Caller must have feature-detected `f16c` + `avx` at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "f16c,avx")]
unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) {
use core::arch::asm;
use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128};
let mut saved_mxcsr: u32 = 0;
// SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op.
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
let n = src.len().min(dst.len());
let chunks = n / 8;
for c in 0..chunks {
let off = c * 8;
let f = _mm256_loadu_ps(src.as_ptr().add(off));
let h = _mm256_cvtps_ph::<0>(f);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Use no-exception rounding mode for F16C downcast

cast_f32_to_f16_batch_f16c currently uses _mm256_cvtps_ph::<0>, which performs round-to-nearest-even but does not request exception suppression. This means conversions of NaN/Inf/overflow/underflow inputs can set MXCSR exception flags (and can trap if FP exceptions are unmasked), which is a behavior change from the previous pure bit-manipulation scalar path and contradicts the function-level contract that says “no exceptions.” Use the _MM_FROUND_NO_EXC variant (imm8 with bit 3 set) to preserve non-trapping behavior.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — the underlying concern (the F16C path can set MXCSR flags that the scalar bit-fiddle path doesn't) is valid. The proposed fix isn't quite right though: _mm256_cvtps_ph's IMM8 is constrained to 3 bits in Rust stdarch (static_assert_uimm_bits!(IMM8, 3) — fails to compile for IMM8 = 8), and in the underlying VCVTPS2PH spec bit 3 selects MXCSR.RM not _MM_FROUND_NO_EXC (NO_EXC is an AVX-512 convention; F16C predates the SAE family). Only valid IMM8 values here are 0..=3 (the four rounding modes).

The right fix is MXCSR save/restore via inline asm!(stmxcsr/ldmxcsr) — landed in 1a73c37. STMXCSR before the SIMD region, LDMXCSR after, preserves every bit of the saved control/status word including the exception flags the SIMD path may have set. Net effect: callers observe zero MXCSR change vs. the scalar path. Inline asm rather than _mm_getcsr/_mm_setcsr because those wrappers are deprecated on Rust 1.95 stable (unsoundness across thread MXCSR visibility; the deprecation notice explicitly recommends inline asm).

Same fix applied to cast_f16_to_f32_batch_f16c since _mm256_cvtph_ps can also raise #I/#D on SNaN/denormal F16 inputs. New test f16c_cast_preserves_mxcsr exercises both directions with inputs that trigger every relevant exception (overflow/underflow/precision/invalid/denormal); snapshots MXCSR before and after via stmxcsr, asserts byte-equal. Test passes.

This fix preserves the MXCSR FLAG state. It does not prevent traps when the caller has unmasked FP exceptions before invoking us — those would fire from the SIMD ops themselves and bypass our restore. That's the same trap behaviour as any plain a + b on overflow-prone f32, and the default OS-set MXCSR has all exception masks set so it's a non-issue for the common case.


Generated by Claude Code

_mm_storeu_si128(dst.as_mut_ptr().add(off) as *mut __m128i, h);
}
// Scalar tail.
for i in (chunks * 8)..n {
dst[i] = F16::from_f32_rounded(src[i]).0;
}
// SAFETY: LDMXCSR restores the saved value bit-for-bit.
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
}

// ============================================================================
// Tests
// ============================================================================
Expand Down Expand Up @@ -759,4 +893,81 @@ mod tests {
assert_eq!(dst[i], expected[i], "mul_f16_inplace mismatch at {}", i);
}
}

/// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions
/// (#O on overflow, #U on underflow, #P on precision loss, #I on
/// SNaN, #D on denormal input) which set bits in MXCSR. The scalar
/// path is pure bit manipulation and never touches MXCSR. The fix:
/// `cast_f32_to_f16_batch_f16c` saves MXCSR via STMXCSR before the
/// SIMD region and restores it via LDMXCSR after. This test feeds
/// inputs that should trigger every exception bit and asserts
/// MXCSR is byte-identical before vs. after the call.
#[cfg(target_arch = "x86_64")]
#[test]
fn f16c_cast_preserves_mxcsr() {
if !std::is_x86_feature_detected!("f16c") {
eprintln!("f16c not detected; skipping");
return;
}
use core::arch::asm;

// Inputs designed to trigger #O / #U / #P / #I / #D in F16C
// downcast:
// - 1e30, -1e30 : overflow (out of F16 range ±65504) → #O
// - 1e-30 : underflow / denormal → #U, #D, #P
// - 1.0/3.0 : precision loss → #P
// - f32::NAN : invalid (if it's an sNaN representation) → #I
let inputs: Vec<f32> = vec![
1e30,
-1e30,
1e-30,
1.0 / 3.0,
f32::NAN,
f32::INFINITY,
0.0,
1.0,
// Pad to 8 lanes so the SIMD chunk loop fires once with no tail.
];
assert_eq!(inputs.len(), 8);
let mut out = vec![F16::ZERO; 8];

// Snapshot MXCSR before.
let mut mxcsr_before: u32 = 0;
unsafe {
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
}

cast_f32_to_f16_batch(&inputs, &mut out);

// Snapshot MXCSR after.
let mut mxcsr_after: u32 = 0;
unsafe {
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
}

assert_eq!(
mxcsr_before, mxcsr_after,
"cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
mxcsr_before, mxcsr_after
);

// Same check for the upcast direction (`_mm256_cvtph_ps` can raise
// #I/#D on SNaN/denormal F16 input).
let f16_inputs: Vec<F16> = (0..8).map(|i| F16(0x7C01 + i as u16)).collect(); // SNaN-ish
let mut f32_out = vec![0.0f32; 8];

unsafe {
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_before, options(nostack));
}
cast_f16_to_f32_batch(&f16_inputs, &mut f32_out);
unsafe {
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut mxcsr_after, options(nostack));
}

assert_eq!(
mxcsr_before, mxcsr_after,
"cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
mxcsr_before, mxcsr_after
);
}
}
Loading