Vectorize tensor.isnan() by using typed predicates.
This commit is contained in:
		
							parent
							
								
									f02856c640
								
							
						
					
					
						commit
						0488b708b4
					
				| @ -443,6 +443,12 @@ pnot(const Packet& a) { | ||||
| template<typename Packet> EIGEN_DEVICE_FUNC inline Packet | ||||
| pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); } | ||||
| 
 | ||||
| /** \internal \returns isnan(a) as a packet mask: ptrue(a) in the lanes where a == a is false (only NaN compares unequal to itself), cleared elsewhere. Generic fallback built from pcmp_eq and pandnot. */ | ||||
| template<typename Packet> EIGEN_DEVICE_FUNC inline Packet | ||||
| pisnan(const Packet& a) { | ||||
|   return pandnot(ptrue(a), pcmp_eq(a, a)); /* pandnot(x, y) is pand(x, pnot(y)), so this is ptrue(a) AND NOT (a == a) */ | ||||
| } | ||||
| 
 | ||||
| // In the general case, use bitwise select.
 | ||||
| template<typename Packet, typename EnableIf = void> | ||||
| struct pselect_impl { | ||||
|  | ||||
| @ -634,6 +634,7 @@ template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8 | ||||
| template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); } | ||||
| template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } | ||||
| template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); } | ||||
| /* Packet8f specialization of pisnan: a single unordered, non-signaling self-compare (_CMP_UNORD_Q) instead of the generic pcmp_eq + pandnot sequence. */ template<> EIGEN_STRONG_INLINE Packet8f pisnan(const Packet8f& a) { return _mm256_cmp_ps(a,a,_CMP_UNORD_Q); } | ||||
| 
 | ||||
| template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); } | ||||
| template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); } | ||||
|  | ||||
| @ -353,7 +353,7 @@ EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) { | ||||
| } | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) { | ||||
|   return _mm512_sub_epi32(_mm512_set1_epi32(0), a); | ||||
|   return _mm512_sub_epi32(_mm512_setzero_si512(), a); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| @ -580,66 +580,72 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) { | ||||
|   return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); | ||||
| } | ||||
| 
 | ||||
| /* Packet16f specialization of pisnan: compares a with itself using the unordered predicate and expands the resulting lane mask into a full-width packet. */ template <> | ||||
| EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) { | ||||
|   __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q); /* one mask bit per lane, set where a is unordered with itself, i.e. NaN */ | ||||
|   return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, 0xffffffffu)); /* zero-masked broadcast: selected lanes become 0xffffffff, the rest 0; the integer vector is then reinterpreted as floats */ | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { | ||||
|   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); | ||||
|   return _mm512_castsi512_ps( | ||||
|       _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); | ||||
|       _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); | ||||
| } | ||||
| template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { | ||||
|   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); | ||||
|   return _mm512_castsi512_ps( | ||||
|       _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); | ||||
|       _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); | ||||
| } | ||||
| 
 | ||||
| template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { | ||||
|   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); | ||||
|   return _mm512_castsi512_ps( | ||||
|       _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); | ||||
|       _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); | ||||
| } | ||||
| 
 | ||||
| template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { | ||||
|   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); | ||||
|   return _mm512_castsi512_ps( | ||||
|       _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); | ||||
|       _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); | ||||
| } | ||||
| 
 | ||||
| template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) { | ||||
|   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ); | ||||
|   return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); | ||||
|   return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); | ||||
| } | ||||
| template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) { | ||||
|   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE); | ||||
|   return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); | ||||
|   return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); | ||||
| } | ||||
| template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) { | ||||
|   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT); | ||||
|   return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); | ||||
|   return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) { | ||||
|   __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ); | ||||
|   return _mm512_castsi512_pd( | ||||
|       _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); | ||||
|       _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); | ||||
| } | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) { | ||||
|   __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ); | ||||
|   return _mm512_castsi512_pd( | ||||
|       _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); | ||||
|       _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); | ||||
| } | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) { | ||||
|   __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ); | ||||
|   return _mm512_castsi512_pd( | ||||
|       _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); | ||||
|       _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); | ||||
| } | ||||
| template <> | ||||
| EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) { | ||||
|   __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ); | ||||
|   return _mm512_castsi512_pd( | ||||
|       _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); | ||||
|       _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); | ||||
| } | ||||
| 
 | ||||
