add using hypre's spmv option

This commit is contained in:
Ruipeng Li 2022-03-08 22:11:31 -08:00
parent 7443a2ac6c
commit 63c9fa65a2
7 changed files with 49 additions and 13 deletions

View File

@ -35,27 +35,34 @@ hypre_CSRMatrixMatvecDevice2( HYPRE_Int trans,
"ERROR::x and y are the same pointer in hypre_CSRMatrixMatvecDevice2");
}
if (hypre_HandleSpMVUseCusparse(hypre_handle()) || trans)
{
#ifdef HYPRE_USING_CUSPARSE
#if CUSPARSE_VERSION >= CUSPARSE_NEWAPI_VERSION
/* Luke E: The generic API is techinically supported on 10.1,10.2 as a preview,
* with Dscrmv being deprecated. However, there are limitations.
* While in Cuda < 11, there are specific mentions of using csr2csc involving
* transposed matrix products with dcsrm*,
* they are not present in SpMV interface.
*/
hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset);
/* Luke E: The generic API is techinically supported on 10.1,10.2 as a preview,
* with Dscrmv being deprecated. However, there are limitations.
* While in Cuda < 11, there are specific mentions of using csr2csc involving
* transposed matrix products with dcsrm*,
* they are not present in SpMV interface.
*/
hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset);
#else
hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset);
hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset);
#endif
#elif defined(HYPRE_USING_DEVICE_OPENMP)
hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset);
hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset);
#elif defined(HYPRE_USING_ROCSPARSE)
hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset);
hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset);
#elif defined(HYPRE_USING_ONEMKLSPARSE)
hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset);
#else // #ifdef HYPRE_USING_CUSPARSE
#error HYPRE SPMV TODO
hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset);
#else
hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Error: No TPL SpMV support configured");
#endif
}
else
{
hypre_CSRMatrixSpMVDevice(alpha, A, x, beta, y, NULL, 0);
}
return hypre_error_flag;
}

View File

@ -286,6 +286,7 @@ main( hypre_int argc,
coarsen_type = 8;
mod_rap2 = 1;
HYPRE_Int spgemm_use_cusparse = 0;
HYPRE_Int spmv_use_cusparse = 1;
HYPRE_Int use_curand = 1;
#if defined(HYPRE_USING_HIP)
spgemm_use_cusparse = 1;
@ -1190,6 +1191,11 @@ main( hypre_int argc,
arg_index++;
spgemm_use_cusparse = atoi(argv[arg_index++]);
}
else if ( strcmp(argv[arg_index], "-mv_cusparse") == 0 )
{
arg_index++;
spmv_use_cusparse = atoi(argv[arg_index++]);
}
else if ( strcmp(argv[arg_index], "-spgemm_alg") == 0 )
{
arg_index++;
@ -2321,6 +2327,7 @@ main( hypre_int argc,
HYPRE_SetExecutionPolicy(default_exec_policy);
#if defined(HYPRE_USING_GPU)
ierr = HYPRE_SetSpMVUseCusparse(spmv_use_cusparse); hypre_assert(ierr == 0);
/* use cuSPARSE for SpGEMM */
ierr = HYPRE_SetSpGemmUseCusparse(spgemm_use_cusparse); hypre_assert(ierr == 0);
ierr = hypre_SetSpGemmAlgorithm(spgemm_alg); hypre_assert(ierr == 0);

View File

@ -13,6 +13,15 @@
#include "_hypre_utilities.h"
/*--------------------------------------------------------------------------
* HYPRE_SetSpMVUseCusparse
*--------------------------------------------------------------------------*/
HYPRE_Int
HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse )
{
return hypre_SetSpMVUseCusparse(use_cusparse);
}
/*--------------------------------------------------------------------------
* HYPRE_SetSpGemmUseCusparse
*--------------------------------------------------------------------------*/

View File

@ -229,6 +229,7 @@ HYPRE_Int HYPRE_SetGPUMemoryPoolSize(HYPRE_Int bin_growth, HYPRE_Int min_bin, HY
* HYPRE handle
*--------------------------------------------------------------------------*/
HYPRE_Int HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int HYPRE_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int HYPRE_SetUseGpuRand( HYPRE_Int use_curand );

View File

@ -1814,6 +1814,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle);
HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle);
/* handle.c */
HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value );
HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value );

View File

@ -14,6 +14,16 @@
#include "_hypre_utilities.h"
#include "_hypre_utilities.hpp"
/* GPU SpMV */
HYPRE_Int
hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse )
{
#if defined(HYPRE_USING_GPU)
hypre_HandleSpMVUseCusparse(hypre_handle()) = use_cusparse;
#endif
return hypre_error_flag;
}
/* GPU SpGemm */
HYPRE_Int
hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse )

View File

@ -324,6 +324,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle);
HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle);
/* handle.c */
HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value );
HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value );