triangular solve on GPUs; runcheck (#256)

This PR fixes triangular solve on GPUs, and runcheck.sh

Co-authored-by: Daniel Osei-Kuffuor <oseikuffuor1@llnl.gov>
This commit is contained in:
Ruipeng Li 2021-01-15 20:46:59 -08:00 committed by GitHub
parent bd76daf124
commit 2186a8fb34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 172 additions and 67 deletions

View File

@ -38,7 +38,7 @@ shift
# Basic build and run tests
mo="-j test"
eo=""
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3"
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2"
ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6"
rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8"
rocuda="-cuda_lassen -rt -mpibind"

View File

@ -38,7 +38,7 @@ shift
# Basic build and run tests
mo="-j test"
eo=""
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3"
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2"
ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6"
rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8"
rocuda="-cuda_ray -rt -mpibind"

View File

@ -320,8 +320,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc
return hypre_error_flag;
}
if ( hypre_GetActualMemLocation(memory_location) !=
hypre_GetActualMemLocation(hypre_ParCSRMatrixMemoryLocation(A)) )
HYPRE_MemoryLocation old_memory_location = hypre_ParCSRMatrixMemoryLocation(A);
if ( hypre_GetActualMemLocation(memory_location) != hypre_GetActualMemLocation(old_memory_location) )
{
hypre_CSRMatrix *A_diag = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixDiag(A), 1, memory_location);
hypre_CSRMatrixDestroy(hypre_ParCSRMatrixDiag(A));
@ -330,6 +331,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc
hypre_CSRMatrix *A_offd = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixOffd(A), 1, memory_location);
hypre_CSRMatrixDestroy(hypre_ParCSRMatrixOffd(A));
hypre_ParCSRMatrixOffd(A) = A_offd;
hypre_TFree(hypre_ParCSRMatrixRowindices(A), old_memory_location);
hypre_TFree(hypre_ParCSRMatrixRowvalues(A), old_memory_location);
}
else
{

View File

@ -618,6 +618,9 @@ hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A )
return hypre_error_flag;
}
/* check if diagonal entry is the first one at each row
* Return: the number of rows that do not have the first entry as diagonal
*/
__global__ void
hypreCUDAKernel_CSRCheckDiagFirst( HYPRE_Int nrows,
HYPRE_Int *ia,
@ -643,7 +646,8 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A )
dim3 gDim = hypre_GetDefaultCUDAGridDimension(hypre_CSRMatrixNumRows(A), "thread", bDim);
HYPRE_Int *result = hypre_TAlloc(HYPRE_Int, hypre_CSRMatrixNumRows(A), HYPRE_MEMORY_DEVICE);
HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim, hypre_CSRMatrixNumRows(A),
HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim,
hypre_CSRMatrixNumRows(A),
hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), result );
HYPRE_Int ierr = HYPRE_THRUST_CALL( reduce,
@ -657,6 +661,62 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A )
return ierr;
}
/* Check whether the first stored entry of each row is the diagonal entry, and
 * overwrite a numerically zero first-entry diagonal value with `v'.
 *
 * Each thread handles the row given by its grid thread id; threads whose id
 * is >= nrows do nothing.
 *
 *   v      - value stored in place of a first-entry diagonal equal to 0.0
 *   nrows  - number of rows to process
 *   ia     - CSR row pointer; ia[row] and ia[row + 1] are read
 *            (assumes ia holds nrows + 1 entries, as usual for CSR)
 *   ja     - CSR column indices
 *   data   - CSR values, modified in place
 *   result - output, one flag per row: 1 if the first entry of the row is not
 *            the diagonal (an empty row is flagged as well), 0 otherwise
 *
 * Summing `result' gives the number of rows that do not have the first entry
 * as diagonal.
 */
__global__ void
hypreCUDAKernel_CSRCheckDiagFirstSetValueZero( HYPRE_Complex  v,
                                               HYPRE_Int      nrows,
                                               HYPRE_Int     *ia,
                                               HYPRE_Int     *ja,
                                               HYPRE_Complex *data,
                                               HYPRE_Int     *result )
{
   const HYPRE_Int row = hypre_cuda_get_grid_thread_id<1,1>();

   if (row < nrows)
   {
      const HYPRE_Int j = ia[row];

      /* An empty row has no first entry. Reading ja[j] / data[j] here would
       * touch the next row's first entry (or run past the end of the arrays
       * for the last row), so flag the row and leave the data untouched. */
      if (j >= ia[row + 1])
      {
         result[row] = 1;
         return;
      }

      const HYPRE_Int col = ja[j];

      result[row] = col != row;

      /* only a diagonal stored first and exactly equal to zero is replaced */
      if (col == row && data[j] == 0.0)
      {
         data[j] = v;
      }
   }
}
/* Host-side driver for hypreCUDAKernel_CSRCheckDiagFirstSetValueZero.
 *
 * Launches the kernel with one thread per row of A on the row pointer,
 * column indices and values of A, so zero first-entry diagonal values of A
 * are replaced by `v' in place.
 *
 * Return: -1 if A is not square; otherwise the sum of the per-row flags
 *         written by the kernel, i.e. the number of rows that do not have
 *         the first entry as diagonal.
 */
HYPRE_Int
hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A,
                                                 HYPRE_Complex    v )
{
   const HYPRE_Int num_rows = hypre_CSRMatrixNumRows(A);

   /* the check is only defined for square matrices */
   if (num_rows != hypre_CSRMatrixNumCols(A))
   {
      return -1;
   }

   dim3 block_dim = hypre_GetDefaultCUDABlockDimension();
   dim3 grid_dim  = hypre_GetDefaultCUDAGridDimension(num_rows, "thread", block_dim);

   /* one flag per row, filled by the kernel */
   HYPRE_Int *row_flags = hypre_TAlloc(HYPRE_Int, num_rows, HYPRE_MEMORY_DEVICE);

   HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirstSetValueZero, grid_dim, block_dim,
                      v, num_rows,
                      hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), hypre_CSRMatrixData(A),
                      row_flags );

   /* add up the flags to count the offending rows */
   HYPRE_Int num_flagged = HYPRE_THRUST_CALL( reduce,
                                              row_flags,
                                              row_flags + num_rows );

   hypre_TFree(row_flags, HYPRE_MEMORY_DEVICE);

   hypre_SyncCudaComputeStream(hypre_handle());

   return num_flagged;
}
typedef thrust::tuple<HYPRE_Int, HYPRE_Int> Int2;
struct Int2Unequal : public thrust::unary_function<Int2, bool>
{
@ -1234,6 +1294,13 @@ hypre_CSRMatrixTriLowerUpperSolveCusparse(char uplo,
hypre_CSRMatrixSortedData(A) = A_sa = hypre_TAlloc(HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE);
hypre_TMemcpy(A_sj, A_j, HYPRE_Int, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE);
hypre_TMemcpy(A_sa, A_a, HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE);
#if defined(HYPRE_USING_CUDA)
hypre_CSRMatrixData(A) = A_sa;
HYPRE_Int err = hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice(A, INFINITY); hypre_assert(err == 0);
hypre_CSRMatrixData(A) = A_a;
#endif
hypre_SortCSRCusparse(nrow, ncol, nnzA, A_i, A_sj, A_sa);
}

View File

@ -34,6 +34,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM
HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz);
HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A );
HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A );
HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v );
void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add);
void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type);
hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B);

View File

@ -300,6 +300,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM
HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz);
HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A );
HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A );
HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v );
void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add);
void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type);
hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B);

View File

@ -179,6 +179,7 @@ FILES="\
${TNAME}.out.322\
${TNAME}.out.323\
${TNAME}.out.324\
${TNAME}.out.325\
"
for i in $FILES

View File

