diff --git a/src/seq_mv/csr_matvec_device.c b/src/seq_mv/csr_matvec_device.c index fc523e368..869f92c72 100644 --- a/src/seq_mv/csr_matvec_device.c +++ b/src/seq_mv/csr_matvec_device.c @@ -35,27 +35,34 @@ hypre_CSRMatrixMatvecDevice2( HYPRE_Int trans, "ERROR::x and y are the same pointer in hypre_CSRMatrixMatvecDevice2"); } + if (hypre_HandleSpMVUseCusparse(hypre_handle()) || trans) + { #ifdef HYPRE_USING_CUSPARSE #if CUSPARSE_VERSION >= CUSPARSE_NEWAPI_VERSION - /* Luke E: The generic API is techinically supported on 10.1,10.2 as a preview, - * with Dscrmv being deprecated. However, there are limitations. - * While in Cuda < 11, there are specific mentions of using csr2csc involving - * transposed matrix products with dcsrm*, - * they are not present in SpMV interface. - */ - hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset); + /* Luke E: The generic API is technically supported on 10.1,10.2 as a preview, + * with Dcsrmv being deprecated. However, there are limitations. + * While in Cuda < 11, there are specific mentions of using csr2csc involving + * transposed matrix products with dcsrm*, + * they are not present in SpMV interface.
+ */ + hypre_CSRMatrixMatvecCusparseNewAPI(trans, alpha, A, x, beta, y, offset); #else - hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset); + hypre_CSRMatrixMatvecCusparseOldAPI(trans, alpha, A, x, beta, y, offset); #endif #elif defined(HYPRE_USING_DEVICE_OPENMP) - hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset); + hypre_CSRMatrixMatvecOMPOffload(trans, alpha, A, x, beta, y, offset); #elif defined(HYPRE_USING_ROCSPARSE) - hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset); + hypre_CSRMatrixMatvecRocsparse(trans, alpha, A, x, beta, y, offset); #elif defined(HYPRE_USING_ONEMKLSPARSE) - hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset); -#else // #ifdef HYPRE_USING_CUSPARSE -#error HYPRE SPMV TODO + hypre_CSRMatrixMatvecOnemklsparse(trans, alpha, A, x, beta, y, offset); +#else + hypre_error_w_msg(HYPRE_ERROR_GENERIC, "Error: No TPL SpMV support configured"); #endif + } + else + { + hypre_CSRMatrixSpMVDevice(alpha, A, x, beta, y, NULL, 0); + } return hypre_error_flag; } diff --git a/src/test/ij.c b/src/test/ij.c index f192ad032..e94269f4a 100644 --- a/src/test/ij.c +++ b/src/test/ij.c @@ -286,6 +286,7 @@ main( hypre_int argc, coarsen_type = 8; mod_rap2 = 1; HYPRE_Int spgemm_use_cusparse = 0; + HYPRE_Int spmv_use_cusparse = 1; HYPRE_Int use_curand = 1; #if defined(HYPRE_USING_HIP) spgemm_use_cusparse = 1; @@ -1190,6 +1191,11 @@ main( hypre_int argc, arg_index++; spgemm_use_cusparse = atoi(argv[arg_index++]); } + else if ( strcmp(argv[arg_index], "-mv_cusparse") == 0 ) + { + arg_index++; + spmv_use_cusparse = atoi(argv[arg_index++]); + } else if ( strcmp(argv[arg_index], "-spgemm_alg") == 0 ) { arg_index++; @@ -2321,6 +2327,7 @@ main( hypre_int argc, HYPRE_SetExecutionPolicy(default_exec_policy); #if defined(HYPRE_USING_GPU) + ierr = HYPRE_SetSpMVUseCusparse(spmv_use_cusparse); hypre_assert(ierr == 0); /* use cuSPARSE for SpGEMM */ ierr = HYPRE_SetSpGemmUseCusparse(spgemm_use_cusparse); 
hypre_assert(ierr == 0); ierr = hypre_SetSpGemmAlgorithm(spgemm_alg); hypre_assert(ierr == 0); diff --git a/src/utilities/HYPRE_handle.c b/src/utilities/HYPRE_handle.c index d8c69a24c..d1bce34f2 100644 --- a/src/utilities/HYPRE_handle.c +++ b/src/utilities/HYPRE_handle.c @@ -13,6 +13,15 @@ #include "_hypre_utilities.h" +/*-------------------------------------------------------------------------- + * HYPRE_SetSpMVUseCusparse + *--------------------------------------------------------------------------*/ +HYPRE_Int +HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse ) +{ + return hypre_SetSpMVUseCusparse(use_cusparse); +} + /*-------------------------------------------------------------------------- * HYPRE_SetSpGemmUseCusparse *--------------------------------------------------------------------------*/ diff --git a/src/utilities/HYPRE_utilities.h b/src/utilities/HYPRE_utilities.h index 5dc0ff6a1..571629281 100644 --- a/src/utilities/HYPRE_utilities.h +++ b/src/utilities/HYPRE_utilities.h @@ -229,6 +229,7 @@ HYPRE_Int HYPRE_SetGPUMemoryPoolSize(HYPRE_Int bin_growth, HYPRE_Int min_bin, HY * HYPRE handle *--------------------------------------------------------------------------*/ +HYPRE_Int HYPRE_SetSpMVUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int HYPRE_SetSpGemmUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int HYPRE_SetUseGpuRand( HYPRE_Int use_curand ); diff --git a/src/utilities/_hypre_utilities.h b/src/utilities/_hypre_utilities.h index 38b4f5fa4..27af159a3 100644 --- a/src/utilities/_hypre_utilities.h +++ b/src/utilities/_hypre_utilities.h @@ -1814,6 +1814,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle); HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle); /* handle.c */ +HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value ); HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value ); diff --git 
a/src/utilities/handle.c b/src/utilities/handle.c index 4f484080b..6c71cd70d 100644 --- a/src/utilities/handle.c +++ b/src/utilities/handle.c @@ -14,6 +14,16 @@ #include "_hypre_utilities.h" #include "_hypre_utilities.hpp" +/* GPU SpMV */ +HYPRE_Int +hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse ) +{ +#if defined(HYPRE_USING_GPU) + hypre_HandleSpMVUseCusparse(hypre_handle()) = use_cusparse; +#endif + return hypre_error_flag; +} + /* GPU SpGemm */ HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse ) diff --git a/src/utilities/protos.h b/src/utilities/protos.h index a45110982..68658ab95 100644 --- a/src/utilities/protos.h +++ b/src/utilities/protos.h @@ -324,6 +324,7 @@ HYPRE_Int hypre_SyncComputeStream(hypre_Handle *hypre_handle); HYPRE_Int hypre_ForceSyncComputeStream(hypre_Handle *hypre_handle); /* handle.c */ +HYPRE_Int hypre_SetSpMVUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int hypre_SetSpGemmUseCusparse( HYPRE_Int use_cusparse ); HYPRE_Int hypre_SetSpGemmAlgorithm( HYPRE_Int value ); HYPRE_Int hypre_SetSpGemmRownnzEstimateMethod( HYPRE_Int value );