reorg parcsrmm

This commit is contained in:
Ruipeng Li 2022-03-31 17:38:19 -07:00
parent 23c7777045
commit 6e8607fd47
2 changed files with 301 additions and 467 deletions

View File

@ -9,7 +9,7 @@
#include "_hypre_parcsr_mv.h"
#include "_hypre_utilities.hpp"
#define PARCSRGEMM_TIMING 2
#define PARCSRGEMM_TIMING 0
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
@ -414,7 +414,6 @@ hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A,
hypre_ParPrintf(comm, "Time SpGemm %f\n", t2);
#endif
hypre_CSRMatrixDestroy(AbarT);
hypre_CSRMatrixDestroy(Bbar);
@ -506,251 +505,21 @@ hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A,
hypre_ParPrintf(comm, "Size Cext %d %d %d\n", hypre_CSRMatrixNumRows(Cext), hypre_CSRMatrixNumCols(Cext), hypre_CSRMatrixNumNonzeros(Cext));
#endif
#if PARCSRGEMM_TIMING > 1
t1 = hypre_MPI_Wtime();
#endif
// to hold Cbar local and Cext
HYPRE_Int tmp_s = local_nnz_Cbar + hypre_CSRMatrixNumNonzeros(Cext);
HYPRE_Int *tmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *tmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *tmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int Cext_diag_nnz, Cext_offd_nnz, *offd_map_to_C;
hypre_CSRMatrixSplitDevice_core(0,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
NULL,
hypre_CSRMatrixBigJ(Cext), NULL, NULL,
hypre_ParCSRMatrixFirstColDiag(B),
hypre_ParCSRMatrixLastColDiag(B),
hypre_CSRMatrixNumCols(B_offd),
NULL, NULL, NULL, NULL,
&Cext_diag_nnz,
NULL, NULL, NULL, NULL,
&Cext_offd_nnz,
NULL, NULL, NULL, NULL);
HYPRE_Int *Cext_ii = hypreDevice_CsrRowPtrsToIndices(hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
hypre_CSRMatrixI(Cext));
hypre_CSRMatrixSplitDevice_core(1,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
Cext_ii,
hypre_CSRMatrixBigJ(Cext),
hypre_CSRMatrixData(Cext),
NULL,
hypre_ParCSRMatrixFirstColDiag(B),
hypre_ParCSRMatrixLastColDiag(B),
hypre_CSRMatrixNumCols(B_offd),
hypre_ParCSRMatrixDeviceColMapOffd(B),
&offd_map_to_C,
&num_cols_offd_C,
&col_map_offd_C,
&Cext_diag_nnz,
tmp_i + local_nnz_Cbar,
tmp_j + local_nnz_Cbar,
tmp_a + local_nnz_Cbar,
NULL,
&Cext_offd_nnz,
tmp_i + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_a + local_nnz_Cbar + Cext_diag_nnz,
NULL);
hypre_CSRMatrixDestroy(Cext);
hypre_TFree(Cext_ii, HYPRE_MEMORY_DEVICE);
hypre_ParCSRCommPkgCopySendMapElmtsToDevice(hypre_ParCSRMatrixCommPkg(A));
#if defined(HYPRE_USING_SYCL)
hypreSycl_gather( tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(hypre_ParCSRMatrixCommPkg(A)),
tmp_i + local_nnz_Cbar );
/* WM: necessary? */
if (tmp_s > local_nnz_Cbar + Cext_diag_nnz)
{
HYPRE_ONEDPL_CALL( std::transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
[const_val = hypre_ParCSRMatrixNumCols(B)] (const auto & x) {return x + const_val;} );
}
#else
HYPRE_THRUST_CALL( gather,
tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(hypre_ParCSRMatrixCommPkg(A)),
tmp_i + local_nnz_Cbar );
HYPRE_THRUST_CALL( transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
thrust::make_constant_iterator(hypre_ParCSRMatrixNumCols(B)),
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
thrust::plus<HYPRE_Int>() );
#endif
hypreDevice_CsrRowPtrsToIndices_v2(hypre_ParCSRMatrixNumCols(A), local_nnz_Cbar,
hypre_CSRMatrixI(Cbar), tmp_i);
hypre_TMemcpy(tmp_a, hypre_CSRMatrixData(Cbar), HYPRE_Complex, local_nnz_Cbar, HYPRE_MEMORY_DEVICE,
HYPRE_MEMORY_DEVICE);
RAP_functor<2, HYPRE_Int> func2(hypre_ParCSRMatrixNumCols(B), 0, offd_map_to_C);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
if (local_nnz_Cbar > 0)
{
HYPRE_ONEDPL_CALL( std::transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
}
#else
HYPRE_THRUST_CALL( transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
#endif
hypre_CSRMatrixDestroy(Cbar);
hypre_TFree(offd_map_to_C, HYPRE_MEMORY_DEVICE);
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time PartialAdd1 %f\n", t2);
#endif
#if PARCSRGEMM_TIMING > 1
t1 = hypre_MPI_Wtime();
#endif
// add Cext to Cbar local. Note: type 2, diagonal entries are put at the first in the rows
hypreDevice_StableSortByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, 2);
HYPRE_Int *zmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *zmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *zmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int local_nnz_C = hypreDevice_ReduceByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, zmp_i, zmp_j,
zmp_a);
hypre_TFree(tmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_a, HYPRE_MEMORY_DEVICE);
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time PartialAdd2 %f\n", t2);
#endif
#if PARCSRGEMM_TIMING > 1
t1 = hypre_MPI_Wtime();
#endif
// split into diag and offd
in_range<HYPRE_Int> pred(0, hypre_ParCSRMatrixNumCols(B) - 1);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
HYPRE_Int nnz_C_diag = 0;
if (local_nnz_C > 0)
{
nnz_C_diag = HYPRE_ONEDPL_CALL( std::count_if,
zmp_j,
zmp_j + local_nnz_C,
pred );
}
#else
HYPRE_Int nnz_C_diag = HYPRE_THRUST_CALL( count_if,
zmp_j,
zmp_j + local_nnz_C,
pred );
#endif
HYPRE_Int nnz_C_offd = local_nnz_C - nnz_C_diag;
C_diag = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(A), hypre_ParCSRMatrixNumCols(B),
nnz_C_diag);
hypre_CSRMatrixInitialize_v2(C_diag, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_ii = hypre_TAlloc(HYPRE_Int, nnz_C_diag, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_j = hypre_CSRMatrixJ(C_diag);
HYPRE_Complex *C_diag_a = hypre_CSRMatrixData(C_diag);
#if defined(HYPRE_USING_SYCL)
auto new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
zmp_j,
oneapi::dpl::make_zip_iterator(C_diag_ii, C_diag_j, C_diag_a),
pred );
hypre_assert( std::get<0>(new_end.base()) == C_diag_ii + nnz_C_diag );
#else
auto new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)),
pred );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii,
hypre_CSRMatrixI(C_diag));
hypre_TFree(C_diag_ii, HYPRE_MEMORY_DEVICE);
C_offd = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(A), num_cols_offd_C, nnz_C_offd);
hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_ii = hypre_TAlloc(HYPRE_Int, nnz_C_offd, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_j = hypre_CSRMatrixJ(C_offd);
HYPRE_Complex *C_offd_a = hypre_CSRMatrixData(C_offd);
#if defined(HYPRE_USING_SYCL)
new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
zmp_j,
oneapi::dpl::make_zip_iterator(C_offd_ii, C_offd_j, C_offd_a),
std::not_fn(pred) );
hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd );
#else
new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_offd_ii, C_offd_j, C_offd_a)),
thrust::not1(pred) );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_offd_ii + nnz_C_offd );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_offd), nnz_C_offd, C_offd_ii,
hypre_CSRMatrixI(C_offd));
hypre_TFree(C_offd_ii, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
if (nnz_C_offd > 0)
{
HYPRE_ONEDPL_CALL( std::transform,
C_offd_j,
C_offd_j + nnz_C_offd,
C_offd_j,
[const_val = hypre_ParCSRMatrixNumCols(B)] (const auto & x) {return x - const_val;} );
}
#else
HYPRE_THRUST_CALL( transform,
C_offd_j,
C_offd_j + nnz_C_offd,
thrust::make_constant_iterator(hypre_ParCSRMatrixNumCols(B)),
C_offd_j,
thrust::minus<HYPRE_Int>() );
#endif
hypre_TFree(zmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_a, HYPRE_MEMORY_DEVICE);
#if PARCSRGEMM_TIMING > 1
hypre_ForceSyncComputeStream(hypre_handle());
t2 = hypre_MPI_Wtime() - t1;
hypre_ParPrintf(comm, "Time Split %f\n", t2);
#endif
/* add Cext to local part of Cbar */
hypre_ParCSRTMatMatPartialAddDevice(hypre_ParCSRMatrixCommPkg(A),
hypre_ParCSRMatrixNumCols(A),
hypre_ParCSRMatrixNumCols(B),
hypre_ParCSRMatrixFirstColDiag(B),
hypre_ParCSRMatrixLastColDiag(B),
hypre_CSRMatrixNumCols(B_offd),
hypre_ParCSRMatrixDeviceColMapOffd(B),
local_nnz_Cbar,
Cbar,
Cext,
&C_diag,
&C_offd,
&num_cols_offd_C,
&col_map_offd_C);
}
else
{
@ -965,227 +734,23 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
hypre_TMemcpy(hypre_CSRMatrixI(Cbar) + hypre_ParCSRMatrixNumCols(R), &local_nnz_Cbar, HYPRE_Int, 1,
HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_HOST);
// to hold Cbar local and Cext
HYPRE_Int tmp_s = local_nnz_Cbar + hypre_CSRMatrixNumNonzeros(Cext);
HYPRE_Int *tmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *tmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *tmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int Cext_diag_nnz, Cext_offd_nnz, *offd_map_to_C;
/* add Cext to local part of Cbar */
hypre_ParCSRTMatMatPartialAddDevice(hypre_ParCSRMatrixCommPkg(R),
hypre_ParCSRMatrixNumCols(R),
hypre_ParCSRMatrixNumCols(P),
hypre_ParCSRMatrixFirstColDiag(P),
hypre_ParCSRMatrixLastColDiag(P),
num_cols_offd,
col_map_offd,
local_nnz_Cbar,
Cbar,
Cext,
&C_diag,
&C_offd,
&num_cols_offd_C,
&col_map_offd_C);
hypre_CSRMatrixSplitDevice_core(0,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
NULL,
hypre_CSRMatrixBigJ(Cext), NULL, NULL,
hypre_ParCSRMatrixFirstColDiag(P),
hypre_ParCSRMatrixLastColDiag(P),
num_cols_offd,
NULL, NULL, NULL, NULL,
&Cext_diag_nnz,
NULL, NULL, NULL, NULL,
&Cext_offd_nnz,
NULL, NULL, NULL, NULL);
HYPRE_Int *Cext_ii = hypreDevice_CsrRowPtrsToIndices(hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
hypre_CSRMatrixI(Cext));
hypre_CSRMatrixSplitDevice_core(1,
hypre_CSRMatrixNumRows(Cext),
hypre_CSRMatrixNumNonzeros(Cext),
Cext_ii,
hypre_CSRMatrixBigJ(Cext),
hypre_CSRMatrixData(Cext),
NULL,
hypre_ParCSRMatrixFirstColDiag(P),
hypre_ParCSRMatrixLastColDiag(P),
num_cols_offd,
col_map_offd,
&offd_map_to_C,
&num_cols_offd_C,
&col_map_offd_C,
&Cext_diag_nnz,
tmp_i + local_nnz_Cbar,
tmp_j + local_nnz_Cbar,
tmp_a + local_nnz_Cbar,
NULL,
&Cext_offd_nnz,
tmp_i + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_a + local_nnz_Cbar + Cext_diag_nnz,
NULL);
hypre_CSRMatrixDestroy(Cext);
hypre_TFree(Cext_ii, HYPRE_MEMORY_DEVICE);
hypre_TFree(col_map_offd, HYPRE_MEMORY_DEVICE);
hypre_ParCSRCommPkgCopySendMapElmtsToDevice(hypre_ParCSRMatrixCommPkg(R));
#if defined(HYPRE_USING_SYCL)
hypreSycl_gather( tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(hypre_ParCSRMatrixCommPkg(R)),
tmp_i + local_nnz_Cbar );
/* WM: necessary? */
if (tmp_s > local_nnz_Cbar + Cext_diag_nnz)
{
HYPRE_ONEDPL_CALL( std::transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
[const_val = hypre_ParCSRMatrixNumCols(P)] (const auto & x) {return x + const_val;} );
}
#else
HYPRE_THRUST_CALL( gather,
tmp_i + local_nnz_Cbar,
tmp_i + tmp_s,
hypre_ParCSRCommPkgDeviceSendMapElmts(hypre_ParCSRMatrixCommPkg(R)),
tmp_i + local_nnz_Cbar );
HYPRE_THRUST_CALL( transform,
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
tmp_j + tmp_s,
thrust::make_constant_iterator(hypre_ParCSRMatrixNumCols(P)),
tmp_j + local_nnz_Cbar + Cext_diag_nnz,
thrust::plus<HYPRE_Int>() );
#endif
hypreDevice_CsrRowPtrsToIndices_v2(hypre_ParCSRMatrixNumCols(R), local_nnz_Cbar,
hypre_CSRMatrixI(Cbar), tmp_i);
hypre_TMemcpy(tmp_a, hypre_CSRMatrixData(Cbar), HYPRE_Complex, local_nnz_Cbar, HYPRE_MEMORY_DEVICE,
HYPRE_MEMORY_DEVICE);
RAP_functor<2, HYPRE_Int> func2(hypre_ParCSRMatrixNumCols(P), 0, offd_map_to_C);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
if (local_nnz_Cbar > 0)
{
HYPRE_ONEDPL_CALL( std::transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
}
#else
HYPRE_THRUST_CALL( transform,
hypre_CSRMatrixJ(Cbar),
hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
tmp_j,
func2 );
#endif
hypre_CSRMatrixDestroy(Cbar);
hypre_TFree(offd_map_to_C, HYPRE_MEMORY_DEVICE);
// add Cext to Cbar local. Note: type 2, diagonal entries are put at the first in the rows
hypreDevice_StableSortByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, 2);
HYPRE_Int *zmp_i = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int *zmp_j = hypre_TAlloc(HYPRE_Int, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Complex *zmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);
HYPRE_Int local_nnz_C = hypreDevice_ReduceByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, zmp_i, zmp_j,
zmp_a);
hypre_TFree(tmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(tmp_a, HYPRE_MEMORY_DEVICE);
// split into diag and offd
in_range<HYPRE_Int> pred(0, hypre_ParCSRMatrixNumCols(P) - 1);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
HYPRE_Int nnz_C_diag = 0;
if (local_nnz_C > 0)
{
nnz_C_diag = HYPRE_ONEDPL_CALL( std::count_if,
zmp_j,
zmp_j + local_nnz_C,
pred );
}
#else
HYPRE_Int nnz_C_diag = HYPRE_THRUST_CALL( count_if,
zmp_j,
zmp_j + local_nnz_C,
pred );
#endif
HYPRE_Int nnz_C_offd = local_nnz_C - nnz_C_diag;
C_diag = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(R), hypre_ParCSRMatrixNumCols(P),
nnz_C_diag);
hypre_CSRMatrixInitialize_v2(C_diag, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_ii = hypre_TAlloc(HYPRE_Int, nnz_C_diag, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_diag_j = hypre_CSRMatrixJ(C_diag);
HYPRE_Complex *C_diag_a = hypre_CSRMatrixData(C_diag);
#if defined(HYPRE_USING_SYCL)
auto new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
zmp_j,
oneapi::dpl::make_zip_iterator(C_diag_ii, C_diag_j, C_diag_a),
pred );
hypre_assert( std::get<0>(new_end.base()) == C_diag_ii + nnz_C_diag );
#else
auto new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)),
pred );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii,
hypre_CSRMatrixI(C_diag));
hypre_TFree(C_diag_ii, HYPRE_MEMORY_DEVICE);
C_offd = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(R), num_cols_offd_C, nnz_C_offd);
hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_ii = hypre_TAlloc(HYPRE_Int, nnz_C_offd, HYPRE_MEMORY_DEVICE);
HYPRE_Int *C_offd_j = hypre_CSRMatrixJ(C_offd);
HYPRE_Complex *C_offd_a = hypre_CSRMatrixData(C_offd);
#if defined(HYPRE_USING_SYCL)
new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
zmp_j,
oneapi::dpl::make_zip_iterator(C_offd_ii, C_offd_j, C_offd_a),
std::not_fn(pred) );
hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd );
#else
new_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
zmp_j,
thrust::make_zip_iterator(thrust::make_tuple(C_offd_ii, C_offd_j, C_offd_a)),
thrust::not1(pred) );
hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_offd_ii + nnz_C_offd );
#endif
hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_offd), nnz_C_offd, C_offd_ii,
hypre_CSRMatrixI(C_offd));
hypre_TFree(C_offd_ii, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_SYCL)
/* WM: necessary? */
if (nnz_C_offd > 0)
{
HYPRE_ONEDPL_CALL( std::transform,
C_offd_j,
C_offd_j + nnz_C_offd,
C_offd_j,
[const_val = hypre_ParCSRMatrixNumCols(P)] (const auto & x) {return x - const_val;} );
}
#else
HYPRE_THRUST_CALL( transform,
C_offd_j,
C_offd_j + nnz_C_offd,
thrust::make_constant_iterator(hypre_ParCSRMatrixNumCols(P)),
C_offd_j,
thrust::minus<HYPRE_Int>() );
#endif
hypre_TFree(zmp_i, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_j, HYPRE_MEMORY_DEVICE);
hypre_TFree(zmp_a, HYPRE_MEMORY_DEVICE);
}
else
{
@ -1239,4 +804,273 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
return C;
}
/*--------------------------------------------------------------------------
 * hypre_ParCSRTMatMatPartialAddDevice
 *
 * Adds the entries of Cext to the first local_nnz_Cbar entries of Cbar and
 * splits the summed result into a diag and an offd CSR matrix (both created
 * here with num_rows rows and initialized in device memory).
 *
 * Steps performed:
 *   1. Cext is split by [first_col_diag, last_col_diag] into a diag and an
 *      offd part, written in COO form behind the slots reserved for Cbar.
 *   2. Row indices of the Cext entries are remapped through the device
 *      send-map elements of comm_pkg; offd column indices are shifted by
 *      num_cols so that diag and offd entries share one index space.
 *   3. The local part of Cbar is expanded to COO form in front of them.
 *   4. Entries are sorted by (row, col), duplicates are reduced, and the
 *      result is split on the column range [0, num_cols - 1].
 *
 * Parameters:
 *   comm_pkg            communication package providing the send-map elements
 *   num_rows            number of rows of the resulting C_diag / C_offd
 *   num_cols            number of columns of C_diag; also the shift applied
 *                       to offd column indices
 *   first_col_diag,
 *   last_col_diag       global column range used to split Cext
 *   num_cols_offd,
 *   col_map_offd        passed through to hypre_CSRMatrixSplitDevice_core
 *                       (col_map_offd is not freed here)
 *   local_nnz_Cbar      number of leading entries of Cbar that are used
 *   Cbar, Cext          inputs; NOTE: both are DESTROYED by this function
 *   C_diag_ptr,
 *   C_offd_ptr          outputs: the newly created matrices
 *   num_cols_offd_C_ptr,
 *   col_map_offd_C_ptr  outputs: values produced by the second call to
 *                       hypre_CSRMatrixSplitDevice_core
 *
 * Returns hypre_error_flag.
 *--------------------------------------------------------------------------*/
HYPRE_Int
hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
                                     HYPRE_Int            num_rows,
                                     HYPRE_Int            num_cols,
                                     HYPRE_BigInt         first_col_diag,
                                     HYPRE_BigInt         last_col_diag,
                                     HYPRE_Int            num_cols_offd,
                                     HYPRE_BigInt        *col_map_offd,
                                     HYPRE_Int            local_nnz_Cbar,
                                     hypre_CSRMatrix     *Cbar,
                                     hypre_CSRMatrix     *Cext,
                                     hypre_CSRMatrix    **C_diag_ptr,
                                     hypre_CSRMatrix    **C_offd_ptr,
                                     HYPRE_Int           *num_cols_offd_C_ptr,
                                     HYPRE_BigInt       **col_map_offd_C_ptr )
{
#if PARCSRGEMM_TIMING > 1
   /* The timing sections below were moved here from the callers, where
    * comm, t1 and t2 were locals; they must be declared in this scope too,
    * otherwise the function does not compile with timing enabled. */
   MPI_Comm   comm = hypre_ParCSRCommPkgComm(comm_pkg);
   HYPRE_Real t1, t2;
   t1 = hypre_MPI_Wtime();
#endif

   // to hold Cbar local and Cext
   HYPRE_Int      tmp_s = local_nnz_Cbar + hypre_CSRMatrixNumNonzeros(Cext);
   HYPRE_Int     *tmp_i = hypre_TAlloc(HYPRE_Int,     tmp_s, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *tmp_j = hypre_TAlloc(HYPRE_Int,     tmp_s, HYPRE_MEMORY_DEVICE);
   HYPRE_Complex *tmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);

   HYPRE_Int        Cext_diag_nnz, Cext_offd_nnz, num_cols_offd_C, *offd_map_to_C;
   HYPRE_BigInt    *col_map_offd_C;
   hypre_CSRMatrix *C_diag, *C_offd;

   /* First pass (job 0): only the diag/offd nnz counts are requested; they
    * are needed to compute the output offsets passed to the second pass. */
   hypre_CSRMatrixSplitDevice_core(0,
                                   hypre_CSRMatrixNumRows(Cext),
                                   hypre_CSRMatrixNumNonzeros(Cext),
                                   NULL,
                                   hypre_CSRMatrixBigJ(Cext), NULL, NULL,
                                   first_col_diag,
                                   last_col_diag,
                                   -1,
                                   NULL, NULL, NULL, NULL,
                                   &Cext_diag_nnz,
                                   NULL, NULL, NULL, NULL,
                                   &Cext_offd_nnz,
                                   NULL, NULL, NULL, NULL);

   HYPRE_Int *Cext_ii = hypreDevice_CsrRowPtrsToIndices(hypre_CSRMatrixNumRows(Cext),
                                                        hypre_CSRMatrixNumNonzeros(Cext),
                                                        hypre_CSRMatrixI(Cext));

   /* Second pass (job 1): write the diag part of Cext right after the slots
    * reserved for Cbar, and the offd part right after the diag part. */
   hypre_CSRMatrixSplitDevice_core(1,
                                   hypre_CSRMatrixNumRows(Cext),
                                   hypre_CSRMatrixNumNonzeros(Cext),
                                   Cext_ii,
                                   hypre_CSRMatrixBigJ(Cext),
                                   hypre_CSRMatrixData(Cext),
                                   NULL,
                                   first_col_diag,
                                   last_col_diag,
                                   num_cols_offd,
                                   col_map_offd,
                                   &offd_map_to_C,
                                   &num_cols_offd_C,
                                   &col_map_offd_C,
                                   &Cext_diag_nnz,
                                   tmp_i + local_nnz_Cbar,
                                   tmp_j + local_nnz_Cbar,
                                   tmp_a + local_nnz_Cbar,
                                   NULL,
                                   &Cext_offd_nnz,
                                   tmp_i + local_nnz_Cbar + Cext_diag_nnz,
                                   tmp_j + local_nnz_Cbar + Cext_diag_nnz,
                                   tmp_a + local_nnz_Cbar + Cext_diag_nnz,
                                   NULL);

   /* Cext is consumed here; the caller must not use or destroy it afterwards */
   hypre_CSRMatrixDestroy(Cext);
   hypre_TFree(Cext_ii, HYPRE_MEMORY_DEVICE);

   hypre_ParCSRCommPkgCopySendMapElmtsToDevice(comm_pkg);

   /* Remap the row indices of the Cext entries through the send-map elements
    * (in place), then shift the offd column indices by num_cols. */
#if defined(HYPRE_USING_SYCL)
   hypreSycl_gather( tmp_i + local_nnz_Cbar,
                     tmp_i + tmp_s,
                     hypre_ParCSRCommPkgDeviceSendMapElmts(comm_pkg),
                     tmp_i + local_nnz_Cbar );

   /* WM: necessary? */
   if (tmp_s > local_nnz_Cbar + Cext_diag_nnz)
   {
      HYPRE_ONEDPL_CALL( std::transform,
                         tmp_j + local_nnz_Cbar + Cext_diag_nnz,
                         tmp_j + tmp_s,
                         tmp_j + local_nnz_Cbar + Cext_diag_nnz,
      [const_val = num_cols] (const auto & x) {return x + const_val;} );
   }
#else
   HYPRE_THRUST_CALL( gather,
                      tmp_i + local_nnz_Cbar,
                      tmp_i + tmp_s,
                      hypre_ParCSRCommPkgDeviceSendMapElmts(comm_pkg),
                      tmp_i + local_nnz_Cbar );

   HYPRE_THRUST_CALL( transform,
                      tmp_j + local_nnz_Cbar + Cext_diag_nnz,
                      tmp_j + tmp_s,
                      thrust::make_constant_iterator(num_cols),
                      tmp_j + local_nnz_Cbar + Cext_diag_nnz,
                      thrust::plus<HYPRE_Int>() );
#endif

   /* Expand the local part of Cbar into the leading local_nnz_Cbar slots */
   hypreDevice_CsrRowPtrsToIndices_v2(num_rows, local_nnz_Cbar, hypre_CSRMatrixI(Cbar), tmp_i);

   hypre_TMemcpy(tmp_a, hypre_CSRMatrixData(Cbar), HYPRE_Complex, local_nnz_Cbar, HYPRE_MEMORY_DEVICE,
                 HYPRE_MEMORY_DEVICE);

   /* NOTE(review): func2 presumably maps Cbar's offd column indices into C's
    * offd numbering via offd_map_to_C -- confirm against RAP_functor */
   RAP_functor<2, HYPRE_Int> func2(num_cols, 0, offd_map_to_C);

#if defined(HYPRE_USING_SYCL)
   /* WM: necessary? */
   if (local_nnz_Cbar > 0)
   {
      HYPRE_ONEDPL_CALL( std::transform,
                         hypre_CSRMatrixJ(Cbar),
                         hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
                         tmp_j,
                         func2 );
   }
#else
   HYPRE_THRUST_CALL( transform,
                      hypre_CSRMatrixJ(Cbar),
                      hypre_CSRMatrixJ(Cbar) + local_nnz_Cbar,
                      tmp_j,
                      func2 );
#endif

   /* Cbar is consumed here; the caller must not use or destroy it afterwards */
   hypre_CSRMatrixDestroy(Cbar);
   hypre_TFree(offd_map_to_C, HYPRE_MEMORY_DEVICE);

#if PARCSRGEMM_TIMING > 1
   hypre_ForceSyncComputeStream(hypre_handle());
   t2 = hypre_MPI_Wtime() - t1;
   hypre_ParPrintf(comm, "Time PartialAdd1 %f\n", t2);
#endif

#if PARCSRGEMM_TIMING > 1
   t1 = hypre_MPI_Wtime();
#endif

   // add Cext to Cbar local. Note: type 2, diagonal entries are put at the first in the rows
   hypreDevice_StableSortByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, 2);

   HYPRE_Int     *zmp_i = hypre_TAlloc(HYPRE_Int,     tmp_s, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *zmp_j = hypre_TAlloc(HYPRE_Int,     tmp_s, HYPRE_MEMORY_DEVICE);
   HYPRE_Complex *zmp_a = hypre_TAlloc(HYPRE_Complex, tmp_s, HYPRE_MEMORY_DEVICE);

   HYPRE_Int local_nnz_C = hypreDevice_ReduceByTupleKey(tmp_s, tmp_i, tmp_j, tmp_a, zmp_i, zmp_j, zmp_a);

   hypre_TFree(tmp_i, HYPRE_MEMORY_DEVICE);
   hypre_TFree(tmp_j, HYPRE_MEMORY_DEVICE);
   hypre_TFree(tmp_a, HYPRE_MEMORY_DEVICE);

#if PARCSRGEMM_TIMING > 1
   hypre_ForceSyncComputeStream(hypre_handle());
   t2 = hypre_MPI_Wtime() - t1;
   hypre_ParPrintf(comm, "Time PartialAdd2 %f\n", t2);
#endif

#if PARCSRGEMM_TIMING > 1
   t1 = hypre_MPI_Wtime();
#endif

   // split into diag and offd
   /* entries with column index in [0, num_cols - 1] belong to diag; the
    * remaining ones are the offd entries that were shifted by num_cols */
   in_range<HYPRE_Int> pred(0, num_cols - 1);

#if defined(HYPRE_USING_SYCL)
   /* WM: necessary? */
   HYPRE_Int nnz_C_diag = 0;
   if (local_nnz_C > 0)
   {
      nnz_C_diag = HYPRE_ONEDPL_CALL( std::count_if,
                                      zmp_j,
                                      zmp_j + local_nnz_C,
                                      pred );
   }
#else
   HYPRE_Int nnz_C_diag = HYPRE_THRUST_CALL( count_if,
                                             zmp_j,
                                             zmp_j + local_nnz_C,
                                             pred );
#endif
   HYPRE_Int nnz_C_offd = local_nnz_C - nnz_C_diag;

   /* diag part */
   C_diag = hypre_CSRMatrixCreate(num_rows, num_cols, nnz_C_diag);
   hypre_CSRMatrixInitialize_v2(C_diag, 0, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *C_diag_ii = hypre_TAlloc(HYPRE_Int, nnz_C_diag, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *C_diag_j  = hypre_CSRMatrixJ(C_diag);
   HYPRE_Complex *C_diag_a  = hypre_CSRMatrixData(C_diag);

#if defined(HYPRE_USING_SYCL)
   auto new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
                                     oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
                                     zmp_j,
                                     oneapi::dpl::make_zip_iterator(C_diag_ii, C_diag_j, C_diag_a),
                                     pred );
   hypre_assert( std::get<0>(new_end.base()) == C_diag_ii + nnz_C_diag );
#else
   auto new_end = HYPRE_THRUST_CALL( copy_if,
                                     thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
                                     thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
                                     zmp_j,
                                     thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)),
                                     pred );
   hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag );
