Revert change that made conversion from bfloat16 to {float, double} implicit.
Add roundtrip tests for casting between bfloat16 and complex types.
This commit is contained in:
		
							parent
							
								
									38b91f256b
								
							
						
					
					
						commit
						1b84f21e32
					
				| @ -34,8 +34,9 @@ namespace Eigen { | ||||
| 
 | ||||
| struct bfloat16; | ||||
| 
 | ||||
| // Since we allow implicit conversion of bfloat16 to float and double, we
 | ||||
| // need to make the cast to complex a bit more explicit
 | ||||
| // explicit conversion operators are no available before C++11 so we first cast
 | ||||
| // bfloat16 to RealScalar rather than to std::complex<RealScalar> directly
 | ||||
| #if !EIGEN_HAS_CXX11 | ||||
| namespace internal { | ||||
| template <typename RealScalar> | ||||
| struct cast_impl<bfloat16, std::complex<RealScalar> > { | ||||
| @ -45,6 +46,7 @@ struct cast_impl<bfloat16, std::complex<RealScalar> > { | ||||
|   } | ||||
| }; | ||||
| } // namespace internal
 | ||||
| #endif  // EIGEN_HAS_CXX11
 | ||||
| 
 | ||||
| namespace bfloat16_impl { | ||||
| 
 | ||||
| @ -129,10 +131,10 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const { | ||||
|     return static_cast<unsigned long long>(bfloat16_to_float(*this)); | ||||
|   } | ||||
|   EIGEN_DEVICE_FUNC operator float() const { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { | ||||
|     return bfloat16_impl::bfloat16_to_float(*this); | ||||
|   } | ||||
|   EIGEN_DEVICE_FUNC operator double() const { | ||||
|   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { | ||||
|     return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)); | ||||
|   } | ||||
|   template<typename RealScalar> | ||||
|  | ||||
| @ -41,6 +41,19 @@ void test_truncate(float input, float expected_truncation, float expected_roundi | ||||
|   VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded)); | ||||
| } | ||||
| 
 | ||||
| template<typename T> | ||||
|  void test_roundtrip() { | ||||
|   // Representable T round trip via bfloat16
 | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(-std::numeric_limits<T>::infinity())), -std::numeric_limits<T>::infinity()); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(std::numeric_limits<T>::infinity())), std::numeric_limits<T>::infinity()); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-1.0))), T(-1.0)); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.5))), T(-0.5)); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.0))), T(-0.0)); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(1.0))), T(1.0)); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.5))), T(0.5)); | ||||
|   VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.0))), T(0.0)); | ||||
| } | ||||
| 
 | ||||
| void test_conversion() | ||||
| { | ||||
|   using Eigen::bfloat16_impl::__bfloat16_raw; | ||||
| @ -53,9 +66,9 @@ void test_conversion() | ||||
|   VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80);  // Becomes infinity.
 | ||||
| 
 | ||||
|   // Verify round-to-nearest-even behavior.
 | ||||
|   float val1 = bfloat16(__bfloat16_raw(0x3c00)); | ||||
|   float val2 = bfloat16(__bfloat16_raw(0x3c01)); | ||||
|   float val3 = bfloat16(__bfloat16_raw(0x3c02)); | ||||
|   float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00))); | ||||
|   float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01))); | ||||
|   float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02))); | ||||
|   VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00); | ||||
|   VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02); | ||||
| 
 | ||||
| @ -106,14 +119,10 @@ void test_conversion() | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f); | ||||
| 
 | ||||
|   // Representable floats round trip via bfloat16
 | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity()); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity()); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f); | ||||
|   VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f); | ||||
|   test_roundtrip<float>(); | ||||
|   test_roundtrip<double>(); | ||||
|   test_roundtrip<std::complex<float> >(); | ||||
|   test_roundtrip<std::complex<double> >(); | ||||
| 
 | ||||
|   // Truncate test
 | ||||
|   test_truncate( | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Rasmus Munk Larsen
						Rasmus Munk Larsen