fix matrix num cols
This commit is contained in:
parent
0eab14592b
commit
430241d7d0
@ -468,6 +468,8 @@ hypre_ConcatDiagOffdAndExtDevice(hypre_ParCSRMatrix *A,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/* The input B_ext is a BigJ matrix, so is the output */
|
||||||
|
/* RL: TODO FIX the num of columns of the output (from B_ext 'big' num cols) */
|
||||||
HYPRE_Int
|
HYPRE_Int
|
||||||
hypre_ExchangeExternalRowsDeviceInit( hypre_CSRMatrix *B_ext,
|
hypre_ExchangeExternalRowsDeviceInit( hypre_CSRMatrix *B_ext,
|
||||||
hypre_ParCSRCommPkg *comm_pkg_A,
|
hypre_ParCSRCommPkg *comm_pkg_A,
|
||||||
|
|||||||
@ -472,6 +472,8 @@ hypre_ParCSRTMatMatKTDevice( hypre_ParCSRMatrix *A,
|
|||||||
thrust::minus<HYPRE_Int>() );
|
thrust::minus<HYPRE_Int>() );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Change Cint into a BigJ matrix
|
||||||
|
// RL: TODO FIX the 'big' num of columns to global size
|
||||||
hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint),
|
hypre_CSRMatrixBigJ(Cint) = hypre_TAlloc(HYPRE_BigInt, hypre_CSRMatrixNumNonzeros(Cint),
|
||||||
HYPRE_MEMORY_DEVICE);
|
HYPRE_MEMORY_DEVICE);
|
||||||
|
|
||||||
@ -908,6 +910,7 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
|||||||
HYPRE_Int *work = hypre_TAlloc(HYPRE_Int, Cext_nnz, HYPRE_MEMORY_DEVICE);
|
HYPRE_Int *work = hypre_TAlloc(HYPRE_Int, Cext_nnz, HYPRE_MEMORY_DEVICE);
|
||||||
HYPRE_Int *map_offd_to_C;
|
HYPRE_Int *map_offd_to_C;
|
||||||
|
|
||||||
|
// Convert Cext from BigJ to J
|
||||||
// Cext offd
|
// Cext offd
|
||||||
#if defined(HYPRE_USING_SYCL)
|
#if defined(HYPRE_USING_SYCL)
|
||||||
auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
|
auto off_end = hypreSycl_copy_if( oneapi::dpl::make_zip_iterator(oneapi::dpl::counting_iterator(0),
|
||||||
@ -1003,6 +1006,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
|||||||
thrust::minus<HYPRE_BigInt>());
|
thrust::minus<HYPRE_BigInt>());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
hypre_CSRMatrixNumCols(Cext) = num_cols + num_cols_offd_C;
|
||||||
|
|
||||||
// transform Cbar_local J index
|
// transform Cbar_local J index
|
||||||
RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C);
|
RAP_functor<2, HYPRE_Int> func2(num_cols, 0, map_offd_to_C);
|
||||||
#if defined(HYPRE_USING_SYCL)
|
#if defined(HYPRE_USING_SYCL)
|
||||||
@ -1019,6 +1024,8 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
|||||||
func2 );
|
func2 );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
hypre_CSRMatrixNumCols(Cbar_local) = num_cols + num_cols_offd_C;
|
||||||
|
|
||||||
hypre_TFree(big_work, HYPRE_MEMORY_DEVICE);
|
hypre_TFree(big_work, HYPRE_MEMORY_DEVICE);
|
||||||
hypre_TFree(work, HYPRE_MEMORY_DEVICE);
|
hypre_TFree(work, HYPRE_MEMORY_DEVICE);
|
||||||
hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE);
|
hypre_TFree(map_offd_to_C, HYPRE_MEMORY_DEVICE);
|
||||||
@ -1059,7 +1066,6 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg,
|
|||||||
|
|
||||||
hypre_CSRMatrixI(IE) = ie_i;
|
hypre_CSRMatrixI(IE) = ie_i;
|
||||||
hypre_CSRMatrixJ(IE) = ie_j;
|
hypre_CSRMatrixJ(IE) = ie_j;
|
||||||
//hypre_CSRMatrixData(IE) = ie_a;
|
|
||||||
|
|
||||||
// CC = [Cbar_local; Cext]
|
// CC = [Cbar_local; Cext]
|
||||||
hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext);
|
hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user