diff --git a/src/parcsr_mv/par_csr_matop_device.c b/src/parcsr_mv/par_csr_matop_device.c index 30398a4a3..8c145f052 100644 --- a/src/parcsr_mv/par_csr_matop_device.c +++ b/src/parcsr_mv/par_csr_matop_device.c @@ -468,6 +468,8 @@ hypre_ConcatDiagOffdAndExtDevice(hypre_ParCSRMatrix *A, } #endif +/* The input B_ext is a BigJ matrix, so is the output */ +/* RL: TODO FIX the num of columns of the output (from B_ext 'big' num cols) */ HYPRE_Int hypre_ExchangeExternalRowsDeviceInit( hypre_CSRMatrix *B_ext, hypre_ParCSRCommPkg *comm_pkg_A, diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index cb9a39306..35318f261 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -472,6 +472,8 @@ hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A, thrust::minus() ); #endif + // Change Cint into a BigJ matrix + // RL: TODO FIX the 'big' num of columns to global size hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint), HYPRE_MEMORY_DEVICE); @@ -908,6 +910,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, HYPRE_Int *work = hypre_TAlloc(HYPRE_Int, Cext_nnz, HYPRE_MEMORY_DEVICE); HYPRE_Int *map_offd_to_C; + // Convert Cext from BigJ to J // Cext offd #if defined(HYPRE_USING_SYCL) auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), @@ -1003,6 +1006,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, thrust::minus()); #endif + hypre_CSRMatrixNumCols(Cext) = num_cols + num_cols_offd_C; + // transform Cbar_local J index RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C); #if defined(HYPRE_USING_SYCL) @@ -1019,6 +1024,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, func2 ); #endif + hypre_CSRMatrixNumCols(Cbar_local) = num_cols + num_cols_offd_C; + hypre_TFree(big_work, HYPRE_MEMORY_DEVICE); hypre_TFree(work, HYPRE_MEMORY_DEVICE); hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE); @@ -1059,7 +1066,6 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, hypre_CSRMatrixI(IE) = ie_i; hypre_CSRMatrixJ(IE) = ie_j; - //hypre_CSRMatrixData(IE) = ie_a; // CC = [Cbar_local; Cext] hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext);