From 2af03fb6854c9e046b2a3f2412f2adf3e55ff0ba Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Thu, 4 May 2023 16:02:08 +0000 Subject: [PATCH] clean up array_cwise test --- test/array_cwise.cpp | 59 ++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index dfa81d4da..5635ca8ab 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -10,6 +10,20 @@ #include #include "main.h" +// suppress annoying unsigned integer warnings +template ::IsSigned> +struct negative_or_zero_impl { + static Scalar run(const Scalar& a) { return -a; } +}; +template +struct negative_or_zero_impl { + static Scalar run(const Scalar&) { return 0; } +}; +template +Scalar negative_or_zero(const Scalar& a) { + return negative_or_zero_impl::run(a); +} + template ::IsInteger,int> = 0> std::vector special_values() { const Scalar zero = Scalar(0); @@ -249,7 +263,7 @@ template (exponent)); + return static_cast(pow(base, static_cast(exponent))); } }; @@ -257,7 +271,7 @@ template struct ref_pow { static Base run(Base base, Exponent exponent) { EIGEN_USING_STD(pow); - return pow(base, exponent); + return static_cast(pow(base, exponent)); } }; @@ -302,7 +316,7 @@ void test_exponent(Exponent exponent) { template void unary_pow_test() { Exponent max_exponent = static_cast(NumTraits::digits()); - Exponent min_exponent = static_cast(NumTraits::IsSigned ? -max_exponent : 0); + Exponent min_exponent = negative_or_zero(max_exponent); for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) { test_exponent(exponent); @@ -374,7 +388,7 @@ void signbit_test() { std::vector special_vals = special_values(); for (size_t i = 0; i < special_vals.size(); i++) { x(2 * i + 0) = special_vals[i]; - x(2 * i + 1) = -special_vals[i]; + x(2 * i + 1) = negative_or_zero(special_vals[i]); } y = x.unaryExpr(internal::test_signbit_op()); @@ -1020,7 +1034,7 @@ template struct shift_left { template Scalar operator()(const Scalar& v) const { - return v << N; + return (v << N); } }; @@ -1028,29 +1042,10 @@ template struct arithmetic_shift_right { template Scalar operator()(const Scalar& v) const { - return v >> N; + return (v >> N); } }; -template void array_integer(const ArrayType& m) -{ - Index rows = m.rows(); - Index cols = m.cols(); - - ArrayType m1 = ArrayType::Random(rows, cols), - m2(rows, cols); - - m2 = m1.template shiftLeft<2>(); - VERIFY( (m2 == m1.unaryExpr(shift_left<2>())).all() ); - m2 = m1.template shiftLeft<9>(); - VERIFY( (m2 == m1.unaryExpr(shift_left<9>())).all() ); - - m2 = m1.template shiftRight<2>(); - VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<2>())).all() ); - m2 = m1.template shiftRight<9>(); - VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() ); -} - template struct signed_shift_test_impl { typedef typename ArrayType::Scalar Scalar; @@ -1064,13 +1059,15 @@ struct signed_shift_test_impl { const Index rows = m.rows(); const Index cols = m.cols(); - ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols); + ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols), m3(rows, cols); - m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; }); - VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op())).all()); + m2 = m1.unaryExpr(internal::scalar_shift_right_op()); + m3 = m1.unaryExpr(arithmetic_shift_right()); + VERIFY_IS_CWISE_EQUAL(m2, m3); - m2 = m1.unaryExpr([](const Scalar& x) { return x << N; }); - VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op())).all()); + m2 = m1.unaryExpr(internal::scalar_shift_left_op()); + m3 = m1.unaryExpr(shift_left()); + VERIFY_IS_CWISE_EQUAL(m2, m3); run(m); } @@ -1193,8 +1190,6 @@ EIGEN_DECLARE_TEST(array_cwise) CALL_SUBTEST_5( array(ArrayXXf(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_6( array(ArrayXXi(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_6( array(Array(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); - CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); - CALL_SUBTEST_6( array_integer(Array(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_7( signed_shift_test(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_8( array(Array(internal::random(1, EIGEN_TEST_MAX_SIZE), internal::random(1, EIGEN_TEST_MAX_SIZE))));