Fixes for sycl. Still debugging incorrect results.
This commit is contained in:
parent
f867d600b2
commit
d8c6556e7e
@ -941,6 +941,11 @@ hypre_ParcsrGetExternalRowsDeviceWait(void *vrequest)
|
||||
return A_ext;
|
||||
}
|
||||
|
||||
|
||||
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
|
||||
|
||||
HYPRE_Int
|
||||
hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
|
||||
HYPRE_Int local_ncols )
|
||||
@ -977,11 +982,6 @@ hypre_ParCSRCommPkgCreateMatrixE( hypre_ParCSRCommPkg *comm_pkg,
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
#endif // defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
|
||||
|
||||
|
||||
hypre_CSRMatrix*
|
||||
hypre_MergeDiagAndOffdDevice(hypre_ParCSRMatrix *A)
|
||||
{
|
||||
|
||||
@ -780,6 +780,14 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
|
||||
HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_HOST);
|
||||
|
||||
/* add Cext to local part of Cbar */
|
||||
/* WM: debug */
|
||||
HYPRE_Int my_id;
|
||||
hypre_MPI_Comm_rank(hypre_MPI_COMM_WORLD, &my_id);
|
||||
if (my_id == 0)
|
||||
{
|
||||
hypre_CSRMatrixPrint(Cbar, "Cbar");
|
||||
hypre_CSRMatrixPrint(Cext, "Cext");
|
||||
}
|
||||
hypre_ParCSRTMatMatPartialAddDevice(hypre_ParCSRMatrixCommPkg(R),
|
||||
hypre_ParCSRMatrixNumCols(R),
|
||||
hypre_ParCSRMatrixNumCols(P),
|
||||
@ -794,6 +802,12 @@ hypre_ParCSRMatrixRAPKTDevice( hypre_ParCSRMatrix *R,
|
||||
&C_offd,
|
||||
&num_cols_offd_C,
|
||||
&col_map_offd_C);
|
||||
/* WM: debug */
|
||||
if (my_id == 0)
|
||||
{
|
||||
hypre_CSRMatrixPrint(C_diag, "C_diag");
|
||||
hypre_CSRMatrixPrint(C_offd, "C_offd");
|
||||
}
|
||||
|
||||
hypre_TFree(col_map_offd, HYPRE_MEMORY_DEVICE);
|
||||
}
|
||||
@ -973,7 +987,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
||||
oneapi::dpl::make_zip_iterator(work, big_work),
|
||||
pred1 );
|
||||
|
||||
HYPRE_Int Cext_diag_nnz = thrust::get<0>(dia_end.base()) - work;
|
||||
HYPRE_Int Cext_diag_nnz = std::get<0>(dia_end.base()) - work;
|
||||
#else
|
||||
auto dia_end = HYPRE_THRUST_CALL( copy_if,
|
||||
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), Cext_bigj)),
|
||||
@ -1049,10 +1063,19 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
||||
HYPRE_Int *ie_ii = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE);
|
||||
HYPRE_Int *ie_j = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE);
|
||||
|
||||
#if defined(HYPRE_USING_SYCL)
|
||||
hypreSycl_sequence(ie_ii, ie_ii + num_rows, 0);
|
||||
HYPRE_ONEDPL_CALL( std::copy, send_map, send_map + num_elemt, ie_ii + num_rows);
|
||||
hypreSycl_sequence(ie_j, ie_j + num_rows + num_elemt, 0);
|
||||
auto zipped_begin = oneapi::dpl::make_zip_iterator(ie_ii, ie_j);
|
||||
HYPRE_ONEDPL_CALL( std::stable_sort, zipped_begin, zipped_begin + num_rows + num_elemt,
|
||||
[](auto lhs, auto rhs) { return std::get<0>(lhs) < std::get<0>(rhs); } );
|
||||
#else
|
||||
HYPRE_THRUST_CALL( sequence, ie_ii, ie_ii + num_rows);
|
||||
HYPRE_THRUST_CALL( copy, send_map, send_map + num_elemt, ie_ii + num_rows);
|
||||
HYPRE_THRUST_CALL( sequence, ie_j, ie_j + num_rows + num_elemt);
|
||||
HYPRE_THRUST_CALL( stable_sort_by_key, ie_ii, ie_ii + num_rows + num_elemt, ie_j );
|
||||
#endif
|
||||
|
||||
HYPRE_Int *ie_i = hypreDevice_CsrRowIndicesToPtrs(num_rows, num_rows + num_elemt, ie_ii);
|
||||
hypre_TFree(ie_ii, HYPRE_MEMORY_DEVICE);
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
******************************************************************************/
|
||||
#include "seq_mv.h"
|
||||
|
||||
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
|
||||
|
||||
#define HYPRE_SPGEMM_DEVICE_USE_DSHMEM
|
||||
|
||||
#include <csr_spgemm_device_numer.h>
|
||||
@ -24,3 +26,4 @@ hypre_spgemm_numerical_max_num_blocks
|
||||
< HYPRE_SPGEMM_NUMER_HASH_SIZE * 32, HYPRE_SPGEMM_BASE_GROUP_SIZE * 32 >
|
||||
( HYPRE_Int multiProcessorCount, HYPRE_Int *num_blocks_ptr, HYPRE_Int *block_size_ptr );
|
||||
|
||||
#endif /* HYPRE_USING_CUDA || defined(HYPRE_USING_HIP) */
|
||||
|
||||
@ -31,6 +31,7 @@ hypreDevice_CSRSpGemmOnemklsparse(HYPRE_Int m,
|
||||
HYPRE_Int **d_jc_out,
|
||||
HYPRE_Complex **d_c_out)
|
||||
{
|
||||
hypre_printf("WM: debug - using oneMKL spgemm\n");
|
||||
std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL;
|
||||
void *tmp_buffer1 = NULL;
|
||||
void *tmp_buffer2 = NULL;
|
||||
|
||||
@ -137,6 +137,22 @@ OutputIter hypreSycl_gather(InputIter1 map_first, InputIter1 map_last,
|
||||
return HYPRE_ONEDPL_CALL( oneapi::dpl::copy, perm_begin, perm_begin + n, result);
|
||||
}
|
||||
|
||||
// Equivalent of thrust::sequence (with step=1)
|
||||
template <class Iter, class T>
|
||||
void hypreSycl_sequence(Iter first, Iter last, T init = 0)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same<typename std::iterator_traits<Iter>::iterator_category,
|
||||
std::random_access_iterator_tag>::value,
|
||||
"Iterators passed to algorithms must be random-access iterators.");
|
||||
using DiffType = typename std::iterator_traits<Iter>::difference_type;
|
||||
HYPRE_ONEDPL_CALL( std::transform,
|
||||
oneapi::dpl::counting_iterator<DiffType>(init),
|
||||
oneapi::dpl::counting_iterator<DiffType>(std::distance(first, last)),
|
||||
first,
|
||||
[](auto i) { return i; });
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
Loading…
Reference in New Issue
Block a user