One based index (#408)

This PR fixes #396, indexing problem in HYPRE_IJVectorGetValues
This commit is contained in:
Ruipeng Li 2021-06-18 15:56:22 -07:00 committed by GitHub
parent 970f75821f
commit 27d6b2bd72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 25 deletions

View File

@ -655,6 +655,7 @@ hypre_IJVectorGetValuesPar(hypre_IJVector *vector,
HYPRE_BigInt *IJpartitioning = hypre_IJVectorPartitioning(vector);
HYPRE_BigInt vec_start;
HYPRE_BigInt vec_stop;
HYPRE_BigInt jmin = hypre_IJVectorGlobalFirstRow(vector);
hypre_ParVector *par_vector = (hypre_ParVector*) hypre_IJVectorObject(vector);
HYPRE_Int print_level = hypre_IJVectorPrintLevel(vector);
@ -721,10 +722,7 @@ hypre_IJVectorGetValuesPar(hypre_IJVector *vector,
return hypre_error_flag;
}
hypre_assert(vec_start == hypre_ParVectorFirstIndex(par_vector));
hypre_assert(vec_stop == hypre_ParVectorLastIndex(par_vector) + 1);
hypre_ParVectorGetValues(par_vector, num_values, (HYPRE_BigInt *) indices, values);
hypre_ParVectorGetValues2(par_vector, num_values, (HYPRE_BigInt *) indices, jmin, values);
return hypre_error_flag;
}

View File

@ -943,12 +943,15 @@ HYPRE_Int hypre_ParVectorReadIJ ( MPI_Comm comm , const char *filename , HYPRE_I
HYPRE_Int hypre_FillResponseParToVectorAll ( void *p_recv_contact_buf , HYPRE_Int contact_size , HYPRE_Int contact_proc , void *ro , MPI_Comm comm , void **p_send_response_buf , HYPRE_Int *response_message_size );
HYPRE_Complex hypre_ParVectorLocalSumElts ( hypre_ParVector *vector );
HYPRE_Int hypre_ParVectorGetValues ( hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices , HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorGetValues2( hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values );
HYPRE_Int hypre_ParVectorGetValuesHost(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorElmdivpy( hypre_ParVector *x, hypre_ParVector *b, hypre_ParVector *y );
/* par_vector_device.c */
HYPRE_Int hypre_ParVectorGetValuesDevice(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorGetValuesDevice(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -1023,6 +1023,7 @@ HYPRE_Int
hypre_ParVectorGetValuesHost(hypre_ParVector *vector,
HYPRE_Int num_values,
HYPRE_BigInt *indices,
HYPRE_BigInt base,
HYPRE_Complex *values)
{
HYPRE_Int i, ierr = 0;
@ -1046,7 +1047,7 @@ hypre_ParVectorGetValuesHost(hypre_ParVector *vector,
#endif
for (i = 0; i < num_values; i++)
{
HYPRE_BigInt index = indices[i];
HYPRE_BigInt index = indices[i] - base;
if (index < first_index || index > last_index)
{
ierr ++;
@ -1085,22 +1086,33 @@ hypre_ParVectorGetValuesHost(hypre_ParVector *vector,
return hypre_error_flag;
}
HYPRE_Int
hypre_ParVectorGetValues2(hypre_ParVector *vector,
HYPRE_Int num_values,
HYPRE_BigInt *indices,
HYPRE_BigInt base,
HYPRE_Complex *values)
{
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
if (HYPRE_EXEC_DEVICE == hypre_GetExecPolicy1( hypre_ParVectorMemoryLocation(vector) ))
{
hypre_ParVectorGetValuesDevice(vector, num_values, indices, base, values);
}
else
#endif
{
hypre_ParVectorGetValuesHost(vector, num_values, indices, base, values);
}
return hypre_error_flag;
}
HYPRE_Int
hypre_ParVectorGetValues(hypre_ParVector *vector,
HYPRE_Int num_values,
HYPRE_BigInt *indices,
HYPRE_Complex *values)
{
#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)
if (HYPRE_EXEC_DEVICE == hypre_GetExecPolicy1( hypre_ParVectorMemoryLocation(vector) ))
{
hypre_ParVectorGetValuesDevice(vector, num_values, indices, values);
}
else
#endif
{
hypre_ParVectorGetValuesHost(vector, num_values, indices, values);
}
return hypre_error_flag;
return hypre_ParVectorGetValues2(vector, num_values, indices, 0, values);
}

View File

@ -13,6 +13,7 @@ HYPRE_Int
hypre_ParVectorGetValuesDevice(hypre_ParVector *vector,
HYPRE_Int num_values,
HYPRE_BigInt *indices,
HYPRE_BigInt base,
HYPRE_Complex *values)
{
HYPRE_Int ierr = 0;
@ -28,7 +29,7 @@ hypre_ParVectorGetValuesDevice(hypre_ParVector *vector,
ierr = HYPRE_THRUST_CALL( count_if,
indices,
indices + num_values,
out_of_range<HYPRE_BigInt>(first_index, last_index) );
out_of_range<HYPRE_BigInt>(first_index + base, last_index + base) );
if (ierr)
{
hypre_error_in_arg(3);
@ -36,18 +37,18 @@ hypre_ParVectorGetValuesDevice(hypre_ParVector *vector,
hypre_printf(" error: %d indices out of range! -- hypre_ParVectorGetValues\n", ierr);
HYPRE_THRUST_CALL( gather_if,
thrust::make_transform_iterator(indices, _1 - first_index),
thrust::make_transform_iterator(indices, _1 - first_index) + num_values,
thrust::make_transform_iterator(indices, _1 - base - first_index),
thrust::make_transform_iterator(indices, _1 - base - first_index) + num_values,
indices,
data,
values,
in_range<HYPRE_BigInt>(first_index, last_index) );
in_range<HYPRE_BigInt>(first_index + base, last_index + base) );
}
else
{
HYPRE_THRUST_CALL( gather,
thrust::make_transform_iterator(indices, _1 - first_index),
thrust::make_transform_iterator(indices, _1 - first_index) + num_values,
thrust::make_transform_iterator(indices, _1 - base - first_index),
thrust::make_transform_iterator(indices, _1 - base - first_index) + num_values,
data,
values);
}

View File

@ -301,6 +301,8 @@ HYPRE_Int hypre_ParVectorReadIJ ( MPI_Comm comm , const char *filename , HYPRE_I
HYPRE_Int hypre_FillResponseParToVectorAll ( void *p_recv_contact_buf , HYPRE_Int contact_size , HYPRE_Int contact_proc , void *ro , MPI_Comm comm , void **p_send_response_buf , HYPRE_Int *response_message_size );
HYPRE_Complex hypre_ParVectorLocalSumElts ( hypre_ParVector *vector );
HYPRE_Int hypre_ParVectorGetValues ( hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices , HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorGetValues2( hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values );
HYPRE_Int hypre_ParVectorGetValuesHost(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorElmdivpy( hypre_ParVector *x, hypre_ParVector *b, hypre_ParVector *y );
/* par_vector_device.c */
HYPRE_Int hypre_ParVectorGetValuesDevice(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_Complex *values);
HYPRE_Int hypre_ParVectorGetValuesDevice(hypre_ParVector *vector, HYPRE_Int num_values, HYPRE_BigInt *indices, HYPRE_BigInt base, HYPRE_Complex *values);