Commit 3f35170

Merge pull request #154 from AdaWorldAPI/claude/w2-hpc-arrayview-conversion
W2: hpc kernel layer ArrayView-first conversion (in-place rename, 32 fns)
2 parents ab20d11 + 66a8a81 commit 3f35170

8 files changed

Lines changed: 1942 additions & 2327 deletions

File tree

.claude/knowledge/w2-arrayview-migration.md

Lines changed: 387 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# W2-3 + W2-4 Audit — BLAS levels + statistics ArrayView compliance

## Verdict
**CLEAN.** No follow-up wave needed for `blas_level{1,2,3}.rs` or `statistics.rs`. All four files are already ArrayView-shaped via trait impls on `ArrayBase`.

## Per-file findings

### `src/hpc/blas_level1.rs`
- Trait impl on ArrayBase: **yes, L47**: `impl<A, S> BlasLevel1<A> for ArrayBase<S, Ix1>`
- Bonus trait impls (not in the original migration doc, but clean): `ScalarArith` (L196), `VecArith` (L242) — both on `ArrayBase<S, Ix1>`
- Slice-taking pub fns: **1**: `blas_rotg` (L152). **OK-as-is**: signature is `(a: A, b: A)` (scalars), not slices. The regex `^pub fn .*&\[` matched a `&[` in the doc-comment example, not the signature.
- `axis_iter` misuse: **0**
- Bridge pattern: verified present in trait methods — `blas_dot`, `blas_axpy`, `blas_scal`, `blas_nrm2`, `blas_asum` all dispatch through `as_slice()` hot path + stride-aware cold path.
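
The bridge pattern named above (a contiguous `as_slice()` hot path plus a stride-aware cold path) can be sketched without ndarray. This is a minimal illustration, not the code in `blas_level1.rs`: `StridedView` and `dot` are hypothetical stand-ins for a 1-D array view and `blas_dot`, and the stride is assumed to be at least 1.

```rust
/// Hypothetical 1-D strided view (illustration only, not an ndarray type).
struct StridedView<'a> {
    data: &'a [f32],
    len: usize,
    stride: usize, // assumed >= 1
}

impl<'a> StridedView<'a> {
    /// Returns the elements as one contiguous slice, or `None` if strided.
    fn as_slice(&self) -> Option<&'a [f32]> {
        if self.stride == 1 {
            Some(&self.data[..self.len])
        } else {
            None
        }
    }

    /// Stride-aware element iterator; works for any stride >= 1.
    fn iter(&self) -> impl Iterator<Item = f32> + 'a {
        self.data.iter().step_by(self.stride).take(self.len).copied()
    }
}

/// Dot product that dispatches between the two paths.
fn dot(x: &StridedView, y: &StridedView) -> f32 {
    assert_eq!(x.len, y.len, "length mismatch");
    match (x.as_slice(), y.as_slice()) {
        // Hot path: both operands contiguous, so plain slices can be used
        // (this is where a SIMD kernel would be called).
        (Some(a), Some(b)) => a.iter().zip(b).map(|(p, q)| p * q).sum(),
        // Cold path: at least one operand is strided.
        _ => x.iter().zip(y.iter()).map(|(p, q)| p * q).sum(),
    }
}
```

Both paths must return the same value for the same logical elements; only the memory access pattern differs.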

### `src/hpc/blas_level2.rs`
- Trait impl on ArrayBase: **yes, L97**: `impl<A, S> BlasLevel2<A> for ArrayBase<S, Ix2>`
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**

### `src/hpc/blas_level3.rs`
- Trait impl on ArrayBase: **yes, L59**: `impl<A, S> BlasLevel3<A> for ArrayBase<S, Ix2>`
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**

### `src/hpc/statistics.rs`
- Trait impl on ArrayBase: **yes, L65**: `impl<A, S, D> Statistics<A> for ArrayBase<S, D>` (note: generic-D, unlike BLAS L1/L2/L3 which fix `Ix1`/`Ix2`)
- Slice-taking pub fns: **0**
- `axis_iter` misuse: **0**

## Build verification
`cargo check -p ndarray --no-default-features --features std` → clean (31.82s, no warnings).

## Surprises
- `blas_level1.rs` carries two extra trait impls (`ScalarArith`, `VecArith`) on `ArrayBase<S, Ix1>` beyond `BlasLevel1` itself. Not mentioned in the original migration doc but clean and consistent with the two-layer rule.
- `blas_rotg` regex match was a false positive (doc-comment `&[` in an example, not in the signature).

## Follow-up needed
**None.** W2-3 and W2-4 require no code changes. The W2 sprint scope reduces to the three converter waves: W2-1 (reductions), W2-2a (vml), W2-2b (activations).

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@ target/
 
 # Apple details
 **/.DS_Store
+
+# Claude Code: agent isolation worktrees (temporary, per-agent)
+.claude/worktrees/

crates/burn/src/ops/activation.rs

Lines changed: 3 additions & 1 deletion
@@ -27,7 +27,9 @@ where
     if view.is_standard_layout() {
         if let Some(input) = view.as_slice() {
             let mut output = alloc::vec![0.0f32; input.len()];
-            ndarray::hpc::activations::sigmoid_f32(input, &mut output);
+            let in_view = ndarray::ArrayView::from(input);
+            let out_view = ndarray::ArrayViewMut::from(&mut output[..]);
+            ndarray::hpc::activations::sigmoid_f32(in_view, out_view);
             let shape: alloc::vec::Vec<usize> = view.shape().to_vec();
             let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output)
                 .expect("sigmoid output shape mismatch");

crates/burn/src/ops/tensor.rs

Lines changed: 17 additions & 13 deletions
@@ -34,26 +34,30 @@ use libm::erf;
 
 /// Try to accelerate a unary f32 operation via ndarray's hpc::vml (F32x16 SIMD).
 ///
-/// VML signature: `fn(input: &[f32], output: &mut [f32])`.
-/// Uses crate::simd::F32x16 internally. Consumer never sees hardware details.
+/// VML signature (post W2-2a): generic over dimension, takes
+/// `ArrayView<f32, D> / ArrayViewMut<f32, D>`. We pass the dyn-D views from
+/// the burn tensor directly; ndarray's vml routes to the F32x16 SIMD
+/// primitive on the contiguous hot path and falls back to a stride-aware
+/// `Zip` on the cold path. Consumer never sees hardware details.
 #[cfg(feature = "simd")]
 fn try_vml_unary(
     tensor: NdArrayTensor,
-    vml_fn: fn(&[f32], &mut [f32]),
+    vml_fn: fn(ndarray::ArrayView<'_, f32, ndarray::IxDyn>, ndarray::ArrayViewMut<'_, f32, ndarray::IxDyn>),
 ) -> Result<NdArrayTensor, NdArrayTensor> {
     if let NdArrayTensor::F32(storage) = tensor {
         let shared = storage.into_shared();
         if shared.is_standard_layout() {
-            if let Some(input) = shared.as_slice() {
-                let mut output = vec![0.0f32; input.len()];
-                vml_fn(input, &mut output);
-                let shape = shared.shape().to_vec();
-                let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output)
-                    .expect("vml output shape mismatch");
-                return Ok(NdArrayTensor::F32(
-                    crate::NdArrayStorage::Owned(array.into_shared()),
-                ));
-            }
+            let shape = shared.shape().to_vec();
+            let len = shared.len();
+            let mut output = ndarray::Array::from_shape_vec(
+                ndarray::IxDyn(&shape),
+                vec![0.0f32; len],
+            )
+            .expect("vml output shape mismatch");
+            vml_fn(shared.view(), output.view_mut());
+            return Ok(NdArrayTensor::F32(
+                crate::NdArrayStorage::Owned(output.into_shared()),
+            ));
         }
         return Err(NdArrayTensor::F32(crate::NdArrayStorage::Owned(shared)));
     }
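
The `Result<NdArrayTensor, NdArrayTensor>` return type of `try_vml_unary` returns the tensor inside `Err` when the accelerated path does not apply, so the caller keeps ownership and can run a fallback. The sketch below shows only that ownership shape, using the standard library. All names (`Tensor`, `UnaryKernel`, `try_unary`, `abs_kernel`) are hypothetical, and the kernel takes plain slices rather than the `ArrayView` signature used in the diff.

```rust
/// Hypothetical stand-in for a multi-dtype tensor enum (illustration only).
#[derive(Debug, PartialEq)]
enum Tensor {
    F32(Vec<f32>),
    I64(Vec<i64>),
}

/// Hypothetical kernel signature: reads `input`, writes `output`.
type UnaryKernel = fn(&[f32], &mut [f32]);

/// Try the accelerated path.
/// `Ok` carries the computed result; `Err` hands the untouched input back
/// so the caller can fall through to a generic implementation.
fn try_unary(tensor: Tensor, kernel: UnaryKernel) -> Result<Tensor, Tensor> {
    match tensor {
        Tensor::F32(input) => {
            // Caller-allocated output buffer, same length as the input.
            let mut output = vec![0.0f32; input.len()];
            kernel(&input, &mut output);
            Ok(Tensor::F32(output))
        }
        // Unsupported dtype: return ownership unchanged.
        other => Err(other),
    }
}

/// Example kernel: element-wise absolute value.
fn abs_kernel(input: &[f32], output: &mut [f32]) {
    for (o, i) in output.iter_mut().zip(input) {
        *o = i.abs();
    }
}
```

A caller would typically write `try_unary(t, abs_kernel).unwrap_or_else(|t| slow_path(t))`, where `slow_path` is whatever generic implementation it already has.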
