diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index be9362958..98305f993 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -75,7 +75,7 @@ struct ei_product_triangular_matrix_matrix Blocking; enum { SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), - IsLower = (Mode&Lower) == Lower + IsLower = (Mode&Lower) == Lower, + SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1 }; Index kc = depth; // cache block size along the K direction @@ -127,7 +128,10 @@ struct ei_product_triangular_matrix_matrix triangularBuffer; triangularBuffer.setZero(); - triangularBuffer.diagonal().setOnes(); + if((Mode&ZeroDiag)==ZeroDiag) + triangularBuffer.diagonal().setZero(); + else + triangularBuffer.diagonal().setOnes(); ei_gebp_kernel gebp_kernel; ei_gemm_pack_lhs pack_lhs; @@ -169,7 +173,7 @@ struct ei_product_triangular_matrix_matrix Blocking; enum { SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), - IsLower = (Mode&Lower) == Lower + IsLower = (Mode&Lower) == Lower, + SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1 }; Index kc = depth; // cache block size along the K direction @@ -252,7 +257,10 @@ struct ei_product_triangular_matrix_matrix triangularBuffer; triangularBuffer.setZero(); - triangularBuffer.diagonal().setOnes(); + if((Mode&ZeroDiag)==ZeroDiag) + triangularBuffer.diagonal().setZero(); + else + triangularBuffer.diagonal().setOnes(); ei_gebp_kernel gebp_kernel; ei_gemm_pack_lhs pack_lhs; @@ -300,7 +308,7 @@ struct ei_product_triangular_matrix_matrix void trmm(int size,int /*othersize*/) DenseIndex cols = ei_random(1,size); MatrixColMaj triV(rows,cols), triH(cols,rows), upTri(cols,rows), loTri(rows,cols), - unitUpTri(cols,rows), unitLoTri(rows,cols); + unitUpTri(cols,rows), unitLoTri(rows,cols), strictlyUpTri(cols,rows), strictlyLoTri(rows,cols); MatrixColMaj ge1(rows,cols), ge2(cols,rows), ge3; MatrixRowMaj rge3; @@ -48,6 +48,8 @@ template void trmm(int size,int /*othersize*/) upTri = triH.template triangularView(); unitLoTri = triV.template triangularView(); unitUpTri = triH.template triangularView(); + strictlyLoTri = triV.template triangularView(); + strictlyUpTri = triH.template triangularView(); ge1.setRandom(); ge2.setRandom(); @@ -72,6 +74,11 @@ template void trmm(int size,int /*othersize*/) VERIFY_IS_APPROX( rge3.noalias() = ge2 * triV.template triangularView(), ge2 * unitLoTri); VERIFY_IS_APPROX( ge3 = ge2 * triV.template triangularView(), ge2 * unitLoTri); VERIFY_IS_APPROX( ge3 = (s1*triV).adjoint().template triangularView() * ge2.adjoint(), ei_conj(s1) * unitLoTri.adjoint() * ge2.adjoint()); + + VERIFY_IS_APPROX( ge3 = triV.template triangularView() * ge2, strictlyLoTri * ge2); + VERIFY_IS_APPROX( rge3.noalias() = ge2 * triV.template triangularView(), ge2 * strictlyLoTri); + VERIFY_IS_APPROX( ge3 = ge2 * triV.template triangularView(), ge2 * strictlyLoTri); + VERIFY_IS_APPROX( ge3 = (s1*triV).adjoint().template triangularView() * ge2.adjoint(), ei_conj(s1) * strictlyLoTri.adjoint() * ge2.adjoint()); } void test_product_trmm()