sycl impl.

This commit is contained in:
Ruipeng Li 2022-06-06 16:42:26 -07:00
parent 0fee4f3c80
commit c12449c44c

View File

@ -851,8 +851,6 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
return C;
}
#if PARCSRGEMM_NEWPARADD
HYPRE_Int
hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int num_rows,
@ -913,6 +911,15 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
// Cext offd
#if defined(HYPRE_USING_SYCL)
// Pair each position 0 .. Cext_nnz-1 with Cext_bigj[position] and copy into
// (work, big_work) the pairs selected by the negation of pred1. Cext_bigj is
// also passed as the third argument -- presumably the stencil the predicate is
// evaluated on, mirroring the Thrust copy_if branch; confirm against
// hypreSycl_copy_if.
auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
Cext_bigj),
oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
Cext_bigj) + Cext_nnz,
Cext_bigj,
oneapi::dpl::make_zip_iterator(work, big_work),
std::not_fn(pred1) );
// Number of copied pairs = how far the `work` component of the output iterator advanced.
HYPRE_Int Cext_offd_nnz = std::get<0>(off_end.base()) - work;
#else
auto off_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)),
@ -929,6 +936,18 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
&num_cols_offd_C, &col_map_offd_C, &map_offd_to_C);
#if defined(HYPRE_USING_SYCL)
// For each of the Cext_offd_nnz values in big_work, find its lower-bound
// position in col_map_offd_C[0 .. num_cols_offd_C) and store that position in
// hypre_CSRMatrixJ(Cext) at the entry index held in work (permutation iterator).
HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound,
col_map_offd_C,
col_map_offd_C + num_cols_offd_C,
big_work,
big_work + Cext_offd_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) );
// Add num_cols to those same J entries, in place.
HYPRE_ONEDPL_CALL( std::transform,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work),
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) + Cext_offd_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work),
[const_val = num_cols] (const auto & x) {return x + const_val;} );
#else
HYPRE_THRUST_CALL( lower_bound,
col_map_offd_C,
@ -947,6 +966,15 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
// Cext diag
#if defined(HYPRE_USING_SYCL)
// Pair each position 0 .. Cext_nnz-1 with Cext_bigj[position] and copy into
// (work, big_work) the pairs selected by pred1. Cext_bigj is also passed as the
// third argument -- presumably the stencil the predicate is evaluated on,
// mirroring the Thrust copy_if branch; confirm against hypreSycl_copy_if.
auto dia_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
Cext_bigj),
oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
Cext_bigj) + Cext_nnz,
Cext_bigj,
oneapi::dpl::make_zip_iterator(work, big_work),
pred1 );
// Number of copied pairs = how far the `work` component of the output iterator
// advanced. Use std::get (as the offd counterpart does) rather than
// thrust::get: this is the SYCL branch, where the oneDPL zip iterator's base()
// is unpacked with std::get and Thrust is the non-SYCL backend.
HYPRE_Int Cext_diag_nnz = std::get<0>(dia_end.base()) - work;
#else
auto dia_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)),
@ -962,6 +990,11 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
hypre_assert(Cext_diag_nnz + Cext_offd_nnz == Cext_nnz);
#if defined(HYPRE_USING_SYCL)
// For k in [0, Cext_diag_nnz): write big_work[k] - first_col_diag into
// hypre_CSRMatrixJ(Cext) at the entry index held in work[k].
HYPRE_ONEDPL_CALL( std::transform,
big_work,
big_work + Cext_diag_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work),
[const_val = first_col_diag](const auto & x) {return x - const_val;} );
#else
HYPRE_THRUST_CALL( transform,
big_work,
@ -974,6 +1007,11 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
// transform Cbar_local J index
RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C);
#if defined(HYPRE_USING_SYCL)
// Apply func2 to each of the local_nnz_Cbar entries of hypre_CSRMatrixJ(Cbar_local),
// overwriting the array in place (input and output ranges are the same).
HYPRE_ONEDPL_CALL( std::transform,
hypre_CSRMatrixJ(Cbar_local),
hypre_CSRMatrixJ(Cbar_local) + local_nnz_Cbar,
hypre_CSRMatrixJ(Cbar_local),
func2 );
#else
HYPRE_THRUST_CALL( transform,
hypre_CSRMatrixJ(Cbar_local),
@ -985,6 +1023,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
hypre_TFree(big_work, HYPRE_MEMORY_DEVICE);
hypre_TFree(work, HYPRE_MEMORY_DEVICE);
hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE);
hypre_TFree(Cext_bigj, HYPRE_MEMORY_DEVICE);
hypre_CSRMatrixBigJ(Cext) = NULL;
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
@ -1064,257 +1104,6 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
in_range<HYPRE_Int> pred(0, num_cols - 1);
#if defined(HYPRE_USING_SYCL)
// NOTE(review): this branch is empty, so nnz_C_diag is never declared when
// HYPRE_USING_SYCL is defined, although it is used right after the #endif.
#else
// Count the entries of zmp_j[0 .. local_nnz_C) that satisfy pred.
HYPRE_Int nnz_C_diag = HYPRE_THRUST_CALL( count_if,
zmp_j,
zmp_j + local_nnz_C,
pred );
#endif
// Every entry not counted above is attributed to the offd part.
HYPRE_Int nnz_C_offd = local_nnz_C - nnz_C_diag;
// diag
hypre_CSRMatrix *C_diag = hypre_CSRMatrixCreate(num_rows, num_cols, nnz_C_diag);
hypre_CSRMatrixInitialize_v2(C_diag, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_ii = hypre_TAlloc(HYPRE_Int, nnz_C_diag, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_j = hypre_CSRMatrixJ(C_diag);
HYPRE_Complex *C_diag_a = hypre_CSRMatrixData(C_diag);
#if defined(HYPRE_USING_SYCL)
#else
auto new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)),
pred );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii,
hypre_CSRMatrixI(C_diag));
hypre_TFree(C_diag_ii, HYPRE_MEMORY_DEVICE);
// offd
hypre_CSRMatrix *C_offd = hypre_CSRMatrixCreate(num_rows, num_cols_offd_C, nnz_C_offd);
hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_ii = hypre_TAlloc(HYPRE_Int, nnz_C_offd, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_j = hypre_CSRMatrixJ(C_offd);
HYPRE_Complex *C_offd_a = hypre_CSRMatrixData(C_offd);
#if defined(HYPRE_USING_SYCL)
#else
new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_offd_ii, C_offd_j, C_offd_a)),
thrust::not1(pred) );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_offd_ii + nnz_C_offd );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_offd), nnz_C_offd, C_offd_ii,
hypre_CSRMatrixI(C_offd));
hypre_TFree(C_offd_ii, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_SYCL)
#else
HYPRE_THRUST_CALL( transform,
C_offd_j,
C_offd_j + nnz_C_offd,
thrust::make_constant_iterator(num_cols),
C_offd_j,
thrust::minus<HYPRE_Int>() );
#endif
// free
hypre_TFree(Cbar_local, HYPRE_MEMORY_HOST);
hypre_TFree(zmp_i, HYPRE_MEMORY_DEVICE);
if (!Cext_nnz)
{
hypre_CSRMatrixDestroy(Cbar);
hypre_CSRMatrixDestroy(Cext);
}
else
{
hypre_CSRMatrixDestroy(Cz);
}
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time Split %f\n", t2);
#endif
// output
*C_diag_ptr = C_diag;
*C_offd_ptr = C_offd;
*num_cols_offd_C_ptr = num_cols_offd_C;
*col_map_offd_C_ptr = col_map_offd_C;
return hypre_error_flag;
}
#else
HYPRE_Int
hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int num_rows,
HYPRE_Int num_cols,
HYPRE_BigInt first_col_diag,
HYPRE_BigInt last_col_diag,
HYPRE_Int num_cols_offd,
HYPRE_BigInt *col_map_offd,
HYPRE_Int local_nnz_Cbar,
hypre_CSRMatrix *Cbar,
hypre_CSRMatrix *Cext,
hypre_CSRMatrix **C_diag_ptr,
hypre_CSRMatrix **C_offd_ptr,
HYPRE_Int *num_cols_offd_C_ptr,
HYPRE_BigInt **col_map_offd_C_ptr )
{
#if PARCSRGEMM_TIMING > 1
MPI_Comm comm = hypre_ParCSRCommPkgComm(comm_pkg);
HYPRE_Real t1, t2;
t1 = hypre_MPI_Wtime();
#endif
// to hold Cbar local and Cext
HYPRE_Int tmp_s = local_nnz_Cbar + hypre_CSRMatrixNumNonzeros(Cext);
HYPRE_Int *tmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *tmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *tmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int Cext_diag_nnz, Cext_offd_nnz, num_cols_offd_C, *offd_map_to_C;
HYPRE_BigInt *col_map_offd_C;
hypre_CSRMatrix *C_diag, *C_offd;
hypre_CSRMatrixSplitDevice_core(0,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
NULL,
hypre_CSRMatrixBigJ(Cext), NULL, NULL,
first_col_diag,
last_col_diag,
-1,
NULL, NULL, NULL, NULL,
&Cext_diag_nnz,
NULL, NULL, NULL, NULL,
&Cext_offd_nnz,
NULL, NULL, NULL, NULL);
HYPRE_Int *Cext_ii = hypreDevice_CsrRowPtrsToIndices(hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
hypre_CSRMatrixI(Cext));
hypre_CSRMatrixSplitDevice_core(1,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
Cext_ii,
hypre_CSRMatrixBigJ(Cext),
hypre_CSRMatrixData(Cext),
NULL,
first_col_diag,
last_col_diag,
num_cols_offd,
col_map_offd,
&offd_map_to_C,
&num_cols_offd_C,
&col_map_offd_C,
&Cext_diag_nnz,
tmp_i + local_nnz_Cbar,
tmp_j + local_nnz_Cbar,
tmp_a + local_nnz_Cbar,
NULL,
&Cext_offd_nnz,
tmp_i + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_a + local_nnz_Cbar + Cext_diag_nnz,
NULL);
hypre_CSRMatrixDestroy(Cext);
hypre_TFree(Cext_ii, HYPRE_MEMORY_DEVICE);
hypre_ParCSRCommPkgCopySendMapElmtsToDevice(comm_pkg);
#if defined(HYPRE_USING_SYCL)
// Replace each tmp_i entry in [local_nnz_Cbar, tmp_s) by the element of the
// device send-map it indexes (in place: map range and output are the same).
// Presumably hypreSycl_gather has the same (map_first, map_last, input, result)
// semantics as the Thrust gather in the #else branch -- confirm.
hypreSycl_gather( tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(comm_pkg),
tmp_i + local_nnz_Cbar );
// Add num_cols, in place, to the tmp_j entries in
// [local_nnz_Cbar + Cext_diag_nnz, tmp_s).
HYPRE_ONEDPL_CALL( std::transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
[const_val = num_cols] (const auto & x) {return x + const_val;} );
#else
// In-place gather: tmp_i[k] = send_map_elmts[tmp_i[k]] for k in [local_nnz_Cbar, tmp_s).
HYPRE_THRUST_CALL( gather,
tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(comm_pkg),
tmp_i + local_nnz_Cbar );
// Add num_cols, in place, to the tmp_j entries in
// [local_nnz_Cbar + Cext_diag_nnz, tmp_s).
HYPRE_THRUST_CALL( transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
thrust::make_constant_iterator(num_cols),
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
thrust::plus<HYPRE_Int>() );
#endif
hypreDevice_CsrRowPtrsToIndices_v2(num_rows, local_nnz_Cbar, hypre_CSRMatrixI(Cbar), tmp_i);
hypre_TMemcpy(tmp_a, hypre_CSRMatrixData(Cbar), HYPRE_Complex, local_nnz_Cbar, HYPRE_MEMORY_DEVICE,
HYPRE_MEMORY_DEVICE);
RAP_functor<2, HYPRE_Int> func2(num_cols, 0, offd_map_to_C);
#if defined(HYPRE_USING_SYCL)
HYPRE_ONEDPL_CALL( std::transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
#else
HYPRE_THRUST_CALL( transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
#endif
hypre_CSRMatrixDestroy(Cbar);
hypre_TFree(offd_map_to_C, HYPRE_MEMORY_DEVICE);
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time PartialAdd1 %f\n", t2);
#endif
#if PARCSRGEMM_TIMING > 1
t1 = hypre_MPI_Wtime();
#endif
// add Cext to Cbar local. Note: type 2, diagonal entries are put at the first in the rows
hypreDevice_StableSortByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, 2);
HYPRE_Int *zmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *zmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *zmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int local_nnz_C = hypreDevice_ReduceByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, zmp_i, zmp_j,
zmp_a);
hypre_TFree(tmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_a, HYPRE_MEMORY_DEVICE);
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time PartialAdd2 %f\n", t2);
#endif
#if PARCSRGEMM_TIMING > 1
t1 = hypre_MPI_Wtime();
#endif
// split into diag and offd
in_range<HYPRE_Int> pred(0, num_cols - 1);
#if defined(HYPRE_USING_SYCL)
HYPRE_Int nnz_C_diag = HYPRE_ONEDPL_CALL( std::count_if,
zmp_j,
@ -1328,7 +1117,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
#endif
HYPRE_Int nnz_C_offd = local_nnz_C - nnz_C_diag;
C_diag = hypre_CSRMatrixCreate(num_rows, num_cols, nnz_C_diag);
// diag
hypre_CSRMatrix *C_diag = hypre_CSRMatrixCreate(num_rows, num_cols, nnz_C_diag);
hypre_CSRMatrixInitialize_v2(C_diag, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_ii = hypre_TAlloc(HYPRE_Int, nnz_C_diag, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_j = hypre_CSRMatrixJ(C_diag);
@ -1354,7 +1144,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
hypre_CSRMatrixI(C_diag));
hypre_TFree(C_diag_ii, HYPRE_MEMORY_DEVICE);
C_offd = hypre_CSRMatrixCreate(num_rows, num_cols_offd_C, nnz_C_offd);
// offd
hypre_CSRMatrix *C_offd = hypre_CSRMatrixCreate(num_rows, num_cols_offd_C, nnz_C_offd);
hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_ii = hypre_TAlloc(HYPRE_Int, nnz_C_offd, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_j = hypre_CSRMatrixJ(C_offd);
@ -1394,15 +1185,27 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
thrust::minus<HYPRE_Int>() );
#endif
// free
hypre_TFree(Cbar_local, HYPRE_MEMORY_HOST);
hypre_TFree(zmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_a, HYPRE_MEMORY_DEVICE);
if (!Cext_nnz)
{
hypre_CSRMatrixDestroy(Cbar);
hypre_CSRMatrixDestroy(Cext);
}
else
{
hypre_CSRMatrixDestroy(Cz);
}
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time Split %f\n", t2);
#endif
// output
*C_diag_ptr = C_diag;
*C_offd_ptr = C_offd;
*num_cols_offd_C_ptr = num_cols_offd_C;
@ -1411,6 +1214,4 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
return hypre_error_flag;
}
#endif
#endif // #if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)