Fix scalar_logistic_function overflow for complex inputs.
This commit is contained in:
		
							parent
							
								
									9688081029
								
							
						
					
					
						commit
						3252ecc7a4
					
				| @ -1091,12 +1091,9 @@ struct functor_traits<scalar_sign_op<Scalar>> { | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| /** \internal
 | ||||
|  * \brief Template functor to compute the logistic function of a scalar | ||||
|  * \sa class CwiseUnaryOp, ArrayBase::logistic() | ||||
|  */ | ||||
| template <typename T> | ||||
| struct scalar_logistic_op { | ||||
| // Real-valued implementation.
 | ||||
| template <typename T, typename EnableIf = void> | ||||
| struct scalar_logistic_op_impl { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); } | ||||
| 
 | ||||
|   template <typename Packet> | ||||
| @ -1109,6 +1106,22 @@ struct scalar_logistic_op { | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Complex-valud implementation.
 | ||||
| template <typename T> | ||||
| struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { | ||||
|     const T e = numext::exp(x); | ||||
|     return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1)); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /** \internal
 | ||||
|  * \brief Template functor to compute the logistic function of a scalar | ||||
|  * \sa class CwiseUnaryOp, ArrayBase::logistic() | ||||
|  */ | ||||
| template <typename T> | ||||
| struct scalar_logistic_op : scalar_logistic_op_impl<T> {}; | ||||
| 
 | ||||
| // TODO(rmlarsen): Enable the following on host when integer_packet is defined
 | ||||
| // for the relevant packet types.
 | ||||
| #ifdef EIGEN_GPU_CC | ||||
| @ -1206,7 +1219,7 @@ struct functor_traits<scalar_logistic_op<T>> { | ||||
|     Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value + | ||||
|            (internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11 | ||||
|                                                : NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost), | ||||
|     PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && | ||||
|     PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && | ||||
|                    (internal::is_same<T, float>::value | ||||
|                         ? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin | ||||
|                         : packet_traits<T>::HasNegate && packet_traits<T>::HasExp) | ||||
|  | ||||
| @ -976,7 +976,14 @@ template<typename ArrayType> void array_complex(const ArrayType& m) | ||||
|   VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1))); | ||||
|   VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1))); | ||||
|   VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1)))); | ||||
|   VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1)))); | ||||
|   VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); | ||||
|   if (m1.size() > 0) { | ||||
|     // Complex exponential overflow edge-case.
 | ||||
|     Scalar old_m1_val = m1(0, 0); | ||||
|     m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0); | ||||
|     VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); | ||||
|     m1(0, 0) = old_m1_val;  // Restore value for future tests.
 | ||||
|   } | ||||
| 
 | ||||
|   for (Index i = 0; i < m.rows(); ++i) | ||||
|     for (Index j = 0; j < m.cols(); ++j) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Antonio Sánchez
						Antonio Sánchez