From 8a859a30d2683f21bac96db46aec0e73acd65d23 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 20 May 2026 07:50:43 +0000 Subject: [PATCH] feat(pr-x2): generalize aos_to_soa / soa_to_aos to MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Worker A of PR-X2 (sequential, per .claude/knowledge/pr-x2-design.md § "Worker decomposition" line 458). Lifts the f32-only constraint that W3-W6 shipped, so downstream consumers with u8/u16/u64/i8 SoA fields (palette indices, BF16 carrier, CausalEdge64 mantissa, quantized weights) can use the public surface instead of rolling their own extract loop. Signature change (Option C per design § "Migration path"): // Pre-PR-X2: pub fn aos_to_soa<T, const N: usize, F>(aos: &[T], extract: F) -> SoaVec<f32, N> where F: Fn(&T) -> [f32; N] // After PR-X2 (this commit): pub fn aos_to_soa<T, U, const N: usize, F>(aos: &[T], extract: F) -> SoaVec<U, N> where F: Fn(&T) -> [U; N] `soa_to_aos` mirrors the same generalisation and adds a `U: Copy` bound (needed to materialise the per-row `[U; N]` via `core::array::from_fn`). `SoaVec<T, N>` itself was already generic over `T` so no internal changes were needed — the constraint was purely at the closure-helper signature layer. Caller migration (callers using return-type inference are unaffected): // Turbofish form gains one type param at position 2: aos_to_soa::<_, 3, _>(&aos, …) // was aos_to_soa::<_, _, 3, _>(&aos, …) // now (or `<_, f32, 3, _>` explicit) Updated callers: - src/hpc/soa.rs 5 inline test bodies (sed-rewrite) - src/hpc/bulk.rs 1 module doctest + 1 inline test body New tests in `src/hpc/soa.rs`: - aos_to_soa_u64_round_trip — 3-field u64, full range incl.
u64::MAX - aos_to_soa_u8_round_trip — palette/alpha + soa_to_aos round-trip - aos_to_soa_u16_round_trip — BF16 carrier + soa_to_aos round-trip - aos_to_soa_inference_only — i8, no turbofish (closure ret-type) Updated doctests on aos_to_soa + soa_to_aos cover f32 (back-compat), u64 (CausalEdge64-style), and u8 (palette indices); module header "Element-type scope" rewritten to record the lift + migration note. Verified: cargo test -p ndarray --lib hpc::soa 33 passed cargo test --doc -p ndarray hpc::soa 13 passed cargo fmt --check clean cargo clippy --features approx,serde,rayon -- -D warnings clean Out of scope (Worker B of PR-X2): #[soa(pad_to_lanes=N)] field attribute on soa_struct! — separate commit per design § "Worker decomposition". --- src/hpc/bulk.rs | 4 +- src/hpc/soa.rs | 221 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 183 insertions(+), 42 deletions(-) diff --git a/src/hpc/bulk.rs b/src/hpc/bulk.rs index 8b97fef6..e0c3cb59 100644 --- a/src/hpc/bulk.rs +++ b/src/hpc/bulk.rs @@ -30,7 +30,7 @@ //! .map(|i| Item { a: i as f32, b: (i * 2) as f32, c: (i * 3) as f32 }) //! .collect(); //! bulk_apply(&mut items, 16, |chunk, _start| { -//! let soa = aos_to_soa::<_, 3, _>(chunk, |it| [it.a, it.b, it.c]); +//! let soa = aos_to_soa::<_, _, 3, _>(chunk, |it| [it.a, it.b, it.c]); //! // ... per-field SIMD-style loops over soa.field(0), soa.field(1), ... //! let _ = soa; //! }); @@ -315,7 +315,7 @@ mod tests { let mut chunk_count = 0; bulk_apply(&mut items, 16, |chunk, start_idx| { - let soa = aos_to_soa::<_, 3, _>(chunk, |it| [it.a, it.b, it.c]); + let soa = aos_to_soa::<_, _, 3, _>(chunk, |it| [it.a, it.b, it.c]); assert_eq!(soa.len(), chunk.len()); // First row of the chunk corresponds to absolute index start_idx. assert_eq!(soa.field(0)[0], start_idx as f32); diff --git a/src/hpc/soa.rs b/src/hpc/soa.rs index 0aebcc4e..2392c10e 100644 --- a/src/hpc/soa.rs +++ b/src/hpc/soa.rs @@ -17,22 +17,30 @@ //! 
Both shapes are SIMD-friendly storage layouts: each field is a //! contiguous `Vec`, so per-field SIMD loops iterate one `Vec`. //! -//! # Element-type scope (this PR) +//! # Element-type scope (PR-X2) //! -//! The macro and `SoaVec` are generic over `T`. The closure-based -//! conversion helpers ([`aos_to_soa`], [`soa_to_aos`]) are currently -//! **hardwired to `f32` output** (`SoaVec`). Downstream consumers -//! with `i8` / `u8` / `u16` / `bf16` SoA fields (palette indices, -//! quantized embeddings, BF16 mantissa bytes) must write their own -//! extract loop today; the public surface for generic-T conversion is -//! a follow-up. The macro itself supports any field type. +//! `SoaVec`, the `soa_struct!` macro, and the closure-based conversion +//! helpers [`aos_to_soa`] / [`soa_to_aos`] are **fully generic over the +//! element type `U`** (was f32-hardwired through W3-W6; PR-X2 lifted the +//! constraint). Common element types now flow through directly: +//! +//! - `f32` — Gaussian batch means, covariances (original W3-W6 case) +//! - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs +//! - `u16` — BF16 carrier values, packed depth fields +//! - `u8` — palette indices, quantized embeddings +//! - `i8` — quantized weights with signed range +//! +//! Callers passing turbofish should now use four type params: +//! `aos_to_soa::<_, U, N, _>(...)` instead of the pre-PR-X2 form +//! `aos_to_soa::<_, N, _>(...)`. Callers using return-type inference are +//! unaffected by the generalisation. //! //! # Layering — why `hpc::soa` and not `simd_ops` //! //! `crate::simd_ops` is the SIMD-dispatch glue layer (every fn there //! dispatches through `F32x16` / `F64x8`). Per the W1a consumer contract //! at `.claude/knowledge/vertical-simd-consumer-contract.md`, free-function -//! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec` belong +//! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec` belong //! 
at the `crate::hpc` level, co-located with the data types they //! convert between. Putting pure-scalar helpers in `simd_ops` would //! contradict that module's charter and the W1a litmus that rejects @@ -370,30 +378,34 @@ macro_rules! soa_struct { }; } -/// Deinterleave an AoS slice into a [`SoaVec`] by extracting `N` field -/// values per item via the user-supplied `extract` closure. +/// Deinterleave an AoS slice into a [`SoaVec`] by extracting `N` +/// field values per item via the user-supplied `extract` closure. +/// +/// `U` is the element type of the resulting `SoaVec` — generic over all +/// `Copy` types. Common values: +/// - `f32` — Gaussian batch means, covariances (original W3-W6 use case) +/// - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs +/// - `u16` — BF16 carrier values, packed depth fields +/// - `u8` — palette indices, quantized embeddings /// /// Scalar implementation. A future bench-justified wave may add per-arch -/// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON) for stride-known -/// dense layouts; the public API is forward-compatible — the dispatcher -/// will grow internal per-arch arms without changing this signature. +/// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON). The public +/// signature is forward-compatible — the dispatcher will grow internal +/// per-arch arms without changing this signature. /// -/// `T` need not be `Copy`; only the extracted `[f32; N]` row is -/// materialized. +/// `T` need not be `Copy`; only the extracted `[U; N]` row is materialised. 
/// /// # Inference /// -/// If the const-generic `N` fails to infer from the closure return type, -/// annotate either with a turbofish or a closure return-type ascription: +/// If `N` fails to infer from the closure return type, annotate via +/// turbofish (note: 4 type params now, was 3 in the f32-only era): /// /// ```ignore -/// aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]); -/// aos_to_soa(&aos, |it| -> [f32; 3] { [it.a, it.b, it.c] }); +/// aos_to_soa::<_, u64, 3, _>(&aos, |it| [it.a, it.b, it.c]); +/// aos_to_soa(&aos, |it| -> [u64; 3] { [it.a, it.b, it.c] }); /// ``` /// -/// (Verified on Rust 1.94.) -/// -/// # Example +/// # Example — f32 (backwards-compatible) /// /// ``` /// use ndarray::hpc::soa::aos_to_soa; @@ -402,32 +414,58 @@ macro_rules! soa_struct { /// Item { a: 1.0, b: 2.0, c: 3.0 }, /// Item { a: 4.0, b: 5.0, c: 6.0 }, /// ]; -/// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]); +/// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]); /// assert_eq!(soa.field(0), &[1.0, 4.0]); /// assert_eq!(soa.field(1), &[2.0, 5.0]); /// assert_eq!(soa.field(2), &[3.0, 6.0]); /// ``` -pub fn aos_to_soa(aos: &[T], extract: F) -> SoaVec +/// +/// # Example — u64 (CausalEdge64-style) +/// +/// ``` +/// use ndarray::hpc::soa::aos_to_soa; +/// struct Edge { src: u64, dst: u64, weight: u64 } +/// let aos = vec![ +/// Edge { src: 1, dst: 2, weight: 10 }, +/// Edge { src: 3, dst: 4, weight: 20 }, +/// ]; +/// let soa = aos_to_soa::<_, u64, 3, _>(&aos, |e| [e.src, e.dst, e.weight]); +/// assert_eq!(soa.field(0), &[1u64, 3]); +/// assert_eq!(soa.field(2), &[10u64, 20]); +/// ``` +/// +/// # Example — u8 (palette indices) +/// +/// ``` +/// use ndarray::hpc::soa::aos_to_soa; +/// struct Cell { palette: u8, alpha: u8 } +/// let aos = vec![Cell { palette: 7, alpha: 255 }, Cell { palette: 3, alpha: 128 }]; +/// let soa = aos_to_soa::<_, u8, 2, _>(&aos, |c| [c.palette, c.alpha]); +/// assert_eq!(soa.field(0), &[7u8, 3]); +/// 
assert_eq!(soa.field(1), &[255u8, 128]); +/// ``` +pub fn aos_to_soa(aos: &[T], extract: F) -> SoaVec where - F: Fn(&T) -> [f32; N], + F: Fn(&T) -> [U; N], { - let mut soa = SoaVec::::with_capacity(aos.len()); + let mut soa = SoaVec::::with_capacity(aos.len()); for item in aos { soa.push(extract(item)); } soa } -/// Interleave a [`SoaVec`] into an AoS `Vec` by building each item +/// Interleave a [`SoaVec`] into an AoS `Vec` by building each item /// from the per-field values via the user-supplied `build` closure. /// -/// Scalar implementation. See [`aos_to_soa`] for the forward-compatible -/// note on future SIMD acceleration. +/// `U` is the element type of the input `SoaVec` (must be `Copy` so a +/// per-row `[U; N]` can be materialised by indexing). Scalar implementation; +/// the public signature is forward-compatible per [`aos_to_soa`]. /// /// Complexity: O(N·len) where N is the field count and len is the row /// count. /// -/// # Example +/// # Example — f32 (backwards-compatible) /// /// ``` /// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos}; @@ -436,20 +474,34 @@ where /// Item { a: 1.0, b: 2.0, c: 3.0 }, /// Item { a: 4.0, b: 5.0, c: 6.0 }, /// ]; -/// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]); +/// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]); /// let back: Vec = soa_to_aos(&soa, |[a, b, c]| Item { a, b, c }); /// assert_eq!(back[0].a, 1.0); /// assert_eq!(back[1].c, 6.0); /// ``` -pub fn soa_to_aos(soa: &SoaVec, build: F) -> Vec +/// +/// # Example — u16 (BF16 carrier) +/// +/// ``` +/// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos}; +/// #[derive(Debug, PartialEq)] +/// struct Pair { lo: u16, hi: u16 } +/// let aos = vec![Pair { lo: 0x1234, hi: 0xABCD }, Pair { lo: 0x5678, hi: 0xEF01 }]; +/// let soa = aos_to_soa::<_, u16, 2, _>(&aos, |p| [p.lo, p.hi]); +/// let back: Vec = soa_to_aos(&soa, |[lo, hi]| Pair { lo, hi }); +/// assert_eq!(back[0], Pair { lo: 0x1234, hi: 0xABCD }); +/// assert_eq!(back[1], 
Pair { lo: 0x5678, hi: 0xEF01 }); +/// ``` +pub fn soa_to_aos(soa: &SoaVec, build: F) -> Vec where - F: Fn([f32; N]) -> T, + F: Fn([U; N]) -> T, + U: Copy, { let n = soa.len(); let fields = soa.all_fields(); let mut out = Vec::with_capacity(n); for i in 0..n { - let row: [f32; N] = core::array::from_fn(|k| fields[k][i]); + let row: [U; N] = core::array::from_fn(|k| fields[k][i]); out.push(build(row)); } out @@ -787,7 +839,7 @@ mod tests { #[test] fn aos_to_soa_n2_roundtrip() { let aos = vec![ItemN2 { a: 1.0, b: 2.0 }, ItemN2 { a: 3.0, b: 4.0 }, ItemN2 { a: 5.0, b: 6.0 }]; - let soa = aos_to_soa::<_, 2, _>(&aos, |it| [it.a, it.b]); + let soa = aos_to_soa::<_, _, 2, _>(&aos, |it| [it.a, it.b]); assert_eq!(soa.len(), 3); assert_eq!(soa.field(0), &[1.0, 3.0, 5.0]); assert_eq!(soa.field(1), &[2.0, 4.0, 6.0]); @@ -798,7 +850,7 @@ mod tests { #[test] fn aos_to_soa_n3_roundtrip() { let aos = vec![ItemN3 { a: 1.0, b: 2.0, c: 3.0 }, ItemN3 { a: 4.0, b: 5.0, c: 6.0 }]; - let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]); + let soa = aos_to_soa::<_, _, 3, _>(&aos, |it| [it.a, it.b, it.c]); assert_eq!(soa.field(0), &[1.0, 4.0]); assert_eq!(soa.field(1), &[2.0, 5.0]); assert_eq!(soa.field(2), &[3.0, 6.0]); @@ -828,7 +880,7 @@ mod tests { d: 12.0, }, ]; - let soa = aos_to_soa::<_, 4, _>(&aos, |it| [it.a, it.b, it.c, it.d]); + let soa = aos_to_soa::<_, _, 4, _>(&aos, |it| [it.a, it.b, it.c, it.d]); assert_eq!(soa.field(0), &[1.0, 5.0, 9.0]); assert_eq!(soa.field(1), &[2.0, 6.0, 10.0]); assert_eq!(soa.field(2), &[3.0, 7.0, 11.0]); @@ -840,7 +892,7 @@ mod tests { #[test] fn aos_to_soa_empty_input() { let aos: Vec = Vec::new(); - let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]); + let soa = aos_to_soa::<_, _, 3, _>(&aos, |it| [it.a, it.b, it.c]); assert!(soa.is_empty()); assert_eq!(soa.field(0), &[] as &[f32]); assert_eq!(soa.field(1), &[] as &[f32]); @@ -856,7 +908,7 @@ mod tests { // applied per row. 
let scale: f32 = 10.0; let aos = vec![ItemN2 { a: 1.0, b: 2.0 }, ItemN2 { a: 3.0, b: 4.0 }]; - let soa = aos_to_soa::<_, 2, _>(&aos, |it| [it.a * scale, it.b * scale]); + let soa = aos_to_soa::<_, _, 2, _>(&aos, |it| [it.a * scale, it.b * scale]); assert_eq!(soa.field(0), &[10.0, 30.0]); assert_eq!(soa.field(1), &[20.0, 40.0]); } @@ -867,4 +919,93 @@ mod tests { let back: Vec = soa_to_aos(&soa, |[a, b]| ItemN2 { a, b }); assert!(back.is_empty()); } + + // ------------------------------------------------------------------ + // PR-X2 — generic-U coverage (was f32-hardwired through W3-W6) + // ------------------------------------------------------------------ + + /// `aos_to_soa` over `u64` (CausalEdge64-style fields). + #[test] + fn aos_to_soa_u64_round_trip() { + struct Edge { + src: u64, + dst: u64, + weight: u64, + } + let aos = [ + Edge { + src: 1, + dst: 2, + weight: 10, + }, + Edge { + src: 3, + dst: 4, + weight: 20, + }, + Edge { + src: 0xDEAD_BEEF_CAFE_BABE, + dst: 0, + weight: u64::MAX, + }, + ]; + let soa = aos_to_soa::<_, u64, 3, _>(&aos, |e| [e.src, e.dst, e.weight]); + assert_eq!(soa.len(), 3); + assert_eq!(soa.field(0), &[1u64, 3, 0xDEAD_BEEF_CAFE_BABE]); + assert_eq!(soa.field(1), &[2u64, 4, 0]); + assert_eq!(soa.field(2), &[10u64, 20, u64::MAX]); + } + + /// `aos_to_soa` over `u8` (palette indices) plus `soa_to_aos` round-trip. + #[test] + fn aos_to_soa_u8_round_trip() { + #[derive(Debug, PartialEq, Eq)] + struct Cell { + palette: u8, + alpha: u8, + } + let aos = vec![Cell { palette: 7, alpha: 255 }, Cell { palette: 3, alpha: 128 }, Cell { palette: 0, alpha: 0 }]; + let soa = aos_to_soa::<_, u8, 2, _>(&aos, |c| [c.palette, c.alpha]); + assert_eq!(soa.field(0), &[7u8, 3, 0]); + assert_eq!(soa.field(1), &[255u8, 128, 0]); + + let back: Vec = soa_to_aos(&soa, |[palette, alpha]| Cell { palette, alpha }); + assert_eq!(back, aos); + } + + /// `aos_to_soa` over `u16` (BF16 carrier bytes). 
+ #[test] + fn aos_to_soa_u16_round_trip() { + #[derive(Debug, PartialEq, Eq)] + struct Bf16Pair { + lo: u16, + hi: u16, + } + let aos = vec![ + Bf16Pair { lo: 0x1234, hi: 0xABCD }, + Bf16Pair { lo: 0x5678, hi: 0xEF01 }, + Bf16Pair { lo: 0xFFFF, hi: 0x0000 }, + ]; + let soa = aos_to_soa::<_, u16, 2, _>(&aos, |p| [p.lo, p.hi]); + assert_eq!(soa.field(0), &[0x1234u16, 0x5678, 0xFFFF]); + assert_eq!(soa.field(1), &[0xABCDu16, 0xEF01, 0x0000]); + + let back: Vec = soa_to_aos(&soa, |[lo, hi]| Bf16Pair { lo, hi }); + assert_eq!(back, aos); + } + + /// Inference-only entry: caller relies on closure return-type ascription, + /// no turbofish at all. + #[test] + fn aos_to_soa_inference_only() { + struct Triple { + a: i8, + b: i8, + c: i8, + } + let aos = [Triple { a: 1, b: 2, c: 3 }, Triple { a: -1, b: -2, c: -3 }]; + let soa = aos_to_soa(&aos, |t| -> [i8; 3] { [t.a, t.b, t.c] }); + assert_eq!(soa.field(0), &[1i8, -1]); + assert_eq!(soa.field(2), &[3i8, -3]); + } }