Merge branch 'spgemm' of github.com:hypre-space/hypre into spgemm

This commit is contained in:
Ruipeng Li 2022-06-10 12:14:35 -07:00
commit 637af55397
9 changed files with 62 additions and 25 deletions

View File

@ -1175,6 +1175,7 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A,
hypre_CSRMatrixData(C_diag) = C_diag_data;
hypre_CSRMatrixI(C_diag) = C_diag_i;
hypre_CSRMatrixJ(C_diag) = C_diag_j;
hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C;
hypre_CSRMatrixSetRownnz(C_diag);
C_offd = hypre_ParCSRMatrixOffd(C);
@ -1186,10 +1187,9 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A,
hypre_CSRMatrixJ(C_offd) = C_offd_j;
hypre_ParCSRMatrixColMapOffd(C) = col_map_offd_C;
}
hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C;
hypre_CSRMatrixSetRownnz(C_offd);
hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C;
hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C;
/*-----------------------------------------------------------------------
* Free various arrays
@ -4009,6 +4009,9 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A,
hypre_ParCSRMatrixRowvalues(C) = NULL;
hypre_ParCSRMatrixGetrowactive(C) = 0;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C;
if (C_diag)
{
hypre_CSRMatrixSetRownnz(C_diag);
@ -4029,9 +4032,6 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A,
hypre_ParCSRMatrixOffd(C) = C_tmp_offd;
}
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C;
if (num_cols_offd_C)
{
HYPRE_Int jj_count_offd, nnz_offd;

View File

@ -943,6 +943,11 @@ hypre_ParcsrGetExternalRowsDeviceWait(void *vrequest)
return A_ext;
}
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
HYPRE_Int
hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int local_ncols )
@ -979,11 +984,6 @@ hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
return hypre_error_flag;
}
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
hypre_CSRMatrix*
hypre_MergeDiagAndOffdDevice(hypre_ParCSRMatrix *A)
{

View File

@ -940,12 +940,16 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
&num_cols_offd_C, &col_map_offd_C, &map_offd_to_C);
#if defined(HYPRE_USING_SYCL)
HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound,
col_map_offd_C,
col_map_offd_C + num_cols_offd_C,
big_work,
big_work + Cext_offd_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) );
/* WM: onedpl lower_bound currently does not accept zero length input */
if (num_cols_offd_C > 0)
{
HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound,
col_map_offd_C,
col_map_offd_C + num_cols_offd_C,
big_work,
big_work + Cext_offd_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) );
}
HYPRE_ONEDPL_CALL( std::transform,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work),
@ -978,7 +982,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
oneapi::dpl::make_zip_iterator(work, big_work),
pred1 );
HYPRE_Int Cext_diag_nnz = thrust::get<0>(dia_end.base()) - work;
HYPRE_Int Cext_diag_nnz = std::get<0>(dia_end.base()) - work;
#else
auto dia_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)),
@ -1062,13 +1066,26 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
if (hypre_HandleSpgemmUseVendor(hypre_handle()))
{
ie_a = hypre_TAlloc(HYPRE_Complex, num_rows + num_elemt, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_SYCL)
HYPRE_ONEDPL_CALL(std::fill, ie_a, ie_a + num_rows + num_elemt, 1.0);
#else
HYPRE_THRUST_CALL(fill, ie_a, ie_a + num_rows + num_elemt, 1.0);
#endif
}
#if defined(HYPRE_USING_SYCL)
hypreSycl_sequence(ie_ii, ie_ii + num_rows, 0);
HYPRE_ONEDPL_CALL( std::copy, send_map, send_map + num_elemt, ie_ii + num_rows);
hypreSycl_sequence(ie_j, ie_j + num_rows + num_elemt, 0);
auto zipped_begin = oneapi::dpl::make_zip_iterator(ie_ii, ie_j);
HYPRE_ONEDPL_CALL( std::stable_sort, zipped_begin, zipped_begin + num_rows + num_elemt,
[](auto lhs, auto rhs) { return std::get<0>(lhs) < std::get<0>(rhs); } );
#else
HYPRE_THRUST_CALL( sequence, ie_ii, ie_ii + num_rows);
HYPRE_THRUST_CALL( copy, send_map, send_map + num_elemt, ie_ii + num_rows);
HYPRE_THRUST_CALL( sequence, ie_j, ie_j + num_rows + num_elemt);
HYPRE_THRUST_CALL( stable_sort_by_key, ie_ii, ie_ii + num_rows + num_elemt, ie_j );
#endif
HYPRE_Int *ie_i = hypreDevice_CsrRowIndicesToPtrs(num_rows, num_rows + num_elemt, ie_ii);
hypre_TFree(ie_ii, HYPRE_MEMORY_DEVICE);

View File

