diff --git a/Eigen/src/Core/DiagonalProduct.h b/Eigen/src/Core/DiagonalProduct.h index 5948111c6..b838d1b31 100644 --- a/Eigen/src/Core/DiagonalProduct.h +++ b/Eigen/src/Core/DiagonalProduct.h @@ -26,8 +26,8 @@ #ifndef EIGEN_DIAGONALPRODUCT_H #define EIGEN_DIAGONALPRODUCT_H -template -struct ei_traits > +template +struct ei_traits > { typedef typename MatrixType::Scalar Scalar; enum { @@ -35,14 +35,15 @@ struct ei_traits > ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime, - Flags = (unsigned int)(MatrixType::Flags) & HereditaryBits, + Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) + | (PacketAccessBit & (unsigned int)(MatrixType::Flags) & (unsigned int)(DiagonalType::DiagonalVectorType::Flags)), CoeffReadCost = NumTraits::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost }; }; -template +template class DiagonalProduct : ei_no_assignment_operator, - public MatrixBase > + public MatrixBase > { public: @@ -51,7 +52,7 @@ class DiagonalProduct : ei_no_assignment_operator, inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal) : m_matrix(matrix), m_diagonal(diagonal) { - ei_assert(diagonal.diagonal().size() == (Order == DiagonalOnTheLeft ? matrix.rows() : matrix.cols())); + ei_assert(diagonal.diagonal().size() == (ProductOrder == DiagonalOnTheLeft ? matrix.rows() : matrix.cols())); } inline int rows() const { return m_matrix.rows(); } @@ -59,7 +60,30 @@ class DiagonalProduct : ei_no_assignment_operator, const Scalar coeff(int row, int col) const { - return m_diagonal.diagonal().coeff(Order == DiagonalOnTheLeft ? row : col) * m_matrix.coeff(row, col); + return m_diagonal.diagonal().coeff(ProductOrder == DiagonalOnTheLeft ? row : col) * m_matrix.coeff(row, col); + } + + template + EIGEN_STRONG_INLINE PacketScalar packet(int row, int col) const + { + enum { + StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor, + InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime, + DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned + }; + const int indexInDiagonalVector = ProductOrder == DiagonalOnTheLeft ? row : col; + + if((int(StorageOrder) == RowMajor && int(ProductOrder) == DiagonalOnTheLeft) + ||(int(StorageOrder) == ColMajor && int(ProductOrder) == DiagonalOnTheRight)) + { + return ei_pmul(m_matrix.template packet(row, col), + ei_pset1(m_diagonal.diagonal().coeff(indexInDiagonalVector))); + } + else + { + return ei_pmul(m_matrix.template packet(row, col), + m_diagonal.diagonal().template packet(indexInDiagonalVector)); + } } protected: diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 3e4c1bf0b..af21c190f 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -53,7 +53,7 @@ template class Product; template class DiagonalBase; template class DiagonalWrapper; template class DiagonalMatrix; -template class DiagonalProduct; +template class DiagonalProduct; template class Diagonal; template class Map; diff --git a/test/diagonalmatrices.cpp b/test/diagonalmatrices.cpp index 0a8b7086b..9eb0f10f2 100644 --- a/test/diagonalmatrices.cpp +++ b/test/diagonalmatrices.cpp @@ -23,7 +23,7 @@ // Eigen. If not, see . #include "main.h" - +using namespace std; template void diagonalmatrices(const MatrixType& m) { typedef typename MatrixType::Scalar Scalar; @@ -34,7 +34,7 @@ template void diagonalmatrices(const MatrixType& m) typedef Matrix SquareMatrixType; typedef DiagonalMatrix LeftDiagonalMatrix; typedef DiagonalMatrix RightDiagonalMatrix; - + typedef Matrix BigMatrix; int rows = m.rows(); int cols = m.cols(); @@ -46,20 +46,7 @@ template void diagonalmatrices(const MatrixType& m) rv2 = RowVectorType::Random(cols); LeftDiagonalMatrix ldm1(v1), ldm2(v2); RightDiagonalMatrix rdm1(rv1), rdm2(rv2); - - int i = ei_random(0, rows-1); - int j = ei_random(0, cols-1); - - VERIFY_IS_APPROX( ((ldm1 * m1)(i,j)) , ldm1.diagonal()(i) * m1(i,j) ); - VERIFY_IS_APPROX( ((ldm1 * (m1+m2))(i,j)) , ldm1.diagonal()(i) * (m1+m2)(i,j) ); - VERIFY_IS_APPROX( ((m1 * rdm1)(i,j)) , rdm1.diagonal()(j) * m1(i,j) ); - VERIFY_IS_APPROX( ((v1.asDiagonal() * m1)(i,j)) , v1(i) * m1(i,j) ); - VERIFY_IS_APPROX( ((m1 * rv1.asDiagonal())(i,j)) , rv1(j) * m1(i,j) ); - VERIFY_IS_APPROX( (((v1+v2).asDiagonal() * m1)(i,j)) , (v1+v2)(i) * m1(i,j) ); - VERIFY_IS_APPROX( (((v1+v2).asDiagonal() * (m1+m2))(i,j)) , (v1+v2)(i) * (m1+m2)(i,j) ); - VERIFY_IS_APPROX( ((m1 * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * m1(i,j) ); - VERIFY_IS_APPROX( (((m1+m2) * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * (m1+m2)(i,j) ); - + SquareMatrixType sq_m1 (v1.asDiagonal()); VERIFY_IS_APPROX(sq_m1, v1.asDiagonal().toDenseMatrix()); sq_m1 = v1.asDiagonal(); @@ -77,6 +64,32 @@ template void diagonalmatrices(const MatrixType& m) VERIFY_IS_APPROX(sq_m1, ldm1.toDenseMatrix()); sq_m1.transpose() = ldm1; VERIFY_IS_APPROX(sq_m1, ldm1.toDenseMatrix()); + + int i = ei_random(0, rows-1); + int j = ei_random(0, cols-1); + + VERIFY_IS_APPROX( ((ldm1 * m1)(i,j)) , ldm1.diagonal()(i) * m1(i,j) ); + VERIFY_IS_APPROX( ((ldm1 * (m1+m2))(i,j)) , ldm1.diagonal()(i) * (m1+m2)(i,j) ); + VERIFY_IS_APPROX( ((m1 * rdm1)(i,j)) , rdm1.diagonal()(j) * m1(i,j) ); + VERIFY_IS_APPROX( ((v1.asDiagonal() * m1)(i,j)) , v1(i) * m1(i,j) ); + VERIFY_IS_APPROX( ((m1 * rv1.asDiagonal())(i,j)) , rv1(j) * m1(i,j) ); + VERIFY_IS_APPROX( (((v1+v2).asDiagonal() * m1)(i,j)) , (v1+v2)(i) * m1(i,j) ); + VERIFY_IS_APPROX( (((v1+v2).asDiagonal() * (m1+m2))(i,j)) , (v1+v2)(i) * (m1+m2)(i,j) ); + VERIFY_IS_APPROX( ((m1 * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * m1(i,j) ); + VERIFY_IS_APPROX( (((m1+m2) * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * (m1+m2)(i,j) ); + + BigMatrix big; + big.setZero(2*rows, 2*cols); + + big.block(i,j,rows,cols) = m1; + big.block(i,j,rows,cols) = v1.asDiagonal() * big.block(i,j,rows,cols); + + VERIFY_IS_APPROX((big.block(i,j,rows,cols)) , v1.asDiagonal() * m1 ); + + big.block(i,j,rows,cols) = m1; + big.block(i,j,rows,cols) = big.block(i,j,rows,cols) * rv1.asDiagonal(); + VERIFY_IS_APPROX((big.block(i,j,rows,cols)) , m1 * rv1.asDiagonal() ); + } void test_diagonalmatrices()