Additional fixes for sycl build. Note that CSR matrix memory location must be set correctly before calling hypre_CSRMatrixSetRownnz.

This commit is contained in:
Wayne Mitchell 2022-06-10 19:02:35 +00:00
parent 3cc0138a9b
commit 7dc8321f4a
3 changed files with 15 additions and 12 deletions

View File

@ -1175,6 +1175,7 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A,
hypre_CSRMatrixData(C_diag) = C_diag_data; hypre_CSRMatrixData(C_diag) = C_diag_data;
hypre_CSRMatrixI(C_diag) = C_diag_i; hypre_CSRMatrixI(C_diag) = C_diag_i;
hypre_CSRMatrixJ(C_diag) = C_diag_j; hypre_CSRMatrixJ(C_diag) = C_diag_j;
hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C;
hypre_CSRMatrixSetRownnz(C_diag); hypre_CSRMatrixSetRownnz(C_diag);
C_offd = hypre_ParCSRMatrixOffd(C); C_offd = hypre_ParCSRMatrixOffd(C);
@ -1186,10 +1187,9 @@ hypre_ParMatmul( hypre_ParCSRMatrix *A,
hypre_CSRMatrixJ(C_offd) = C_offd_j; hypre_CSRMatrixJ(C_offd) = C_offd_j;
hypre_ParCSRMatrixColMapOffd(C) = col_map_offd_C; hypre_ParCSRMatrixColMapOffd(C) = col_map_offd_C;
} }
hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C;
hypre_CSRMatrixSetRownnz(C_offd); hypre_CSRMatrixSetRownnz(C_offd);
hypre_CSRMatrixMemoryLocation(C_diag) = memory_location_C;
hypre_CSRMatrixMemoryLocation(C_offd) = memory_location_C;
/*----------------------------------------------------------------------- /*-----------------------------------------------------------------------
* Free various arrays * Free various arrays
@ -4009,6 +4009,9 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A,
hypre_ParCSRMatrixRowvalues(C) = NULL; hypre_ParCSRMatrixRowvalues(C) = NULL;
hypre_ParCSRMatrixGetrowactive(C) = 0; hypre_ParCSRMatrixGetrowactive(C) = 0;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C;
if (C_diag) if (C_diag)
{ {
hypre_CSRMatrixSetRownnz(C_diag); hypre_CSRMatrixSetRownnz(C_diag);
@ -4029,9 +4032,6 @@ hypre_ParTMatmul( hypre_ParCSRMatrix *A,
hypre_ParCSRMatrixOffd(C) = C_tmp_offd; hypre_ParCSRMatrixOffd(C) = C_tmp_offd;
} }
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixDiag(C)) = memory_location_C;
hypre_CSRMatrixMemoryLocation(hypre_ParCSRMatrixOffd(C)) = memory_location_C;
if (num_cols_offd_C) if (num_cols_offd_C)
{ {
HYPRE_Int jj_count_offd, nnz_offd; HYPRE_Int jj_count_offd, nnz_offd;

View File

@ -940,12 +940,16 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
&num_cols_offd_C, &col_map_offd_C, &map_offd_to_C); &num_cols_offd_C, &col_map_offd_C, &map_offd_to_C);
#if defined(HYPRE_USING_SYCL) #if defined(HYPRE_USING_SYCL)
HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound, /* WM: onedpl lower_bound currently does not accept zero length input */
col_map_offd_C, if (num_cols_offd_C > 0)
col_map_offd_C + num_cols_offd_C, {
big_work, HYPRE_ONEDPL_CALL( oneapi::dpl::lower_bound,
big_work + Cext_offd_nnz, col_map_offd_C,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) ); col_map_offd_C + num_cols_offd_C,
big_work,
big_work + Cext_offd_nnz,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work) );
}
HYPRE_ONEDPL_CALL( std::transform, HYPRE_ONEDPL_CALL( std::transform,
oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work), oneapi::dpl::make_permutation_iterator(hypre_CSRMatrixJ(Cext), work),

View File

@ -31,7 +31,6 @@ hypreDevice_CSRSpGemmOnemklsparse(HYPRE_Int m,
HYPRE_Int **d_jc_out, HYPRE_Int **d_jc_out,
HYPRE_Complex **d_c_out) HYPRE_Complex **d_c_out)
{ {
hypre_printf("WM: debug - using oneMKL spgemm\n");
std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL; std::int64_t *tmp_size1 = NULL, *tmp_size2, *nnzC = NULL;
void *tmp_buffer1 = NULL; void *tmp_buffer1 = NULL;
void *tmp_buffer2 = NULL; void *tmp_buffer2 = NULL;