A: row nnz = 1, symbl

This commit is contained in:
Ruipeng Li 2022-04-12 11:17:56 -07:00
parent 4025b32ad8
commit df10f0b5ed

View File

@ -94,7 +94,7 @@ hypre_spgemm_hash_insert_symbl( HYPRE_Int HashSize,
return -1;
}
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, HYPRE_Int UNROLL_FACTOR>
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, bool IA1, HYPRE_Int UNROLL_FACTOR>
static __device__ __forceinline__
HYPRE_Int
hypre_spgemm_compute_row_symbl( HYPRE_Int istart_a,
@ -143,23 +143,30 @@ hypre_spgemm_compute_row_symbl( HYPRE_Int istart_a,
{
if (k < rowB_end)
{
const HYPRE_Int k_idx = read_only_load(jb + k);
/* first try to insert into shared memory hash table */
HYPRE_Int pos = hypre_spgemm_hash_insert_symbl<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
(s_HashKeys, k_idx, num_new_insert);
if (HAS_GHASH && -1 == pos)
{
pos = hypre_spgemm_hash_insert_symbl<HASHTYPE>
(g_HashSize, g_HashKeys, k_idx, num_new_insert);
}
/* if failed again, both hash tables must have been full
(hash table size estimation was too small).
Increase the counter anyhow (will lead to over-counting) */
if (pos == -1)
if (IA1)
{
num_new_insert ++;
failed = 1;
}
else
{
const HYPRE_Int k_idx = read_only_load(jb + k);
/* first try to insert into shared memory hash table */
HYPRE_Int pos = hypre_spgemm_hash_insert_symbl<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
(s_HashKeys, k_idx, num_new_insert);
if (HAS_GHASH && -1 == pos)
{
pos = hypre_spgemm_hash_insert_symbl<HASHTYPE>
(g_HashSize, g_HashKeys, k_idx, num_new_insert);
}
/* if failed again, both hash tables must have been full
(hash table size estimation was too small).
Increase the counter anyhow (will lead to over-counting) */
if (pos == -1)
{
num_new_insert ++;
failed = 1;
}
}
}
}
@ -263,9 +270,18 @@ hypre_spgemm_symbolic( const HYPRE_Int M, /* HYPRE_Int K, HYPRE_In
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
/* work with two hash tables */
HYPRE_Int jsum =
hypre_spgemm_compute_row_symbl<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, UNROLL_FACTOR>
(istart_a, iend_a, ja, ib, jb, group_s_HashKeys, ghash_size, jg + istart_g, failed);
HYPRE_Int jsum;
if (iend_a == istart_a + 1)
{
jsum = hypre_spgemm_compute_row_symbl<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, true, UNROLL_FACTOR>
(istart_a, iend_a, ja, ib, jb, group_s_HashKeys, ghash_size, jg + istart_g, failed);
}
else
{
jsum = hypre_spgemm_compute_row_symbl<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, false, UNROLL_FACTOR>
(istart_a, iend_a, ja, ib, jb, group_s_HashKeys, ghash_size, jg + istart_g, failed);
}
#if defined(HYPRE_DEBUG)
hypre_device_assert(CAN_FAIL || failed == 0);
@ -279,7 +295,7 @@ hypre_spgemm_symbolic( const HYPRE_Int M, /* HYPRE_Int K, HYPRE_In
}
else
{
__syncthreads();
group_sync<GROUP_SIZE>();
jsum = group_reduce_sum<HYPRE_Int, NUM_GROUPS_PER_BLOCK, GROUP_SIZE>(jsum, s_HashKeys);
}