sync device at ending timing

This commit is contained in:
Ruipeng Li 2022-03-07 15:13:08 -08:00
parent c2e4836c1e
commit b97fbc13ed
4 changed files with 23 additions and 3 deletions

View File

@ -1793,6 +1793,8 @@ HYPRE_Int hypre_CurandUniform( HYPRE_Int n, HYPRE_Real *urand, HYPRE_Int set_see
HYPRE_Int hypre_CurandUniformSingle( HYPRE_Int n, float *urand, HYPRE_Int set_seed,
hypre_ulonglongint seed, HYPRE_Int set_offset, hypre_ulonglongint offset);
HYPRE_Int hypre_ResetDeviceRandGenerator( hypre_ulonglongint seed, hypre_ulonglongint offset );
HYPRE_Int hypre_bind_device(HYPRE_Int myid, HYPRE_Int nproc, MPI_Comm comm);
/* nvtx.c */

View File

@ -1309,8 +1309,6 @@ hypre_DeviceDataComputeStream(hypre_DeviceData *data)
#endif // #if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
#if defined(HYPRE_USING_CURAND)
curandGenerator_t
hypre_DeviceDataCurandGenerator(hypre_DeviceData *data)
@ -1323,6 +1321,7 @@ hypre_DeviceDataCurandGenerator(hypre_DeviceData *data)
curandGenerator_t gen;
HYPRE_CURAND_CALL( curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT) );
HYPRE_CURAND_CALL( curandSetPseudoRandomGeneratorSeed(gen, 1234ULL) );
HYPRE_CURAND_CALL( curandSetGeneratorOffset(gen, 0) );
HYPRE_CURAND_CALL( curandSetStream(gen, hypre_DeviceDataComputeStream(data)) );
data->curand_generator = gen;
@ -1377,6 +1376,7 @@ hypre_DeviceDataCurandGenerator(hypre_DeviceData *data)
rocrand_generator gen;
HYPRE_ROCRAND_CALL( rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_DEFAULT) );
HYPRE_ROCRAND_CALL( rocrand_set_seed(gen, 1234ULL) );
HYPRE_ROCRAND_CALL( rocrand_set_offset(gen, 0) );
HYPRE_ROCRAND_CALL( rocrand_set_stream(gen, hypre_DeviceDataComputeStream(data)) );
data->curand_generator = gen;
@ -1446,6 +1446,21 @@ hypre_CurandUniformSingle( HYPRE_Int n,
return hypre_CurandUniform_core(n, urand, set_seed, seed, set_offset, offset);
}
HYPRE_Int
hypre_ResetDeviceRandGenerator( hypre_ulonglongint seed,
hypre_ulonglongint offset )
{
curandGenerator_t gen = hypre_HandleCurandGenerator(hypre_handle());
#if defined(HYPRE_USING_CURAND)
HYPRE_CURAND_CALL( curandSetPseudoRandomGeneratorSeed(gen, seed) );
HYPRE_CURAND_CALL( curandSetGeneratorOffset(gen, offset) );
#elif defined(HYPRE_USING_ROCRAND)
HYPRE_ROCRAND_CALL( rocrand_set_seed(gen, seed) );
HYPRE_ROCRAND_CALL( rocrand_set_offset(gen, offset) );
#endif
return hypre_error_flag;
}
#endif /* #if defined(HYPRE_USING_CURAND) || defined(HYPRE_USING_ROCRAND) */
#if defined(HYPRE_USING_CUBLAS)

View File

@ -303,6 +303,8 @@ HYPRE_Int hypre_CurandUniform( HYPRE_Int n, HYPRE_Real *urand, HYPRE_Int set_see
HYPRE_Int hypre_CurandUniformSingle( HYPRE_Int n, float *urand, HYPRE_Int set_seed,
hypre_ulonglongint seed, HYPRE_Int set_offset, hypre_ulonglongint offset);
HYPRE_Int hypre_ResetDeviceRandGenerator( hypre_ulonglongint seed, hypre_ulonglongint offset );
HYPRE_Int hypre_bind_device(HYPRE_Int myid, HYPRE_Int nproc, MPI_Comm comm);
/* nvtx.c */

View File

@ -279,7 +279,8 @@ hypre_EndTiming( HYPRE_Int time_index )
if (hypre_TimingState(time_index) == 0)
{
#if defined(HYPRE_USING_GPU)
hypre_ForceSyncComputeStream(hypre_handle());
/* hypre_ForceSyncComputeStream(hypre_handle()); */
hypre_SyncCudaDevice(hypre_handle());
#endif
hypre_StopTiming();
hypre_TimingWallTime(time_index) += hypre_TimingWallCount;