@ -24,79 +24,110 @@ fi
#echo "runcheck rtol = $RTOL, atol = $ATOL"
awk -v filename="$SNAME" 'BEGIN{
awk -v ofilename="$FNAME" -v sfilename="$SNAME" 'BEGIN{
FS=" ";
key = 0;
saved_key = 0;
ln = 0;
# Read saved file data into array
while (getline < filename)
{
if(NF > 0 && substr($1,1,1) !~ /#/)
{
saved_line[key]=$0;
saved_array[++key]=$NF;
}
}
close(filename);
# read out file and compare
filename="'"$FNAME"'"
rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number
atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number
key=0;
ln=0;
pass=1;
while (getline < filename)
while (getline < sfilename)
{
ln++;
if(NF > 0 && substr($1,1,1) !~ /#/)
{
# get corresponding value in saved array
val = saved_array[++key];
# floating point field comparison
if($NF != int($NF))
# loop over fields in current line
for(id=1; id<=NF; id++)
{
err = val - $NF;
# get absolute value of err and val
err = err < 0 ? -err : err;
val = val < 0 ? -val : val;
# abs err <= atol or rel err <= rtol
if(err <= atol || err <= rtol*val)
# check if field is numeric
if($id ~ /^[0-9]+/)
{
#print "PASSED"
}
else
{
pass=0;
printf "(%d) - %s\n", ln, saved_line[key-1]
printf "(%d) + %s (err %.2e)\n\n", ln, $0, err
#printf "(%d) + %s <-- %s, err %.2e\n", ln, $0, val, err
}
}
else # integer comparison
{
tau = val - $NF;
# get absolute value of tau
tau = tau < 0 ? -tau : tau;
# get ceiling of rtol*val (= max allowed change)
gamma = int(1.0 + rtol*val);
if(tau <= gamma)
{
#print "PASSED"
}
else
{
pass=0;
printf "(%d) %s <-- %s, err %d\n", ln, $0, val, tau
ln_id[saved_key]=ln;
saved_line[saved_key]=$0;
saved_array[++saved_key]=$id;
}
}
}
}
close(sfilename);
# Read out file data into array
out_key=0;
ln=0;
while (getline < ofilename)
{
if(NF > 0 && substr($1,1,1) !~ /#/)
{
# loop over fields in current line
for(id=1; id<=NF; id++)
{
# check if field is numeric
if($id ~ /^[0-9]+/)
{
out_line[out_key]=$0;
out_array[++out_key]=$id;
}
}
}
}
close(ofilename);
# compare data arrays
if(saved_key != out_key)
{
printf "Number of numeric entries do not match!!\n"
printf "Saved file (%d entries) Output file (%d entries)\n\n", saved_key, out_key
}
# compare numeric entries
rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number
atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number
for(id=1; id<=saved_key; id++)
{
# get value from arrays
saved_val = saved_array[id];
out_val = out_array[id];
# floating point field comparison
if(length(saved_val) != length(int(saved_val)) && length(out_val) != length(int(out_val)))
{
err = saved_val - out_val;
# get absolute value of err and saved_val
err = err < 0 ? -err : err;
saved_val = saved_val < 0 ? -saved_val : saved_val;
# abs err <= atol or rel err <= rtol
if(err <= atol || err <= rtol*saved_val)
{
#print "PASSED"
}
else
{
pass=0;
printf "(%d) - %s\n", ln_id[id-1], saved_line[id-1]
printf "(%d) + %s (err %.2e)\n\n", ln_id[id-1], out_line[id-1], err
}
}
else if(length(saved_val) == length(int(saved_val)) && length(out_val) == length(int(out_val)))# integer comparison
{
tau = saved_val - out_val;
# get absolute value of tau
tau = tau < 0 ? -tau : tau;
# get ceiling of rtol*saved_val (= max allowed change)
gamma = int(1.0 + rtol*saved_val);
if(tau <= gamma)
{
#print "PASSED"
}
else
{
pass=0;
printf "(%d) %s <-- %s, err %d\n", ln, $0, saved_val, tau
}
}
else # type mismatch
{
printf "Numeric type mismatch in floating point or integer comparison!!\n"
printf "(%d) - %s \n", ln_id[id-1], saved_line[id-1]
printf "(%d) + %s \n\n", ln_id[id-1], out_line[id-1]
}
}
}'
#if [ "x$PASSFAIL" != "x" ];
#then
# echo $PASSFAIL
# diff -U3 -bI"time" $SNAME $FNAME >&2
#fi