From b6c60065d22782625468cfa8e0e9143780ce8534 Mon Sep 17 00:00:00 2001 From: Rui Peng Li Date: Wed, 25 Oct 2023 14:35:09 -0700 Subject: [PATCH] superlu_dist support on GPUs (#869) This PR adds superlu_dist support on GPUs. --- src/parcsr_ls/Makefile | 4 +- src/parcsr_ls/dsuperlu.c | 76 +++++++++++++++++++++++++--------- src/parcsr_ls/par_amg.c | 6 +-- src/parcsr_mv/par_csr_matrix.c | 19 ++++++++- 4 files changed, 78 insertions(+), 27 deletions(-) diff --git a/src/parcsr_ls/Makefile b/src/parcsr_ls/Makefile index ebdb43e7b..4172f97c7 100644 --- a/src/parcsr_ls/Makefile +++ b/src/parcsr_ls/Makefile @@ -148,13 +148,13 @@ FILES =\ schwarz.c\ block_tridiag.c\ par_restr.c\ - par_lr_restr.c\ - dsuperlu.c + par_lr_restr.c CUFILES =\ ads.c\ ams.c\ ame.c\ + dsuperlu.c\ par_amg_setup.c\ par_ge_device.c\ par_ilu_setup_device.c\ diff --git a/src/parcsr_ls/dsuperlu.c b/src/parcsr_ls/dsuperlu.c index 31bd21825..4f24ea09e 100644 --- a/src/parcsr_ls/dsuperlu.c +++ b/src/parcsr_ls/dsuperlu.c @@ -33,20 +33,22 @@ hypre_DSLUData; #endif */ -HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE_Int print_level) +HYPRE_Int +hypre_SLUDistSetup(HYPRE_Solver *solver, + hypre_ParCSRMatrix *A, + HYPRE_Int print_level) { /* Par Data Structure variables */ - HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A); + HYPRE_BigInt global_num_rows = hypre_ParCSRMatrixGlobalNumRows(A); MPI_Comm comm = hypre_ParCSRMatrixComm(A); - hypre_CSRMatrix *A_local; - HYPRE_Int num_rows; - HYPRE_Int num_procs, my_id; - HYPRE_Int pcols = 1, prows = 1; - HYPRE_BigInt *big_rowptr = NULL; - hypre_DSLUData *dslu_data = NULL; - - HYPRE_Int info = 0; - HYPRE_Int nrhs = 0; + hypre_CSRMatrix *A_local; + HYPRE_Int num_rows; + HYPRE_Int num_procs, my_id; + HYPRE_Int pcols = 1, prows = 1; + HYPRE_BigInt *big_rowptr = NULL; + hypre_DSLUData *dslu_data = NULL; + HYPRE_Int info = 0; + HYPRE_Int nrhs = 0; hypre_MPI_Comm_size(comm, &num_procs); hypre_MPI_Comm_rank(comm, &my_id); @@ -59,6 +61,13 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE /* Merge diag and offd into one matrix (global ids) */ A_local = hypre_MergeDiagAndOffd(A); +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL) + if (hypre_GetActualMemLocation(hypre_CSRMatrixMemoryLocation(A_local)) != hypre_MEMORY_HOST) + { + hypre_CSRMatrixMigrate(A_local, HYPRE_MEMORY_HOST); + } +#endif + num_rows = hypre_CSRMatrixNumRows(A_local); /* Now convert hypre matrix to a SuperMatrix */ #ifdef HYPRE_MIXEDINT @@ -75,6 +84,7 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE #else big_rowptr = hypre_CSRMatrixI(A_local); #endif + dCreate_CompRowLoc_Matrix_dist( &(dslu_data->A_dslu), global_num_rows, global_num_rows, hypre_CSRMatrixNumNonzeros(A_local), @@ -134,28 +144,54 @@ HYPRE_Int hypre_SLUDistSetup( HYPRE_Solver *solver, hypre_ParCSRMatrix *A, HYPRE dslu_data->dslu_options.Fact = FACTORED; *solver = (HYPRE_Solver) dslu_data; + return hypre_error_flag; } -HYPRE_Int hypre_SLUDistSolve( void* solver, hypre_ParVector *b, hypre_ParVector *x) +HYPRE_Int +hypre_SLUDistSolve(void *solver, + hypre_ParVector *b, + hypre_ParVector *x) { - hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver; - HYPRE_Int info = 0; - HYPRE_Real *B = hypre_VectorData(hypre_ParVectorLocalVector(x)); - HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x)); - HYPRE_Int nrhs = 1; + hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver; + HYPRE_Int info = 0; + HYPRE_Real *x_data; + hypre_ParVector *x_host = NULL; + HYPRE_Int size = hypre_VectorSize(hypre_ParVectorLocalVector(x)); + HYPRE_Int nrhs = 1; hypre_ParVectorCopy(b, x); +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL) + if (hypre_GetActualMemLocation(hypre_ParVectorMemoryLocation(x)) != hypre_MEMORY_HOST) + { + x_host = hypre_ParVectorCloneDeep_v2(x, HYPRE_MEMORY_HOST); + x_data = hypre_VectorData(hypre_ParVectorLocalVector(x_host)); + } + else +#endif + { + x_data = hypre_VectorData(hypre_ParVectorLocalVector(x)); + } + pdgssvx(&(dslu_data->dslu_options), &(dslu_data->A_dslu), - &(dslu_data->dslu_ScalePermstruct), B, size, nrhs, + &(dslu_data->dslu_ScalePermstruct), x_data, size, nrhs, &(dslu_data->dslu_data_grid), &(dslu_data->dslu_data_LU), &(dslu_data->dslu_solve), dslu_data->berr, &(dslu_data->dslu_data_stat), &info); +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL) + if (x_host) + { + hypre_ParVectorCopy(x_host, x); + hypre_ParVectorDestroy(x_host); + } +#endif + return hypre_error_flag; } -HYPRE_Int hypre_SLUDistDestroy( void* solver) +HYPRE_Int +hypre_SLUDistDestroy(void* solver) { hypre_DSLUData *dslu_data = (hypre_DSLUData *) solver; @@ -171,6 +207,8 @@ HYPRE_Int hypre_SLUDistDestroy( void* solver) superlu_gridexit(&(dslu_data->dslu_data_grid)); hypre_TFree(dslu_data->berr, HYPRE_MEMORY_HOST); hypre_TFree(dslu_data, HYPRE_MEMORY_HOST); + return hypre_error_flag; } + #endif diff --git a/src/parcsr_ls/par_amg.c b/src/parcsr_ls/par_amg.c index 6da5aa269..9fd38e2f7 100644 --- a/src/parcsr_ls/par_amg.c +++ b/src/parcsr_ls/par_amg.c @@ -12,11 +12,7 @@ *****************************************************************************/ #include "_hypre_parcsr_ls.h" -#include "par_amg.h" -#ifdef HYPRE_USING_DSUPERLU -#include -#include "superlu_ddefs.h" -#endif + /*-------------------------------------------------------------------------- * hypre_BoomerAMGCreate *--------------------------------------------------------------------------*/ diff --git a/src/parcsr_mv/par_csr_matrix.c b/src/parcsr_mv/par_csr_matrix.c index d5d964b78..0703b49f1 100644 --- a/src/parcsr_mv/par_csr_matrix.c +++ b/src/parcsr_mv/par_csr_matrix.c @@ -1991,7 +1991,7 @@ GenerateDiagAndOffd(hypre_CSRMatrix *A, } hypre_CSRMatrix * -hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix) +hypre_MergeDiagAndOffdHost(hypre_ParCSRMatrix *par_matrix) { hypre_CSRMatrix *diag = hypre_ParCSRMatrixDiag(par_matrix); hypre_CSRMatrix *offd = hypre_ParCSRMatrixOffd(par_matrix); @@ -2070,6 +2070,23 @@ hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix) return matrix; } +hypre_CSRMatrix * +hypre_MergeDiagAndOffd(hypre_ParCSRMatrix *par_matrix) +{ +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + HYPRE_ExecutionPolicy exec = hypre_GetExecPolicy1( hypre_ParCSRMatrixMemoryLocation(par_matrix) ); + + if (exec == HYPRE_EXEC_DEVICE) + { + return hypre_MergeDiagAndOffdDevice(par_matrix); + } + else +#endif + { + return hypre_MergeDiagAndOffdHost(par_matrix); + } +} + /*-------------------------------------------------------------------------- * hypre_ParCSRMatrixToCSRMatrixAll *