SpGEMM numeric phase: special-case rows of A with nnz = 1 (1st version)

This commit is contained in:
Ruipeng Li 2022-04-12 11:18:42 -07:00
parent df10f0b5ed
commit 6450927874

View File

@ -95,16 +95,19 @@ hypre_spgemm_hash_insert_numer( HYPRE_Int HashSize,
return -1;
}
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, HYPRE_Int UNROLL_FACTOR>
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, bool IA1, HYPRE_Int UNROLL_FACTOR>
static __device__ __forceinline__
void
hypre_spgemm_compute_row_numer( HYPRE_Int istart_a,
HYPRE_Int iend_a,
HYPRE_Int istart_c,
const HYPRE_Int *ja,
const HYPRE_Complex *aa,
const HYPRE_Int *ib,
const HYPRE_Int *jb,
const HYPRE_Complex *ab,
HYPRE_Int *jc,
HYPRE_Complex *ac,
volatile HYPRE_Int *s_HashKeys,
volatile HYPRE_Complex *s_HashVals,
HYPRE_Int g_HashSize,
@ -150,17 +153,26 @@ hypre_spgemm_compute_row_numer( HYPRE_Int istart_a,
{
const HYPRE_Int k_idx = read_only_load(jb + k);
const HYPRE_Complex k_val = (ab ? read_only_load(ab + k) : 1.0) * mult;
/* first try to insert into shared memory hash table */
HYPRE_Int pos = hypre_spgemm_hash_insert_numer<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
(s_HashKeys, s_HashVals, k_idx, k_val);
if (HAS_GHASH && -1 == pos)
if (IA1)
{
pos = hypre_spgemm_hash_insert_numer<HASHTYPE>
(g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val);
jc[istart_c + k - rowB_start] = k_idx;
ac[istart_c + k - rowB_start] = k_val;
}
else
{
/* first try to insert into shared memory hash table */
HYPRE_Int pos = hypre_spgemm_hash_insert_numer<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
(s_HashKeys, s_HashVals, k_idx, k_val);
hypre_device_assert(pos != -1);
if (HAS_GHASH && -1 == pos)
{
pos = hypre_spgemm_hash_insert_numer<HASHTYPE>
(g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val);
}
hypre_device_assert(pos != -1);
}
}
}
}
@ -321,10 +333,32 @@ hypre_spgemm_numeric( const HYPRE_Int M,
istart_a, iend_a,
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
/* start/end position of row of C */
HYPRE_Int istart_c = 0;
#if defined(HYPRE_DEBUG)
HYPRE_Int iend_c = 0;
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
istart_c, iend_c,
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
#else
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
istart_c,
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
#endif
/* work with two hash tables */
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, UNROLL_FACTOR>
(istart_a, iend_a, ja, aa, ib, jb, ab, group_s_HashKeys, group_s_HashVals, ghash_size,
jg + istart_g, ag + istart_g);
if (iend_a == istart_a + 1)
{
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, true, UNROLL_FACTOR>
(istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size,
jg + istart_g, ag + istart_g);
}
else
{
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, false, UNROLL_FACTOR>
(istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size,
jg + istart_g, ag + istart_g);
}
if (GROUP_SIZE > HYPRE_WARP_SIZE)
{
@ -333,21 +367,10 @@ hypre_spgemm_numeric( const HYPRE_Int M,
HYPRE_Int jsum = 0;
/* copy results into the final C */
/* the first warp of the group copies results into the final C
* if GROUP_SIZE < WARP_SIZE, the whole group copies */
if (get_warp_in_group_id<GROUP_SIZE>() == 0)
{
HYPRE_Int istart_c = 0;
#ifdef HYPRE_DEBUG
HYPRE_Int iend_c = 0;
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
istart_c, iend_c,
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
#else
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
istart_c,
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
#endif
jsum = hypre_spgemm_copy_from_hash_into_C_row<GROUP_SIZE, SHMEM_HASH_SIZE, HAS_GHASH, UNROLL_FACTOR>
(lane_id,
GROUP_SIZE >= HYPRE_WARP_SIZE || i < M ? group_s_HashKeys : NULL,