superlu_dist support on GPUs (#869)

This PR adds superlu_dist support on GPUs.
This commit is contained in:
Rui Peng Li 2023-10-25 14:35:09 -07:00 committed by GitHub
parent ca784500bc
commit b6c60065d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 27 deletions

View File

@ -148,13 +148,13 @@ FILES =\
schwarz.c\
block_tridiag.c\
par_restr.c\
par_lr_restr.c\
dsuperlu.c
par_lr_restr.c
CUFILES =\
ads.c\
ams.c\
ame.c\
dsuperlu.c\
par_amg_setup.c\
par_ge_device.c\
par_ilu_setup_device.c\

View File

@ -33,20 +33,22 @@ hypre_DSLUData;
#endif
*/
HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE_Int print_level)
HYPRE_Int
hypre_SLUDistSetup(HYPRE_Solver *solver,
hypre_ParCSRMatrix *A,
HYPRE_Int print_level)
{
/* Par Data Structure variables */
HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A);
HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A);
MPI_Comm comm = hypre_ParCSRMatrixComm(A);
hypre_CSRMatrix *A_local;
HYPRE_Int num_rows;
HYPRE_Int num_procs, my_id;
HYPRE_Int pcols = 1, prows = 1;
HYPRE_BigInt *big_rowptr = NULL;
hypre_DSLUData *dslu_data = NULL;
HYPRE_Int info = 0;
HYPRE_Int nrhs = 0;
hypre_CSRMatrix *A_local;
HYPRE_Int num_rows;
HYPRE_Int num_procs, my_id;
HYPRE_Int pcols = 1, prows = 1;
HYPRE_BigInt *big_rowptr = NULL;
hypre_DSLUData *dslu_data = NULL;
HYPRE_Int info = 0;
HYPRE_Int nrhs = 0;
hypre_MPI_Comm_size(comm, &num_procs);
hypre_MPI_Comm_rank(comm, &my_id);
@ -59,6 +61,13 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
/* Merge diag and offd into one matrix (global ids) */
A_local = hypre_MergeDiagAndOffd(A);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (hypre_GetActualMemLocation(hypre_CSRMatrixMemoryLocation(A_local)) != hypre_MEMORY_HOST)
{
hypre_CSRMatrixMigrate(A_local, HYPRE_MEMORY_HOST);
}
#endif
num_rows = hypre_CSRMatrixNumRows(A_local);
/* Now convert hypre matrix to a SuperMatrix */
#ifdef HYPRE_MIXEDINT
@ -75,6 +84,7 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
#else
big_rowptr = hypre_CSRMatrixI(A_local);
#endif
dCreate_CompRowLoc_Matrix_dist(
&(dslu_data->A_dslu), global_num_rows, global_num_rows,
hypre_CSRMatrixNumNonzeros(A_local),
@ -134,28 +144,54 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
dslu_data->dslu_options.Fact = FACTORED;
*solver = (HYPRE_Solver) dslu_data;
return hypre_error_flag;
}
HYPRE_Int hypre_SLUDistSolve( void* solver, hypre_ParVector *b, hypre_ParVector *x)
HYPRE_Int
hypre_SLUDistSolve(void *solver,
hypre_ParVector *b,
hypre_ParVector *x)
{
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
HYPRE_Int info = 0;
HYPRE_Real *B = hypre_VectorData(hypre_ParVectorLocalVector(x));
HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x));
HYPRE_Int nrhs = 1;
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
HYPRE_Int info = 0;
HYPRE_Real *x_data;
hypre_ParVector *x_host = NULL;
HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x));
HYPRE_Int nrhs = 1;
hypre_ParVectorCopy(b, x);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (hypre_GetActualMemLocation(hypre_ParVectorMemoryLocation(x)) != hypre_MEMORY_HOST)
{
x_host = hypre_ParVectorCloneDeep_v2(x, HYPRE_MEMORY_HOST);
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x_host));
}
else
#endif
{
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x));
}
pdgssvx(&(dslu_data->dslu_options), &(dslu_data->A_dslu),
&(dslu_data->dslu_ScalePermstruct), B, size, nrhs,
&(dslu_data->dslu_ScalePermstruct), x_data, size, nrhs,
&(dslu_data->dslu_data_grid), &(dslu_data->dslu_data_LU),
&(dslu_data->dslu_solve), dslu_data->berr, &(dslu_data->dslu_data_stat), &info);
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
if (x_host)
{
hypre_ParVectorCopy(x_host, x);
hypre_ParVectorDestroy(x_host);
}
#endif
return hypre_error_flag;
}
HYPRE_Int hypre_SLUDistDestroy( void* solver)
HYPRE_Int
hypre_SLUDistDestroy(void* solver)
{
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
@ -171,6 +207,8 @@ HYPRE_Int hypre_SLUDistDestroy( void* solver)
superlu_gridexit(&(dslu_data->dslu_data_grid));
hypre_TFree(dslu_data->berr, HYPRE_MEMORY_HOST);
hypre_TFree(dslu_data, HYPRE_MEMORY_HOST);
return hypre_error_flag;
}
#endif

View File

@ -12,11 +12,7 @@
*****************************************************************************/
#include "_hypre_parcsr_ls.h"
#include "par_amg.h"
#ifdef HYPRE_USING_DSUPERLU
#include <math.h>
#include "superlu_ddefs.h"
#endif
/*--------------------------------------------------------------------------
* hypre_BoomerAMGCreate
*--------------------------------------------------------------------------*/

View File

@ -1991,7 +1991,7 @@ GenerateDiagAndOffd(hypre_CSRMatrix *A,
}
hypre_CSRMatrix *
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
hypre_MergeDiagAndOffdHost(hypre_ParCSRMatrix *par_matrix)
{
hypre_CSRMatrix *diag = hypre_ParCSRMatrixDiag(par_matrix);
hypre_CSRMatrix *offd = hypre_ParCSRMatrixOffd(par_matrix);
@ -2070,6 +2070,23 @@ hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
return matrix;
}
hypre_CSRMatrix *
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
{
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
HYPRE_ExecutionPolicy exec = hypre_GetExecPolicy1( hypre_ParCSRMatrixMemoryLocation(par_matrix) );
if (exec == HYPRE_EXEC_DEVICE)
{
return hypre_MergeDiagAndOffdDevice(par_matrix);
}
else
#endif
{
return hypre_MergeDiagAndOffdHost(par_matrix);
}
}
/*--------------------------------------------------------------------------
* hypre_ParCSRMatrixToCSRMatrixAll
*