Cusparse spmv (#512)

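This PR removes the frequent GPU malloc/free of the SpMV work buffer in CSRMatvec with cuSPARSE 11: the buffer is now allocated on first use, cached in the matrix's hypre_GpuMatData, and freed when that data is destroyed. See #507.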
This commit is contained in:
Ruipeng Li 2021-11-01 10:33:52 -07:00 committed by GitHub
parent 5262bff461
commit 7f2762cffb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 48 additions and 27 deletions

View File

@ -102,7 +102,9 @@ hypre_GpuMatDataDestroy(hypre_GpuMatData *data)
HYPRE_ROCSPARSE_CALL( rocsparse_destroy_mat_info(hypre_GpuMatDataMatInfo(data)) );
#endif
hypre_TFree(data, HYPRE_MEMORY_HOST);
hypre_TFree(hypre_GpuMatDataSpMVBuffer(data), HYPRE_MEMORY_DEVICE);
hypre_TFree(data, HYPRE_MEMORY_HOST);
}
#endif /* #if defined(HYPRE_USING_CUSPARSE) || defined(HYPRE_USING_ROCSPARSE) */

View File

@ -49,7 +49,7 @@ typedef struct
HYPRE_Int *sorted_j; /* some cusparse routines require sorted CSR */
HYPRE_Complex *sorted_data;
hypre_CsrsvData *csrsv_data;
hypre_GpuMatData *mat_data;
hypre_GpuMatData *mat_data;
#endif
} hypre_CSRMatrix;

View File

@ -153,24 +153,33 @@ hypre_CSRMatrixMatvecCusparseNewAPI( HYPRE_Int trans,
/* SpMV */
size_t bufferSize = 0;
char *dBuffer = NULL;
char *dBuffer = hypre_CSRMatrixGPUMatSpMVBuffer(A);
HYPRE_Int x_size_override = trans ? hypre_CSRMatrixNumRows(A) : hypre_CSRMatrixNumCols(A);
HYPRE_Int y_size_override = trans ? hypre_CSRMatrixNumCols(A) : hypre_CSRMatrixNumRows(A);
cusparseDnVecDescr_t vecX = hypre_VectorToCusparseDnVec(x, 0, x_size_override);
cusparseDnVecDescr_t vecY = hypre_VectorToCusparseDnVec(y, offset, y_size_override - offset);
HYPRE_CUSPARSE_CALL( cusparseSpMV_bufferSize(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
vecX,
&beta,
vecY,
data_type,
CUSPARSE_CSRMV_ALG2,
&bufferSize) );
if (!dBuffer)
{
HYPRE_CUSPARSE_CALL( cusparseSpMV_bufferSize(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
vecX,
&beta,
vecY,
data_type,
#if CUSPARSE_SPMV_CSR_ALG2 >= 11200
CUSPARSE_SPMV_CSR_ALG2,
#else
CUSPARSE_CSRMV_ALG2,
#endif
&bufferSize) );
dBuffer = hypre_TAlloc(char, bufferSize, HYPRE_MEMORY_DEVICE);
dBuffer = hypre_TAlloc(char, bufferSize, HYPRE_MEMORY_DEVICE);
hypre_CSRMatrixGPUMatSpMVBuffer(A) = dBuffer;
}
HYPRE_CUSPARSE_CALL( cusparseSpMV(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
@ -180,7 +189,11 @@ hypre_CSRMatrixMatvecCusparseNewAPI( HYPRE_Int trans,
&beta,
vecY,
data_type,
#if CUSPARSE_SPMV_CSR_ALG2 >= 11200
CUSPARSE_SPMV_CSR_ALG2,
#else
CUSPARSE_CSRMV_ALG2,
#endif
dBuffer) );
hypre_SyncCudaComputeStream(hypre_handle());
@ -189,11 +202,11 @@ hypre_CSRMatrixMatvecCusparseNewAPI( HYPRE_Int trans,
{
hypre_CSRMatrixDestroy(AT);
}
hypre_TFree(dBuffer, HYPRE_MEMORY_DEVICE);
/* This function releases the host memory allocated for the sparse matrix descriptor */
HYPRE_CUSPARSE_CALL(cusparseDestroySpMat(matA));
HYPRE_CUSPARSE_CALL(cusparseDestroyDnVec(vecX));
HYPRE_CUSPARSE_CALL(cusparseDestroyDnVec(vecY));
HYPRE_CUSPARSE_CALL( cusparseDestroySpMat(matA) );
HYPRE_CUSPARSE_CALL( cusparseDestroyDnVec(vecX) );
HYPRE_CUSPARSE_CALL( cusparseDestroyDnVec(vecY) );
return hypre_error_flag;
}

View File

@ -221,7 +221,8 @@ void hypre_CsrsvDataDestroy(hypre_CsrsvData *data);
/* Lifecycle of the GPU sparse-library data (hypre_GpuMatData) attached to a CSR matrix.
 * hypre_GpuMatDataDestroy releases the SpMV buffer held by the data (HYPRE_MEMORY_DEVICE)
 * before releasing the struct itself (HYPRE_MEMORY_HOST). */
hypre_GpuMatData* hypre_GpuMatDataCreate();
void hypre_GpuMatDataDestroy(hypre_GpuMatData *data);
/* Returns the hypre_GpuMatData belonging to `matrix`.
 * NOTE(review): whether this allocates the data on first use is not visible here; the
 * accessor macros below dereference the result without a NULL check -- confirm. */
hypre_GpuMatData* hypre_CSRMatrixGetGPUMatData(hypre_CSRMatrix *matrix);
/* Convenience accessors: look up the matrix's hypre_GpuMatData through
 * hypre_CSRMatrixGetGPUMatData() and select one field of it. Every use performs that
 * function call. "Decsr" is the existing spelling of the underlying accessor name. */
#define hypre_CSRMatrixGPUMatDescr(matrix) ( hypre_GpuMatDataMatDecsr(hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatInfo(matrix) ( hypre_GpuMatDataMatInfo (hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatDescr(matrix) ( hypre_GpuMatDataMatDecsr(hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatInfo(matrix) ( hypre_GpuMatDataMatInfo (hypre_CSRMatrixGetGPUMatData(matrix)) )
/* Device work buffer cached for the cuSPARSE SpMV path. The matvec routine both reads
 * this accessor and assigns to it after allocating the buffer, so it must stay usable
 * on the left-hand side of an assignment. */
#define hypre_CSRMatrixGPUMatSpMVBuffer(matrix) ( hypre_GpuMatDataSpMVBuffer (hypre_CSRMatrixGetGPUMatData(matrix)) )
#endif
void hypre_CSRMatrixGpuSpMVAnalysis(hypre_CSRMatrix *matrix);

View File

@ -69,7 +69,7 @@ typedef struct
HYPRE_Int *sorted_j; /* some cusparse routines require sorted CSR */
HYPRE_Complex *sorted_data;
hypre_CsrsvData *csrsv_data;
hypre_GpuMatData *mat_data;
hypre_GpuMatData *mat_data;
#endif
} hypre_CSRMatrix;
@ -493,8 +493,9 @@ void hypre_CsrsvDataDestroy(hypre_CsrsvData *data);
/* Lifecycle of the GPU sparse-library data (hypre_GpuMatData) attached to a CSR matrix.
 * hypre_GpuMatDataDestroy releases the SpMV buffer held by the data (HYPRE_MEMORY_DEVICE)
 * before releasing the struct itself (HYPRE_MEMORY_HOST). */
hypre_GpuMatData* hypre_GpuMatDataCreate();
void hypre_GpuMatDataDestroy(hypre_GpuMatData *data);
/* Returns the hypre_GpuMatData belonging to `matrix`.
 * NOTE(review): whether this allocates the data on first use is not visible here; the
 * accessor macros below dereference the result without a NULL check -- confirm. */
hypre_GpuMatData* hypre_CSRMatrixGetGPUMatData(hypre_CSRMatrix *matrix);
/* Convenience accessors: look up the matrix's hypre_GpuMatData through
 * hypre_CSRMatrixGetGPUMatData() and select one field of it. Every use performs that
 * function call. "Decsr" is the existing spelling of the underlying accessor name. */
#define hypre_CSRMatrixGPUMatDescr(matrix) ( hypre_GpuMatDataMatDecsr(hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatInfo(matrix) ( hypre_GpuMatDataMatInfo (hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatDescr(matrix) ( hypre_GpuMatDataMatDecsr(hypre_CSRMatrixGetGPUMatData(matrix)) )
#define hypre_CSRMatrixGPUMatInfo(matrix) ( hypre_GpuMatDataMatInfo (hypre_CSRMatrixGetGPUMatData(matrix)) )
/* Device work buffer cached for the cuSPARSE SpMV path. The matvec routine both reads
 * this accessor and assigns to it after allocating the buffer, so it must stay usable
 * on the left-hand side of an assignment. */
#define hypre_CSRMatrixGPUMatSpMVBuffer(matrix) ( hypre_GpuMatDataSpMVBuffer (hypre_CSRMatrixGetGPUMatData(matrix)) )
#endif
void hypre_CSRMatrixGpuSpMVAnalysis(hypre_CSRMatrix *matrix);

View File

@ -316,6 +316,7 @@ struct hypre_GpuMatData
{
#if defined(HYPRE_USING_CUSPARSE)
cusparseMatDescr_t mat_descr;
char *spmv_buffer;
#endif
#if defined(HYPRE_USING_ROCSPARSE)
@ -324,8 +325,9 @@ struct hypre_GpuMatData
#endif
};
/* Field accessors for struct hypre_GpuMatData; `data` is a pointer to the struct.
 * Each expands to a plain member access, so it can be read or assigned.
 * "Decsr" is the existing spelling of the mat_descr accessor and other macros use it.
 * NOTE(review): in the struct above, mat_descr and spmv_buffer are declared under
 * HYPRE_USING_CUSPARSE; the declaration of mat_info is not visible here (presumably the
 * rocSPARSE section). These macros are defined regardless of backend, so each use must
 * sit under the matching guard -- confirm this for hypre_GpuMatDataSpMVBuffer. */
#define hypre_GpuMatDataMatDecsr(data) ((data) -> mat_descr)
#define hypre_GpuMatDataMatInfo(data) ((data) -> mat_info)
#define hypre_GpuMatDataMatDecsr(data) ((data) -> mat_descr)
#define hypre_GpuMatDataMatInfo(data) ((data) -> mat_info)
/* Pointer to the SpMV work buffer stored in the struct (char *spmv_buffer). */
#define hypre_GpuMatDataSpMVBuffer(data) ((data) -> spmv_buffer)
#endif //#if defined(HYPRE_USING_GPU)

View File

@ -264,6 +264,7 @@ struct hypre_GpuMatData
{
#if defined(HYPRE_USING_CUSPARSE)
cusparseMatDescr_t mat_descr;
char *spmv_buffer;
#endif
#if defined(HYPRE_USING_ROCSPARSE)
@ -272,8 +273,9 @@ struct hypre_GpuMatData
#endif
};
/* Field accessors for struct hypre_GpuMatData; `data` is a pointer to the struct.
 * Each expands to a plain member access, so it can be read or assigned.
 * "Decsr" is the existing spelling of the mat_descr accessor and other macros use it.
 * NOTE(review): in the struct above, mat_descr and spmv_buffer are declared under
 * HYPRE_USING_CUSPARSE; the declaration of mat_info is not visible here (presumably the
 * rocSPARSE section). These macros are defined regardless of backend, so each use must
 * sit under the matching guard -- confirm this for hypre_GpuMatDataSpMVBuffer. */
#define hypre_GpuMatDataMatDecsr(data) ((data) -> mat_descr)
#define hypre_GpuMatDataMatInfo(data) ((data) -> mat_info)
#define hypre_GpuMatDataMatDecsr(data) ((data) -> mat_descr)
#define hypre_GpuMatDataMatInfo(data) ((data) -> mat_info)
/* Pointer to the SpMV work buffer stored in the struct (char *spmv_buffer). */
#define hypre_GpuMatDataSpMVBuffer(data) ((data) -> spmv_buffer)
#endif //#if defined(HYPRE_USING_GPU)