Skip to content

Commit ed77d23

Browse files
Transurgeonclaude
andcommitted
Add kron_left affine atom
Implements Z = kron(C, X) where C is a constant sparse matrix and X is an expression. Lives in src/atoms/affine/ alongside left_matmul, takes the same (param_node, u, ...) signature with param_node required to be NULL for now — the slot is wired into the struct so adding updatable parameter support later is a local change. All inner loops iterate only over the nonzeros of C (cached as active (i, j) tuples at construction), so C = I_m automatically collapses to O(m * p * q) work with no identity-detection code. Jacobian sparsity is built in two passes over the same active-tuple skeleton. wsum_hess inherits the child's sparsity and runs the same index pattern as an adjoint. Covered by six tests: forward (generic + identity), Jacobian (vector and matrix child), and weighted-sum Hessian (variable child + numerical diff against a composite exp child). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 140267c commit ed77d23

7 files changed

Lines changed: 673 additions & 0 deletions

File tree

include/atoms/affine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
5858
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
5959
const double *data);
6060

61+
/* Kronecker product with constant on the left: Z = kron(C, u) where C is a
62+
* constant sparse matrix and u is a (p x q) expression. Output shape
63+
* (C->m * p, C->n * q). param_node must be NULL; the parameter path is
64+
* reserved for a future change. */
65+
expr *new_kron_left(expr *param_node, expr *u, const CSR_Matrix *C, int p,
66+
int q);
67+
6168
/* Scalar multiplication: a * f(x) where a comes from param_node */
6269
expr *new_scalar_mult(expr *param_node, expr *child);
6370

include/subexpr.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,30 @@ typedef struct matmul_expr
170170
int *idx_map_Hg;
171171
} matmul_expr;
172172

