From d8c6556e7e57f6741d808439ead07b7b93c47e71 Mon Sep 17 00:00:00 2001 From: Wayne Mitchell Date: Wed, 8 Jun 2022 18:32:49 +0000 Subject: [PATCH 1/4] Fixes for sycl. Still debugging incorrect results. --- src/parcsr_mv/par_csr_matop_device.c | 10 ++++----- src/parcsr_mv/par_csr_triplemat_device.c | 25 ++++++++++++++++++++- src/seq_mv/csr_spgemm_device_numer10.c | 3 +++ src/seq_mv/csr_spgemm_device_onemklsparse.c | 1 + src/utilities/_hypre_onedpl.hpp | 16 +++++++++++++ 5 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/parcsr_mv/par_csr_matop_device.c b/src/parcsr_mv/par_csr_matop_device.c index 30398a4a3..fb45bbf21 100644 --- a/src/parcsr_mv/par_csr_matop_device.c +++ b/src/parcsr_mv/par_csr_matop_device.c @@ -941,6 +941,11 @@ hypre_ParcsrGetExternalRowsDeviceWait(void *vrequest) return A_ext; } + +#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL) + +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + HYPRE_Int hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg, HYPRE_Int local_ncols ) @@ -977,11 +982,6 @@ hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg, return hypre_error_flag; } -#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL) - -#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) - - hypre_CSRMatrix* hypre_MergeDiagAndOffdDevice(hypre_ParCSRMatrix *A) { diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index cb9a39306..bf96fbe5b 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -780,6 +780,14 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_HOST); /* add Cext to local part of Cbar */ + /* WM: debug */ + HYPRE_Int my_id; + hypre_MPI_Comm_rank(hypre_MPI_COMM_WORLD, &my_id); + if (my_id == 0) + { + hypre_CSRMatrixPrint(Cbar, "Cbar"); + hypre_CSRMatrixPrint(Cext, "Cext"); + } 
hypre_ParCSRTMatMatPartialAddDevice(hypre_ParCSRMatrixCommPkg(R), hypre_ParCSRMatrixNumCols(R), hypre_ParCSRMatrixNumCols(P), @@ -794,6 +802,12 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, &C_offd, &num_cols_offd_C, &col_map_offd_C); + /* WM: debug */ + if (my_id == 0) + { + hypre_CSRMatrixPrint(C_diag, "C_diag"); + hypre_CSRMatrixPrint(C_offd, "C_offd"); + } hypre_TFree(col_map_offd, HYPRE_MEMORY_DEVICE); } @@ -973,7 +987,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, oneapi::dpl::make_zip_iterator(work, big_work), pred1 ); - HYPRE_Int Cext_diag_nnz = thrust::get<0>(dia_end.base()) - work; + HYPRE_Int Cext_diag_nnz = std::get<0>(dia_end.base()) - work; #else auto dia_end = HYPRE_THRUST_CALL( copy_if, thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)), @@ -1049,10 +1063,19 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, HYPRE_Int *ie_ii = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); HYPRE_Int *ie_j = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); +#if defined(HYPRE_USING_SYCL) + hypreSycl_sequence(ie_ii, ie_ii + num_rows, 0); + HYPRE_ONEDPL_CALL( std::copy, send_map, send_map + num_elemt, ie_ii + num_rows); + hypreSycl_sequence(ie_j, ie_j + num_rows + num_elemt, 0); + auto zipped_begin = oneapi::dpl::make_zip_iterator(ie_ii, ie_j); + HYPRE_ONEDPL_CALL( std::stable_sort, zipped_begin, zipped_begin + num_rows + num_elemt, + [](auto lhs, auto rhs) { return std::get<0>(lhs) < std::get<0>(rhs); } ); +#else HYPRE_THRUST_CALL( sequence, ie_ii, ie_ii + num_rows); HYPRE_THRUST_CALL( copy, send_map, send_map + num_elemt, ie_ii + num_rows); HYPRE_THRUST_CALL( sequence, ie_j, ie_j + num_rows + num_elemt); HYPRE_THRUST_CALL( stable_sort_by_key, ie_ii, ie_ii + num_rows + num_elemt, ie_j ); +#endif HYPRE_Int *ie_i = hypreDevice_CsrRowIndicesToPtrs(num_rows, num_rows + num_elemt, ie_ii); hypre_TFree(ie_ii, HYPRE_MEMORY_DEVICE); diff 
--git a/src/seq_mv/csr_spgemm_device_numer10.c b/src/seq_mv/csr_spgemm_device_numer10.c index 229fff518..7eb59c2a7 100644 --- a/src/seq_mv/csr_spgemm_device_numer10.c +++ b/src/seq_mv/csr_spgemm_device_numer10.c @@ -6,6 +6,8 @@ ******************************************************************************/ #include "seq_mv.h" +#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) + #define HYPRE_SPGEMM_DEVICE_USE_DSHMEM #include @@ -24,3 +26,4 @@ hypre_spgemm_numerical_max_num_blocks < HYPRE_SPGEMM_NUMER_HASH_SIZE * 32, HYPRE_SPGEMM_BASE_GROUP_SIZE * 32 > ( HYPRE_Int multiProcessorCount, HYPRE_Int *num_blocks_ptr, HYPRE_Int *block_size_ptr ); +#endif /* HYPRE_USING_CUDA || defined(HYPRE_USING_HIP) */ diff --git a/src/seq_mv/csr_spgemm_device_onemklsparse.c b/src/seq_mv/csr_spgemm_device_onemklsparse.c index a15ef161d..7b3c530bb 100644 --- a/src/seq_mv/csr_spgemm_device_onemklsparse.c +++ b/src/seq_mv/csr_spgemm_device_onemklsparse.c @@ -31,6 +31,7 @@ hypreDevice_CSRSpGemmOnemklsparse(HYPRE_Int m, HYPRE_Int **d_jc_out, HYPRE_Complex **d_c_out) { + hypre_printf("WM: debug - using oneMKL spgemm\n"); std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL; void *tmp_buffer1 = NULL; void *tmp_buffer2 = NULL; diff --git a/src/utilities/_hypre_onedpl.hpp b/src/utilities/_hypre_onedpl.hpp index b87acbbb9..fd820ef84 100644 --- a/src/utilities/_hypre_onedpl.hpp +++ b/src/utilities/_hypre_onedpl.hpp @@ -137,6 +137,22 @@ OutputIter hypreSycl_gather(InputIter1 map_first, InputIter1 map_last, return HYPRE_ONEDPL_CALL( oneapi::dpl::copy, perm_begin, perm_begin + n, result); } +// Equivalent of thrust::sequence (with step=1): fills [first, last) with init, init + 1, ...; the counting range ends at init + (last - first) so a non-zero init still writes every element +template <typename Iter, typename T> +void hypreSycl_sequence(Iter first, Iter last, T init = 0) +{ + static_assert( + std::is_same<typename std::iterator_traits<Iter>::iterator_category, + std::random_access_iterator_tag>::value, + "Iterators passed to algorithms must be random-access iterators."); + using DiffType = typename std::iterator_traits<Iter>::difference_type; + HYPRE_ONEDPL_CALL( std::transform, +
oneapi::dpl::counting_iterator<DiffType>(init), + oneapi::dpl::counting_iterator<DiffType>(init + std::distance(first, last)), + first, + [](auto i) { return i; }); +} + #endif #endif From 8730346c03468e1d63e8dd16ef82b64917b1d28a Mon Sep 17 00:00:00 2001 From: Wayne Mitchell Date: Thu, 9 Jun 2022 19:39:03 +0000 Subject: [PATCH 2/4] More debugging code and astyle. About to sync up with Ruipeng. --- src/parcsr_mv/par_csr_triplemat_device.c | 20 +++++++++++++++++--- src/seq_mv/csr_matrix.c | 10 ++++++---- src/utilities/_hypre_utilities.h | 3 ++- src/utilities/merge_sort.c | 3 ++- src/utilities/protos.h | 3 ++- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index 6631ba566..9134fb06a 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -785,7 +785,7 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, /* WM: debug */ HYPRE_Int my_id; hypre_MPI_Comm_rank(hypre_MPI_COMM_WORLD, &my_id); - if (my_id == 0) + if (my_id == 0 && hypre_ParCSRMatrixNumRows(A) > 400) { hypre_CSRMatrixPrint(Cbar, "Cbar"); hypre_CSRMatrixPrint(Cext, "Cext"); @@ -805,7 +805,7 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R, &num_cols_offd_C, &col_map_offd_C); /* WM: debug */ - if (my_id == 0) + if (my_id == 0 && hypre_ParCSRMatrixNumRows(A) > 400) { hypre_CSRMatrixPrint(C_diag, "C_diag"); hypre_CSRMatrixPrint(C_offd, "C_offd"); @@ -927,6 +927,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, // Convert Cext from BigJ to J // Cext offd #if defined(HYPRE_USING_SYCL) + /* WM: debug - the below is suspicious... 
*/ auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), Cext_bigj), oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), @@ -982,6 +983,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, // Cext diag #if defined(HYPRE_USING_SYCL) + /* WM: debug - the below is suspicious... */ auto dia_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), Cext_bigj), oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), @@ -1106,6 +1108,10 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, t1 = hypre_MPI_Wtime(); #endif hypreDevice_CSRSpGemm(IE, CC, &Cz); + /* WM: debug */ + hypre_CSRMatrixPrint(IE, "IE"); + hypre_CSRMatrixPrint(CC, "CC"); + hypre_CSRMatrixPrint(Cz, "Cz"); hypre_CSRMatrixDestroy(IE); hypre_CSRMatrixDestroy(CC); @@ -1151,6 +1157,10 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, HYPRE_Complex *C_diag_a = hypre_CSRMatrixData(C_diag); #if defined(HYPRE_USING_SYCL) + /* WM: debug */ + hypre_printf("WM: debug - zmp_a = "); + for (auto i = 0; i < 100; i++) hypre_printf("%f ", zmp_a[i]); + hypre_printf("\n"); auto new_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a), oneapi::dpl::make_zip_iterator(zmp_i, zmp_j, zmp_a) + local_nnz_C, zmp_j, @@ -1164,6 +1174,11 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, zmp_j, thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)), pred ); + /* WM: debug */ + hypre_printf("WM: debug - C_diag_a = "); + for (auto i = 0; i < 100; i++) hypre_printf("%f ", C_diag_a[i]); + hypre_printf("\n"); + hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd ); hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag ); #endif hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii, @@ -1182,7 +1197,6 @@ 
hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, zmp_j, oneapi::dpl::make_zip_iterator(C_offd_ii, C_offd_j, C_offd_a), std::not_fn(pred) ); - hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd ); #else new_end = HYPRE_THRUST_CALL( copy_if, thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)), diff --git a/src/seq_mv/csr_matrix.c b/src/seq_mv/csr_matrix.c index 5f38c01f9..306bb99d7 100644 --- a/src/seq_mv/csr_matrix.c +++ b/src/seq_mv/csr_matrix.c @@ -588,7 +588,8 @@ hypre_CSRMatrixPrintMM( hypre_CSRMatrix *matrix, HYPRE_Int trans, const char *file_name ) { - hypre_assert(hypre_CSRMatrixI(matrix)[hypre_CSRMatrixNumRows(matrix)] == hypre_CSRMatrixNumNonzeros(matrix)); + hypre_assert(hypre_CSRMatrixI(matrix)[hypre_CSRMatrixNumRows(matrix)] == hypre_CSRMatrixNumNonzeros( + matrix)); FILE *fp = file_name ? fopen(file_name, "w") : stdout; @@ -611,9 +612,10 @@ hypre_CSRMatrixPrintMM( hypre_CSRMatrix *matrix, hypre_fprintf(fp, "%%%%MatrixMarket matrix coordinate pattern general\n"); } - hypre_fprintf(fp, "%d %d %d\n", trans ? hypre_CSRMatrixNumCols(matrix) : hypre_CSRMatrixNumRows(matrix), - trans ? hypre_CSRMatrixNumRows(matrix) : hypre_CSRMatrixNumCols(matrix), - hypre_CSRMatrixNumNonzeros(matrix)); + hypre_fprintf(fp, "%d %d %d\n", + trans ? hypre_CSRMatrixNumCols(matrix) : hypre_CSRMatrixNumRows(matrix), + trans ? hypre_CSRMatrixNumRows(matrix) : hypre_CSRMatrixNumCols(matrix), + hypre_CSRMatrixNumNonzeros(matrix)); HYPRE_Int i, j; diff --git a/src/utilities/_hypre_utilities.h b/src/utilities/_hypre_utilities.h index 6468f6228..e907a9d1d 100644 --- a/src/utilities/_hypre_utilities.h +++ b/src/utilities/_hypre_utilities.h @@ -1761,7 +1761,8 @@ typedef struct * 1) Merge sort can take advantage of eliminating duplicates. 
* 2) Merge sort is more efficiently parallelizable than qsort */ -HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, hypre_IntArray *array3 ); +HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, + hypre_IntArray *array3 ); void hypre_union2(HYPRE_Int n1, HYPRE_BigInt *arr1, HYPRE_Int n2, HYPRE_BigInt *arr2, HYPRE_Int *n3, HYPRE_BigInt *arr3, HYPRE_Int *map1, HYPRE_Int *map2); void hypre_merge_sort(HYPRE_Int *in, HYPRE_Int *temp, HYPRE_Int len, HYPRE_Int **sorted); diff --git a/src/utilities/merge_sort.c b/src/utilities/merge_sort.c index 80c76ed97..65e8217e4 100644 --- a/src/utilities/merge_sort.c +++ b/src/utilities/merge_sort.c @@ -61,7 +61,8 @@ hypre_MergeOrderedArrays( hypre_IntArray *array1, array3_data[k++] = array2_data[j++]; } - array3_data = hypre_TReAlloc_v2(array3_data, HYPRE_Int, size1 + size2, HYPRE_Int, k, memory_location); + array3_data = hypre_TReAlloc_v2(array3_data, HYPRE_Int, size1 + size2, HYPRE_Int, k, + memory_location); hypre_IntArraySize(array3) = k; hypre_IntArrayData(array3) = array3_data; diff --git a/src/utilities/protos.h b/src/utilities/protos.h index 78797212b..e847339a5 100644 --- a/src/utilities/protos.h +++ b/src/utilities/protos.h @@ -258,7 +258,8 @@ typedef struct * 1) Merge sort can take advantage of eliminating duplicates. 
* 2) Merge sort is more efficiently parallelizable than qsort */ -HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, hypre_IntArray *array3 ); +HYPRE_Int hypre_MergeOrderedArrays( hypre_IntArray *array1, hypre_IntArray *array2, + hypre_IntArray *array3 ); void hypre_union2(HYPRE_Int n1, HYPRE_BigInt *arr1, HYPRE_Int n2, HYPRE_BigInt *arr2, HYPRE_Int *n3, HYPRE_BigInt *arr3, HYPRE_Int *map1, HYPRE_Int *map2); void hypre_merge_sort(HYPRE_Int *in, HYPRE_Int *temp, HYPRE_Int len, HYPRE_Int **sorted); From 7dc8321f4ab4365309c41f879c9fdd7f67105253 Mon Sep 17 00:00:00 2001 From: Wayne Mitchell Date: Fri, 10 Jun 2022 19:02:35 +0000 Subject: [PATCH 3/4] Additional fixes for sycl build. Note that CSR matrix memory location must be set correctly before calling hypre_CSRMatrixSetRownnz. --- src/parcsr_mv/par_csr_matop.c | 10 +++++----- src/parcsr_mv/par_csr_triplemat_device.c | 16 ++++++++++------ src/seq_mv/csr_spgemm_device_onemklsparse.c | 1 - 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/parcsr_mv/par_csr_matop.c b/src/parcsr_mv/par_csr_matop.c index 8473c3142..908c1e471 100644 --- a/src/parcsr_mv/par_csr_matop.c +++ b/src/parcsr_mv/par_csr_matop.c @@ -1175,6 +1175,7 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A, hypre_CSRMatrixData(C_diag) = C_diag_data; hypre_CSRMatrixI(C_diag) = C_diag_i; hypre_CSRMatrixJ(C_diag) = C_diag_j; + hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C; hypre_CSRMatrixSetRownnz(C_diag); C_offd = hypre_ParCSRMatrixOffd(C); @@ -1186,10 +1187,9 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A, hypre_CSRMatrixJ(C_offd) = C_offd_j; hypre_ParCSRMatrixColMapOffd(C) = col_map_offd_C; } + hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C; hypre_CSRMatrixSetRownnz(C_offd); - hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C; - hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C; /*----------------------------------------------------------------------- * Free various arrays @@ -4009,6 
+4009,9 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A, hypre_ParCSRMatrixRowvalues(C) = NULL; hypre_ParCSRMatrixGetrowactive(C) = 0; + hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C; + hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C; + if (C_diag) { hypre_CSRMatrixSetRownnz(C_diag); @@ -4029,9 +4032,6 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A, hypre_ParCSRMatrixOffd(C) = C_tmp_offd; } - hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C; - hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C; - if (num_cols_offd_C) { HYPRE_Int jj_count_offd, nnz_offd; diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index a15953412..d67e2c312 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -940,12 +940,16 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, &num_cols_offd_C, &col_map_offd_C, &map_offd_to_C); #if defined(HYPRE_USING_SYCL) - HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound, - col_map_offd_C, - col_map_offd_C + num_cols_offd_C, - big_work, - big_work + Cext_offd_nnz, - oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) ); + /* WM: onedpl lower_bound currently does not accept zero length input */ + if (num_cols_offd_C > 0) + { + HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound, + col_map_offd_C, + col_map_offd_C + num_cols_offd_C, + big_work, + big_work + Cext_offd_nnz, + oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) ); + } HYPRE_ONEDPL_CALL( std::transform, oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work), diff --git a/src/seq_mv/csr_spgemm_device_onemklsparse.c b/src/seq_mv/csr_spgemm_device_onemklsparse.c index 7b3c530bb..a15ef161d 100644 --- a/src/seq_mv/csr_spgemm_device_onemklsparse.c +++ b/src/seq_mv/csr_spgemm_device_onemklsparse.c @@ -31,7 +31,6 @@ 
hypreDevice_CSRSpGemmOnemklsparse(HYPRE_Int m, HYPRE_Int **d_jc_out, HYPRE_Complex **d_c_out) { - hypre_printf("WM: debug - using oneMKL spgemm\n"); std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL; void *tmp_buffer1 = NULL; void *tmp_buffer2 = NULL; From 951fa56baa954984e966d66f986a1b53191393ef Mon Sep 17 00:00:00 2001 From: Wayne Mitchell Date: Fri, 10 Jun 2022 19:06:34 +0000 Subject: [PATCH 4/4] Fix copy paste error --- src/parcsr_mv/par_csr_triplemat_device.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index d67e2c312..b0bbf5162 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -1168,7 +1168,6 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, zmp_j, thrust::make_zip_iterator(thrust::make_tuple(C_diag_ii, C_diag_j, C_diag_a)), pred ); - hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd ); hypre_assert( thrust::get<0>(new_end.get_iterator_tuple()) == C_diag_ii + nnz_C_diag ); #endif hypreDevice_CsrRowIndicesToPtrs_v2(hypre_CSRMatrixNumRows(C_diag), nnz_C_diag, C_diag_ii, @@ -1187,6 +1186,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, zmp_j, oneapi::dpl::make_zip_iterator(C_offd_ii, C_offd_j, C_offd_a), std::not_fn(pred) ); + hypre_assert( std::get<0>(new_end.base()) == C_offd_ii + nnz_C_offd ); #else new_end = HYPRE_THRUST_CALL( copy_if, thrust::make_zip_iterator(thrust::make_tuple(zmp_i, zmp_j, zmp_a)),