fix matrix num cols

This commit is contained in:
Ruipeng Li 2022-06-08 11:29:01 -07:00
parent 0eab14592b
commit 430241d7d0
2 changed files with 9 additions and 1 deletions

View File

@ -468,6 +468,8 @@ hypre_ConcatDiagOffdAndExtDevice(hypre_ParCSRMatrix *A,
} }
#endif #endif
/* The input B_ext is a BigJ matrix, so is the output */
/* RL: TODO FIX the num of columns of the output (from B_ext 'big' num cols) */
HYPRE_Int HYPRE_Int
hypre_ExchangeExternalRowsDeviceInit( hypre_CSRMatrix *B_ext, hypre_ExchangeExternalRowsDeviceInit( hypre_CSRMatrix *B_ext,
hypre_ParCSRCommPkg *comm_pkg_A, hypre_ParCSRCommPkg *comm_pkg_A,

View File

@ -472,6 +472,8 @@ hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A,
thrust::minus<HYPRE_Int>() ); thrust::minus<HYPRE_Int>() );
#endif #endif
// Change Cint into a BigJ matrix
// RL: TODO FIX the 'big' num of columns to global size
hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint), hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint),
HYPRE_MEMORY_DEVICE); HYPRE_MEMORY_DEVICE);
@ -908,6 +910,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
HYPRE_Int *work = hypre_TAlloc(HYPRE_Int, Cext_nnz, HYPRE_MEMORY_DEVICE); HYPRE_Int *work = hypre_TAlloc(HYPRE_Int, Cext_nnz, HYPRE_MEMORY_DEVICE);
HYPRE_Int *map_offd_to_C; HYPRE_Int *map_offd_to_C;
// Convert Cext from BigJ to J
// Cext offd // Cext offd
#if defined(HYPRE_USING_SYCL) #if defined(HYPRE_USING_SYCL)
auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0), auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
@ -1003,6 +1006,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
thrust::minus<HYPRE_BigInt>()); thrust::minus<HYPRE_BigInt>());
#endif #endif
hypre_CSRMatrixNumCols(Cext) = num_cols + num_cols_offd_C;
// transform Cbar_local J index // transform Cbar_local J index
RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C); RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C);
#if defined(HYPRE_USING_SYCL) #if defined(HYPRE_USING_SYCL)
@ -1019,6 +1024,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
func2 ); func2 );
#endif #endif
hypre_CSRMatrixNumCols(Cbar_local) = num_cols + num_cols_offd_C;
hypre_TFree(big_work, HYPRE_MEMORY_DEVICE); hypre_TFree(big_work, HYPRE_MEMORY_DEVICE);
hypre_TFree(work, HYPRE_MEMORY_DEVICE); hypre_TFree(work, HYPRE_MEMORY_DEVICE);
hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE); hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE);
@ -1059,7 +1066,6 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
hypre_CSRMatrixI(IE) = ie_i; hypre_CSRMatrixI(IE) = ie_i;
hypre_CSRMatrixJ(IE) = ie_j; hypre_CSRMatrixJ(IE) = ie_j;
//hypre_CSRMatrixData(IE) = ie_a;
// CC = [Cbar_local; Cext] // CC = [Cbar_local; Cext]
hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext); hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext);