add using hypre's spmv option
This commit is contained in:
parent
7443a2ac6c
commit
63c9fa65a2
@ -35,27 +35,34 @@ hypre_CSRMatrixMatvecDevice2( HYPRE_Int trans,
|
||||
"ERROR::x and y are the same pointer in hypre_CSRMatrixMatvecDevice2");
|
||||
}
|
||||
|
||||
if (hypre_HandleSpMVUseCusparse(hypre_handle()) || trans)
|
||||
{
|
||||
#ifdef HYPRE_USING_CUSPARSE
|
||||
#if CUSPARSE_VERSION >= CUSPARSE_NEWAPI_VERSION
|
||||
/* Luke E: The generic API is techinically supported on 10.1,10.2 as a preview,
|
||||
* with Dscrmv being deprecated. However, there are limitations.
|
||||
* While in Cuda < 11, there are specific mentions of using csr2csc involving
|
||||
* transposed matrix products with dcsrm*,
|
||||
* they are not present in SpMV interface.
|
||||
*/
|
||||
hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset);
|
||||
/* Luke E: The generic API is techinically supported on 10.1,10.2 as a preview,
|
||||
* with Dscrmv being deprecated. However, there are limitations.
|
||||
* While in Cuda < 11, there are specific mentions of using csr2csc involving
|
||||
* transposed matrix products with dcsrm*,
|
||||
* they are not present in SpMV interface.
|
||||
*/
|
||||
hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset);
|
||||
#else
|
||||
hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset);
|
||||
hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset);
|
||||
#endif
|
||||
#elif defined(HYPRE_USING_DEVICE_OPENMP)
|
||||
hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset);
|
||||
hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset);
|
||||
#elif defined(HYPRE_USING_ROCSPARSE)
|
||||
hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset);
|
||||
hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset);
|
||||
#elif defined(HYPRE_USING_ONEMKLSPARSE)
|
||||
hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset);
|
||||
#else // #ifdef HYPRE_USING_CUSPARSE
|
||||
#error HYPRE SPMV TODO
|
||||
hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset);
|
||||
#else
|
||||
hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Error: No TPL SpMV support configured");
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
hypre_CSRMatrixSpMVDevice(alpha, A, x, beta, y, NULL, 0);
|
||||
}
|
||||
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
@ -286,6 +286,7 @@ main( hypre_int argc,
|
||||
coarsen_type = 8;
|
||||
mod_rap2 = 1;
|
||||
HYPRE_Int spgemm_use_cusparse = 0;
|
||||
HYPRE_Int spmv_use_cusparse = 1;
|
||||
HYPRE_Int use_curand = 1;
|
||||
#if defined(HYPRE_USING_HIP)
|
||||
spgemm_use_cusparse = 1;
|
||||
@ -1190,6 +1191,11 @@ main( hypre_int argc,
|
||||
arg_index++;
|
||||
spgemm_use_cusparse = atoi(argv[arg_index++]);
|
||||
}
|
||||
else if ( strcmp(argv[arg_index], "-mv_cusparse") == 0 )
|
||||
{
|
||||
arg_index++;
|
||||
spmv_use_cusparse = atoi(argv[arg_index++]);
|
||||
}
|
||||
else if ( strcmp(argv[arg_index], "-spgemm_alg") == 0 )
|
||||
{
|
||||
arg_index++;
|
||||
@ -2321,6 +2327,7 @@ main( hypre_int argc,
|
||||
HYPRE_SetExecutionPolicy(default_exec_policy);
|
||||
|
||||
#if defined(HYPRE_USING_GPU)
|
||||
ierr = HYPRE_SetSpMVUseCusparse(spmv_use_cusparse); hypre_assert(ierr == 0);
|
||||
/* use cuSPARSE for SpGEMM */
|
||||
ierr = HYPRE_SetSpGemmUseCusparse(spgemm_use_cusparse); hypre_assert(ierr == 0);
|
||||
ierr = hypre_SetSpGemmAlgorithm(spgemm_alg); hypre_assert(ierr == 0);
|
||||
|
||||
@ -13,6 +13,15 @@
|
||||
|
||||
#include "_hypre_utilities.h"
|
||||
|
||||
/*--------------------------------------------------------------------------
|
||||
* HYPRE_SetSpMVUseCusparse
|
||||
*--------------------------------------------------------------------------*/
|
||||
HYPRE_Int
|
||||
HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse )
|
||||
{
|
||||
return hypre_SetSpMVUseCusparse(use_cusparse);
|
||||
}
|
||||
|
||||
/*--------------------------------------------------------------------------
|
||||
* HYPRE_SetSpGemmUseCusparse
|
||||
*--------------------------------------------------------------------------*/
|
||||
|
||||
@ -229,6 +229,7 @@ HYPRE_Int HYPRE_SetGPUMemoryPoolSize(HYPRE_Int bin_growth, HYPRE_Int min_bin, HY
|
||||
* HYPRE handle
|
||||
*--------------------------------------------------------------------------*/
|
||||
|
||||
HYPRE_Int HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int HYPRE_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int HYPRE_SetUseGpuRand( HYPRE_Int use_curand );
|
||||
|
||||
|
||||
@ -1814,6 +1814,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle);
|
||||
HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle);
|
||||
|
||||
/* handle.c */
|
||||
HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value );
|
||||
HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value );
|
||||
|
||||
@ -14,6 +14,16 @@
|
||||
#include "_hypre_utilities.h"
|
||||
#include "_hypre_utilities.hpp"
|
||||
|
||||
/* GPU SpMV */
|
||||
HYPRE_Int
|
||||
hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse )
|
||||
{
|
||||
#if defined(HYPRE_USING_GPU)
|
||||
hypre_HandleSpMVUseCusparse(hypre_handle()) = use_cusparse;
|
||||
#endif
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
/* GPU SpGemm */
|
||||
HYPRE_Int
|
||||
hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse )
|
||||
|
||||
@ -324,6 +324,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle);
|
||||
HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle);
|
||||
|
||||
/* handle.c */
|
||||
HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse );
|
||||
HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value );
|
||||
HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value );
|
||||
|
||||
Loading…
Reference in New Issue
Block a user