| template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } | ||||
|  | ||||
| @ -16,6 +16,33 @@ namespace Eigen { | ||||
| 
 | ||||
| namespace internal { | ||||
| 
 | ||||
| /* Cast traits for float -> bool: flags the cast as vectorized (VectorizedCast = 1) with source and target coefficient ratios both equal to 1. */ template <> | ||||
| struct type_casting_traits<float, bool> { | ||||
|   enum { | ||||
|     VectorizedCast = 1, | ||||
|     SrcCoeffRatio = 1, | ||||
|     TgtCoeffRatio = 1 | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| /* Cast traits for bool -> float: flags the cast as vectorized (VectorizedCast = 1) with source and target coefficient ratios both equal to 1. */ template <> | ||||
| struct type_casting_traits<bool,float> { | ||||
|   enum { | ||||
|     VectorizedCast = 1, | ||||
|     SrcCoeffRatio = 1, | ||||
|     TgtCoeffRatio = 1 | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| /* Packet cast float -> bool: each output lane is 1 where the input lane compares not-equal to pzero(a) (presumably an all-zero packet), and 0 otherwise. */ template<> EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) { | ||||
|   __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a)); /* one mask bit per lane, set where a != pzero(a) */ | ||||
|   return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1)); /* narrows a vector of 32-bit ones to 8-bit lanes, zeroing the lanes whose mask bit is clear */ | ||||
| } | ||||
| 
 | ||||
| /* Packet cast bool -> float: converts each 8-bit lane of a to the float with the same integer value. */ template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) { | ||||
|   return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a)); /* widen the 8-bit lanes to 32-bit integers, then convert those integers to float */ | ||||
| } | ||||
| 
 | ||||
| /* Packet cast float -> int32 for 16 lanes. */ template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) { | ||||
|   return _mm512_cvttps_epi32(a); /* the cvtt variant converts with truncation rather than the current rounding mode */ | ||||
| } | ||||
|  | ||||
| @ -1124,7 +1124,7 @@ Packet psqrt_complex(const Packet& a) { | ||||
|   Packet imag_inf_result; | ||||
|   imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); | ||||
|   // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan
 | ||||
|   Packet result_is_nan = pandnot(ptrue(result), pcmp_eq(result, result)); | ||||
|   Packet result_is_nan = pisnan(result); | ||||
|   result = por(result_is_nan, result); | ||||
| 
 | ||||
|   return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); | ||||
| @ -1796,7 +1796,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac | ||||
|   const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); | ||||
|   const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); | ||||
|   const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); | ||||
|   const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); | ||||
|   const Packet x_is_nan = pisnan(x); | ||||
| 
 | ||||
|   // Predicates for sign and magnitude of y.
 | ||||
