From 1e1848fdb12f63781b1edf1448552c05a3015816 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 28 Sep 2022 20:46:49 +0000 Subject: [PATCH] Add a vectorized implementation of atan2 to Eigen. --- Eigen/src/Core/GlobalFunctions.h | 19 ++++++++ Eigen/src/Core/arch/Default/BFloat16.h | 3 ++ Eigen/src/Core/arch/Default/Half.h | 3 ++ Eigen/src/Core/functors/BinaryFunctors.h | 58 ++++++++++++++++++++++++ Eigen/src/plugins/ArrayCwiseBinaryOps.h | 8 ++++ doc/snippets/Cwise_array_atan2_array.cpp | 4 ++ test/array_cwise.cpp | 10 +++- 7 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 doc/snippets/Cwise_array_atan2_array.cpp diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index f8d00b165..c801add74 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -181,6 +181,25 @@ namespace Eigen } #endif + /** \returns an expression of the coefficient-wise atan2(\a x, \a y). \a x and \a y must be of the same type. + * + * This function computes the coefficient-wise atan2(). 
+ * + * \sa ArrayBase::atan2() + * + * \relates ArrayBase + */ + template + inline const std::enable_if_t< + std::is_same::value, + Eigen::CwiseBinaryOp, const LhsDerived, const RhsDerived> + > + atan2(const Eigen::ArrayBase& x, const Eigen::ArrayBase& exponents) { + return Eigen::CwiseBinaryOp, const LhsDerived, const RhsDerived>( + x.derived(), + exponents.derived() + ); + } namespace internal { diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index c444b8a5d..fc07c2190 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -626,6 +626,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) { return bfloat16(::powf(float(a), float(b))); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) { + return bfloat16(::atan2f(float(a), float(b))); +} EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { return bfloat16(::sinf(float(a))); } diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index d58b6a37f..75d62283e 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -758,6 +758,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) { return half(::powf(float(a), float(b))); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan2(const half& a, const half& b) { + return half(::atan2f(float(a), float(b))); +} EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) { return half(::sinf(float(a))); } diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index 9b560e991..ee568ecf6 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -509,6 +509,64 @@ struct 
functor_traits > { }; +template +struct scalar_atan2_op { + using Scalar = LhsScalar; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Scalar> + operator()(const Scalar& y, const Scalar& x) const { + EIGEN_USING_STD(atan2); + return static_cast(atan2(y, x)); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> packetOp(const Packet& y, + const Packet& x) const { + // See https://en.cppreference.com/w/cpp/numeric/math/atan2 + // for how corner cases are supposed to be handled according to the + // IEEE floating-point standard (IEC 60559). + constexpr Scalar k3PiO3f = Scalar(3.0 * M_PI_4); + const Packet kSignMask = pset1(Scalar(-0.0)); + const Packet kPi = pset1(Scalar(EIGEN_PI)); + const Packet kPiO2 = pset1(Scalar(M_PI_2)); + const Packet kPiO4 = pset1(Scalar(M_PI_4)); + const Packet k3PiO4 = pset1(k3PiO3f); + Packet x_neg = pcmp_lt(x, pzero(x)); + Packet x_sign = pand(x, kSignMask); + Packet y_sign = pand(y, kSignMask); + Packet x_zero = pcmp_eq(x, pzero(x)); + Packet y_zero = pcmp_eq(y, pzero(y)); + + // Compute the normal case. Notice that we expect that + // finite/infinite = +/-0 here. + Packet result = patan(pdiv(y, x)); + + // Compute shift for when x != 0 and y != 0. + Packet shift = pselect(x_neg, por(kPi, y_sign), pzero(x)); + + // Special cases: + // Handle x = +/-inf && y = +/-inf. + Packet is_not_nan = pcmp_eq(result, result); + result = pselect(is_not_nan, padd(shift, result), + pselect(x_neg, por(k3PiO4, y_sign), por(kPiO4, y_sign))); + // Handle x == +/-0. + result = + pselect(x_zero, pselect(y_zero, pzero(y), por(y_sign, kPiO2)), result); + // Handle y == +/-0. 
+ result = pselect(y_zero, + pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))), + result); + + return result; + } +}; + +template + struct functor_traits> { + enum { + PacketAccess = is_same::value && packet_traits::HasATan && packet_traits::HasDiv && !NumTraits::IsInteger && !NumTraits::IsComplex, + Cost = + scalar_div_cost::value + 5 * NumTraits::MulCost + 5 * NumTraits::AddCost + }; +}; //---------- binary functors bound to a constant, thus appearing as a unary functor ---------- diff --git a/Eigen/src/plugins/ArrayCwiseBinaryOps.h b/Eigen/src/plugins/ArrayCwiseBinaryOps.h index 5f1e84459..30e3ee107 100644 --- a/Eigen/src/plugins/ArrayCwiseBinaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseBinaryOps.h @@ -134,6 +134,14 @@ absolute_difference */ EIGEN_MAKE_CWISE_BINARY_OP(pow,pow) +/** \returns an expression of the coefficient-wise atan2(\c *this, \a y), where \a y is the given array argument. + * + * This function computes the coefficient-wise atan2. + * + */ +EIGEN_MAKE_CWISE_BINARY_OP(atan2,atan2) + + // TODO code generating macros could be moved to Macros.h and could include generation of documentation #define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \ template \ diff --git a/doc/snippets/Cwise_array_atan2_array.cpp b/doc/snippets/Cwise_array_atan2_array.cpp new file mode 100644 index 000000000..ace075a4a --- /dev/null +++ b/doc/snippets/Cwise_array_atan2_array.cpp @@ -0,0 +1,4 @@ +Array x(8,-25,3), + y(1./3.,0.5,-2.); +cout << "atan2([" << x << "], [" << y << "]) = " << x.atan2(y) << endl; // using ArrayBase::atan2 +cout << "atan2([" << x << "], [" << y << "] = " << atan2(x,y) << endl; // using Eigen::atan2 diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 319eba303..b44837fd2 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -531,6 +531,8 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.sinh(), sinh(m1)); VERIFY_IS_APPROX(m1.cosh(), cosh(m1)); VERIFY_IS_APPROX(m1.tanh(), tanh(m1)); + VERIFY_IS_APPROX(m1.atan2(m2), 
atan2(m1,m2)); + #if EIGEN_HAS_CXX11_MATH VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1))); VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1))); @@ -592,6 +594,13 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() ); VERIFY_IS_APPROX( m1*m1.sign(),m1.abs()); VERIFY_IS_APPROX(m1.sign() * m1.abs(), m1); + + ArrayType tmp = m1.atan2(m2); + for (Index i = 0; i < tmp.size(); ++i) { + Scalar actual = tmp.array()(i); + Scalar expected = atan2(m1.array()(i), m2.array()(i)); + VERIFY_IS_APPROX(actual, expected); + } VERIFY_IS_APPROX(numext::abs2(numext::real(m1)) + numext::abs2(numext::imag(m1)), numext::abs2(m1)); VERIFY_IS_APPROX(numext::abs2(Eigen::real(m1)) + numext::abs2(Eigen::imag(m1)), numext::abs2(m1)); @@ -684,7 +693,6 @@ template void array_complex(const ArrayType& m) VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval())); VERIFY_IS_APPROX(m1.sign(), sign(m1)); - VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2)); VERIFY_IS_APPROX(m1.exp(), exp(m1)); VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());