tune for AMD GPUs

This commit is contained in:
Ruipeng Li 2022-06-12 11:34:22 -07:00
parent ee88d0b1a4
commit a5b10499cb
5 changed files with 31 additions and 1 deletions

View File

@ -17,7 +17,12 @@
/* Hash type constant, set to 'D' here (the meaning of 'D' is not visible in this excerpt). */
static const char HYPRE_SPGEMM_HASH_TYPE = 'D';
/* default settings associated with bin 5 */
/* HYPRE_SPGEMM_NUMER_HASH_SIZE is backend specific: 256 under HYPRE_USING_CUDA,
 * 128 under HYPRE_USING_HIP. The two guards are independent #if blocks, so the
 * macro stays undefined when neither backend macro is defined.
 * NOTE(review): presumably NUMER/SYMBL refer to the numeric and symbolic SpGEMM
 * phases -- confirm against the kernels that consume these macros. */
#if defined(HYPRE_USING_CUDA)
#define HYPRE_SPGEMM_NUMER_HASH_SIZE 256
#endif
#if defined(HYPRE_USING_HIP)
#define HYPRE_SPGEMM_NUMER_HASH_SIZE 128
#endif
/* The remaining two defaults are shared by both backends. */
#define HYPRE_SPGEMM_SYMBL_HASH_SIZE 512
#define HYPRE_SPGEMM_BASE_GROUP_SIZE 32
/* unroll factor in the kernels */

View File

@ -31,9 +31,16 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundNoBin( HYPRE_Int m,
HYPRE_Complex **d_c_out,
HYPRE_Int *nnzC_out )
{
/* Backend-specific compile-time configuration for the no-bin path.
 * CUDA: BIN 5 with the default hash size and group size macros unchanged. */
#if defined(HYPRE_USING_CUDA)
const HYPRE_Int SHMEM_HASH_SIZE = HYPRE_SPGEMM_NUMER_HASH_SIZE;
const HYPRE_Int GROUP_SIZE = HYPRE_SPGEMM_BASE_GROUP_SIZE;
const HYPRE_Int BIN = 5;
#endif
/* HIP: BIN 6 with both the hash size and the group size doubled relative to
 * the macros. These three values must have a matching explicit template
 * instantiation wherever the kernel template is instantiated. */
#if defined(HYPRE_USING_HIP)
const HYPRE_Int SHMEM_HASH_SIZE = 2 * HYPRE_SPGEMM_NUMER_HASH_SIZE;
const HYPRE_Int GROUP_SIZE = 2 * HYPRE_SPGEMM_BASE_GROUP_SIZE;
const HYPRE_Int BIN = 6;
#endif
#ifdef HYPRE_SPGEMM_PRINTF
HYPRE_Int max_rc = HYPRE_THRUST_CALL(reduce, d_rc, d_rc + m, 0, thrust::maximum<HYPRE_Int>());
@ -122,7 +129,12 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundBinned( HYPRE_Int m,
HYPRE_Int h_bin_ptr[HYPRE_SPGEMM_MAX_NBIN + 1];
//HYPRE_Int num_bins = hypre_HandleSpgemmNumBin(hypre_handle());
HYPRE_Int high_bin = hypre_HandleSpgemmHighestBin(hypre_handle())[1];
/* Binning parameters handed to hypre_SpGemmCreateBins below. Only s differs
 * between backends (8 for CUDA, 4 for HIP); t is 2 and u is high_bin on both.
 * Note that u stores high_bin (a HYPRE_Int) in a const char.
 * NOTE(review): the exact role of s, t and u is defined by
 * hypre_SpGemmCreateBins, which is not visible here -- confirm before retuning. */
#if defined(HYPRE_USING_CUDA)
const char s = 8, t = 2, u = high_bin;
#endif
#if defined(HYPRE_USING_HIP)
const char s = 4, t = 2, u = high_bin;
#endif
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);

View File

@ -454,7 +454,7 @@ hypre_spgemm_numerical_with_rownnz( HYPRE_Int m,
/* BDIMX is capped by GROUP_SIZE via hypre_min; BDIMY is derived from it so that
 * BDIMX * BDIMY covers GROUP_SIZE when BDIMX divides GROUP_SIZE. */
#if defined(HYPRE_USING_CUDA)
const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE);
#elif defined(HYPRE_USING_HIP)
/* NOTE(review): the two BDIMX declarations below look like the removed (2) and
 * added (4) lines of a one-line diff hunk rendered without +/- markers; the
 * real file can contain only one of them -- confirm against the repository. */
const HYPRE_Int BDIMX = hypre_min(2, GROUP_SIZE);
const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE);
#endif
const HYPRE_Int BDIMY = GROUP_SIZE / BDIMX;

View File

@ -22,6 +22,7 @@ hypre_spgemm_numerical_with_rownnz
HYPRE_Complex *d_b,
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
#if defined(HYPRE_USING_CUDA)
template HYPRE_Int
hypre_spgemm_numerical_with_rownnz
< 5, HYPRE_SPGEMM_NUMER_HASH_SIZE, HYPRE_SPGEMM_BASE_GROUP_SIZE, false >
@ -30,6 +31,7 @@ hypre_spgemm_numerical_with_rownnz
HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb,
HYPRE_Complex *d_b,
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
#endif
template HYPRE_Int
hypre_spgemm_numerical_max_num_blocks

View File

@ -21,6 +21,17 @@ hypre_spgemm_numerical_with_rownnz
HYPRE_Complex *d_b,
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
/* HIP-only explicit instantiation of hypre_spgemm_numerical_with_rownnz with
 * template arguments <6, 2 * HYPRE_SPGEMM_NUMER_HASH_SIZE,
 * 2 * HYPRE_SPGEMM_BASE_GROUP_SIZE, false>.
 * NOTE(review): presumably this is the instantiation required by the HIP no-bin
 * caller that uses BIN = 6 with doubled hash and group sizes; the meaning of the
 * trailing bool parameter is not visible here -- confirm against the template
 * definition. */
#if defined(HYPRE_USING_HIP)
template HYPRE_Int
hypre_spgemm_numerical_with_rownnz
< 6, HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false >
( HYPRE_Int m, HYPRE_Int *row_ind, HYPRE_Int k, HYPRE_Int n, bool need_ghash,
HYPRE_Int exact_rownnz,
HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb,
HYPRE_Complex *d_b,
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
#endif
template HYPRE_Int
hypre_spgemm_numerical_max_num_blocks
< HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2 >