superlu_dist support on GPUs (#869)

This PR adds superlu_dist support on GPUs.
This commit is contained in:
Rui Peng Li 2023-10-25 14:35:09 -07:00 committed by GitHub
parent ca784500bc
commit b6c60065d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 27 deletions

View File

@ -148,13 +148,13 @@ FILES =\
schwarz.c\ schwarz.c\
block_tridiag.c\ block_tridiag.c\
par_restr.c\ par_restr.c\
par_lr_restr.c\ par_lr_restr.c
dsuperlu.c
CUFILES =\ CUFILES =\
ads.c\ ads.c\
ams.c\ ams.c\
ame.c\ ame.c\
dsuperlu.c\
par_amg_setup.c\ par_amg_setup.c\
par_ge_device.c\ par_ge_device.c\
par_ilu_setup_device.c\ par_ilu_setup_device.c\

View File

@ -33,20 +33,22 @@ hypre_DSLUData;
#endif #endif
*/ */
HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE_Int print_level) HYPRE_Int
hypre_SLUDistSetup(HYPRE_Solver *solver,
hypre_ParCSRMatrix *A,
HYPRE_Int print_level)
{ {
/* Par Data Structure variables */ /* Par Data Structure variables */
HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A); HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A);
MPI_Comm comm = hypre_ParCSRMatrixComm(A); MPI_Comm comm = hypre_ParCSRMatrixComm(A);
hypre_CSRMatrix *A_local; hypre_CSRMatrix *A_local;
HYPRE_Int num_rows; HYPRE_Int num_rows;
HYPRE_Int num_procs, my_id; HYPRE_Int num_procs, my_id;
HYPRE_Int pcols = 1, prows = 1; HYPRE_Int pcols = 1, prows = 1;
HYPRE_BigInt *big_rowptr = NULL; HYPRE_BigInt *big_rowptr = NULL;
hypre_DSLUData *dslu_data = NULL; hypre_DSLUData *dslu_data = NULL;
HYPRE_Int info = 0;
HYPRE_Int info = 0; HYPRE_Int nrhs = 0;
HYPRE_Int nrhs = 0;
hypre_MPI_Comm_size(comm, &num_procs); hypre_MPI_Comm_size(comm, &num_procs);
hypre_MPI_Comm_rank(comm, &my_id); hypre_MPI_Comm_rank(comm, &my_id);
@ -59,6 +61,13 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
/* Merge diag and offd into one matrix (global ids) */ /* Merge diag and offd into one matrix (global ids) */
A_local = hypre_MergeDiagAndOffd(A); A_local = hypre_MergeDiagAndOffd(A);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (hypre_GetActualMemLocation(hypre_CSRMatrixMemoryLocation(A_local)) != hypre_MEMORY_HOST)
{
hypre_CSRMatrixMigrate(A_local, HYPRE_MEMORY_HOST);
}
#endif
num_rows = hypre_CSRMatrixNumRows(A_local); num_rows = hypre_CSRMatrixNumRows(A_local);
/* Now convert hypre matrix to a SuperMatrix */ /* Now convert hypre matrix to a SuperMatrix */
#ifdef HYPRE_MIXEDINT #ifdef HYPRE_MIXEDINT
@ -75,6 +84,7 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
#else #else
big_rowptr = hypre_CSRMatrixI(A_local); big_rowptr = hypre_CSRMatrixI(A_local);
#endif #endif
dCreate_CompRowLoc_Matrix_dist( dCreate_CompRowLoc_Matrix_dist(
&(dslu_data->A_dslu), global_num_rows, global_num_rows, &(dslu_data->A_dslu), global_num_rows, global_num_rows,
hypre_CSRMatrixNumNonzeros(A_local), hypre_CSRMatrixNumNonzeros(A_local),
@ -134,28 +144,54 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
dslu_data->dslu_options.Fact = FACTORED; dslu_data->dslu_options.Fact = FACTORED;
*solver = (HYPRE_Solver) dslu_data; *solver = (HYPRE_Solver) dslu_data;
return hypre_error_flag; return hypre_error_flag;
} }
HYPRE_Int hypre_SLUDistSolve( void* solver, hypre_ParVector *b, hypre_ParVector *x) HYPRE_Int
hypre_SLUDistSolve(void *solver,
hypre_ParVector *b,
hypre_ParVector *x)
{ {
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver; hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
HYPRE_Int info = 0; HYPRE_Int info = 0;
HYPRE_Real *B = hypre_VectorData(hypre_ParVectorLocalVector(x)); HYPRE_Real *x_data;
HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x)); hypre_ParVector *x_host = NULL;
HYPRE_Int nrhs = 1; HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x));
HYPRE_Int nrhs = 1;
hypre_ParVectorCopy(b, x); hypre_ParVectorCopy(b, x);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (hypre_GetActualMemLocation(hypre_ParVectorMemoryLocation(x)) != hypre_MEMORY_HOST)
{
x_host = hypre_ParVectorCloneDeep_v2(x, HYPRE_MEMORY_HOST);
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x_host));
}
else
#endif
{
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x));
}
pdgssvx(&(dslu_data->dslu_options), &(dslu_data->A_dslu), pdgssvx(&(dslu_data->dslu_options), &(dslu_data->A_dslu),
&(dslu_data->dslu_ScalePermstruct), B, size, nrhs, &(dslu_data->dslu_ScalePermstruct), x_data, size, nrhs,
&(dslu_data->dslu_data_grid), &(dslu_data->dslu_data_LU), &(dslu_data->dslu_data_grid), &(dslu_data->dslu_data_LU),
&(dslu_data->dslu_solve), dslu_data->berr, &(dslu_data->dslu_data_stat), &info); &(dslu_data->dslu_solve), dslu_data->berr, &(dslu_data->dslu_data_stat), &info);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (x_host)
{
hypre_ParVectorCopy(x_host, x);
hypre_ParVectorDestroy(x_host);
}
#endif
return hypre_error_flag; return hypre_error_flag;
} }
HYPRE_Int hypre_SLUDistDestroy( void* solver) HYPRE_Int
hypre_SLUDistDestroy(void* solver)
{ {
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver; hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
@ -171,6 +207,8 @@ HYPRE_Int hypre_SLUDistDestroy( void* solver)
superlu_gridexit(&(dslu_data->dslu_data_grid)); superlu_gridexit(&(dslu_data->dslu_data_grid));
hypre_TFree(dslu_data->berr, HYPRE_MEMORY_HOST); hypre_TFree(dslu_data->berr, HYPRE_MEMORY_HOST);
hypre_TFree(dslu_data, HYPRE_MEMORY_HOST); hypre_TFree(dslu_data, HYPRE_MEMORY_HOST);
return hypre_error_flag; return hypre_error_flag;
} }
#endif #endif

