diff --git a/src/seq_mv/csr_matvec_device.c b/src/seq_mv/csr_matvec_device.c index f44acd161..1c7e9633e 100644 --- a/src/seq_mv/csr_matvec_device.c +++ b/src/seq_mv/csr_matvec_device.c @@ -17,9 +17,14 @@ #if defined(HYPRE_USING_GPU) -#if CUSPARSE_VERSION >= CUSPARSE_NEWAPI_VERSION +#if CUSPARSE_VERSION >= CUSPARSE_NEWSPMM_VERSION #define HYPRE_CUSPARSE_SPMV_ALG CUSPARSE_SPMV_CSR_ALG2 #define HYPRE_CUSPARSE_SPMM_ALG CUSPARSE_SPMM_CSR_ALG3 + +#elif CUSPARSE_VERSION >= CUSPARSE_NEWAPI_VERSION +#define HYPRE_CUSPARSE_SPMV_ALG CUSPARSE_CSRMV_ALG2 +#define HYPRE_CUSPARSE_SPMM_ALG CUSPARSE_SPMM_CSR_ALG1 + #else #define HYPRE_CUSPARSE_SPMV_ALG CUSPARSE_CSRMV_ALG2 #define HYPRE_CUSPARSE_SPMM_ALG CUSPARSE_CSRMM_ALG1 @@ -265,6 +270,7 @@ hypre_CSRMatrixMatvecCusparseNewAPI( HYPRE_Int trans, dBuffer = hypre_TAlloc(char, bufferSize, HYPRE_MEMORY_DEVICE); hypre_CSRMatrixGPUMatSpMVBuffer(A) = dBuffer; +#if CUSPARSE_VERSION >= CUSPARSE_NEWSPMM_VERSION if (num_vectors > 1) { HYPRE_CUSPARSE_CALL( cusparseSpMM_preprocess(handle, @@ -279,6 +285,7 @@ hypre_CSRMatrixMatvecCusparseNewAPI( HYPRE_Int trans, HYPRE_CUSPARSE_SPMM_ALG, dBuffer) ); } +#endif } if (num_vectors == 1) diff --git a/src/utilities/_hypre_utilities.hpp b/src/utilities/_hypre_utilities.hpp index 5a8d9f821..0a6a3a521 100644 --- a/src/utilities/_hypre_utilities.hpp +++ b/src/utilities/_hypre_utilities.hpp @@ -96,6 +96,7 @@ using hypre_DeviceItem = void*; #endif #define CUSPARSE_NEWAPI_VERSION 11000 +#define CUSPARSE_NEWSPMM_VERSION 11201 #define CUDA_MALLOCASYNC_VERSION 11020 #define THRUST_CALL_BLOCKING 1 diff --git a/src/utilities/device_utils.h b/src/utilities/device_utils.h index 10bd38102..d4bfc77a2 100644 --- a/src/utilities/device_utils.h +++ b/src/utilities/device_utils.h @@ -39,6 +39,7 @@ using hypre_DeviceItem = void*; #endif #define CUSPARSE_NEWAPI_VERSION 11000 +#define CUSPARSE_NEWSPMM_VERSION 11201 #define CUDA_MALLOCASYNC_VERSION 11020 #define THRUST_CALL_BLOCKING 1