diff --git a/include/xtensor-blas/xlapack.hpp b/include/xtensor-blas/xlapack.hpp index b1743ae..e9d1660 100644 --- a/include/xtensor-blas/xlapack.hpp +++ b/include/xtensor-blas/xlapack.hpp @@ -427,6 +427,54 @@ namespace lapack return info; } + template + int potrs(E1& A, E2& b, char uplo = 'L') + { + XTENSOR_ASSERT(A.dimension() == 2); + XTENSOR_ASSERT(A.layout() == layout_type::column_major); + + XTENSOR_ASSERT(b.dimension() == 1); + + XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]); + + int info = cxxlapack::potrs( + uplo, + static_cast(A.shape()[0]), + 1, + A.data(), + static_cast(A.shape()[0]), + b.data(), + static_cast(b.shape()[0]) + ); + + return info; + } + + template + int trtrs(E1& A, E2& b, char uplo = 'L', char trans = 'N', char diag = 'N') + { + XTENSOR_ASSERT(A.dimension() == 2); + XTENSOR_ASSERT(A.layout() == layout_type::column_major); + + XTENSOR_ASSERT(b.dimension() == 1); + + XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]); + + int info = cxxlapack::trtrs( + uplo, + trans, + diag, + static_cast(A.shape()[0]), + 1, + A.data(), + static_cast(A.shape()[0]), + b.data(), + static_cast(b.shape()[0]) + ); + + return info; + } + /** * Interface to LAPACK getri. * diff --git a/include/xtensor-blas/xlinalg.hpp b/include/xtensor-blas/xlinalg.hpp index 3cd6c6f..b6e9db7 100644 --- a/include/xtensor-blas/xlinalg.hpp +++ b/include/xtensor-blas/xlinalg.hpp @@ -1253,6 +1253,50 @@ namespace linalg return M; } + /** + * Solves a system of linear equations M*X = B with a symmetric + * where M = A*A**T if uplo is L. + * Factorization of M can be computed with cholesky. + * @return solution X + */ + template + auto solve_cholesky(const xexpression& A, const xexpression& b) + { + assert_nd_square(A); + auto M = copy_to_layout(A.derived_cast()); + auto p = copy_to_layout(b.derived_cast()); + + int info = lapack::potrs(M, p, 'L'); + + if (info > 0) + { + throw std::runtime_error("Cholesky decomposition failed."); + } + + return p; + } + + /** + * Solves Ax = b, where A is a lower triangular matrix + * @return solution x + */ + template + auto solve_triangular(const xexpression& A, const xexpression& b) + { + assert_nd_square(A); + auto M = copy_to_layout(A.derived_cast()); + auto p = copy_to_layout(b.derived_cast()); + + int info = lapack::trtrs(M, p, 'L', 'N'); + + if (info > 0) + { + throw std::runtime_error("Cholesky decomposition failed."); + } + + return p; + } + /** * Compute the SVD decomposition of \em A. * @return tuple containing S, V, and D diff --git a/test/test_lapack.cpp b/test/test_lapack.cpp index 880be14..9fcea5a 100644 --- a/test/test_lapack.cpp +++ b/test/test_lapack.cpp @@ -147,4 +147,44 @@ namespace xt EXPECT_EQ(expected3, res3); } + TEST(xlapack, solveCholesky) { + + xarray A = + {{ 1. , 0. , 0. , 0. , 0. }, + { 0.44615865, 0.89495389, 0. , 0. , 0. }, + { 0.39541532, 0.24253783, 0.88590187, 0. , 0. }, + {-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. }, + { 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}}; + + xarray b = {1, 1, 1, -1, -1}; + auto x = linalg::solve_cholesky(A, b); + + const xarray x_expected = { 0.13757507429403265, 0.26609253571318064, 1.03715526610177222, + -1.3449222878385465 , -1.81183493755905478}; + + for (int i = 0; i < x_expected.shape()[0]; ++i) { + EXPECT_DOUBLE_EQ(x_expected[i], x[i]); + } + } + + TEST(xlapack, solveTriangular) { + + const xt::xtensor A = + {{ 1. , 0. , 0. , 0. , 0. }, + { 0.44615865, 0.89495389, 0. , 0. , 0. }, + { 0.39541532, 0.24253783, 0.88590187, 0. , 0. }, + {-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. }, + { 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}}; + + const xt::xtensor b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322}; + auto x = linalg::solve_triangular(A, b); + + const xarray x_expected = { 0.38867998999999998, 0.32544416381003327, 0.17813128230545805, + -0.05799057434472885, 0.08606304705465571}; + + for (int i = 0; i < x_expected.shape()[0]; ++i) { + EXPECT_DOUBLE_EQ(x_expected[i], x[i]); + } + } + }