View File

@ -12,11 +12,7 @@
*****************************************************************************/ *****************************************************************************/
#include "_hypre_parcsr_ls.h" #include "_hypre_parcsr_ls.h"
#include "par_amg.h"
#ifdef HYPRE_USING_DSUPERLU
#include <math.h>
#include "superlu_ddefs.h"
#endif
/*-------------------------------------------------------------------------- /*--------------------------------------------------------------------------
* hypre_BoomerAMGCreate * hypre_BoomerAMGCreate
*--------------------------------------------------------------------------*/ *--------------------------------------------------------------------------*/

View File

@ -1991,7 +1991,7 @@ GenerateDiagAndOffd(hypre_CSRMatrix *A,
} }
hypre_CSRMatrix * hypre_CSRMatrix *
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix) hypre_MergeDiagAndOffdHost(hypre_ParCSRMatrix *par_matrix)
{ {
hypre_CSRMatrix *diag = hypre_ParCSRMatrixDiag(par_matrix); hypre_CSRMatrix *diag = hypre_ParCSRMatrixDiag(par_matrix);
hypre_CSRMatrix *offd = hypre_ParCSRMatrixOffd(par_matrix); hypre_CSRMatrix *offd = hypre_ParCSRMatrixOffd(par_matrix);
@ -2070,6 +2070,23 @@ hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
return matrix; return matrix;
} }
/*--------------------------------------------------------------------------
 * hypre_MergeDiagAndOffd
 *
 * Dispatches to the device or host implementation of the diag/offd merge.
 *
 * When hypre is built with CUDA or HIP, the execution policy is derived
 * from the memory location of par_matrix; if that policy is
 * HYPRE_EXEC_DEVICE the result of hypre_MergeDiagAndOffdDevice is returned.
 * In every other case (including builds without CUDA/HIP) the result of
 * hypre_MergeDiagAndOffdHost is returned.
 *--------------------------------------------------------------------------*/

hypre_CSRMatrix *
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
{
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
   /* Device path is only compiled in for CUDA/HIP builds */
   if (hypre_GetExecPolicy1(hypre_ParCSRMatrixMemoryLocation(par_matrix)) == HYPRE_EXEC_DEVICE)
   {
      return hypre_MergeDiagAndOffdDevice(par_matrix);
   }
#endif

   /* Host fallback: non-GPU builds, or a policy other than HYPRE_EXEC_DEVICE */
   return hypre_MergeDiagAndOffdHost(par_matrix);
}
/*-------------------------------------------------------------------------- /*--------------------------------------------------------------------------
* hypre_ParCSRMatrixToCSRMatrixAll * hypre_ParCSRMatrixToCSRMatrixAll
* *