diff --git a/src/seq_mv/csr_spgemm_device.h b/src/seq_mv/csr_spgemm_device.h index 04f305023..7bb50edd4 100644 --- a/src/seq_mv/csr_spgemm_device.h +++ b/src/seq_mv/csr_spgemm_device.h @@ -17,7 +17,12 @@ static const char HYPRE_SPGEMM_HASH_TYPE = 'D'; /* default settings associated with bin 5 */ +#if defined(HYPRE_USING_CUDA) #define HYPRE_SPGEMM_NUMER_HASH_SIZE 256 +#endif +#if defined(HYPRE_USING_HIP) +#define HYPRE_SPGEMM_NUMER_HASH_SIZE 128 +#endif #define HYPRE_SPGEMM_SYMBL_HASH_SIZE 512 #define HYPRE_SPGEMM_BASE_GROUP_SIZE 32 /* unroll factor in the kernels */ diff --git a/src/seq_mv/csr_spgemm_device_numer.c b/src/seq_mv/csr_spgemm_device_numer.c index 2df44baa6..e9b0cf632 100644 --- a/src/seq_mv/csr_spgemm_device_numer.c +++ b/src/seq_mv/csr_spgemm_device_numer.c @@ -31,9 +31,16 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundNoBin( HYPRE_Int m, HYPRE_Complex **d_c_out, HYPRE_Int *nnzC_out ) { +#if defined(HYPRE_USING_CUDA) const HYPRE_Int SHMEM_HASH_SIZE = HYPRE_SPGEMM_NUMER_HASH_SIZE; const HYPRE_Int GROUP_SIZE = HYPRE_SPGEMM_BASE_GROUP_SIZE; const HYPRE_Int BIN = 5; +#endif +#if defined(HYPRE_USING_HIP) + const HYPRE_Int SHMEM_HASH_SIZE = 2 * HYPRE_SPGEMM_NUMER_HASH_SIZE; + const HYPRE_Int GROUP_SIZE = 2 * HYPRE_SPGEMM_BASE_GROUP_SIZE; + const HYPRE_Int BIN = 6; +#endif #ifdef HYPRE_SPGEMM_PRINTF HYPRE_Int max_rc = HYPRE_THRUST_CALL(reduce, d_rc, d_rc + m, 0, thrust::maximum()); @@ -122,7 +129,12 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundBinned( HYPRE_Int m, HYPRE_Int h_bin_ptr[HYPRE_SPGEMM_MAX_NBIN + 1]; //HYPRE_Int num_bins = hypre_HandleSpgemmNumBin(hypre_handle()); HYPRE_Int high_bin = hypre_HandleSpgemmHighestBin(hypre_handle())[1]; +#if defined(HYPRE_USING_CUDA) const char s = 8, t = 2, u = high_bin; +#endif +#if defined(HYPRE_USING_HIP) + const char s = 4, t = 2, u = high_bin; +#endif hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr); diff --git a/src/seq_mv/csr_spgemm_device_numer.h b/src/seq_mv/csr_spgemm_device_numer.h index 005a60ead..fac4fca41 100644 --- a/src/seq_mv/csr_spgemm_device_numer.h +++ b/src/seq_mv/csr_spgemm_device_numer.h @@ -454,7 +454,7 @@ hypre_spgemm_numerical_with_rownnz( HYPRE_Int m, #if defined(HYPRE_USING_CUDA) const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE); #elif defined(HYPRE_USING_HIP) - const HYPRE_Int BDIMX = hypre_min(2, GROUP_SIZE); + const HYPRE_Int BDIMX = hypre_min(4, GROUP_SIZE); #endif const HYPRE_Int BDIMY = GROUP_SIZE / BDIMX; diff --git a/src/seq_mv/csr_spgemm_device_numer5.c b/src/seq_mv/csr_spgemm_device_numer5.c index 5c55dad4f..050e9c3ff 100644 --- a/src/seq_mv/csr_spgemm_device_numer5.c +++ b/src/seq_mv/csr_spgemm_device_numer5.c @@ -22,6 +22,7 @@ hypre_spgemm_numerical_with_rownnz HYPRE_Complex *d_b, HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c ); +#if defined(HYPRE_USING_CUDA) template HYPRE_Int hypre_spgemm_numerical_with_rownnz < 5, HYPRE_SPGEMM_NUMER_HASH_SIZE, HYPRE_SPGEMM_BASE_GROUP_SIZE, false > @@ -30,6 +31,7 @@ hypre_spgemm_numerical_with_rownnz HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb, HYPRE_Complex *d_b, HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c ); +#endif template HYPRE_Int hypre_spgemm_numerical_max_num_blocks diff --git a/src/seq_mv/csr_spgemm_device_numer6.c b/src/seq_mv/csr_spgemm_device_numer6.c index 561bf0a3f..6b96b1249 100644 --- a/src/seq_mv/csr_spgemm_device_numer6.c +++ b/src/seq_mv/csr_spgemm_device_numer6.c @@ -21,6 +21,17 @@ hypre_spgemm_numerical_with_rownnz HYPRE_Complex *d_b, HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c ); +#if defined(HYPRE_USING_HIP) +template HYPRE_Int +hypre_spgemm_numerical_with_rownnz +< 6, HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false > +( HYPRE_Int m, HYPRE_Int *row_ind, HYPRE_Int k, HYPRE_Int n, bool need_ghash, + HYPRE_Int exact_rownnz, + HYPRE_Int *d_ia, HYPRE_Int *d_ja, HYPRE_Complex *d_a, HYPRE_Int *d_ib, HYPRE_Int *d_jb, + HYPRE_Complex *d_b, + HYPRE_Int *d_rc, HYPRE_Int *d_ic, HYPRE_Int *d_jc, HYPRE_Complex *d_c ); +#endif + template HYPRE_Int hypre_spgemm_numerical_max_num_blocks < HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, HYPRE_SPGEMM_BASE_GROUP_SIZE * 2 >