diff --git a/src/parcsr_mv/par_csr_triplemat_device.c b/src/parcsr_mv/par_csr_triplemat_device.c index 54452ae4c..c2c0e65c5 100644 --- a/src/parcsr_mv/par_csr_triplemat_device.c +++ b/src/parcsr_mv/par_csr_triplemat_device.c @@ -1055,8 +1055,15 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, num_rows + num_elemt); hypre_CSRMatrixMemoryLocation(IE) = HYPRE_MEMORY_DEVICE; - HYPRE_Int *ie_ii = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); - HYPRE_Int *ie_j = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); + HYPRE_Int *ie_ii = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); + HYPRE_Int *ie_j = hypre_TAlloc(HYPRE_Int, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); + HYPRE_Complex *ie_a = NULL; + + if (hypre_HandleSpgemmUseVendor(hypre_handle())) + { + ie_a = hypre_TAlloc(HYPRE_Complex, num_rows + num_elemt, HYPRE_MEMORY_DEVICE); + HYPRE_THRUST_CALL(fill, ie_a, ie_a + num_rows + num_elemt, 1.0); + } HYPRE_THRUST_CALL( sequence, ie_ii, ie_ii + num_rows); HYPRE_THRUST_CALL( copy, send_map, send_map + num_elemt, ie_ii + num_rows); @@ -1066,8 +1073,9 @@ hypre_ParCSRTMatMatPartialAddDevice( hypre_ParCSRCommPkg *comm_pkg, HYPRE_Int *ie_i = hypreDevice_CsrRowIndicesToPtrs(num_rows, num_rows + num_elemt, ie_ii); hypre_TFree(ie_ii, HYPRE_MEMORY_DEVICE); - hypre_CSRMatrixI(IE) = ie_i; - hypre_CSRMatrixJ(IE) = ie_j; + hypre_CSRMatrixI(IE) = ie_i; + hypre_CSRMatrixJ(IE) = ie_j; + hypre_CSRMatrixData(IE) = ie_a; // CC = [Cbar_local; Cext] hypre_CSRMatrix *CC = hypre_CSRMatrixStack2Device(Cbar_local, Cext);