#endif

   hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii,
                                      hypre_CSRMatrixI(C_diag));
   hypre_TFree(C_diag_ii, HYPRE_MEMORY_DEVICE);

   /* offd part */
   C_offd = hypre_CSRMatrixCreate(num_rows, num_cols_offd_C, nnz_C_offd);
   hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *C_offd_ii = hypre_TAlloc(HYPRE_Int, nnz_C_offd, HYPRE_MEMORY_DEVICE);
   HYPRE_Int     *C_offd_j  = hypre_CSRMatrixJ(C_offd);
   HYPRE_Complex *C_offd_a  = hypre_CSRMatrixData(C_offd);

#if defined(HYPRE_USING_SYCL)
   new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a),
                                oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C,
                                zmp_j,
                                oneapi::dpl::make_zip_iterator(C_offd_ii, C_offd_j, C_offd_a),
                                std::not_fn(pred) );
   hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd );
#else
   new_end = HYPRE_THRUST_CALL( copy_if,
                                thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),
                                thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)) + local_nnz_C,
                                zmp_j,
                                thrust::make_zip_iterator(thrust::make_tuple(C_offd_ii, C_offd_j, C_offd_a)),
                                thrust::not1(pred) );
   hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_offd_ii + nnz_C_offd );
#endif

   hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_offd), nnz_C_offd, C_offd_ii,
                                      hypre_CSRMatrixI(C_offd));
   hypre_TFree(C_offd_ii, HYPRE_MEMORY_DEVICE);

   /* undo the num_cols shift that was applied to the offd column indices */
