From b50c3e967e1676f248c93c1a79e6574ae746e2fd Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 9 Jan 2017 23:42:16 +0100 Subject: [PATCH] Add a minimalistic symbolic scalar type with expression template and make use of it to define the last placeholder and to unify the return type of seq and seqN. --- Eigen/src/Core/ArithmeticSequence.h | 220 ++++++++++++++++++++++++---- test/indexed_view.cpp | 5 + 2 files changed, 197 insertions(+), 28 deletions(-) diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h index 71301797a..9f4fe327b 100644 --- a/Eigen/src/Core/ArithmeticSequence.h +++ b/Eigen/src/Core/ArithmeticSequence.h @@ -34,7 +34,7 @@ struct last_t { int operator- (last_t) const { return 0; } int operator- (shifted_last x) const { return -x.offset; } }; -static const last_t last; +static const last_t last_legacy; struct shifted_end { @@ -52,7 +52,145 @@ struct end_t { int operator- (end_t) const { return 0; } int operator- (shifted_end x) const { return -x.offset; } }; -static const end_t end; +static const end_t end_legacy; + +// A simple wrapper around an Index to provide the eval method. +// We could also use a free-function symbolic_eval... +class symbolic_value_wrapper { +public: + symbolic_value_wrapper(Index val) : m_value(val) {} + template + Index eval(const T&) const { return m_value; } +protected: + Index m_value; +}; + +//-------------------------------------------------------------------------------- +// minimalistic symbolic scalar type +//-------------------------------------------------------------------------------- + +template class symbolic_symbol; +template class symbolic_negate; +template class symbolic_add; +template class symbolic_product; +template class symbolic_quotient; + +template +class symbolic_index_base +{ +public: + const Derived& derived() const { return *static_cast(this); } + + symbolic_negate operator-() const { return symbolic_negate(derived()); } + + symbolic_add operator+(Index b) const + { return symbolic_add(derived(), b); } + symbolic_add operator-(Index a) const + { return symbolic_add(derived(), -a); } + symbolic_quotient operator/(Index a) const + { return symbolic_quotient(derived(),a); } + + friend symbolic_add operator+(Index a, const symbolic_index_base& b) + { return symbolic_add(b.derived(), a); } + friend symbolic_add,symbolic_value_wrapper> operator-(Index a, const symbolic_index_base& b) + { return symbolic_add,symbolic_value_wrapper>(-b.derived(), a); } + friend symbolic_add operator/(Index a, const symbolic_index_base& b) + { return symbolic_add(a,b.derived()); } + + template + symbolic_add operator+(const symbolic_index_base &b) const + { return symbolic_add(derived(), b.derived()); } + + template + symbolic_add > operator-(const symbolic_index_base &b) const + { return symbolic_add >(derived(), -b.derived()); } + + template + symbolic_add operator/(const symbolic_index_base &b) const + { return symbolic_quotient(derived(), b.derived()); } +}; + +template +struct is_symbolic { + enum { value = internal::is_convertible >::value }; +}; + +template +class symbolic_value_pair +{ +public: + symbolic_value_pair(Index val) : m_value(val) {} + Index value() const { return m_value; } +protected: + Index m_value; +}; + +template +class symbolic_value : public symbolic_index_base > +{ +public: + symbolic_value() {} + + Index eval(const symbolic_value_pair &values) const { return values.value(); } + + // TODO add a c++14 eval taking a tuple of symbolic_value_pair and getting the value with std::get >... +}; + +template +class symbolic_negate : public symbolic_index_base > +{ +public: + symbolic_negate(const Arg0& arg0) : m_arg0(arg0) {} + + template + Index eval(const T& values) const { return -m_arg0.eval(values); } +protected: + Arg0 m_arg0; +}; + +template +class symbolic_add : public symbolic_index_base > +{ +public: + symbolic_add(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template + Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +template +class symbolic_product : public symbolic_index_base > +{ +public: + symbolic_product(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template + Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +template +class symbolic_quotient : public symbolic_index_base > +{ +public: + symbolic_quotient(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} + + template + Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); } +protected: + Arg0 m_arg0; + Arg1 m_arg1; +}; + +struct symb_last_tag {}; + +static const symbolic_value last; +static const symbolic_add,symbolic_value_wrapper> end(last+1); //-------------------------------------------------------------------------------- // integral constant @@ -116,34 +254,30 @@ protected: IncrType m_incr; }; -template struct cleanup_slice_type { typedef Index type; }; -template<> struct cleanup_slice_type { typedef last_t type; }; -template<> struct cleanup_slice_type { typedef shifted_last type; }; -template<> struct cleanup_slice_type { typedef end_t type; }; -template<> struct cleanup_slice_type { typedef shifted_end type; }; -template struct cleanup_slice_type > { typedef fix_t type; }; -template struct cleanup_slice_type (*)() > { typedef fix_t type; }; +template struct cleanup_seq_type { typedef T type; }; +template struct cleanup_seq_type > { typedef fix_t type; }; +template struct cleanup_seq_type (*)() > { typedef fix_t type; }; template -ArithemeticSequenceProxyWithBounds::type,typename cleanup_slice_type::type > -seq(FirstType f, LastType l) { - return ArithemeticSequenceProxyWithBounds::type,typename cleanup_slice_type::type>(f,l); +ArithemeticSequenceProxyWithBounds::type,typename cleanup_seq_type::type > +seq_legacy(FirstType f, LastType l) { + return ArithemeticSequenceProxyWithBounds::type,typename cleanup_seq_type::type>(f,l); } template -ArithemeticSequenceProxyWithBounds::type,typename cleanup_slice_type::type,typename cleanup_slice_type::type > -seq(FirstType f, LastType l, IncrType s) { - return ArithemeticSequenceProxyWithBounds::type,typename cleanup_slice_type::type,typename cleanup_slice_type::type>(f,l,typename cleanup_slice_type::type(s)); +ArithemeticSequenceProxyWithBounds::type,typename cleanup_seq_type::type,typename cleanup_seq_type::type > +seq_legacy(FirstType f, LastType l, IncrType s) { + return ArithemeticSequenceProxyWithBounds::type,typename cleanup_seq_type::type,typename cleanup_seq_type::type>(f,l,typename cleanup_seq_type::type(s)); } template > -class ArithemeticSequenceProxyWithSize +class ArithemeticSequence { public: - ArithemeticSequenceProxyWithSize(FirstType first, SizeType size) : m_first(first), m_size(size) {} - ArithemeticSequenceProxyWithSize(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {} + ArithemeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {} + ArithemeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {} enum { SizeAtCompileTime = get_compile_time::value, @@ -165,18 +299,30 @@ protected: template -ArithemeticSequenceProxyWithSize::type,typename cleanup_slice_type::type,typename cleanup_slice_type::type > +ArithemeticSequence::type,typename cleanup_seq_type::type,typename cleanup_seq_type::type > seqN(FirstType first, SizeType size, IncrType incr) { - return ArithemeticSequenceProxyWithSize::type,typename cleanup_slice_type::type,typename cleanup_slice_type::type>(first,size,incr); + return ArithemeticSequence::type,typename cleanup_seq_type::type,typename cleanup_seq_type::type>(first,size,incr); } template -ArithemeticSequenceProxyWithSize::type,typename cleanup_slice_type::type > +ArithemeticSequence::type,typename cleanup_seq_type::type > seqN(FirstType first, SizeType size) { - return ArithemeticSequenceProxyWithSize::type,typename cleanup_slice_type::type>(first,size); + return ArithemeticSequence::type,typename cleanup_seq_type::type>(first,size); } +template +auto seq(FirstType f, LastType l) -> decltype(seqN(f,(l-f+1))) +{ + return seqN(f,(l-f+1)); +} +template +auto seq(FirstType f, LastType l, IncrType incr) + -> decltype(seqN(f,(l-f+typename cleanup_seq_type::type(incr))/typename cleanup_seq_type::type(incr),typename cleanup_seq_type::type(incr))) +{ + typedef typename cleanup_seq_type::type CleanedIncrType; + return seqN(f,(l-f+CleanedIncrType(incr))/CleanedIncrType(incr),CleanedIncrType(incr)); +} namespace internal { @@ -214,7 +360,7 @@ struct get_compile_time_incr -struct get_compile_time_incr > { +struct get_compile_time_incr > { enum { value = get_compile_time::value }; }; @@ -258,6 +404,17 @@ Index symbolic2value(shifted_last x, Index size) { return size+x.offset-1; } Index symbolic2value(end_t, Index size) { return size; } Index symbolic2value(shifted_end x, Index size) { return size+x.offset; } +template +fix_t symbolic2value(fix_t x, Index /*size*/) { return x; } + +template +Index symbolic2value(const symbolic_index_base &x, Index size) +{ + Index h=x.derived().eval(symbolic_value_pair(size-1)); + return x.derived().eval(symbolic_value_pair(size-1)); +} + + // Convert a symbolic range into a usable one (i.e., remove last/end "keywords") template struct MakeIndexing > { @@ -270,14 +427,21 @@ ArithemeticSequenceProxyWithBounds make_indexing(const Ari } // Convert a symbolic span into a usable one (i.e., remove last/end "keywords") -template -struct MakeIndexing > { - typedef ArithemeticSequenceProxyWithSize type; +template +struct make_size_type { + typedef typename internal::conditional::value, Index, T>::type type; }; template -ArithemeticSequenceProxyWithSize make_indexing(const ArithemeticSequenceProxyWithSize& ids, Index size) { - return ArithemeticSequenceProxyWithSize(symbolic2value(ids.firstObject(),size),ids.sizeObject(),ids.incrObject()); +struct MakeIndexing > { + typedef ArithemeticSequence::type,IncrType> type; +}; + +template +ArithemeticSequence::type,IncrType> +make_indexing(const ArithemeticSequence& ids, Index size) { + return ArithemeticSequence::type,IncrType>( + symbolic2value(ids.firstObject(),size),symbolic2value(ids.sizeObject(),size),ids.incrObject()); } // Convert a symbolic 'all' into a usable range diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index 23ad2d743..25a25499c 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -139,6 +139,11 @@ void check_indexed_view() VERIFY_IS_EQUAL( (A(eii, eii)).InnerStrideAtCompileTime, 0); VERIFY_IS_EQUAL( (A(eii, eii)).OuterStrideAtCompileTime, 0); + VERIFY_IS_APPROX( A(seq(n-1,2,-2), seqN(n-1-6,4)), A(seq(last,2,-2), seqN(last-6,4)) ); + VERIFY_IS_APPROX( A(seq(n-1-6,n-1-2), seqN(n-1-6,4)), A(seq(last-6,last-2), seqN(6+last-6-6,4)) ); + VERIFY_IS_APPROX( A(seq((n-1)/2,(n)/2+3), seqN(2,4)), A(seq(last/2,(last+1)/2+3), seqN(last+2-last,4)) ); + VERIFY_IS_APPROX( A(seq(n-2,2,-2), seqN(n-8,4)), A(seq(end-2,2,-2), seqN(end-8,4)) ); + #if EIGEN_HAS_CXX11 VERIFY( (A(all, std::array{{1,3,2,4}})).ColsAtCompileTime == 4);