From d65c8cb60acd34a0eb898194713ff45e604de0fe Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 23 Dec 2009 11:48:53 +0100 Subject: [PATCH] fix #69 and extend unit tests or triangular solvers --- Eigen/src/Core/SolveTriangular.h | 20 +++++++++++++---- test/CMakeLists.txt | 2 +- .../{product_trsm.cpp => product_trsolve.cpp} | 22 ++++++++++++------- 3 files changed, 31 insertions(+), 13 deletions(-) rename test/{product_trsm.cpp => product_trsolve.cpp} (81%) diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index c7f0cd227..e8230dd50 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -31,7 +31,7 @@ template struct ei_triangular_solver_selector; @@ -143,18 +143,30 @@ struct ei_triangular_solver_selector +struct ei_triangular_solver_selector +{ + static void run(const Lhs& lhs, Rhs& rhs) + { + Transpose rhsTr(rhs); + Transpose lhsTr(lhs); + ei_triangular_solver_selector,Transpose,OnTheLeft,TriangularView::TransposeMode>::run(lhsTr,rhsTr); + } +}; + template struct ei_triangular_solve_matrix; // the rhs is a matrix -template -struct ei_triangular_solver_selector +template +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits LhsProductTraits; typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; static void run(const Lhs& lhs, Rhs& rhs) - { + {std::cerr << "mat\n"; const ActualLhsType actualLhs = LhsProductTraits::extract(lhs); ei_triangular_solve_matrix diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ffe89915a..b8efbcf51 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -116,7 +116,7 @@ ei_add_test(product_symm) ei_add_test(product_syrk) ei_add_test(product_trmv) ei_add_test(product_trmm) -ei_add_test(product_trsm) +ei_add_test(product_trsolve) ei_add_test(product_notemporary) ei_add_test(stable_norm) ei_add_test(bandmatrix) diff --git a/test/product_trsm.cpp b/test/product_trsolve.cpp similarity index 81% rename from test/product_trsm.cpp rename to test/product_trsolve.cpp index 1103e79a9..449240f7c 100644 --- a/test/product_trsm.cpp +++ b/test/product_trsolve.cpp @@ -36,15 +36,15 @@ VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \ } -template void trsm(int size,int cols) +template void trsolve(int size=Size,int cols=Cols) { typedef typename NumTraits::Real RealScalar; - Matrix cmLhs(size,size); - Matrix rmLhs(size,size); + Matrix cmLhs(size,size); + Matrix rmLhs(size,size); - Matrix cmRhs(size,cols), ref(size,cols); - Matrix rmRhs(size,cols); + Matrix cmRhs(size,cols), ref(size,cols); + Matrix rmRhs(size,cols); cmLhs.setRandom(); cmLhs *= static_cast(0.1); cmLhs.diagonal().cwise() += static_cast(1); rmLhs.setRandom(); rmLhs *= static_cast(0.1); rmLhs.diagonal().cwise() += static_cast(1); @@ -73,11 +73,17 @@ template void trsm(int size,int cols) VERIFY_TRSM_ONTHERIGHT(rmLhs.conjugate().template triangularView(), rmRhs); } -void test_product_trsm() +void test_product_trsolve() { for(int i = 0; i < g_repeat ; i++) { - CALL_SUBTEST_1((trsm(ei_random(1,320),ei_random(1,320)))); - CALL_SUBTEST_2((trsm >(ei_random(1,320),ei_random(1,320)))); + // matrices + CALL_SUBTEST_1((trsolve(ei_random(1,320),ei_random(1,320)))); + CALL_SUBTEST_2((trsolve,Dynamic,Dynamic>(ei_random(1,320),ei_random(1,320)))); + + // vectors + CALL_SUBTEST_3((trsolve,Dynamic,1>(ei_random(1,320)))); + CALL_SUBTEST_4((trsolve())); + CALL_SUBTEST_5((trsolve,4,1>())); } }