tune for AMD GPUs
This commit is contained in:
parent
ee88d0b1a4
commit
a5b10499cb
@ -17,7 +17,12 @@
|
|||||||
static const char HYPRE_SPGEMM_HASH_TYPE = 'D';
|
static const char HYPRE_SPGEMM_HASH_TYPE = 'D';
|
||||||
|
|
||||||
/* default settings associated with bin 5 */
|
/* default settings associated with bin 5 */
|
||||||
|
#if defined(HYPRE_USING_CUDA)
|
||||||
#define HYPRE_SPGEMM_NUMER_HASH_SIZE 256
|
#define HYPRE_SPGEMM_NUMER_HASH_SIZE 256
|
||||||
|
#endif
|
||||||
|
#if defined(HYPRE_USING_HIP)
|
||||||
|
#define HYPRE_SPGEMM_NUMER_HASH_SIZE 128
|
||||||
|
#endif
|
||||||
#define HYPRE_SPGEMM_SYMBL_HASH_SIZE 512
|
#define HYPRE_SPGEMM_SYMBL_HASH_SIZE 512
|
||||||
#define HYPRE_SPGEMM_BASE_GROUP_SIZE 32
|
#define HYPRE_SPGEMM_BASE_GROUP_SIZE 32
|
||||||
/* unroll factor in the kernels */
|
/* unroll factor in the kernels */
|
||||||
|
|||||||
@ -31,9 +31,16 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundNoBin( HYPRE_Int m,
|
|||||||
HYPRE_Complex **d_c_out,
|
HYPRE_Complex **d_c_out,
|
||||||
HYPRE_Int *nnzC_out )
|
HYPRE_Int *nnzC_out )
|
||||||
{
|
{
|
||||||
|
#if defined(HYPRE_USING_CUDA)
|
||||||
const HYPRE_Int SHMEM_HASH_SIZE = HYPRE_SPGEMM_NUMER_HASH_SIZE;
|
const HYPRE_Int SHMEM_HASH_SIZE = HYPRE_SPGEMM_NUMER_HASH_SIZE;
|
||||||
const HYPRE_Int GROUP_SIZE = HYPRE_SPGEMM_BASE_GROUP_SIZE;
|
const HYPRE_Int GROUP_SIZE = HYPRE_SPGEMM_BASE_GROUP_SIZE;
|
||||||
const HYPRE_Int BIN = 5;
|
const HYPRE_Int BIN = 5;
|
||||||
|
#endif
|
||||||
|
#if defined(HYPRE_USING_HIP)
|
||||||
|
const HYPRE_Int SHMEM_HASH_SIZE = 2 * HYPRE_SPGEMM_NUMER_HASH_SIZE;
|
||||||
|
const HYPRE_Int GROUP_SIZE = 2 * HYPRE_SPGEMM_BASE_GROUP_SIZE;
|
||||||
|
const HYPRE_Int BIN = 6;
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef HYPRE_SPGEMM_PRINTF
|
#ifdef HYPRE_SPGEMM_PRINTF
|
||||||
HYPRE_Int max_rc = HYPRE_THRUST_CALL(reduce, d_rc, d_rc + m, 0, thrust::maximum<HYPRE_Int>());
|
HYPRE_Int max_rc = HYPRE_THRUST_CALL(reduce, d_rc, d_rc + m, 0, thrust::maximum<HYPRE_Int>());
|
||||||
@ -122,7 +129,12 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundBinned( HYPRE_Int m,
|
|||||||
HYPRE_Int h_bin_ptr[HYPRE_SPGEMM_MAX_NBIN + 1];
|
HYPRE_Int h_bin_ptr[HYPRE_SPGEMM_MAX_NBIN + 1];
|
||||||
//HYPRE_Int num_bins = hypre_HandleSpgemmNumBin(hypre_handle());
|
//HYPRE_Int num_bins = hypre_HandleSpgemmNumBin(hypre_handle());
|
||||||
HYPRE_Int high_bin = hypre_HandleSpgemmHighestBin(hypre_handle())[1];
|
HYPRE_Int high_bin = hypre_HandleSpgemmHighestBin(hypre_handle())[1];
|
||||||
|
#if defined(HYPRE_USING_CUDA)
|
||||||
const char s = 8, t = 2, u = high_bin;
|
const char s = 8, t = 2, u = high_bin;
|
||||||
|
#endif
|
||||||
|
#if defined(HYPRE_USING_HIP)
|
||||||
|
const char s = 4, t = 2, u = high_bin;
|
||||||
|
#endif
|
||||||
|
|
||||||
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);
|
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);
|
||||||
|
|
||||||
|
|||||||
@ -454,7 +454,7 @@ hypre_spgemm_numerical_with_rownnz( HYPRE_Int m,
|
|||||||
#if defined(HYPRE_USING_CUDA)
|
#if defined(HYPRE_USING_CUDA)
|
||||||
const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE);
|
const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE);
|
||||||
#elif defined(HYPRE_USING_HIP)
|
#elif defined(HYPRE_USING_HIP)
|
||||||
const HYPRE_Int BDIMX = hypre_min(2, GROUP_SIZE);
|
const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE);
|
||||||
#endif
|
#endif
|
||||||
const HYPRE_Int BDIMY = GROUP_SIZE / BDIMX;
|
const HYPRE_Int BDIMY = GROUP_SIZE / BDIMX;
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ hypre_spgemm_numerical_with_rownnz
|
|||||||
HYPRE_Complex *d_b,
|
HYPRE_Complex *d_b,
|
||||||
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
||||||
|
|
||||||
|
#if defined(HYPRE_USING_CUDA)
|
||||||
template HYPRE_Int
|
template HYPRE_Int
|
||||||
hypre_spgemm_numerical_with_rownnz
|
hypre_spgemm_numerical_with_rownnz
|
||||||
< 5, HYPRE_SPGEMM_NUMER_HASH_SIZE, HYPRE_SPGEMM_BASE_GROUP_SIZE, false >
|
< 5, HYPRE_SPGEMM_NUMER_HASH_SIZE, HYPRE_SPGEMM_BASE_GROUP_SIZE, false >
|
||||||
@ -30,6 +31,7 @@ hypre_spgemm_numerical_with_rownnz
|
|||||||
HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb,
|
HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb,
|
||||||
HYPRE_Complex *d_b,
|
HYPRE_Complex *d_b,
|
||||||
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
||||||
|
#endif
|
||||||
|
|
||||||
template HYPRE_Int
|
template HYPRE_Int
|
||||||
hypre_spgemm_numerical_max_num_blocks
|
hypre_spgemm_numerical_max_num_blocks
|
||||||
|
|||||||
@ -21,6 +21,17 @@ hypre_spgemm_numerical_with_rownnz
|
|||||||
HYPRE_Complex *d_b,
|
HYPRE_Complex *d_b,
|
||||||
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
||||||
|
|
||||||
|
#if defined(HYPRE_USING_HIP)
|
||||||
|
template HYPRE_Int
|
||||||
|
hypre_spgemm_numerical_with_rownnz
|
||||||
|
< 6, HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false >
|
||||||
|
( HYPRE_Int m, HYPRE_Int *row_ind, HYPRE_Int k, HYPRE_Int n, bool need_ghash,
|
||||||
|
HYPRE_Int exact_rownnz,
|
||||||
|
HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb,
|
||||||
|
HYPRE_Complex *d_b,
|
||||||
|
HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c );
|
||||||
|
#endif
|
||||||
|
|
||||||
template HYPRE_Int
|
template HYPRE_Int
|
||||||
hypre_spgemm_numerical_max_num_blocks
|
hypre_spgemm_numerical_max_num_blocks
|
||||||
< HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2 >
|
< HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2 >
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user