This commit is contained in:
Ruipeng Li 2022-06-06 13:23:32 -07:00
parent 4efa15836d
commit 0fee4f3c80
3 changed files with 78 additions and 49 deletions

View File

@ -133,8 +133,9 @@ HYPRE_Int hypreDevice_CSRSpGemmBinnedGetBlockNumDim()
#if defined(HYPRE_SPGEMM_PRINTF)
HYPRE_SPGEMM_PRINT("===========================================================================\n");
HYPRE_SPGEMM_PRINT("SM count %d\n", multiProcessorCount);
HYPRE_SPGEMM_PRINT("Highest Bin Symbl %d, Numer %d\n", hypre_HandleSpgemmHighestBin(hypre_handle())[0],
hypre_HandleSpgemmHighestBin(hypre_handle())[1]);
HYPRE_SPGEMM_PRINT("Highest Bin Symbl %d, Numer %d\n",
hypre_HandleSpgemmHighestBin(hypre_handle())[0],
hypre_HandleSpgemmHighestBin(hypre_handle())[1]);
HYPRE_SPGEMM_PRINT("---------------------------------------------------------------------------\n");
HYPRE_SPGEMM_PRINT("Bin: ");
for (HYPRE_Int i = 0; i < num_bins + 1; i++) { HYPRE_SPGEMM_PRINT("%5d ", i); } HYPRE_SPGEMM_PRINT("\n");

View File