#if defined(HYPRE_USING_SYCL)
   /* WM: necessary? */
   if (nnz_C_offd > 0)
   {
      HYPRE_ONEDPL_CALL( std::transform,
                         C_offd_j,
                         C_offd_j + nnz_C_offd,
                         C_offd_j,
      [const_val = num_cols] (const auto & x) {return x - const_val;} );
   }
#else
   HYPRE_THRUST_CALL( transform,
                      C_offd_j,
                      C_offd_j + nnz_C_offd,
                      thrust::make_constant_iterator(num_cols),
                      C_offd_j,
                      thrust::minus<HYPRE_Int>() );
#endif

   hypre_TFree(zmp_i, HYPRE_MEMORY_DEVICE);
   hypre_TFree(zmp_j, HYPRE_MEMORY_DEVICE);
   hypre_TFree(zmp_a, HYPRE_MEMORY_DEVICE);

#if PARCSRGEMM_TIMING > 1
   hypre_ForceSyncComputeStream(hypre_handle());
   t2 = hypre_MPI_Wtime() - t1;
   hypre_ParPrintf(comm, "Time Split %f\n", t2);
#endif

   /* hand the results back to the caller */
   *C_diag_ptr = C_diag;
   *C_offd_ptr = C_offd;
   *num_cols_offd_C_ptr = num_cols_offd_C;
   *col_map_offd_C_ptr = col_map_offd_C;

   return hypre_error_flag;
}
#endif // #if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)

