Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ typedef struct parameter_expr
{
expr base;
int param_id;
bool has_been_refreshed;
/* Set to true by problem_update_params(), cleared by
refresh_param_values() after propagating new values. */
bool needs_refresh;
} parameter_expr;

/* Linear operator: y = A * x + b
Expand Down
2 changes: 2 additions & 0 deletions include/utils/dense_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ Matrix *new_dense_matrix(int m, int n, const double *data);
/* Transpose helper */
Matrix *dense_matrix_trans(const Dense_Matrix *self);

void A_transpose(double *AT, const double *A, int m, int n);

#endif /* DENSE_MATRIX_H */
36 changes: 19 additions & 17 deletions src/atoms/affine/left_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ static void refresh_param_values(left_matmul_expr *lnode)
return;
}
parameter_expr *param = (parameter_expr *) lnode->param_source;
if (param->has_been_refreshed)
if (!param->needs_refresh)
{
return;
}
param->has_been_refreshed = true;
param->needs_refresh = false;
lnode->refresh_param_values(lnode);
}

Expand Down Expand Up @@ -168,28 +168,25 @@ static void eval_wsum_hess(expr *node, const double *w)

static void refresh_sparse_left(left_matmul_expr *lnode)
{
Sparse_Matrix *sm_A = (Sparse_Matrix *) lnode->A;
Sparse_Matrix *sm_AT = (Sparse_Matrix *) lnode->AT;
lnode->A->update_values(lnode->A, lnode->param_source->value);
/* Recompute AT values from A */
AT_fill_values(sm_A->csr, sm_AT->csr, lnode->base.work->iwork);
(void) lnode;
fprintf(stderr,
"Error in refresh_sparse_left: parameter for a sparse matrix not "
"supported \n");
exit(1);
}

static void refresh_dense_left(left_matmul_expr *lnode)
{
Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A;
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
int m = dm_A->base.m;
int n = dm_A->base.n;
lnode->A->update_values(lnode->A, lnode->param_source->value);
/* Recompute AT data (transpose of row-major A) */
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
dm_AT->x[j * m + i] = dm_A->x[i * n + j];
}
}

/* The parameter represents the A in left_matmul_dense(A, x) in column-major.
In this diffengine, we store A in row-major order. Hence, param->vals
actually corresponds to the transpose of A, and we transpose AT to get A. */
memcpy(dm_AT->x, lnode->param_source->value, m * n * sizeof(double));
A_transpose(dm_A->x, dm_AT->x, n, m);
}

expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
Expand Down Expand Up @@ -243,6 +240,11 @@ expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
lnode->param_source = param_node;
if (param_node != NULL)
{

fprintf(stderr, "Error in new_left_matmul: parameter for a sparse matrix "
"not supported \n");
exit(1);

expr_retain(param_node);
lnode->refresh_param_values = refresh_sparse_left;
}
Expand Down
2 changes: 1 addition & 1 deletion src/atoms/affine/parameter.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);

pnode->param_id = param_id;
pnode->has_been_refreshed = false;
pnode->needs_refresh = false;