@ -123,27 +123,37 @@ hypreDevice_CSRSpGemmNumerWithRownnzUpperboundBinned( HYPRE_Int m,
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);
#if 0
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 1, HYPRE_SPGEMM_NUMER_HASH_SIZE / 16, /* 16, 2 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 1,
HYPRE_SPGEMM_NUMER_HASH_SIZE / 16, /* 16, 2 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 16, exact_rownnz, false);
#endif
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 2, HYPRE_SPGEMM_NUMER_HASH_SIZE / 8, /* 32, 4 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 2,
HYPRE_SPGEMM_NUMER_HASH_SIZE / 8, /* 32, 4 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 8, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 3, HYPRE_SPGEMM_NUMER_HASH_SIZE / 4, /* 64, 8 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 3,
HYPRE_SPGEMM_NUMER_HASH_SIZE / 4, /* 64, 8 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 4, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 4, HYPRE_SPGEMM_NUMER_HASH_SIZE / 2, /* 128, 16 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 4,
HYPRE_SPGEMM_NUMER_HASH_SIZE / 2, /* 128, 16 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 2, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 5, HYPRE_SPGEMM_NUMER_HASH_SIZE, /* 256, 32 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 5,
HYPRE_SPGEMM_NUMER_HASH_SIZE, /* 256, 32 */
HYPRE_SPGEMM_BASE_GROUP_SIZE, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 6, HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, /* 512, 64 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 6,
HYPRE_SPGEMM_NUMER_HASH_SIZE * 2, /* 512, 64 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 7, HYPRE_SPGEMM_NUMER_HASH_SIZE * 4, /* 1024, 128 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 7,
HYPRE_SPGEMM_NUMER_HASH_SIZE * 4, /* 1024, 128 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 4, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 8, HYPRE_SPGEMM_NUMER_HASH_SIZE * 8, /* 2048, 256 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 8,
HYPRE_SPGEMM_NUMER_HASH_SIZE * 8, /* 2048, 256 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 8, exact_rownnz, false);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 9, HYPRE_SPGEMM_NUMER_HASH_SIZE * 16, /* 4096, 512 */
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 9,
HYPRE_SPGEMM_NUMER_HASH_SIZE * 16, /* 4096, 512 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 16, exact_rownnz, 9 == high_bin);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED(10, HYPRE_SPGEMM_NUMER_HASH_SIZE * 32, /* 8192, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, exact_rownnz, true);
HYPRE_SPGEMM_NUMERICAL_WITH_ROWNNZ_BINNED( 10,
HYPRE_SPGEMM_NUMER_HASH_SIZE * 32, /* 8192, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, exact_rownnz, true);
if (!exact_rownnz)
{

View File

@ -72,22 +72,30 @@ hypreDevice_CSRSpGemmRownnzUpperboundBinned( HYPRE_Int m,
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);
HYPRE_SPGEMM_ROWNNZ_BINNED( 3, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 4, /* 128, 8 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 4, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 4, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 2, /* 256, 16 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 2, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 5, HYPRE_SPGEMM_SYMBL_HASH_SIZE, /* 512, 32 */
HYPRE_SPGEMM_BASE_GROUP_SIZE, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 6, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 2, /* 1024, 64 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 7, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 4, /* 2048, 128 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 4, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 8, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 8, /* 4096, 256 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 8, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 9, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 16, /* 8192, 512 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 16, 9 == high_bin, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 10, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 32, /* 16384, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, true, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 3,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 4, /* 128, 8 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 4, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 4,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 2, /* 256, 16 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 2, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 5,
HYPRE_SPGEMM_SYMBL_HASH_SIZE, /* 512, 32 */
HYPRE_SPGEMM_BASE_GROUP_SIZE, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 6,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 2, /* 1024, 64 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 7,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 4, /* 2048, 128 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 4, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 8,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 8, /* 4096, 256 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 8, false, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 9,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 16, /* 8192, 512 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 16, 9 == high_bin, CAN_FAIL, d_rf);
HYPRE_SPGEMM_ROWNNZ_BINNED( 10,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 32, /* 16384, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, true, CAN_FAIL, d_rf);
hypre_TFree(d_rind, HYPRE_MEMORY_DEVICE);
@ -268,21 +276,26 @@ hypreDevice_CSRSpGemmRownnzBinned( HYPRE_Int m,
hypre_SpGemmCreateBins(m, s, t, u, d_rc, false, d_rind, h_bin_ptr);
HYPRE_SPGEMM_ROWNNZ_BINNED( 1, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 16, /* 32, 2 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 16, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 2, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 8, /* 64, 4 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 8, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 3, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 4, /* 128, 8 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 4, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 4, HYPRE_SPGEMM_SYMBL_HASH_SIZE / 2, /* 256, 16 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 2, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 1,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 16, /* 32, 2 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 16, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 2,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 8, /* 64, 4 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 8, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 3,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 4, /* 128, 8 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 4, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 4,
HYPRE_SPGEMM_SYMBL_HASH_SIZE / 2, /* 256, 16 */
HYPRE_SPGEMM_BASE_GROUP_SIZE / 2, false, false, NULL);
if (h_bin_ptr[5] > h_bin_ptr[4])
{
char *d_rf = hypre_CTAlloc(char, m, HYPRE_MEMORY_DEVICE);
HYPRE_SPGEMM_ROWNNZ_BINNED( 5, HYPRE_SPGEMM_SYMBL_HASH_SIZE,
HYPRE_SPGEMM_BASE_GROUP_SIZE, false, true, d_rf); /* 512, 32 */
HYPRE_SPGEMM_ROWNNZ_BINNED( 5,
HYPRE_SPGEMM_SYMBL_HASH_SIZE,
HYPRE_SPGEMM_BASE_GROUP_SIZE, false, true, d_rf); /* 512, 32 */
HYPRE_Int num_failed_rows =
HYPRE_THRUST_CALL( reduce,
@ -311,16 +324,21 @@ hypreDevice_CSRSpGemmRownnzBinned( HYPRE_Int m,
hypre_SpGemmCreateBins(num_failed_rows, s, t, u, d_rc, true, d_rind, h_bin_ptr);
HYPRE_SPGEMM_ROWNNZ_BINNED( 6, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 2, /* 1024, 64 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 7, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 4, /* 2048, 128 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 4, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 8, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 8, /* 4096, 256 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 8, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 9, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 16, /* 8192, 512 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 16, 9 == high_bin, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 10, HYPRE_SPGEMM_SYMBL_HASH_SIZE * 32, /* 16384, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, true, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 6,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 2, /* 1024, 64 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 2, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 7,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 4, /* 2048, 128 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 4, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 8,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 8, /* 4096, 256 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 8, false, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 9,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 16, /* 8192, 512 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 16, 9 == high_bin, false, NULL);
HYPRE_SPGEMM_ROWNNZ_BINNED( 10,
HYPRE_SPGEMM_SYMBL_HASH_SIZE * 32, /* 16384, 1024 */
HYPRE_SPGEMM_BASE_GROUP_SIZE * 32, true, false, NULL);
}
hypre_TFree(d_rf, HYPRE_MEMORY_DEVICE);