View File

@ -452,10 +452,10 @@ HYPRE_Int hypre_ParCSRMatrixMatvec_FF ( HYPRE_Complex alpha, hypre_ParCSRMatrix
hypre_ParVector *x, HYPRE_Complex beta, hypre_ParVector *y, HYPRE_Int *CF_marker, HYPRE_Int fpt );
/* par_csr_triplemat.c */
HYPRE_Int hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg_A, HYPRE_Int num_cols_A, HYPRE_Int num_cols_B, HYPRE_BigInt first_col_diag_B, HYPRE_BigInt last_col_diag_B, HYPRE_Int num_cols_offd_B, HYPRE_BigInt *col_map_offd_B, HYPRE_Int local_nnz_Cbar, hypre_CSRMatrix *Cbar, hypre_CSRMatrix *Cext, hypre_CSRMatrix **C_diag_ptr, hypre_CSRMatrix **C_offd_ptr, HYPRE_Int *num_cols_offd_C_ptr, HYPRE_BigInt **col_map_offd_C_ptr );
hypre_ParCSRMatrix *hypre_ParCSRMatMat( hypre_ParCSRMatrix *A, hypre_ParCSRMatrix *B );
hypre_ParCSRMatrix *hypre_ParCSRMatMatHost( hypre_ParCSRMatrix *A, hypre_ParCSRMatrix *B );
hypre_ParCSRMatrix *hypre_ParCSRMatMatDevice( hypre_ParCSRMatrix *A, hypre_ParCSRMatrix *B );
hypre_ParCSRMatrix *hypre_ParCSRTMatMatKTHost( hypre_ParCSRMatrix *A, hypre_ParCSRMatrix *B,
HYPRE_Int keep_transpose);
hypre_ParCSRMatrix *hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A, hypre_ParCSRMatrix *B,