Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,16 @@ typedef struct expr
// ------------------------------------------------------------------------
double *value;
CSR_Matrix *jacobian;
CSC_Matrix *jacobian_csc;
int *csc_work; /* workspace for CSR-CSC conversion */

/* jacobian_csc_filled is only used for affine functions to avoid redundant
conversions. Could become relevant for non-affine functions if we start
supporting common subexpressions on the Python side. */
bool jacobian_csc_filled;
CSR_Matrix *wsum_hess;
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
forward_fn forward;
jacobian_init_fn jacobian_init;
wsum_hess_init_fn wsum_hess_init;
Expand All @@ -67,6 +76,7 @@ typedef struct expr
// other things
// ------------------------------------------------------------------------
is_affine_fn is_affine;
double *local_jac_diag; /* cached f'(g(x)) diagonal */
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
Expand All @@ -83,6 +93,10 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,

void free_expr(expr *node);

/* Initialize the CSC form of the Jacobian from the CSR Jacobian.
 *
 * Allocates node->csc_work (n_vars ints) and sets node->jacobian_csc from
 * node->jacobian via csr_to_csc_fill_sparsity. Both are released by
 * free_expr.
 *
 * Must be called after jacobian_init, because node->jacobian is read here.
 *
 * NOTE(review): presumably only the sparsity pattern is set up here; the
 * numerical values appear to be refreshed separately with
 * csr_to_csc_fill_values -- confirm against utils/CSC_Matrix.h. */
void jacobian_csc_init(expr *node);

/* Reference counting helpers */
void expr_retain(expr *node);

Expand Down
83 changes: 73 additions & 10 deletions src/elementwise_full_dom/common.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "elementwise_full_dom.h"
#include "subexpr.h"
#include "utils/CSC_Matrix.h"
#include "utils/CSR_Matrix.h"
#include "utils/CSR_sum.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
Expand All @@ -20,14 +22,14 @@ void jacobian_init_elementwise(expr *node)
}
node->jacobian->p[node->size] = node->size;
}
/* otherwise it will be a linear operator */
else
{
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
child->jacobian_init(child);
CSR_Matrix *Jg = child->jacobian;
node->jacobian = new_csr_matrix(Jg->m, Jg->n, Jg->nnz);
node->dwork = (double *) malloc(node->size * sizeof(double));
node->local_jac_diag = (double *) malloc(node->size * sizeof(double));

/* copy sparsity pattern of child */
memcpy(node->jacobian->p, Jg->p, sizeof(int) * (Jg->m + 1));
Expand All @@ -48,7 +50,8 @@ void eval_jacobian_elementwise(expr *node)
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
child->eval_jacobian(child);
CSR_Matrix *Jg = child->jacobian;
node->local_jacobian(node, node->dwork);
node->local_jacobian(node, node->local_jac_diag);
memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double));
diag_csr_mult_fill_values(node->dwork, Jg, node->jacobian);
}
}
Expand All @@ -59,7 +62,7 @@ void wsum_hess_init_elementwise(expr *node)
int id = child->var_id;
int i;

/* if the variable is a child*/
/* if the variable is a child */
if (id != NOT_A_VARIABLE)
{
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size);
Expand All @@ -75,11 +78,38 @@ void wsum_hess_init_elementwise(expr *node)
node->wsum_hess->p[i] = node->size;
}
}
/* otherwise it will be a linear operator */
else
{
linear_op_expr *lin_child = (linear_op_expr *) child;
node->wsum_hess = ATA_alloc(lin_child->A_csc);
/* Hessian of h(x) = w^T f(g(x)) is term1 + term2 where
term1 = J_g^T @ D @ J_g with D = sum_i w_i Hf_i,
term2 = sum_i (J_f^T w)_i Hg_i.

For elementwise functions, D is diagonal. */
jacobian_csc_init(child);
CSC_Matrix *Jg = child->jacobian_csc;

if (child->is_affine(child))
{
node->wsum_hess = ATA_alloc(Jg);
}
else
{
/* term1: Jg^T @ D @ Jg */
node->hess_term1 = ATA_alloc(Jg);

/* term2: child's Hessian */
child->wsum_hess_init(child);
CSR_Matrix *Hg = child->wsum_hess;
node->hess_term2 = new_csr_matrix(Hg->m, Hg->n, Hg->nnz);
memcpy(node->hess_term2->p, Hg->p, (Hg->m + 1) * sizeof(int));
memcpy(node->hess_term2->i, Hg->i, Hg->nnz * sizeof(int));

/* wsum_hess = term1 + term2 */
int max_nnz = node->hess_term1->nnz + node->hess_term2->nnz;
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz);
sum_csr_matrices_fill_sparsity(node->hess_term1, node->hess_term2,
node->wsum_hess);
}
}
}

Expand All @@ -93,10 +123,43 @@ void eval_wsum_hess_elementwise(expr *node, const double *w)
}
else
{
/* Child will be a linear operator */
linear_op_expr *lin_child = (linear_op_expr *) child;
node->local_wsum_hess(node, node->dwork, w);
ATDA_fill_values(lin_child->A_csc, node->dwork, node->wsum_hess);
if (child->is_affine(child))
{
if (!child->jacobian_csc_filled)
{
csr_to_csc_fill_values(child->jacobian, child->jacobian_csc,
child->csc_work);
child->jacobian_csc_filled = true;
}

node->local_wsum_hess(node, node->dwork, w);
ATDA_fill_values(child->jacobian_csc, node->dwork, node->wsum_hess);
}
else
{
/* refresh CSC jacobian values */
csr_to_csc_fill_values(child->jacobian, child->jacobian_csc,
child->csc_work);

/* term1: Jg^T @ D @ Jg */
node->local_wsum_hess(node, node->dwork, w);
ATDA_fill_values(child->jacobian_csc, node->dwork, node->hess_term1);

/* term2: child Hessian with weight Jf^T w */
memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double));
for (int k = 0; k < node->size; k++)
{
node->dwork[k] *= w[k];
}

child->eval_wsum_hess(child, node->dwork);
memcpy(node->hess_term2->x, child->wsum_hess->x,
child->wsum_hess->nnz * sizeof(double));

/* wsum_hess = term1 + term2 */
sum_csr_matrices_fill_values(node->hess_term1, node->hess_term2,
node->wsum_hess);
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/expr.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "expr.h"
#include "utils/CSC_Matrix.h"
#include "utils/int_double_pair.h"
#include <stdlib.h>
#include <string.h>
Expand All @@ -41,6 +42,12 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
node->free_type_data = free_type_data;
}

/* Build the CSC form of node's Jacobian from its CSR form.
 *
 * On the first call this allocates node->csc_work (n_vars ints, used as the
 * CSR-to-CSC conversion workspace) and stores the matrix returned by
 * csr_to_csc_fill_sparsity in node->jacobian_csc. Both are owned by the node
 * and released by free_expr.
 *
 * The call is idempotent: once node->jacobian_csc exists, later calls return
 * immediately. Without this, a node reached from more than one parent would
 * have its workspace and CSC matrix overwritten and the old ones leaked.
 *
 * Must be called after jacobian_init, because node->jacobian is read here.
 *
 * NOTE(review): relies on csc_work / jacobian_csc being NULL before the first
 * call, the same assumption free_expr makes when it frees them
 * unconditionally -- confirm against the node constructors. */
void jacobian_csc_init(expr *node)
{
    /* Already built: keep the existing matrix and workspace. */
    if (node->jacobian_csc != NULL)
    {
        return;
    }

    /* Reuse a workspace left over from an earlier, incomplete attempt. */
    if (node->csc_work == NULL)
    {
        node->csc_work = (int *) malloc(node->n_vars * sizeof(int));

        /* malloc(0) may legitimately return NULL, so only treat NULL as an
         * allocation failure when a non-empty workspace was requested. */
        if (node->csc_work == NULL && node->n_vars > 0)
        {
            return;
        }
    }

    node->jacobian_csc = csr_to_csc_fill_sparsity(node->jacobian, node->csc_work);
}

