From cb1e8228e9654f78f89e3371c2e5f6e33a41118a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20Schl=C3=BCter?= Date: Sun, 13 Mar 2022 22:27:06 +0900 Subject: [PATCH] Convert bit calculation to constexpr, avoid casts. --- .../arch/Default/GenericPacketMathFunctions.h | 59 +++++++++---------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index dedf976cb..822113b9b 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -28,11 +28,11 @@ template<> struct make_integer { typedef numext::int64_t type; }; template<> struct make_integer { typedef numext::int16_t type; }; template<> struct make_integer { typedef numext::int16_t type; }; -template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic_get_biased_exponent(const Packet& a) { typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::integer_packet PacketI; - enum { mantissa_bits = numext::numeric_limits::digits - 1}; + static constexpr int mantissa_bits = numext::numeric_limits::digits - 1; return pcast(plogical_shift_right(preinterpret(pabs(a)))); } @@ -42,42 +42,41 @@ template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic(const Packet& a, Packet& exponent) { typedef typename unpacket_traits::type Scalar; typedef typename make_unsigned::type>::type ScalarUI; - enum { + static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT, MantissaBits = numext::numeric_limits::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; + ExponentBits = TotalBits - MantissaBits - 1; - EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = - ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000 - const Packet sign_mantissa_mask = pset1frombits(static_cast(scalar_sign_mantissa_mask)); + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = + ~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits); // ~0x7f800000 + const Packet sign_mantissa_mask = pset1frombits(static_cast(scalar_sign_mantissa_mask)); const Packet half = pset1(Scalar(0.5)); const Packet zero = pzero(a); const Packet normal_min = pset1((numext::numeric_limits::min)()); // Minimum normal value, 2^-126 - + // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1). const Packet is_denormal = pcmp_lt(pabs(a), normal_min); - EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24 + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(MantissaBits + 1); // 24 // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 - const Packet normalization_factor = pset1(scalar_normalization_factor); + const Packet normalization_factor = pset1(scalar_normalization_factor); const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); - + // Determine exponent offset: -126 if normal, -126-24 if denormal - const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126 + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(ExponentBits-1)) - ScalarUI(2)); // -126 Packet exponent_offset = pset1(scalar_exponent_offset); const Packet normalization_offset = pset1(-Scalar(scalar_normalization_offset)); // -24 exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); - + // Determine exponent and mantissa from normalized_a. exponent = pfrexp_generic_get_biased_exponent(normalized_a); // Zero, Inf and NaN return 'a' unmodified, exponent is zero // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) - const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255 + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1)); // 255 const Packet non_finite_exponent = pset1(scalar_non_finite_exponent); const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); - exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); + exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); return m; } @@ -110,25 +109,24 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) { typedef typename unpacket_traits::integer_packet PacketI; typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::type ScalarI; - enum { + static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT, MantissaBits = numext::numeric_limits::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; + ExponentBits = TotalBits - MantissaBits - 1; - const Packet max_exponent = pset1(Scalar((ScalarI(1)<((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 + const Packet max_exponent = pset1(Scalar((ScalarI(1)<((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1)); // 127 const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); - Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b + Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) b = psub(psub(psub(e, b), b), b); // e - 3b - c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) + c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) out = pmul(out, c); return out; } -// Explicitly multiplies +// Explicitly multiplies // a * (2^e) // clamping e to the range // [NumTraits::min_exponent()-2, NumTraits::max_exponent()] @@ -142,20 +140,19 @@ struct pldexp_fast_impl { typedef typename unpacket_traits::integer_packet PacketI; typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::type ScalarI; - enum { + static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT, MantissaBits = numext::numeric_limits::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; - + ExponentBits = TotalBits - MantissaBits - 1; + static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet run(const Packet& a, const Packet& exponent) { - const Packet bias = pset1(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127 - const Packet limit = pset1(Scalar((ScalarI(1)<(Scalar((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1))); // 127 + const Packet limit = pset1(Scalar((ScalarI(1)<(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 // return a * (2^e) - return pmul(a, preinterpret(plogical_shift_left(e))); + return pmul(a, preinterpret(plogical_shift_left(e))); } };