|   const Packet abs_y = pabs(y); | ||||
| @ -1804,7 +1804,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac | ||||
|   const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero); | ||||
|   const Packet y_is_neg = pcmp_lt(y, cst_zero); | ||||
|   const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg)); | ||||
|   const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); | ||||
|   const Packet y_is_nan = pisnan(y); | ||||
|   const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf); | ||||
|   EIGEN_CONSTEXPR Scalar huge_exponent = | ||||
|       (NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon(); | ||||
|  | ||||
| @ -859,22 +859,39 @@ struct functor_traits<scalar_ceil_op<Scalar> > | ||||
|   * \brief Template functor to compute whether a scalar is NaN | ||||
|   * \sa class CwiseUnaryOp, ArrayBase::isnan() | ||||
|   */ | ||||
| template<typename Scalar> struct scalar_isnan_op { | ||||
|   typedef bool result_type; | ||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { | ||||
| template<typename Scalar, bool UseTypedPredicate=false> | ||||
| struct scalar_isnan_op { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const Scalar& a) const { | ||||
| #if defined(SYCL_DEVICE_ONLY) | ||||
|     return numext::isnan(a); | ||||
| #else | ||||
|     return (numext::isnan)(a); | ||||
|     return numext::isnan EIGEN_NOT_A_MACRO (a); | ||||
| #endif | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| template<typename Scalar> | ||||
| struct functor_traits<scalar_isnan_op<Scalar> > | ||||
| struct scalar_isnan_op<Scalar, true> { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator() (const Scalar& a) const { | ||||
| #if defined(SYCL_DEVICE_ONLY) | ||||
|     return (numext::isnan(a) ? ptrue(a) : pzero(a)); | ||||
| #else | ||||
|     return (numext::isnan EIGEN_NOT_A_MACRO (a)  ? ptrue(a) : pzero(a)); | ||||
| #endif | ||||
|   } | ||||
|   template <typename Packet> | ||||
|   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { | ||||
|     return pisnan(a); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template<typename Scalar, bool UseTypedPredicate> | ||||
| struct functor_traits<scalar_isnan_op<Scalar, UseTypedPredicate> > | ||||
| { | ||||
|   enum { | ||||
|     Cost = NumTraits<Scalar>::MulCost, | ||||
|     PacketAccess = false | ||||
|     PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -611,12 +611,13 @@ class TensorBase<Derived, ReadOnlyAccessors> | ||||
|       return operator!=(constant(threshold)); | ||||
|     } | ||||
| 
 | ||||
|     // Checks
 | ||||
|     // Predicates.
 | ||||
|     EIGEN_DEVICE_FUNC | ||||
|     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar>, const Derived> | ||||
|     EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar, true>, const Derived>> | ||||
|     (isnan)() const { | ||||
|       return unaryExpr(internal::scalar_isnan_op<Scalar>()); | ||||
|       return unaryExpr(internal::scalar_isnan_op<Scalar, true>()).template cast<bool>(); | ||||
|     } | ||||
| 
 | ||||
|     EIGEN_DEVICE_FUNC | ||||
|     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived> | ||||
|     (isinf)() const { | ||||
| @ -1219,4 +1220,3 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> { | ||||
| } // end namespace Eigen
 | ||||
| 
 | ||||
| #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H
 | ||||
| 
 | ||||
|  | ||||
| @ -79,8 +79,36 @@ static void test_equality() | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| /* Verifies that Tensor isnan() agrees element-wise with std::isnan on a 2x3x7 tensor in which randomly chosen entries have been replaced by NaN. */ static void test_isnan() | ||||
| { | ||||
|   Tensor<Scalar, 3> mat(2,3,7); | ||||
| 
 | ||||
|   mat.setRandom(); | ||||
|   for (int i = 0; i < 2; ++i) { | ||||
|     for (int j = 0; j < 3; ++j) { | ||||
|       for (int k = 0; k < 7; ++k) { | ||||
|         if (internal::random<bool>()) { /* random coin flip per coefficient decides whether it is overwritten */ | ||||
|           mat(i,j,k) = std::numeric_limits<Scalar>::quiet_NaN(); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   Tensor<bool, 3> nan(2,3,7); | ||||
|   nan = (mat.isnan)(); /* parenthesized member name: presumably guards against an isnan macro being expanded here - TODO confirm */ | ||||
|   for (int i = 0; i < 2; ++i) { | ||||
|     for (int j = 0; j < 3; ++j) { | ||||
|       for (int k = 0; k < 7; ++k) { | ||||
|         VERIFY_IS_EQUAL(nan(i,j,k), (std::isnan)(mat(i,j,k))); /* the scalar std::isnan result is the reference for every coefficient */ | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| /* Test entry point for cxx11_tensor_comparisons: runs the ordering, equality and isnan subtests in that order. */ EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) | ||||
| { | ||||
|   CALL_SUBTEST(test_orderings()); | ||||
|   CALL_SUBTEST(test_equality()); | ||||
|   CALL_SUBTEST(test_isnan()); | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Rasmus Munk Larsen
						Rasmus Munk Larsen