@ -588,7 +588,8 @@ hypre_CSRMatrixPrintMM( hypre_CSRMatrix *matrix,
HYPRE_Int trans,
const char *file_name )
{
hypre_assert(hypre_CSRMatrixI(matrix)[hypre_CSRMatrixNumRows(matrix)] == hypre_CSRMatrixNumNonzeros(matrix));
hypre_assert(hypre_CSRMatrixI(matrix)[hypre_CSRMatrixNumRows(matrix)] == hypre_CSRMatrixNumNonzeros(
matrix));
FILE *fp = file_name ? fopen(file_name, "w") : stdout;
@ -611,9 +612,10 @@ hypre_CSRMatrixPrintMM( hypre_CSRMatrix *matrix,
hypre_fprintf(fp, "%%%%MatrixMarket matrix coordinate pattern general\n");
}
hypre_fprintf(fp, "%d %d %d\n", trans ? hypre_CSRMatrixNumCols(matrix) : hypre_CSRMatrixNumRows(matrix),
trans ? hypre_CSRMatrixNumRows(matrix) : hypre_CSRMatrixNumCols(matrix),
hypre_CSRMatrixNumNonzeros(matrix));
hypre_fprintf(fp, "%d %d %d\n",
trans ? hypre_CSRMatrixNumCols(matrix) : hypre_CSRMatrixNumRows(matrix),
trans ? hypre_CSRMatrixNumRows(matrix) : hypre_CSRMatrixNumCols(matrix),
hypre_CSRMatrixNumNonzeros(matrix));
HYPRE_Int i, j;

View File

@ -27,4 +27,3 @@ hypre_spgemm_numerical_max_num_blocks
( HYPRE_Int multiProcessorCount, HYPRE_Int *num_blocks_ptr, HYPRE_Int *block_size_ptr );
#endif /* HYPRE_USING_CUDA || defined(HYPRE_USING_HIP) */

View File

@ -137,6 +137,22 @@ OutputIter hypreSycl_gather(InputIter1 map_first, InputIter1 map_last,
return HYPRE_ONEDPL_CALL( oneapi::dpl::copy, perm_begin, perm_begin + n, result);
}
// Equivalent of thrust::sequence (with step=1)
template <class Iter, class T>
void hypreSycl_sequence(Iter first, Iter last, T init = 0)
{
static_assert(
std::is_same<typename std::iterator_traits<Iter>::iterator_category,
std::random_access_iterator_tag>::value,
"Iterators passed to algorithms must be random-access iterators.");
using DiffType = typename std::iterator_traits<Iter>::difference_type;
HYPRE_ONEDPL_CALL( std::transform,
oneapi::dpl::counting_iterator<DiffType>(init),
oneapi::dpl::counting_iterator<DiffType>(std::distance(first, last)),
first,
[](auto i) { return i; });
}
#endif
#endif

View File

@ -1761,7 +1761,8 @@ typedef struct
* 1) Merge sort can take advantage of eliminating duplicates.
* 2) Merge sort is more efficiently parallelizable than qsort
*/
HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, hypre_IntArray *array3 );
HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2,
hypre_IntArray *array3 );
void hypre_union2(HYPRE_Int n1, HYPRE_BigInt *arr1, HYPRE_Int n2, HYPRE_BigInt *arr2, HYPRE_Int *n3,
HYPRE_BigInt *arr3, HYPRE_Int *map1, HYPRE_Int *map2);
void hypre_merge_sort(HYPRE_Int *in, HYPRE_Int *temp, HYPRE_Int len, HYPRE_Int **sorted);

View File

@ -61,7 +61,8 @@ hypre_MergeOrderedArrays( hypre_IntArray *array1,
array3_data[k++] = array2_data[j++];
}
array3_data = hypre_TReAlloc_v2(array3_data, HYPRE_Int, size1 + size2, HYPRE_Int, k, memory_location);
array3_data = hypre_TReAlloc_v2(array3_data, HYPRE_Int, size1 + size2, HYPRE_Int, k,
memory_location);
hypre_IntArraySize(array3) = k;
hypre_IntArrayData(array3) = array3_data;

View File

@ -258,7 +258,8 @@ typedef struct
* 1) Merge sort can take advantage of eliminating duplicates.
* 2) Merge sort is more efficiently parallelizable than qsort
*/
HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, hypre_IntArray *array3 );
HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2,
hypre_IntArray *array3 );
void hypre_union2(HYPRE_Int n1, HYPRE_BigInt *arr1, HYPRE_Int n2, HYPRE_BigInt *arr2, HYPRE_Int *n3,
HYPRE_BigInt *arr3, HYPRE_Int *map1, HYPRE_Int *map2);
void hypre_merge_sort(HYPRE_Int *in, HYPRE_Int *temp, HYPRE_Int len, HYPRE_Int **sorted);