Fixes for sycl. Still debugging incorrect results.

This commit is contained in:
Wayne Mitchell 2022-06-08 18:32:49 +00:00
parent f867d600b2
commit d8c6556e7e
5 changed files with 49 additions and 6 deletions

View File

@ -941,6 +941,11 @@ hypre_ParcsrGetExternalRowsDeviceWait(void *vrequest)
return A_ext;
}
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
HYPRE_Int
hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int local_ncols )
@ -977,11 +982,6 @@ hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
return hypre_error_flag;
}
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
hypre_CSRMatrix*
hypre_MergeDiagAndOffdDevice(hypre_ParCSRMatrix *A)
{

View File

@ -780,6 +780,14 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_HOST);
/* add Cext to local part of Cbar */
/* WM: debug */
HYPRE_Int my_id;
hypre_MPI_Comm_rank(hypre_MPI_COMM_WORLD, &my_id);
if (my_id == 0)
{
hypre_CSRMatrixPrint(Cbar, "Cbar");
hypre_CSRMatrixPrint(Cext, "Cext");
}
hypre_ParCSRTMatMatPartialAddDevice(hypre_ParCSRMatrixCommPkg(R),
hypre_ParCSRMatrixNumCols(R),
hypre_ParCSRMatrixNumCols(P),
@ -794,6 +802,12 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
&C_offd,
&num_cols_offd_C,
&col_map_offd_C);
/* WM: debug */
if (my_id == 0)
{
hypre_CSRMatrixPrint(C_diag, "C_diag");
hypre_CSRMatrixPrint(C_offd, "C_offd");
}
hypre_TFree(col_map_offd, HYPRE_MEMORY_DEVICE);
}
@ -973,7 +987,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
oneapi::dpl::make_zip_iterator(work, big_work),
pred1 );
HYPRE_Int Cext_diag_nnz = thrust::get<0>(dia_end.base()) - work;
HYPRE_Int Cext_diag_nnz = std::get<0>(dia_end.base()) - work;
#else
auto dia_end = HYPRE_THRUST_CALL( copy_if,
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)),
@ -1049,10 +1063,19 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int *ie_ii = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE);
HYPRE_Int *ie_j = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_SYCL)
hypreSycl_sequence(ie_ii, ie_ii + num_rows, 0);
HYPRE_ONEDPL_CALL( std::copy, send_map, send_map + num_elemt, ie_ii + num_rows);
hypreSycl_sequence(ie_j, ie_j + num_rows + num_elemt, 0);
auto zipped_begin = oneapi::dpl::make_zip_iterator(ie_ii, ie_j);
HYPRE_ONEDPL_CALL( std::stable_sort, zipped_begin, zipped_begin + num_rows + num_elemt,
[](auto lhs, auto rhs) { return std::get<0>(lhs) < std::get<0>(rhs); } );
#else
HYPRE_THRUST_CALL( sequence, ie_ii, ie_ii + num_rows);
HYPRE_THRUST_CALL( copy, send_map, send_map + num_elemt, ie_ii + num_rows);
HYPRE_THRUST_CALL( sequence, ie_j, ie_j + num_rows + num_elemt);
HYPRE_THRUST_CALL( stable_sort_by_key, ie_ii, ie_ii + num_rows + num_elemt, ie_j );
#endif
HYPRE_Int *ie_i = hypreDevice_CsrRowIndicesToPtrs(num_rows, num_rows + num_elemt, ie_ii);
hypre_TFree(ie_ii, HYPRE_MEMORY_DEVICE);

View File

@ -6,6 +6,8 @@
******************************************************************************/
#include "seq_mv.h"
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
#define HYPRE_SPGEMM_DEVICE_USE_DSHMEM
#include <csr_spgemm_device_numer.h>
@ -24,3 +26,4 @@ hypre_spgemm_numerical_max_num_blocks
< HYPRE_SPGEMM_NUMER_HASH_SIZE * 32, HYPRE_SPGEMM_BASE_GROUP_SIZE * 32 >
( HYPRE_Int multiProcessorCount, HYPRE_Int *num_blocks_ptr, HYPRE_Int *block_size_ptr );
#endif /* HYPRE_USING_CUDA || defined(HYPRE_USING_HIP) */

View File

@ -31,6 +31,7 @@ hypreDevice_CSRSpGemmOnemklsparse(HYPRE_Int m,
HYPRE_Int **d_jc_out,
HYPRE_Complex **d_c_out)
{
hypre_printf("WM: debug - using oneMKL spgemm\n");
std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL;
void *tmp_buffer1 = NULL;
void *tmp_buffer2 = NULL;

View File

@ -137,6 +137,22 @@ OutputIter hypreSycl_gather(InputIter1 map_first, InputIter1 map_last,
return HYPRE_ONEDPL_CALL( oneapi::dpl::copy, perm_begin, perm_begin + n, result);
}
// Equivalent of thrust::sequence (with step=1).
//
// Fills [first, last) with the consecutive values init, init + 1, init + 2, ...
//
//   first, last : random-access iterators delimiting the output range
//   init        : value written to *first; defaults to 0
//
// T defaults to the iterator's difference_type so that the two-argument form
// hypreSycl_sequence(first, last) is callable (a defaulted function argument
// alone does not let the compiler deduce T).
template <class Iter, class T = typename std::iterator_traits<Iter>::difference_type>
void hypreSycl_sequence(Iter first, Iter last, T init = 0)
{
   static_assert(
      std::is_same<typename std::iterator_traits<Iter>::iterator_category,
      std::random_access_iterator_tag>::value,
      "Iterators passed to algorithms must be random-access iterators.");
   using DiffType = typename std::iterator_traits<Iter>::difference_type;

   const DiffType start = static_cast<DiffType>(init);
   const DiffType count = std::distance(first, last);

   // The source range must span exactly `count` values starting at `start`;
   // ending it at `count` alone is only correct when init == 0.
   HYPRE_ONEDPL_CALL( std::transform,
                      oneapi::dpl::counting_iterator<DiffType>(start),
                      oneapi::dpl::counting_iterator<DiffType>(start + count),
                      first,
                      [](auto i) { return i; });
}
#endif
#endif