if (values != NULL)
{
Expand Down
54 changes: 25 additions & 29 deletions src/atoms/affine/right_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "utils/CSR_Matrix.h"
#include "utils/dense_matrix.h"
#include "utils/tracked_alloc.h"
#include <stdio.h>
#include <stdlib.h>

/* This file implements the atom 'right_matmul' corresponding to the operation y =
Expand All @@ -38,30 +39,27 @@
So: update lnode->AT from param values, then recompute lnode->A. */
static void refresh_sparse_right(left_matmul_expr *lnode)
{
Sparse_Matrix *sm_AT_inner = (Sparse_Matrix *) lnode->A;
Sparse_Matrix *sm_A_inner = (Sparse_Matrix *) lnode->AT;
/* lnode->AT holds the original A; update its values from param */
lnode->AT->update_values(lnode->AT, lnode->param_source->value);
/* Recompute A^T (lnode->A) from A (lnode->AT) */
AT_fill_values(sm_A_inner->csr, sm_AT_inner->csr, lnode->base.work->iwork);
(void) lnode;
fprintf(stderr,
"Error in refresh_sparse_right: parameter for a sparse matrix not "
"supported \n");
exit(1);
}

static void refresh_dense_right(left_matmul_expr *lnode)
{
Dense_Matrix *dm_AT_inner = (Dense_Matrix *) lnode->A;
Dense_Matrix *dm_A_inner = (Dense_Matrix *) lnode->AT;
int m_orig = dm_A_inner->base.m; /* original A is m x n */
int n_orig = dm_A_inner->base.n;
/* Update original A (inner's AT) from param values */
lnode->AT->update_values(lnode->AT, lnode->param_source->value);
/* Recompute A^T (inner's A) from A */
for (int i = 0; i < m_orig; i++)
{
for (int j = 0; j < n_orig; j++)
{
dm_AT_inner->x[j * m_orig + i] = dm_A_inner->x[i * n_orig + j];
}
}
/* This left_matmul_expr node corresponds to left multiplication with B = AT,
where A is the original (m x n) matrix given to the right_matmul function.
Furthermore, lnode->param_source->value corresponds to the column-major
version of A, which is BT (an m x n matrix) */

Dense_Matrix *B = (Dense_Matrix *) lnode->AT;
Dense_Matrix *BT = (Dense_Matrix *) lnode->A;
int m = B->base.n;
int n = B->base.m;

memcpy(BT->x, lnode->param_source->value, m * n * sizeof(double));
A_transpose(B->x, BT->x, m, n);
}

expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
Expand All @@ -78,6 +76,11 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
left_matmul */
if (param_node != NULL)
{

fprintf(stderr, "Error in new_right_matmul: parameter for a sparse matrix "
"not supported \n");
exit(1);

left_matmul_expr *lnode = (left_matmul_expr *) left_matmul;
lnode->param_source = param_node;
expr_retain(param_node);
Expand All @@ -94,16 +97,9 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
const double *data)
{
/* We express: u @ A = (A^T @ u^T)^T
A is m x n, so A^T is n x m. */
/* We express: u @ A = (A^T @ u^T)^T. A is m x n, so A^T is n x m. */
double *AT = (double *) SP_MALLOC(n * m * sizeof(double));
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
AT[j * m + i] = data[i * n + j];
}
}
A_transpose(AT, data, m, n);

expr *u_transpose = new_transpose(u);
expr *left_matmul_node = new_left_matmul_dense(NULL, u_transpose, n, m, AT);
Expand Down
2 changes: 1 addition & 1 deletion src/problem.c
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ void problem_update_params(problem *prob, const double *theta)
if (param->param_id == PARAM_FIXED) continue;
int offset = param->param_id;
memcpy(pnode->value, theta + offset, pnode->size * sizeof(double));
param->has_been_refreshed = false;
param->needs_refresh = true;
}

/* Force re-evaluation of affine Jacobians on next call */
Expand Down
15 changes: 10 additions & 5 deletions src/utils/dense_matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,20 @@ Matrix *dense_matrix_trans(const Dense_Matrix *A)
int n = A->base.n;
double *AT_x = (double *) SP_MALLOC(m * n * sizeof(double));

A_transpose(AT_x, A->x, m, n);

Matrix *result = new_dense_matrix(n, m, AT_x);
free(AT_x);
return result;
}

void A_transpose(double *AT, const double *A, int m, int n)
{
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
AT_x[j * m + i] = A->x[i * n + j];
AT[j * m + i] = A[i * n + j];
}
}

Matrix *result = new_dense_matrix(n, m, AT_x);
free(AT_x);
return result;
}
2 changes: 2 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ int main(void)
mu_run_test(test_param_vector_mult_problem, tests_run);
mu_run_test(test_param_left_matmul_problem, tests_run);
mu_run_test(test_param_right_matmul_problem, tests_run);
mu_run_test(test_param_left_matmul_rectangular, tests_run);
mu_run_test(test_param_right_matmul_rectangular, tests_run);
mu_run_test(test_param_fixed_skip_in_update, tests_run);
#endif /* PROFILE_ONLY */

Expand Down
Loading
Loading