void free_expr(expr *node)
{
if (node == NULL) return;
Expand All @@ -63,8 +70,13 @@ void free_expr(expr *node)
/* free value array and jacobian */
free(node->value);
free_csr_matrix(node->jacobian);
free_csc_matrix(node->jacobian_csc);
free(node->csc_work);
free_csr_matrix(node->wsum_hess);
free_csr_matrix(node->hess_term1);
free_csr_matrix(node->hess_term2);
free(node->dwork);
free(node->local_jac_diag);
free(node->iwork);
node->value = NULL;
node->jacobian = NULL;
Expand Down
6 changes: 6 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include "wsum_hess/elementwise/test_trig.h"
#include "wsum_hess/elementwise/test_xexp.h"
#include "wsum_hess/test_broadcast.h"
#include "wsum_hess/test_chain_rule_wsum_hess.h"
#include "wsum_hess/test_const_scalar_mult.h"
#include "wsum_hess/test_const_vector_mult.h"
#include "wsum_hess/test_hstack.h"
Expand Down Expand Up @@ -259,6 +260,11 @@ int main(void)
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
mu_run_test(test_wsum_hess_trace_composite, tests_run);
mu_run_test(test_wsum_hess_transpose, tests_run);
mu_run_test(test_wsum_hess_exp_sum, tests_run);
mu_run_test(test_wsum_hess_exp_sum_mult, tests_run);
mu_run_test(test_wsum_hess_exp_sum_matmul, tests_run);
mu_run_test(test_wsum_hess_sin_sum_axis0_matmul, tests_run);
mu_run_test(test_wsum_hess_logistic_sum_axis0_matmul, tests_run);

printf("\n--- Utility Tests ---\n");
mu_run_test(test_cblas_ddot, tests_run);
Expand Down
100 changes: 100 additions & 0 deletions tests/wsum_hess/test_chain_rule_wsum_hess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#include "affine.h"
#include "bivariate.h"
#include "elementwise_full_dom.h"
#include "minunit.h"
#include "numerical_diff.h"

/* Checks the weighted-sum Hessian of exp(sum(x)) for a 3x1 variable x
 * using check_wsum_hess with the default step NUMERICAL_DIFF_DEFAULT_H.
 * Returns 0 on success; on failure mu_assert yields the message below. */
const char *test_wsum_hess_exp_sum(void)
{
    /* evaluation point and the single output weight */
    double point[3] = {1.0, 2.0, 3.0};
    double weight[1] = {1.0};

    /* exp(sum(x)), built from the inside out */
    expr *root = new_exp(new_sum(new_variable(3, 1, 0, 3), -1));

    mu_assert("check_wsum_hess failed",
              check_wsum_hess(root, point, weight, NUMERICAL_DIFF_DEFAULT_H));

    /* release the expression through its root node */
    free_expr(root);
    return 0;
}

/* Checks the weighted-sum Hessian of exp(sum(x * y)), where x and y are
 * 2x1 variables and * is the elementwise product. The inner expression is
 * not affine, so this exercises the chain-rule path.
 * Returns 0 on success; on failure mu_assert yields the message below. */
const char *test_wsum_hess_exp_sum_mult(void)
{
    /* evaluation point (4 values) and the single output weight */
    double point[4] = {1.0, 2.0, 3.0, 4.0};
    double weight[1] = {1.0};

    /* two 2x1 variables, created in the same order as before */
    expr *lhs = new_variable(2, 1, 0, 4);
    expr *rhs = new_variable(2, 1, 2, 4);

    /* exp(sum(lhs * rhs)) */
    expr *root = new_exp(new_sum(new_elementwise_mult(lhs, rhs), -1));

    mu_assert("check_wsum_hess failed",
              check_wsum_hess(root, point, weight, NUMERICAL_DIFF_DEFAULT_H));

    /* release the expression through its root node */
    free_expr(root);
    return 0;
}

/* Checks the weighted-sum Hessian of exp(sum(X @ Y)) with X of shape 2x3
 * and Y of shape 3x2, i.e. 6 + 6 = 12 scalar variables in total.
 * Returns 0 on success; on failure mu_assert yields the message below. */
const char *test_wsum_hess_exp_sum_matmul(void)
{
    /* evaluation point for all 12 variables and the single output weight */
    double point[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
                        0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
    double weight[1] = {1.0};

    /* matrix variables, created in the same order as before */
    expr *left = new_variable(2, 3, 0, 12);
    expr *right = new_variable(3, 2, 6, 12);

    /* exp(sum(left @ right)) */
    expr *root = new_exp(new_sum(new_matmul(left, right), -1));

    mu_assert("check_wsum_hess failed",
              check_wsum_hess(root, point, weight, NUMERICAL_DIFF_DEFAULT_H));

    /* release the expression through its root node */
    free_expr(root);
    return 0;
}

/* Checks the weighted-sum Hessian of sin(sum(X @ Y, axis=0)) with X of
 * shape 2x3 and Y of shape 3x2 (12 scalar variables in total). X @ Y is
 * 2x2, the axis-0 sum is 1x2, so the output has two entries and two weights.
 * Returns 0 on success; on failure mu_assert yields the message below. */
const char *test_wsum_hess_sin_sum_axis0_matmul(void)
{
    /* evaluation point for all 12 variables, one weight per output entry */
    double point[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
                        0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
    double weights[2] = {1.0, 1.0};

    /* matrix variables, created in the same order as before */
    expr *left = new_variable(2, 3, 0, 12);
    expr *right = new_variable(3, 2, 6, 12);

    /* sin(sum(left @ right, axis=0)) */
    expr *root = new_sin(new_sum(new_matmul(left, right), 0));

    mu_assert("check_wsum_hess failed",
              check_wsum_hess(root, point, weights, NUMERICAL_DIFF_DEFAULT_H));

    /* release the expression through its root node */
    free_expr(root);
    return 0;
}

/* Checks the weighted-sum Hessian of logistic(sum(X @ Y, axis=0)) with X of
 * shape 2x3 and Y of shape 3x2 (12 scalar variables in total); two weights
 * are supplied, one per output entry.
 * Returns 0 on success; on failure mu_assert yields the message below. */
const char *test_wsum_hess_logistic_sum_axis0_matmul(void)
{
    /* evaluation point for all 12 variables, one weight per output entry */
    double point[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
                        0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
    double weights[2] = {1.0, 1.0};

    /* matrix variables, created in the same order as before */
    expr *left = new_variable(2, 3, 0, 12);
    expr *right = new_variable(3, 2, 6, 12);

    /* logistic(sum(left @ right, axis=0)) */
    expr *root = new_logistic(new_sum(new_matmul(left, right), 0));

    mu_assert("check_wsum_hess failed",
              check_wsum_hess(root, point, weights, NUMERICAL_DIFF_DEFAULT_H));

    /* release the expression through its root node */
    free_expr(root);
    return 0;
}
Loading