diff --git a/src/seq_mv/csr_spgemm_device_numer.h b/src/seq_mv/csr_spgemm_device_numer.h index addf22f00..d58a0292d 100644 --- a/src/seq_mv/csr_spgemm_device_numer.h +++ b/src/seq_mv/csr_spgemm_device_numer.h @@ -95,16 +95,19 @@ hypre_spgemm_hash_insert_numer( HYPRE_Int HashSize, return -1; } -template +template static __device__ __forceinline__ void hypre_spgemm_compute_row_numer( HYPRE_Int istart_a, HYPRE_Int iend_a, + HYPRE_Int istart_c, const HYPRE_Int *ja, const HYPRE_Complex *aa, const HYPRE_Int *ib, const HYPRE_Int *jb, const HYPRE_Complex *ab, + HYPRE_Int *jc, + HYPRE_Complex *ac, volatile HYPRE_Int *s_HashKeys, volatile HYPRE_Complex *s_HashVals, HYPRE_Int g_HashSize, @@ -150,17 +153,26 @@ hypre_spgemm_compute_row_numer( HYPRE_Int istart_a, { const HYPRE_Int k_idx = read_only_load(jb + k); const HYPRE_Complex k_val = (ab ? read_only_load(ab + k) : 1.0) * mult; - /* first try to insert into shared memory hash table */ - HYPRE_Int pos = hypre_spgemm_hash_insert_numer - (s_HashKeys, s_HashVals, k_idx, k_val); - if (HAS_GHASH && -1 == pos) + if (IA1) { - pos = hypre_spgemm_hash_insert_numer - (g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val); + jc[istart_c + k - rowB_start] = k_idx; + ac[istart_c + k - rowB_start] = k_val; } + else + { + /* first try to insert into shared memory hash table */ + HYPRE_Int pos = hypre_spgemm_hash_insert_numer + (s_HashKeys, s_HashVals, k_idx, k_val); - hypre_device_assert(pos != -1); + if (HAS_GHASH && -1 == pos) + { + pos = hypre_spgemm_hash_insert_numer + (g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val); + } + + hypre_device_assert(pos != -1); + } } } } @@ -321,10 +333,32 @@ hypre_spgemm_numeric( const HYPRE_Int M, istart_a, iend_a, GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id); + /* start/end position of row of C */ + HYPRE_Int istart_c = 0; +#if defined(HYPRE_DEBUG) + HYPRE_Int iend_c = 0; + group_read(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M, + istart_c, iend_c, + GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id); +#else + group_read(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M, + istart_c, + GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id); +#endif + /* work with two hash tables */ - hypre_spgemm_compute_row_numer - (istart_a, iend_a, ja, aa, ib, jb, ab, group_s_HashKeys, group_s_HashVals, ghash_size, - jg + istart_g, ag + istart_g); + if (iend_a == istart_a + 1) + { + hypre_spgemm_compute_row_numer + (istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size, + jg + istart_g, ag + istart_g); + } + else + { + hypre_spgemm_compute_row_numer + (istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size, + jg + istart_g, ag + istart_g); + } if (GROUP_SIZE > HYPRE_WARP_SIZE) { @@ -333,21 +367,10 @@ hypre_spgemm_numeric( const HYPRE_Int M, HYPRE_Int jsum = 0; - /* copy results into the final C */ + /* the first warp of the group copies results into the final C + * if GROUP_SIZE < WARP_SIZE, the whole group copies */ if (get_warp_in_group_id() == 0) { - HYPRE_Int istart_c = 0; -#ifdef HYPRE_DEBUG - HYPRE_Int iend_c = 0; - group_read(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M, - istart_c, iend_c, - GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id); -#else - group_read(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M, - istart_c, - GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id); -#endif - jsum = hypre_spgemm_copy_from_hash_into_C_row (lane_id, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M ? group_s_HashKeys : NULL,