322 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			322 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | |
| 
 | |
| #include "main.h"
 | |
| 
 | |
| template <typename T, typename U>
 | |
| bool check_if_equal_or_nans(const T& actual, const U& expected) {
 | |
|   return (numext::equal_strict(actual, expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
 | |
| }
 | |
| 
 | |
| template <typename T, typename U>
 | |
| bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
 | |
|   return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) &&
 | |
|          check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
 | |
| }
 | |
| 
 | |
// Test predicate wrapping check_if_equal_or_nans. On mismatch it prints both
// operands to std::cerr, so a failing VERIFY shows the offending values.
template <typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected) {
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr << "\n    actual   = " << actual << "\n    expected = " << expected << "\n\n";
  }
  return matches;
}

// VERIFY variant that accepts NaN == NaN (see test_is_equal_or_nans).
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
 | |
| 
 | |
| template <typename T>
 | |
| void check_abs() {
 | |
|   typedef typename NumTraits<T>::Real Real;
 | |
|   Real zero(0);
 | |
| 
 | |
|   if (NumTraits<T>::IsSigned) VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
 | |
|   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
 | |
|   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
 | |
| 
 | |
|   for (int k = 0; k < 100; ++k) {
 | |
|     T x = internal::random<T>();
 | |
|     if (!internal::is_same<T, bool>::value) x = x / Real(2);
 | |
|     if (NumTraits<T>::IsSigned) {
 | |
|       VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
 | |
|       VERIFY(numext::abs(-x) >= zero);
 | |
|     }
 | |
|     VERIFY(numext::abs(x) >= zero);
 | |
|     VERIFY_IS_APPROX(numext::abs2(x), numext::abs2(numext::abs(x)));
 | |
|   }
 | |
| }
 | |
| 
 | |
| template <typename T>
 | |
| void check_arg() {
 | |
|   typedef typename NumTraits<T>::Real Real;
 | |
|   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
 | |
|   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
 | |
| 
 | |
|   for (int k = 0; k < 100; ++k) {
 | |
|     T x = internal::random<T>();
 | |
|     Real y = numext::arg(x);
 | |
|     VERIFY_IS_APPROX(y, std::arg(x));
 | |
|   }
 | |
| }
 | |
| 
 | |
| template <typename T>
 | |
