@@ -93,21 +93,21 @@ static void wsum_hess_init(expr *node)
9393 many numerical zeros in child->wsum_hess that are actually
9494 structural zeros, but we do not try to exploit that sparsity
9595 right now. */
96- CSR_Matrix * H_child = x -> wsum_hess ;
97- node -> wsum_hess = new_csr_matrix (H_child -> m , H_child -> n , H_child -> nnz );
98- memcpy (node -> wsum_hess -> p , H_child -> p , (H_child -> m + 1 ) * sizeof (int ));
99- memcpy (node -> wsum_hess -> i , H_child -> i , H_child -> nnz * sizeof (int ));
96+ CSR_Matrix * Hx = x -> wsum_hess ;
97+ node -> wsum_hess = new_csr_matrix (Hx -> m , Hx -> n , Hx -> nnz );
98+ memcpy (node -> wsum_hess -> p , Hx -> p , (Hx -> m + 1 ) * sizeof (int ));
99+ memcpy (node -> wsum_hess -> i , Hx -> i , Hx -> nnz * sizeof (int ));
100100}
101101
102102static void eval_wsum_hess (expr * node , const double * w )
103103{
104- expr * child = node -> left ;
104+ expr * x = node -> left ;
105105 index_expr * idx = (index_expr * ) node ;
106106
107107 if (idx -> has_duplicates )
108108 {
109109 /* zero and accumulate for repeated indices */
110- memset (node -> dwork , 0 , child -> size * sizeof (double ));
110+ memset (node -> dwork , 0 , x -> size * sizeof (double ));
111111 for (int i = 0 ; i < idx -> n_idxs ; i ++ )
112112 {
113113 node -> dwork [idx -> indices [i ]] += w [i ];
@@ -122,12 +122,9 @@ static void eval_wsum_hess(expr *node, const double *w)
122122 }
123123 }
124124
125- /* delegate to child */
126- child -> eval_wsum_hess (child , node -> dwork );
127-
128- /* copy values from child */
129- memcpy (node -> wsum_hess -> x , child -> wsum_hess -> x ,
130- child -> wsum_hess -> nnz * sizeof (double ));
125+ /* evalute hessian of child */
126+ x -> eval_wsum_hess (x , node -> dwork );
127+ memcpy (node -> wsum_hess -> x , x -> wsum_hess -> x , x -> wsum_hess -> nnz * sizeof (double ));
131128}
132129
133130static bool is_affine (const expr * node )
0 commit comments