diff --git a/src/parcsr_ls/HYPRE_parcsr_ls.h b/src/parcsr_ls/HYPRE_parcsr_ls.h index 205d28a65..0e23e311a 100644 --- a/src/parcsr_ls/HYPRE_parcsr_ls.h +++ b/src/parcsr_ls/HYPRE_parcsr_ls.h @@ -3961,7 +3961,7 @@ HYPRE_MGRSetReservedCpointsLevelToKeep( HYPRE_Solver solver, HYPRE_Int level); * (Optional) Set the relaxation type for F-relaxation. * Currently supports the following flavors of relaxation types * as described in the \e BoomerAMGSetRelaxType: - * \e relax_type 0, 3 - 8, 13, 14, 18. Also supports AMG (options 1 and 2) + * \e relax_type 0, 3 - 8, 13, 14, 18. Also supports AMG (options 1 and 2) * and direct solver variants (9, 99, 199). See HYPRE_MGRSetLevelFRelaxType for details. **/ HYPRE_Int diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index 61f3b9aec..4863ca6bc 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -623,17 +623,18 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, hypre_ParCSRMatrix *P, HYPRE_Int keep_transpose ) { - hypre_CSRMatrix *R_diag = hypre_ParCSRMatrixDiag(R); - hypre_CSRMatrix *R_offd = hypre_ParCSRMatrixOffd(R); + MPI_Comm comm = hypre_ParCSRMatrixComm(A); + hypre_CSRMatrix *R_diag = hypre_ParCSRMatrixDiag(R); + hypre_CSRMatrix *R_offd = hypre_ParCSRMatrixOffd(R); - hypre_ParCSRMatrix *C; - hypre_CSRMatrix *C_diag; - hypre_CSRMatrix *C_offd; - HYPRE_Int num_cols_offd_C = 0; - HYPRE_BigInt *col_map_offd_C = NULL; + hypre_ParCSRMatrix *C; + hypre_CSRMatrix *C_diag; + hypre_CSRMatrix *C_offd; + HYPRE_Int num_cols_offd_C = 0; + HYPRE_BigInt *col_map_offd_C = NULL; + + HYPRE_Int num_procs; - HYPRE_Int num_procs; - MPI_Comm comm = hypre_ParCSRMatrixComm(A); hypre_MPI_Comm_size(comm, &num_procs); if ( hypre_ParCSRMatrixGlobalNumRows(R) != hypre_ParCSRMatrixGlobalNumRows(A) || @@ -717,12 +718,14 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, hypre_CSRMatrixDestroy(Abar); hypre_CSRMatrixDestroy(Pbar); - hypre_assert(hypre_CSRMatrixNumRows(Cbar) == hypre_ParCSRMatrixNumCols(R) + hypre_CSRMatrixNumCols( - R_offd)); - hypre_assert(hypre_CSRMatrixNumCols(Cbar) == hypre_ParCSRMatrixNumCols(P) + num_cols_offd); + hypre_assert(hypre_CSRMatrixNumRows(Cbar) == + hypre_ParCSRMatrixNumCols(R) + hypre_CSRMatrixNumCols(R_offd)); + hypre_assert(hypre_CSRMatrixNumCols(Cbar) == + hypre_ParCSRMatrixNumCols(P) + num_cols_offd); - hypre_TMemcpy(&local_nnz_Cbar, hypre_CSRMatrixI(Cbar) + hypre_ParCSRMatrixNumCols(R), HYPRE_Int, 1, - HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE); + hypre_TMemcpy(&local_nnz_Cbar, + hypre_CSRMatrixI(Cbar) + hypre_ParCSRMatrixNumCols(R), + HYPRE_Int, 1, HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE); // Cint is the bottom part of Cbar Cint = hypre_CSRMatrixCreate(hypre_CSRMatrixNumCols(R_offd), hypre_CSRMatrixNumCols(Cbar), @@ -748,10 +751,12 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, // Change Cint into a BigJ matrix // RL: TODO FIX the 'big' num of columns to global size - hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint), + hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, + hypre_CSRMatrixNumNonzeros(Cint), HYPRE_MEMORY_DEVICE); - RAP_functor<1, HYPRE_BigInt> func1(hypre_ParCSRMatrixNumCols(P), hypre_ParCSRMatrixFirstColDiag(P), + RAP_functor<1, HYPRE_BigInt> func1(hypre_ParCSRMatrixNumCols(P), + hypre_ParCSRMatrixFirstColDiag(P), col_map_offd); #if defined(HYPRE_USING_SYCL) HYPRE_ONEDPL_CALL( std::transform, @@ -780,7 +785,8 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, hypre_TFree(hypre_CSRMatrixBigJ(Cint), HYPRE_MEMORY_DEVICE); hypre_TFree(Cint, HYPRE_MEMORY_HOST); - hypre_TMemcpy(hypre_CSRMatrixI(Cbar) + hypre_ParCSRMatrixNumCols(R), &local_nnz_Cbar, HYPRE_Int, 1, + hypre_TMemcpy(hypre_CSRMatrixI(Cbar) + hypre_ParCSRMatrixNumCols(R), + &local_nnz_Cbar, HYPRE_Int, 1, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_HOST); /* add Cext to local part of Cbar */ @@ -806,17 +812,32 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, hypre_CSRMatrix *R_diagT; hypre_CSRMatrix *A_diag = hypre_ParCSRMatrixDiag(A); hypre_CSRMatrix *P_diag = hypre_ParCSRMatrixDiag(P); - hypre_CSRMatrixTransposeDevice(R_diag, &R_diagT, 1); - C_diag = hypre_CSRMatrixTripleMultiplyDevice(R_diagT, A_diag, P_diag); - C_offd = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(R), 0, 0); - hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE); - if (keep_transpose) + + /* Recover or compute transpose of R_diag */ + if (hypre_ParCSRMatrixDiagT(R)) { - hypre_ParCSRMatrixDiagT(R) = R_diagT; + R_diagT = hypre_ParCSRMatrixDiagT(R); } else { - hypre_CSRMatrixDestroy(R_diagT); + hypre_CSRMatrixTransposeDevice(R_diag, &R_diagT, 1); + } + + C_diag = hypre_CSRMatrixTripleMultiplyDevice(R_diagT, A_diag, P_diag); + C_offd = hypre_CSRMatrixCreate(hypre_ParCSRMatrixNumCols(R), 0, 0); + hypre_CSRMatrixInitialize_v2(C_offd, 0, HYPRE_MEMORY_DEVICE); + + /* Keep or destroy transpose of R_diag */ + if (!hypre_ParCSRMatrixDiagT(R)) + { + if (keep_transpose) + { + hypre_ParCSRMatrixDiagT(R) = R_diagT; + } + else + { + hypre_CSRMatrixDestroy(R_diagT); + } } } @@ -842,9 +863,11 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, { hypre_ParCSRMatrixDeviceColMapOffd(C) = col_map_offd_C; - hypre_ParCSRMatrixColMapOffd(C) = hypre_TAlloc(HYPRE_BigInt, num_cols_offd_C, HYPRE_MEMORY_HOST); - hypre_TMemcpy(hypre_ParCSRMatrixColMapOffd(C), col_map_offd_C, HYPRE_BigInt, num_cols_offd_C, - HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE); + hypre_ParCSRMatrixColMapOffd(C) = hypre_TAlloc(HYPRE_BigInt, + num_cols_offd_C, + HYPRE_MEMORY_HOST); + hypre_TMemcpy(hypre_ParCSRMatrixColMapOffd(C), col_map_offd_C, HYPRE_BigInt, + num_cols_offd_C, HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE); } hypre_assert(!hypre_CSRMatrixCheckDiagFirstDevice(hypre_ParCSRMatrixDiag(C)));