| struct check_sqrt_impl {
 | |
|   static void run() {
 | |
|     for (int i = 0; i < 1000; ++i) {
 | |
|       const T x = numext::abs(internal::random<T>());
 | |
|       const T sqrtx = numext::sqrt(x);
 | |
|       VERIFY_IS_APPROX(sqrtx * sqrtx, x);
 | |
|     }
 | |
| 
 | |
|     // Corner cases.
 | |
|     const T zero = T(0);
 | |
|     const T one = T(1);
 | |
|     const T inf = std::numeric_limits<T>::infinity();
 | |
|     const T nan = std::numeric_limits<T>::quiet_NaN();
 | |
|     VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
 | |
|     VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
 | |
|     VERIFY((numext::isnan)(numext::sqrt(nan)));
 | |
|     VERIFY((numext::isnan)(numext::sqrt(-one)));
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct check_sqrt_impl<std::complex<T> > {
 | |
|   static void run() {
 | |
|     typedef typename std::complex<T> ComplexT;
 | |
| 
 | |
|     for (int i = 0; i < 1000; ++i) {
 | |
|       const ComplexT x = internal::random<ComplexT>();
 | |
|       const ComplexT sqrtx = numext::sqrt(x);
 | |
|       VERIFY_IS_APPROX(sqrtx * sqrtx, x);
 | |
|     }
 | |
| 
 | |
|     // Corner cases.
 | |
|     const T zero = T(0);
 | |
|     const T one = T(1);
 | |
|     const T inf = std::numeric_limits<T>::infinity();
 | |
|     const T nan = std::numeric_limits<T>::quiet_NaN();
 | |
| 
 | |
|     // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
 | |
|     const int kNumCorners = 20;
 | |
|     const ComplexT corners[kNumCorners][2] = {
 | |
|         {ComplexT(zero, zero), ComplexT(zero, zero)},  {ComplexT(-zero, zero), ComplexT(zero, zero)},
 | |
|         {ComplexT(zero, -zero), ComplexT(zero, zero)}, {ComplexT(-zero, -zero), ComplexT(zero, zero)},
 | |
|         {ComplexT(one, inf), ComplexT(inf, inf)},      {ComplexT(nan, inf), ComplexT(inf, inf)},
 | |
|         {ComplexT(one, -inf), ComplexT(inf, -inf)},    {ComplexT(nan, -inf), ComplexT(inf, -inf)},
 | |
|         {ComplexT(-inf, one), ComplexT(zero, inf)},    {ComplexT(inf, one), ComplexT(inf, zero)},
 | |
|         {ComplexT(-inf, -one), ComplexT(zero, -inf)},  {ComplexT(inf, -one), ComplexT(inf, -zero)},
 | |
|         {ComplexT(-inf, nan), ComplexT(nan, inf)},     {ComplexT(inf, nan), ComplexT(inf, nan)},
 | |
|         {ComplexT(zero, nan), ComplexT(nan, nan)},     {ComplexT(one, nan), ComplexT(nan, nan)},
 | |
|         {ComplexT(nan, zero), ComplexT(nan, nan)},     {ComplexT(nan, one), ComplexT(nan, nan)},
 | |
|         {ComplexT(nan, -one), ComplexT(nan, nan)},     {ComplexT(nan, nan), ComplexT(nan, nan)},
 | |
|     };
 | |
| 
 | |
|     for (int i = 0; i < kNumCorners; ++i) {
 | |
|       const ComplexT& x = corners[i][0];
 | |
|       const ComplexT sqrtx = corners[i][1];
 | |
|       VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
 | |
|     }
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| void check_sqrt() {
 | |
|   check_sqrt_impl<T>::run();
 | |
| }
 | |
| 
 | |
| template <typename T>
 | |
| struct check_rsqrt_impl {
 | |
|   static void run() {
 | |
|     const T zero = T(0);
 | |
|     const T one = T(1);
 | |
|     const T inf = std::numeric_limits<T>::infinity();
 | |
|     const T nan = std::numeric_limits<T>::quiet_NaN();
 | |
| 
 | |
|     for (int i = 0; i < 1000; ++i) {
 | |
|       const T x = numext::abs(internal::random<T>());
 | |
|       const T rsqrtx = numext::rsqrt(x);
 | |
|       const T invx = one / x;
 | |
|       VERIFY_IS_APPROX(rsqrtx * rsqrtx, invx);
 | |
|     }
 | |
| 
 | |
|     // Corner cases.
 | |
|     VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
 | |
|     VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
 | |
|     VERIFY((numext::isnan)(numext::rsqrt(nan)));
 | |
|     VERIFY((numext::isnan)(numext::rsqrt(-one)));
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| struct check_rsqrt_impl<std::complex<T> > {
 | |
|   static void run() {
 | |
|     typedef typename std::complex<T> ComplexT;
 | |
|     const T zero = T(0);
 | |
|     const T one = T(1);
 | |
|     const T inf = std::numeric_limits<T>::infinity();
 | |
|     const T nan = std::numeric_limits<T>::quiet_NaN();
 | |
| 
 | |
|     for (int i = 0; i < 1000; ++i) {
 | |
|       const ComplexT x = internal::random<ComplexT>();
 | |
|       const ComplexT invx = ComplexT(one, zero) / x;
 | |
|       const ComplexT rsqrtx = numext::rsqrt(x);
 | |
|       VERIFY_IS_APPROX(rsqrtx * rsqrtx, invx);
 | |
|     }
 | |
| 
 | |
| // GCC and MSVC differ in their treatment of 1/(0 + 0i)
 | |
| //   GCC/clang = (inf, nan)
 | |
| //   MSVC = (nan, nan)
 | |
| // and 1 / (x + inf i)
 | |
| //   GCC/clang = (0, 0)
 | |
| //   MSVC = (nan, nan)
 | |
| #if (EIGEN_COMP_GNUC)
 | |
|     {
 | |
|       const int kNumCorners = 20;
 | |
|       const ComplexT corners[kNumCorners][2] = {
 | |
|           // Only consistent across GCC, clang
 | |
|           {ComplexT(zero, zero), ComplexT(zero, zero)},
 | |
|           {ComplexT(-zero, zero), ComplexT(zero, zero)},
 | |
|           {ComplexT(zero, -zero), ComplexT(zero, zero)},
 | |
|           {ComplexT(-zero, -zero), ComplexT(zero, zero)},
 | |
|           {ComplexT(one, inf), ComplexT(inf, inf)},
 | |
|           {ComplexT(nan, inf), ComplexT(inf, inf)},
 | |
|           {ComplexT(one, -inf), ComplexT(inf, -inf)},
 | |
|           {ComplexT(nan, -inf), ComplexT(inf, -inf)},
 | |
|           // Consistent across GCC, clang, MSVC
 | |
|           {ComplexT(-inf, one), ComplexT(zero, inf)},
 | |
|           {ComplexT(inf, one), ComplexT(inf, zero)},
 | |
|           {ComplexT(-inf, -one), ComplexT(zero, -inf)},
 | |
|           {ComplexT(inf, -one), ComplexT(inf, -zero)},
 | |
|           {ComplexT(-inf, nan), ComplexT(nan, inf)},
 | |
|           {ComplexT(inf, nan), ComplexT(inf, nan)},
 | |
|           {ComplexT(zero, nan), ComplexT(nan, nan)},
 | |
|           {ComplexT(one, nan), ComplexT(nan, nan)},
 | |
|           {ComplexT(nan, zero), ComplexT(nan, nan)},
 | |
|           {ComplexT(nan, one), ComplexT(nan, nan)},
 | |
|           {ComplexT(nan, -one), ComplexT(nan, nan)},
 | |
|           {ComplexT(nan, nan), ComplexT(nan, nan)},
 | |
|       };
 | |
| 
 | |
|       for (int i = 0; i < kNumCorners; ++i) {
 | |
|         const ComplexT& x = corners[i][0];
 | |
|         const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
 | |
|         VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
 | |
|       }
 | |
|     }
 | |
| #endif
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <typename T>
 | |
| void check_rsqrt() {
 | |
|   check_rsqrt_impl<T>::run();
 | |
| }
 | |
| 
 | |
| template <typename T>
 | |
| struct check_signbit_impl {
 | |
|   static void run() {
 | |
|     T true_mask;
 | |
|     std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(T));
 | |
|     T false_mask;
 | |
|     std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(T));
 | |
| 
 | |
|     std::vector<T> negative_values;
 | |
|     std::vector<T> non_negative_values;
 | |
| 
 | |
|     if (NumTraits<T>::IsInteger) {
 | |
|       negative_values = {static_cast<T>(-1), static_cast<T>(NumTraits<T>::lowest())};
 | |
|       non_negative_values = {static_cast<T>(0), static_cast<T>(1), static_cast<T>(NumTraits<T>::highest())};
 | |
|     } else {
 | |
|       // has sign bit
 | |
|       const T neg_zero = static_cast<T>(-0.0);
 | |
|       const T neg_one = static_cast<T>(-1.0);
 | |
|       const T neg_inf = -std::numeric_limits<T>::infinity();
 | |
|       const T neg_nan = -std::numeric_limits<T>::quiet_NaN();
 | |
|       // does not have sign bit
 | |
|       const T pos_zero = static_cast<T>(0.0);
 | |
|       const T pos_one = static_cast<T>(1.0);
 | |
|       const T pos_inf = std::numeric_limits<T>::infinity();
 | |
|       const T pos_nan = std::numeric_limits<T>::quiet_NaN();
 | |
|       negative_values = {neg_zero, neg_one, neg_inf, neg_nan};
 | |
|       non_negative_values = {pos_zero, pos_one, pos_inf, pos_nan};
 | |
|     }
 | |
| 
 | |
|     auto check_all = [](auto values, auto expected) {
 | |
|       bool all_pass = true;
 | |
|       for (T val : values) {
 | |
|         const T numext_val = numext::signbit(val);
 | |
|         bool not_same = internal::predux_any(internal::bitwise_helper<T>::bitwise_xor(expected, numext_val));
 | |
|         all_pass = all_pass && !not_same;
 | |
|         if (not_same) std::cout << "signbit(" << val << ") = " << numext_val << " != " << expected << std::endl;
 | |
|       }
 | |
|       return all_pass;
 | |
|     };
 | |
| 
 | |
|     bool check_all_pass = check_all(non_negative_values, false_mask);
 | |
|     check_all_pass = check_all_pass && check_all(negative_values, (NumTraits<T>::IsSigned ? true_mask : false_mask));
 | |
|     VERIFY(check_all_pass);
 | |
|   }
 | |
| };
 | |
| template <typename T>
 | |
| void check_signbit() {
 | |
|   check_signbit_impl<T>::run();
 | |
| }
 | |
| 
 | |
| EIGEN_DECLARE_TEST(numext) {
 | |
|   for (int k = 0; k < g_repeat; ++k) {
 | |
|     CALL_SUBTEST(check_abs<bool>());
 | |
|     CALL_SUBTEST(check_abs<signed char>());
 | |
|     CALL_SUBTEST(check_abs<unsigned char>());
 | |
|     CALL_SUBTEST(check_abs<short>());
 | |
|     CALL_SUBTEST(check_abs<unsigned short>());
 | |
|     CALL_SUBTEST(check_abs<int>());
 | |
|     CALL_SUBTEST(check_abs<unsigned int>());
 | |
|     CALL_SUBTEST(check_abs<long>());
 | |
|     CALL_SUBTEST(check_abs<unsigned long>());
 | |
|     CALL_SUBTEST(check_abs<half>());
 | |
|     CALL_SUBTEST(check_abs<bfloat16>());
 | |
|     CALL_SUBTEST(check_abs<float>());
 | |
|     CALL_SUBTEST(check_abs<double>());
 | |
|     CALL_SUBTEST(check_abs<long double>());
 | |
|     CALL_SUBTEST(check_abs<std::complex<float> >());
 | |
|     CALL_SUBTEST(check_abs<std::complex<double> >());
 | |
| 
 | |
|     CALL_SUBTEST(check_arg<std::complex<float> >());
 | |
|     CALL_SUBTEST(check_arg<std::complex<double> >());
 | |
| 
 | |
|     CALL_SUBTEST(check_sqrt<float>());
 | |
|     CALL_SUBTEST(check_sqrt<double>());
 | |
|     CALL_SUBTEST(check_sqrt<std::complex<float> >());
 | |
|     CALL_SUBTEST(check_sqrt<std::complex<double> >());
 | |
| 
 | |
|     CALL_SUBTEST(check_rsqrt<float>());
 | |
|     CALL_SUBTEST(check_rsqrt<double>());
 | |
|     CALL_SUBTEST(check_rsqrt<std::complex<float> >());
 | |
|     CALL_SUBTEST(check_rsqrt<std::complex<double> >());
 | |
| 
 | |
|     CALL_SUBTEST(check_signbit<half>());
 | |
|     CALL_SUBTEST(check_signbit<bfloat16>());
 | |
|     CALL_SUBTEST(check_signbit<float>());
 | |
|     CALL_SUBTEST(check_signbit<double>());
 | |
| 
 | |
|     CALL_SUBTEST(check_signbit<uint8_t>());
 | |
|     CALL_SUBTEST(check_signbit<uint16_t>());
 | |
|     CALL_SUBTEST(check_signbit<uint32_t>());
 | |
|     CALL_SUBTEST(check_signbit<uint64_t>());
 | |
| 
 | |
|     CALL_SUBTEST(check_signbit<int8_t>());
 | |
|     CALL_SUBTEST(check_signbit<int16_t>());
 | |
|     CALL_SUBTEST(check_signbit<int32_t>());
 | |
|     CALL_SUBTEST(check_signbit<int64_t>());
 | |
|   }
 | |
| }
 | 