173+
/* Kronecker product with a constant on the left: Z = kron(C, X) where C is
174+
* a constant (m x n) sparse matrix and X is an expression of shape (p x q).
175+
* Output has shape (m*p, n*q). The atom is affine in X; the param_source
176+
* slot is reserved for a future update that makes C an updatable parameter.
177+
*
178+
* We cache the active entries of C (one per nonzero of C) so that all
179+
* inner loops run in O(nnz_C * p * q) rather than touching zero rows of
180+
* the output. This automatically collapses to O(m * p * q) when C = I_m,
181+
* with no special case in the code. */
182+
typedef struct kron_left_expr
183+
{
184+
expr base;
185+
CSR_Matrix *C; /* constant matrix, owned */
186+
int p, q; /* child shape (m, n are C->m, C->n) */
187+
/* active-entry tables (length C->nnz), filled in constructor */
188+
int n_active;
189+
int *active_i; /* row index i of each nonzero */
190+
int *active_j; /* col index j of each nonzero */
191+
int *active_idx; /* index into C->x */
192+
/* parameter slot (not wired up yet — param_source must be NULL) */
193+
expr *param_source;
194+
void (*refresh_param_values)(struct kron_left_expr *);
195+
} kron_left_expr;
196+
173197
/* Index/slicing: y = child[indices] where indices is a list of flat positions */
174198
typedef struct index_expr
175199
{

src/atoms/affine/kron_left.c

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
/*
2+
* Copyright 2026 Daniel Cederberg and William Zhang
3+
*
4+
* This file is part of the SparseDiffEngine project.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
#include "atoms/affine.h"
19+
#include "subexpr.h"
20+
#include "utils/tracked_alloc.h"
21+
#include <assert.h>
22+
#include <stdio.h>
23+
#include <stdlib.h>
24+
#include <string.h>
25+
26+
/* Kronecker product with constant on the left: Z = kron(C, X) where
27+
* C has shape (m, n) and is a constant sparse matrix,
28+
* X has shape (p, q) and is an expression.
29+
* Output Z has shape (m*p, n*q), stored column-major as vec(Z) of length
30+
* m*p*n*q.
31+
*
32+
* Key identity: Z[i*p+k, j*q+l] = C[i,j] * X[k,l].
33+
* In column-major: vec(Z)[r] with r = (j*q+l)*(m*p) + i*p + k
34+
* depends on vec(X)[s] with s = l*p + k and coefficient C[i,j].
35+
*
36+
* The atom is affine in X: each output row r (when C[i,j] != 0) is a
37+
* scaled copy of child row s of the child's Jacobian, and the weighted
38+
* Hessian inherits the child's sparsity with an adjoint accumulation
39+
* over the same index pattern.
40+
*
41+
* All inner loops iterate only over nonzeros of C (cached in the
42+
* active_i / active_j / active_idx tables at construction). No explicit
43+
* identity-detection is needed: for C = I_m, nnz_C == m and the work
44+
* naturally drops to O(m * p * q) without any special-case code. */
45+
46+
/* ------------------------------------------------------------------ */
47+
/* Forward pass */
48+
/* ------------------------------------------------------------------ */
49+
static void forward(expr *node, const double *u)
50+
{
51+
kron_left_expr *lnode = (kron_left_expr *) node;
52+
expr *child = node->left;
53+
CSR_Matrix *C = lnode->C;
54+
int p = lnode->p, q = lnode->q;
55+
int mp = C->m * p;
56+
57+
child->forward(child, u);
58+
59+
memset(node->value, 0, (size_t) node->size * sizeof(double));
60+
61+
/* For each nonzero C[i,j], scatter the (p x q) block cij * X into
62+
* position Z[i*p .. i*p+p-1, j*q .. j*q+q-1]. */
63+
for (int t = 0; t < lnode->n_active; t++)
64+
{
65+
int i = lnode->active_i[t];
66+
int j = lnode->active_j[t];
67+
double cij = C->x[lnode->active_idx[t]];
68+
for (int l = 0; l < q; l++)
69+
{
70+
int z_col_start = (j * q + l) * mp + i * p;
71+
int x_col_start = l * p;
72+
for (int k = 0; k < p; k++)
73+
{
74+
node->value[z_col_start + k] =
75+
cij * child->value[x_col_start + k];
76+
}
77+
}
78+
}
79+
}
80+
81+
/* ------------------------------------------------------------------ */
82+
/* Affine check */
83+
/* ------------------------------------------------------------------ */
84+
static bool is_affine(const expr *node)
85+
{
86+
return node->left->is_affine(node->left);
87+
}
88+
89+
/* ------------------------------------------------------------------ */
90+
/* Jacobian initialization */
91+
/* ------------------------------------------------------------------ */
92+
/* Two-pass construction over active C entries × (l, k):
93+
* pass 1 fills row_nnz[r] for every active output row,
94+
* pass 2 writes column indices into the already-allocated CSR.
95+
* Rows r that don't correspond to an active (i, j) stay at 0 nnz.
96+
*
97+
* Work: O(nnz_C * p * q * avg_nnz_per_Jchild_row). For C = I_m this is
98+
* O(m * p * q * avg_Jchild_row_nnz), i.e. a factor-of-n reduction vs a
99+
* naive iteration over every output row of Z. */
100+
static void jacobian_init_impl(expr *node)
101+
{
102+
kron_left_expr *lnode = (kron_left_expr *) node;
103+
expr *child = node->left;
104+
CSR_Matrix *C = lnode->C;
105+
int p = lnode->p, q = lnode->q;
106+
int mp = C->m * p;
107+
int out_size = node->size;
108+
109+
jacobian_init(child);
110+
CSR_Matrix *Jchild = child->jacobian;
111+
112+
/* Pass 1: row_nnz[r] = Jchild row-nnz for active r, else 0. */
113+
int *row_nnz = (int *) SP_CALLOC((size_t) out_size, sizeof(int));
114+
for (int t = 0; t < lnode->n_active; t++)
115+
{
116+
int i = lnode->active_i[t];
117+
int j = lnode->active_j[t];
118+
for (int l = 0; l < q; l++)
119+
{
120+
int r_col_base = (j * q + l) * mp + i * p;
121+
for (int k = 0; k < p; k++)
122+
{
123+
int s = l * p + k;
124+
row_nnz[r_col_base + k] = Jchild->p[s + 1] - Jchild->p[s];
125+
}
126+
}
127+
}
128+
129+
/* Cumulative sum into a local buffer; we'll memcpy into the
130+
* Jacobian's p[] after allocation. */
131+
int *Jp = (int *) SP_MALLOC((size_t) (out_size + 1) * sizeof(int));
132+
int total_nnz = 0;
133+
for (int r = 0; r < out_size; r++)
134+
{
135+
Jp[r] = total_nnz;
136+
total_nnz += row_nnz[r];
137+
}
138+
Jp[out_size] = total_nnz;
139+
free(row_nnz);
140+
141+
node->jacobian = new_csr_matrix(out_size, node->n_vars, total_nnz);
142+
memcpy(node->jacobian->p, Jp, (size_t) (out_size + 1) * sizeof(int));
143+
free(Jp);
144+
145+
/* Pass 2: column indices are a copy of the corresponding Jchild row. */
146+
for (int t = 0; t < lnode->n_active; t++)
147+
{
148+
int i = lnode->active_i[t];
149+
int j = lnode->active_j[t];
150+
for (int l = 0; l < q; l++)
151+
{
152+
int r_col_base = (j * q + l) * mp + i * p;
153+
for (int k = 0; k < p; k++)
154+
{
155+
int s = l * p + k;
156+
int r = r_col_base + k;
157+
int cs = Jchild->p[s];
158+
int row_nnz_r = Jchild->p[s + 1] - cs;
159+
int row_start = node->jacobian->p[r];
160+
memcpy(node->jacobian->i + row_start, Jchild->i + cs,
161+
(size_t) row_nnz_r * sizeof(int));
162+
}
163+
}
164+
}
165+
}
166+
167+
/* ------------------------------------------------------------------ */
168+
/* Jacobian evaluation */
169+
/* ------------------------------------------------------------------ */
170+
static void eval_jacobian(expr *node)
171+
{
172+
kron_left_expr *lnode = (kron_left_expr *) node;
173+
expr *child = node->left;
174+
CSR_Matrix *C = lnode->C;
175+
CSR_Matrix *Jchild = child->jacobian;
176+
CSR_Matrix *J = node->jacobian;
177+
int p = lnode->p, q = lnode->q;
178+
int mp = C->m * p;
179+
180+
child->eval_jacobian(child);
181+
182+
for (int t = 0; t < lnode->n_active; t++)
183+
{
184+
int i = lnode->active_i[t];
185+
int j = lnode->active_j[t];
186+
double cij = C->x[lnode->active_idx[t]];
187+
for (int l = 0; l < q; l++)
188+
{
189+
int r_col_base = (j * q + l) * mp + i * p;
190+
for (int k = 0; k < p; k++)
191+
{
192+
int s = l * p + k;
193+
int r = r_col_base + k;
194+
int cs = Jchild->p[s];
195+
int row_nnz_r = Jchild->p[s + 1] - cs;
196+
int row_start = J->p[r];
197+
for (int u = 0; u < row_nnz_r; u++)
198+
{
199+
J->x[row_start + u] = cij * Jchild->x[cs + u];
200+
}
201+
}
202+
}
203+
}
204+
}
205+
206+
/* ------------------------------------------------------------------ */
207+
/* Weighted-sum Hessian initialization */
208+
/* ------------------------------------------------------------------ */
209+
static void wsum_hess_init_impl(expr *node)
210+
{
211+
expr *child = node->left;
212+
213+
wsum_hess_init(child);
214+
215+
/* Linear in X: Hessian sparsity equals the child's. */
216+
node->wsum_hess = new_csr_copy_sparsity(child->wsum_hess);
217+
218+
/* Workspace for the reverse-mode weight vector passed down to child. */
219+
node->work->dwork = (double *) SP_MALLOC((size_t) child->size * sizeof(double));
220+
}
221+
222+
/* ------------------------------------------------------------------ */
223+
/* Weighted-sum Hessian evaluation */
224+
/* ------------------------------------------------------------------ */
225+
static void eval_wsum_hess(expr *node, const double *w)
226+
{
227+
kron_left_expr *lnode = (kron_left_expr *) node;
228+
expr *child = node->left;
229+
CSR_Matrix *C = lnode->C;
230+
int p = lnode->p, q = lnode->q;
231+
int mp = C->m * p;
232+
int child_size = child->size;
233+
double *w_child = node->work->dwork;
234+
235+
/* Adjoint of the forward pass: w_child[s] = sum_{(i,j,k,l): s=l*p+k}
236+
* C[i,j] * w[(j*q+l)*mp + i*p + k]. */
237+
memset(w_child, 0, (size_t) child_size * sizeof(double));
238+
for (int t = 0; t < lnode->n_active; t++)
239+
{
240+
int i = lnode->active_i[t];
241+
int j = lnode->active_j[t];
242+
double cij = C->x[lnode->active_idx[t]];
243+
for (int l = 0; l < q; l++)
244+
{
245+
int r_col_base = (j * q + l) * mp + i * p;
246+
for (int k = 0; k < p; k++)
247+
{
248+
int s = l * p + k;
249+
w_child[s] += cij * w[r_col_base + k];
250+
}
251+
}
252+
}
253+
254+
child->eval_wsum_hess(child, w_child);
255+
memcpy(node->wsum_hess->x, child->wsum_hess->x,
256+
(size_t) node->wsum_hess->nnz * sizeof(double));
257+
}
258+
259+
/* ------------------------------------------------------------------ */
260+
/* Cleanup */
261+
/* ------------------------------------------------------------------ */
262+
static void free_type_data(expr *node)
263+
{
264+
kron_left_expr *lnode = (kron_left_expr *) node;
265+
free_csr_matrix(lnode->C);
266+
free(lnode->active_i);
267+
free(lnode->active_j);
268+
free(lnode->active_idx);
269+
if (lnode->param_source != NULL)
270+
{
271+
free_expr(lnode->param_source);
272+
}
273+
lnode->C = NULL;
274+
lnode->active_i = NULL;
275+
lnode->active_j = NULL;
276+
lnode->active_idx = NULL;
277+
lnode->param_source = NULL;
278+
}
279+
280+
/* ------------------------------------------------------------------ */
281+
/* Constructor */
282+
/* ------------------------------------------------------------------ */
283+
expr *new_kron_left(expr *param_node, expr *u, const CSR_Matrix *C, int p, int q)
284+
{
285+
if (u->size != p * q)
286+
{
287+
fprintf(stderr,
288+
"Error in new_kron_left: child size %d != p*q = %d*%d = %d\n",
289+
u->size, p, q, p * q);
290+
exit(1);
291+
}
292+
293+
int m = C->m;
294+
int n = C->n;
295+
296+
kron_left_expr *lnode =
297+
(kron_left_expr *) SP_CALLOC(1, sizeof(kron_left_expr));
298+
expr *node = &lnode->base;
299+
init_expr(node, m * p, n * q, u->n_vars, forward, jacobian_init_impl,
300+
eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess,
301+
free_type_data);
302+
node->left = u;
303+
expr_retain(u);
304+
305+
lnode->p = p;
306+
lnode->q = q;
307+
lnode->C = new_csr(C);
308+
309+
/* Precompute active (i, j) tuples and their offset into C->x. */
310+
lnode->n_active = C->nnz;
311+
lnode->active_i = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
312+
lnode->active_j = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
313+
lnode->active_idx = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
314+
int t = 0;
315+
for (int i = 0; i < m; i++)
316+
{
317+
for (int idx = C->p[i]; idx < C->p[i + 1]; idx++)
318+
{
319+
lnode->active_i[t] = i;
320+
lnode->active_j[t] = C->i[idx];
321+
lnode->active_idx[t] = idx;
322+
t++;
323+
}
324+
}
325+
assert(t == C->nnz);
326+
327+
/* Parameter slot is reserved but not yet wired up. */
328+
lnode->param_source = param_node;
329+
if (param_node != NULL)
330+
{
331+
fprintf(stderr, "Error in new_kron_left: parameter for kron C "
332+
"not supported yet\n");
333+
exit(1);
334+
}
335+
336+
return node;
337+
}

0 commit comments

Comments
 (0)