From d072fc4b1432b193d24e44d70885b636d4132405 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 10 Jan 2017 17:10:35 +0100 Subject: [PATCH] add writeable IndexedView --- Eigen/src/Core/DenseBase.h | 66 ++++++++++++++++++++++++++++++++---- Eigen/src/Core/IndexedView.h | 7 ++++ test/indexed_view.cpp | 15 ++++++++ 3 files changed, 81 insertions(+), 7 deletions(-) diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h index 779cb4549..909fa0f12 100644 --- a/Eigen/src/Core/DenseBase.h +++ b/Eigen/src/Core/DenseBase.h @@ -558,27 +558,27 @@ template class DenseBase EIGEN_DEVICE_FUNC void reverseInPlace(); template - struct IndexedViewType { + struct ConstIndexedViewType { typedef IndexedView::type,typename internal::MakeIndexing::type> type; }; template typename internal::enable_if< - ! (internal::traits::type>::IsBlockAlike + ! (internal::traits::type>::IsBlockAlike || (internal::is_integral::value && internal::is_integral::value)), - typename IndexedViewType::type >::type + typename ConstIndexedViewType::type >::type operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { - return typename IndexedViewType::type( + return typename ConstIndexedViewType::type( derived(), internal::make_indexing(rowIndices,derived().rows()), internal::make_indexing(colIndices,derived().cols())); } template typename internal::enable_if< - internal::traits::type>::IsBlockAlike + internal::traits::type>::IsBlockAlike && !(internal::is_integral::value && internal::is_integral::value), - typename internal::traits::type>::BlockType>::type + typename internal::traits::type>::BlockType>::type operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { - typedef typename internal::traits::type>::BlockType BlockType; + typedef typename internal::traits::type>::BlockType BlockType; typename internal::MakeIndexing::type actualRowIndices = internal::make_indexing(rowIndices,derived().rows()); typename internal::MakeIndexing::type actualColIndices = internal::make_indexing(colIndices,derived().cols()); return BlockType(derived(), @@ -609,6 +609,58 @@ template class DenseBase derived(), rowIndices, colIndices); } + template + struct IndexedViewType { + typedef IndexedView::type,typename internal::MakeIndexing::type> type; + }; + + template + typename internal::enable_if< + ! (internal::traits::type>::IsBlockAlike + || (internal::is_integral::value && internal::is_integral::value)), + typename IndexedViewType::type >::type + operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + return typename IndexedViewType::type( + derived(), internal::make_indexing(rowIndices,derived().rows()), internal::make_indexing(colIndices,derived().cols())); + } + + template + typename internal::enable_if< + internal::traits::type>::IsBlockAlike + && !(internal::is_integral::value && internal::is_integral::value), + typename internal::traits::type>::BlockType>::type + operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + typedef typename internal::traits::type>::BlockType BlockType; + typename internal::MakeIndexing::type actualRowIndices = internal::make_indexing(rowIndices,derived().rows()); + typename internal::MakeIndexing::type actualColIndices = internal::make_indexing(colIndices,derived().cols()); + return BlockType(derived(), + internal::first(actualRowIndices), + internal::first(actualColIndices), + internal::size(actualRowIndices), + internal::size(actualColIndices)); + } + + template + IndexedView::type> + operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndices& colIndices) { + return IndexedView::type>( + derived(), rowIndices, internal::make_indexing(colIndices,derived().cols())); + } + + template + IndexedView::type, const ColIndicesT (&)[ColIndicesN]> + operator()(const RowIndices& rowIndices, const ColIndicesT (&colIndices)[ColIndicesN]) { + return IndexedView::type,const ColIndicesT (&)[ColIndicesN]>( + derived(), internal::make_indexing(rowIndices,derived().rows()), colIndices); + } + + template + IndexedView + operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndicesT (&colIndices)[ColIndicesN]) { + return IndexedView( + derived(), rowIndices, colIndices); + } + #define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase #define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL #define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND) diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h index 5aaf5b4e0..81ff53758 100644 --- a/Eigen/src/Core/IndexedView.h +++ b/Eigen/src/Core/IndexedView.h @@ -104,6 +104,7 @@ class IndexedView : public IndexedViewImpl::StorageKind>::Base Base; EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView) + EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView) typedef typename internal::ref_selector::non_const_type MatrixTypeNested; typedef typename internal::remove_all::type NestedExpression; @@ -180,6 +181,12 @@ struct unary_evaluator, IndexBased> return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]); + } + protected: evaluator m_argImpl; diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index fde3ee8f9..42d136847 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -168,6 +168,12 @@ void check_indexed_view() // Check fall-back to Block { + VERIFY( is_same_type(A.col(0), A(all,0)) ); + VERIFY( is_same_type(A.row(0), A(0,all)) ); + VERIFY( is_same_type(A.block(0,0,2,2), A(seqN(0,2),seq(0,1))) ); + VERIFY( is_same_type(A.middleRows(2,4), A(seqN(2,4),all)) ); + VERIFY( is_same_type(A.middleCols(2,4), A(all,seqN(2,4))) ); + const ArrayXXi& cA(A); VERIFY( is_same_type(cA.col(0), cA(all,0)) ); VERIFY( is_same_type(cA.row(0), cA(0,all)) ); @@ -176,6 +182,15 @@ void check_indexed_view() VERIFY( is_same_type(cA.middleCols(2,4), cA(all,seqN(2,4))) ); } + ArrayXXi A1=A, A2 = ArrayXXi::Random(4,4); + ArrayXi range25(4); range25 << 3,2,4,5; + A1(seqN(3,4),seq(2,5)) = A2; + VERIFY_IS_APPROX( A1.block(3,2,4,4), A2 ); + A1 = A; + A2.setOnes(); + A1(seq(6,3,-1),range25) = A2; + VERIFY_IS_APPROX( A1.block(3,2,4,4), A2 ); + #if EIGEN_HAS_CXX11 VERIFY( (A(all, std::array{{1,3,2,4}})).ColsAtCompileTime == 4);