superlu_dist support on GPUs (#869)
This PR adds superlu_dist support on GPUs.
This commit is contained in:
parent
ca784500bc
commit
b6c60065d2
@ -148,13 +148,13 @@ FILES =\
|
||||
schwarz.c\
|
||||
block_tridiag.c\
|
||||
par_restr.c\
|
||||
par_lr_restr.c\
|
||||
dsuperlu.c
|
||||
par_lr_restr.c
|
||||
|
||||
CUFILES =\
|
||||
ads.c\
|
||||
ams.c\
|
||||
ame.c\
|
||||
dsuperlu.c\
|
||||
par_amg_setup.c\
|
||||
par_ge_device.c\
|
||||
par_ilu_setup_device.c\
|
||||
|
||||
@ -33,7 +33,10 @@ hypre_DSLUData;
|
||||
|
||||
#endif
|
||||
*/
|
||||
HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE_Int print_level)
|
||||
HYPRE_Int
|
||||
hypre_SLUDistSetup(HYPRE_Solver *solver,
|
||||
hypre_ParCSRMatrix *A,
|
||||
HYPRE_Int print_level)
|
||||
{
|
||||
/* Par Data Structure variables */
|
||||
HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A);
|
||||
@ -44,7 +47,6 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
|
||||
HYPRE_Int pcols = 1, prows = 1;
|
||||
HYPRE_BigInt *big_rowptr = NULL;
|
||||
hypre_DSLUData *dslu_data = NULL;
|
||||
|
||||
HYPRE_Int info = 0;
|
||||
HYPRE_Int nrhs = 0;
|
||||
|
||||
@ -59,6 +61,13 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
|
||||
/* Merge diag and offd into one matrix (global ids) */
|
||||
A_local = hypre_MergeDiagAndOffd(A);
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
|
||||
if (hypre_GetActualMemLocation(hypre_CSRMatrixMemoryLocation(A_local)) != hypre_MEMORY_HOST)
|
||||
{
|
||||
hypre_CSRMatrixMigrate(A_local, HYPRE_MEMORY_HOST);
|
||||
}
|
||||
#endif
|
||||
|
||||
num_rows = hypre_CSRMatrixNumRows(A_local);
|
||||
/* Now convert hypre matrix to a SuperMatrix */
|
||||
#ifdef HYPRE_MIXEDINT
|
||||
@ -75,6 +84,7 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
|
||||
#else
|
||||
big_rowptr = hypre_CSRMatrixI(A_local);
|
||||
#endif
|
||||
|
||||
dCreate_CompRowLoc_Matrix_dist(
|
||||
&(dslu_data->A_dslu), global_num_rows, global_num_rows,
|
||||
hypre_CSRMatrixNumNonzeros(A_local),
|
||||
@ -134,28 +144,54 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE
|
||||
|
||||
dslu_data->dslu_options.Fact = FACTORED;
|
||||
*solver = (HYPRE_Solver) dslu_data;
|
||||
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
HYPRE_Int hypre_SLUDistSolve( void* solver, hypre_ParVector *b, hypre_ParVector *x)
|
||||
HYPRE_Int
|
||||
hypre_SLUDistSolve(void *solver,
|
||||
hypre_ParVector *b,
|
||||
hypre_ParVector *x)
|
||||
{
|
||||
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
|
||||
HYPRE_Int info = 0;
|
||||
HYPRE_Real *B = hypre_VectorData(hypre_ParVectorLocalVector(x));
|
||||
HYPRE_Real *x_data;
|
||||
hypre_ParVector *x_host = NULL;
|
||||
HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x));
|
||||
HYPRE_Int nrhs = 1;
|
||||
|
||||
hypre_ParVectorCopy(b, x);
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
|
||||
if (hypre_GetActualMemLocation(hypre_ParVectorMemoryLocation(x)) != hypre_MEMORY_HOST)
|
||||
{
|
||||
x_host = hypre_ParVectorCloneDeep_v2(x, HYPRE_MEMORY_HOST);
|
||||
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x_host));
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
x_data = hypre_VectorData(hypre_ParVectorLocalVector(x));
|
||||
}
|
||||
|
||||
pdgssvx(&(dslu_data->dslu_options), &(dslu_data->A_dslu),
|
||||
&(dslu_data->dslu_ScalePermstruct), B, size, nrhs,
|
||||
&(dslu_data->dslu_ScalePermstruct), x_data, size, nrhs,
|
||||
&(dslu_data->dslu_data_grid), &(dslu_data->dslu_data_LU),
|
||||
&(dslu_data->dslu_solve), dslu_data->berr, &(dslu_data->dslu_data_stat), &info);
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
|
||||
if (x_host)
|
||||
{
|
||||
hypre_ParVectorCopy(x_host, x);
|
||||
hypre_ParVectorDestroy(x_host);
|
||||
}
|
||||
#endif
|
||||
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
HYPRE_Int hypre_SLUDistDestroy( void* solver)
|
||||
HYPRE_Int
|
||||
hypre_SLUDistDestroy(void* solver)
|
||||
{
|
||||
hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver;
|
||||
|
||||
@ -171,6 +207,8 @@ HYPRE_Int hypre_SLUDistDestroy( void* solver)
|
||||
superlu_gridexit(&(dslu_data->dslu_data_grid));
|
||||
hypre_TFree(dslu_data->berr, HYPRE_MEMORY_HOST);
|
||||
hypre_TFree(dslu_data, HYPRE_MEMORY_HOST);
|
||||
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -12,11 +12,7 @@
|
||||
*****************************************************************************/
|
||||
|
||||
#include "_hypre_parcsr_ls.h"
|
||||
#include "par_amg.h"
|
||||
#ifdef HYPRE_USING_DSUPERLU
|
||||
#include <math.h>
|
||||
#include "superlu_ddefs.h"
|
||||
#endif
|
||||
|
||||
/*--------------------------------------------------------------------------
|
||||
* hypre_BoomerAMGCreate
|
||||
*--------------------------------------------------------------------------*/
|
||||
|
||||
@ -1991,7 +1991,7 @@ GenerateDiagAndOffd(hypre_CSRMatrix *A,
|
||||
}
|
||||
|
||||
hypre_CSRMatrix *
|
||||
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
|
||||
hypre_MergeDiagAndOffdHost(hypre_ParCSRMatrix *par_matrix)
|
||||
{
|
||||
hypre_CSRMatrix *diag = hypre_ParCSRMatrixDiag(par_matrix);
|
||||
hypre_CSRMatrix *offd = hypre_ParCSRMatrixOffd(par_matrix);
|
||||
@ -2070,6 +2070,23 @@ hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
|
||||
return matrix;
|
||||
}
|
||||
|
||||
hypre_CSRMatrix *
|
||||
hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix)
|
||||
{
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
|
||||
HYPRE_ExecutionPolicy exec = hypre_GetExecPolicy1( hypre_ParCSRMatrixMemoryLocation(par_matrix) );
|
||||
|
||||
if (exec == HYPRE_EXEC_DEVICE)
|
||||
{
|
||||
return hypre_MergeDiagAndOffdDevice(par_matrix);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
return hypre_MergeDiagAndOffdHost(par_matrix);
|
||||
}
|
||||
}
|
||||
|
||||
/*--------------------------------------------------------------------------
|
||||
* hypre_ParCSRMatrixToCSRMatrixAll
|
||||
*
|
||||
|
||||
Loading…
Reference in New Issue
Block a user