A: row nnz = 1, numer (1st version)
This commit is contained in:
parent
df10f0b5ed
commit
6450927874
@ -95,16 +95,19 @@ hypre_spgemm_hash_insert_numer( HYPRE_Int HashSize,
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, HYPRE_Int UNROLL_FACTOR>
|
template <HYPRE_Int SHMEM_HASH_SIZE, char HASHTYPE, HYPRE_Int GROUP_SIZE, bool HAS_GHASH, bool IA1, HYPRE_Int UNROLL_FACTOR>
|
||||||
static __device__ __forceinline__
|
static __device__ __forceinline__
|
||||||
void
|
void
|
||||||
hypre_spgemm_compute_row_numer( HYPRE_Int istart_a,
|
hypre_spgemm_compute_row_numer( HYPRE_Int istart_a,
|
||||||
HYPRE_Int iend_a,
|
HYPRE_Int iend_a,
|
||||||
|
HYPRE_Int istart_c,
|
||||||
const HYPRE_Int *ja,
|
const HYPRE_Int *ja,
|
||||||
const HYPRE_Complex *aa,
|
const HYPRE_Complex *aa,
|
||||||
const HYPRE_Int *ib,
|
const HYPRE_Int *ib,
|
||||||
const HYPRE_Int *jb,
|
const HYPRE_Int *jb,
|
||||||
const HYPRE_Complex *ab,
|
const HYPRE_Complex *ab,
|
||||||
|
HYPRE_Int *jc,
|
||||||
|
HYPRE_Complex *ac,
|
||||||
volatile HYPRE_Int *s_HashKeys,
|
volatile HYPRE_Int *s_HashKeys,
|
||||||
volatile HYPRE_Complex *s_HashVals,
|
volatile HYPRE_Complex *s_HashVals,
|
||||||
HYPRE_Int g_HashSize,
|
HYPRE_Int g_HashSize,
|
||||||
@ -150,17 +153,26 @@ hypre_spgemm_compute_row_numer( HYPRE_Int istart_a,
|
|||||||
{
|
{
|
||||||
const HYPRE_Int k_idx = read_only_load(jb + k);
|
const HYPRE_Int k_idx = read_only_load(jb + k);
|
||||||
const HYPRE_Complex k_val = (ab ? read_only_load(ab + k) : 1.0) * mult;
|
const HYPRE_Complex k_val = (ab ? read_only_load(ab + k) : 1.0) * mult;
|
||||||
/* first try to insert into shared memory hash table */
|
|
||||||
HYPRE_Int pos = hypre_spgemm_hash_insert_numer<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
|
|
||||||
(s_HashKeys, s_HashVals, k_idx, k_val);
|
|
||||||
|
|
||||||
if (HAS_GHASH && -1 == pos)
|
if (IA1)
|
||||||
{
|
{
|
||||||
pos = hypre_spgemm_hash_insert_numer<HASHTYPE>
|
jc[istart_c + k - rowB_start] = k_idx;
|
||||||
(g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val);
|
ac[istart_c + k - rowB_start] = k_val;
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
/* first try to insert into shared memory hash table */
|
||||||
|
HYPRE_Int pos = hypre_spgemm_hash_insert_numer<SHMEM_HASH_SIZE, HASHTYPE, UNROLL_FACTOR>
|
||||||
|
(s_HashKeys, s_HashVals, k_idx, k_val);
|
||||||
|
|
||||||
hypre_device_assert(pos != -1);
|
if (HAS_GHASH && -1 == pos)
|
||||||
|
{
|
||||||
|
pos = hypre_spgemm_hash_insert_numer<HASHTYPE>
|
||||||
|
(g_HashSize, g_HashKeys, g_HashVals, k_idx, k_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
hypre_device_assert(pos != -1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -321,10 +333,32 @@ hypre_spgemm_numeric( const HYPRE_Int M,
|
|||||||
istart_a, iend_a,
|
istart_a, iend_a,
|
||||||
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
||||||
|
|
||||||
|
/* start/end position of row of C */
|
||||||
|
HYPRE_Int istart_c = 0;
|
||||||
|
#if defined(HYPRE_DEBUG)
|
||||||
|
HYPRE_Int iend_c = 0;
|
||||||
|
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
|
||||||
|
istart_c, iend_c,
|
||||||
|
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
||||||
|
#else
|
||||||
|
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
|
||||||
|
istart_c,
|
||||||
|
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
||||||
|
#endif
|
||||||
|
|
||||||
/* work with two hash tables */
|
/* work with two hash tables */
|
||||||
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, UNROLL_FACTOR>
|
if (iend_a == istart_a + 1)
|
||||||
(istart_a, iend_a, ja, aa, ib, jb, ab, group_s_HashKeys, group_s_HashVals, ghash_size,
|
{
|
||||||
jg + istart_g, ag + istart_g);
|
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, true, UNROLL_FACTOR>
|
||||||
|
(istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size,
|
||||||
|
jg + istart_g, ag + istart_g);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
hypre_spgemm_compute_row_numer<SHMEM_HASH_SIZE, HASHTYPE, GROUP_SIZE, HAS_GHASH, false, UNROLL_FACTOR>
|
||||||
|
(istart_a, iend_a, istart_c, ja, aa, ib, jb, ab, jc, ac, group_s_HashKeys, group_s_HashVals, ghash_size,
|
||||||
|
jg + istart_g, ag + istart_g);
|
||||||
|
}
|
||||||
|
|
||||||
if (GROUP_SIZE > HYPRE_WARP_SIZE)
|
if (GROUP_SIZE > HYPRE_WARP_SIZE)
|
||||||
{
|
{
|
||||||
@ -333,21 +367,10 @@ hypre_spgemm_numeric( const HYPRE_Int M,
|
|||||||
|
|
||||||
HYPRE_Int jsum = 0;
|
HYPRE_Int jsum = 0;
|
||||||
|
|
||||||
/* copy results into the final C */
|
/* the first warp of the group copies results into the final C
|
||||||
|
* if GROUP_SIZE < WARP_SIZE, the whole group copies */
|
||||||
if (get_warp_in_group_id<GROUP_SIZE>() == 0)
|
if (get_warp_in_group_id<GROUP_SIZE>() == 0)
|
||||||
{
|
{
|
||||||
HYPRE_Int istart_c = 0;
|
|
||||||
#ifdef HYPRE_DEBUG
|
|
||||||
HYPRE_Int iend_c = 0;
|
|
||||||
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
|
|
||||||
istart_c, iend_c,
|
|
||||||
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
|
||||||
#else
|
|
||||||
group_read<GROUP_SIZE>(ic + ii, GROUP_SIZE >= HYPRE_WARP_SIZE || i < M,
|
|
||||||
istart_c,
|
|
||||||
GROUP_SIZE >= HYPRE_WARP_SIZE ? warp_lane_id : lane_id);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
jsum = hypre_spgemm_copy_from_hash_into_C_row<GROUP_SIZE, SHMEM_HASH_SIZE, HAS_GHASH, UNROLL_FACTOR>
|
jsum = hypre_spgemm_copy_from_hash_into_C_row<GROUP_SIZE, SHMEM_HASH_SIZE, HAS_GHASH, UNROLL_FACTOR>
|
||||||
(lane_id,
|
(lane_id,
|
||||||
GROUP_SIZE >= HYPRE_WARP_SIZE || i < M ? group_s_HashKeys : NULL,
|
GROUP_SIZE >= HYPRE_WARP_SIZE || i < M ? group_s_HashKeys : NULL,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user