Skip to content

Commit

Permalink
feat: add solve_cholesky and solve_triangular functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Alina TUHOLUKOVA committed Jun 10, 2020
1 parent a757f06 commit 33ba8da
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
48 changes: 48 additions & 0 deletions include/xtensor-blas/xlapack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,54 @@ namespace lapack
return info;
}

template <class E1, class E2>
int potrs(E1& A, E2& b, char uplo = 'L')
{
XTENSOR_ASSERT(A.dimension() == 2);
XTENSOR_ASSERT(A.layout() == layout_type::column_major);

XTENSOR_ASSERT(b.dimension() == 1);

XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]);

int info = cxxlapack::potrs<blas_index_t>(
uplo,
static_cast<blas_index_t>(A.shape()[0]),
1,
A.data(),
static_cast<blas_index_t>(A.shape()[0]),
b.data(),
static_cast<blas_index_t>(b.shape()[0])
);

return info;
}

template <class E1, class E2>
int trtrs(E1& A, E2& b, char uplo = 'L', char trans = 'N', char diag = 'N')
{
XTENSOR_ASSERT(A.dimension() == 2);
XTENSOR_ASSERT(A.layout() == layout_type::column_major);

XTENSOR_ASSERT(b.dimension() == 1);

XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]);

int info = cxxlapack::trtrs<blas_index_t>(
uplo,
trans,
diag,
static_cast<blas_index_t>(A.shape()[0]),
1,
A.data(),
static_cast<blas_index_t>(A.shape()[0]),
b.data(),
static_cast<blas_index_t>(b.shape()[0])
);

return info;
}

/**
* Interface to LAPACK getri.
*
Expand Down
44 changes: 44 additions & 0 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,50 @@ namespace linalg
return M;
}

/**
* Solves a system of linear equations M*X = B with a symmetric
* where M = A*A**T if uplo is L.
* Factorization of M can be computed with cholesky.
* @return solution X
*/
template <class T, class D>
auto solve_cholesky(const xexpression<T>& A, const xexpression<D>& b)
{
assert_nd_square(A);
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
auto p = copy_to_layout<layout_type::column_major>(b.derived_cast());

int info = lapack::potrs(M, p, 'L');

if (info > 0)
{
throw std::runtime_error("Cholesky decomposition failed.");
}

return p;
}

/**
* Solves Ax = b, where A is a lower triangular matrix
* @return solution x
*/
template <class T, class D>
auto solve_triangular(const xexpression<T>& A, const xexpression<D>& b)
{
assert_nd_square(A);
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
auto p = copy_to_layout<layout_type::column_major>(b.derived_cast());

int info = lapack::trtrs(M, p, 'L', 'N');

if (info > 0)
{
throw std::runtime_error("Cholesky decomposition failed.");
}

return p;
}

/**
* Compute the SVD decomposition of \em A.
* @return tuple containing S, V, and D
Expand Down
40 changes: 40 additions & 0 deletions test/test_lapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,44 @@ namespace xt
EXPECT_EQ(expected3, res3);
}

TEST(xlapack, solveCholesky) {

xarray<double> A =
{{ 1. , 0. , 0. , 0. , 0. },
{ 0.44615865, 0.89495389, 0. , 0. , 0. },
{ 0.39541532, 0.24253783, 0.88590187, 0. , 0. },
{-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. },
{ 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}};

xarray<double> b = {1, 1, 1, -1, -1};
auto x = linalg::solve_cholesky(A, b);

const xarray<double> x_expected = { 0.13757507429403265, 0.26609253571318064, 1.03715526610177222,
-1.3449222878385465 , -1.81183493755905478};

for (int i = 0; i < x_expected.shape()[0]; ++i) {
EXPECT_DOUBLE_EQ(x_expected[i], x[i]);
}
}

TEST(xlapack, solveTriangular) {

const xt::xtensor<double, 2> A =
{{ 1. , 0. , 0. , 0. , 0. },
{ 0.44615865, 0.89495389, 0. , 0. , 0. },
{ 0.39541532, 0.24253783, 0.88590187, 0. , 0. },
{-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. },
{ 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}};

const xt::xtensor<double, 1> b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322};
auto x = linalg::solve_triangular(A, b);

const xarray<double> x_expected = { 0.38867998999999998, 0.32544416381003327, 0.17813128230545805,
-0.05799057434472885, 0.08606304705465571};

for (int i = 0; i < x_expected.shape()[0]; ++i) {
EXPECT_DOUBLE_EQ(x_expected[i], x[i]);
}
}

}

0 comments on commit 33ba8da

Please sign in to comment.