Skip to content

Commit aed2fb8

Browse files
committed
fixes beam test
1 parent d4ea7d6 commit aed2fb8

3 files changed

Lines changed: 15 additions & 18 deletions

File tree

src/affine/index.c

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,21 @@ static void wsum_hess_init(expr *node)
9393
many numerical zeros in child->wsum_hess that are actually
9494
structural zeros, but we do not try to exploit that sparsity
9595
right now. */
96-
CSR_Matrix *H_child = x->wsum_hess;
97-
node->wsum_hess = new_csr_matrix(H_child->m, H_child->n, H_child->nnz);
98-
memcpy(node->wsum_hess->p, H_child->p, (H_child->m + 1) * sizeof(int));
99-
memcpy(node->wsum_hess->i, H_child->i, H_child->nnz * sizeof(int));
96+
CSR_Matrix *Hx = x->wsum_hess;
97+
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
98+
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
99+
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
100100
}
101101

102102
static void eval_wsum_hess(expr *node, const double *w)
103103
{
104-
expr *child = node->left;
104+
expr *x = node->left;
105105
index_expr *idx = (index_expr *) node;
106106

107107
if (idx->has_duplicates)
108108
{
109109
/* zero and accumulate for repeated indices */
110-
memset(node->dwork, 0, child->size * sizeof(double));
110+
memset(node->dwork, 0, x->size * sizeof(double));
111111
for (int i = 0; i < idx->n_idxs; i++)
112112
{
113113
node->dwork[idx->indices[i]] += w[i];
@@ -122,12 +122,9 @@ static void eval_wsum_hess(expr *node, const double *w)
122122
}
123123
}
124124

125-
/* delegate to child */
126-
child->eval_wsum_hess(child, node->dwork);
127-
128-
/* copy values from child */
129-
memcpy(node->wsum_hess->x, child->wsum_hess->x,
130-
child->wsum_hess->nnz * sizeof(double));
125+
/* evalute hessian of child */
126+
x->eval_wsum_hess(x, node->dwork);
127+
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
131128
}
132129

133130
static bool is_affine(const expr *node)

src/affine/promote.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ static void jacobian_init(expr *node)
2222
expr *x = node->left;
2323
x->jacobian_init(x);
2424

25-
// each output row copies the single row from child's jacobian
25+
/* each output row copies the single row from child's jacobian */
2626
int nnz = node->size * x->jacobian->nnz;
2727
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
2828

29-
// fill sparsity pattern
29+
/* fill sparsity pattern */
3030
CSR_Matrix *J = node->jacobian;
3131
J->nnz = 0;
3232
for (int row = 0; row < node->size; row++)

src/dnlp_diff_engine/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _convert_matmul(expr, children):
2929
# One of them should be a Constant, the other a variable expression
3030
left_arg, right_arg = expr.args
3131

32-
if isinstance(left_arg, cp.Constant):
32+
if left_arg.is_constant():
3333
# A @ f(x) -> left_matmul
3434
# TODO: why is this always dense? What's going on here?
3535
# we later convert it to csr....
@@ -46,7 +46,7 @@ def _convert_matmul(expr, children):
4646
m,
4747
n,
4848
)
49-
elif isinstance(right_arg, cp.Constant):
49+
elif right_arg.is_constant():
5050
# f(x) @ A -> right_matmul
5151

5252
# TODO: why is this always dense? What's going on here?
@@ -74,7 +74,7 @@ def _convert_multiply(expr, children):
7474
left_arg, right_arg = expr.args
7575

7676
# Check if left is a constant
77-
if isinstance(left_arg, cp.Constant):
77+
if left_arg.is_constant():
7878
value = np.asarray(left_arg.value, dtype=np.float64)
7979

8080
# Scalar constant
@@ -88,7 +88,7 @@ def _convert_multiply(expr, children):
8888
return _diffengine.make_const_vector_mult(children[1], vector)
8989

9090
# Check if right is a constant
91-
elif isinstance(right_arg, cp.Constant):
91+
elif right_arg.is_constant():
9292
value = np.asarray(right_arg.value, dtype=np.float64)
9393

9494
# Scalar constant

0 commit comments

Comments
 (0)