triangular solve on GPUs; runcheck (#256)
This PR fixes the triangular solve on GPUs and runcheck.sh. Co-authored-by: Daniel Osei-Kuffuor <oseikuffuor1@llnl.gov>
This commit is contained in:
parent
bd76daf124
commit
2186a8fb34
@ -38,7 +38,7 @@ shift
|
||||
# Basic build and run tests
|
||||
mo="-j test"
|
||||
eo=""
|
||||
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3"
|
||||
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2"
|
||||
ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6"
|
||||
rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8"
|
||||
rocuda="-cuda_lassen -rt -mpibind"
|
||||
|
||||
@ -38,7 +38,7 @@ shift
|
||||
# Basic build and run tests
|
||||
mo="-j test"
|
||||
eo=""
|
||||
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3"
|
||||
roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2"
|
||||
ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6"
|
||||
rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8"
|
||||
rocuda="-cuda_ray -rt -mpibind"
|
||||
|
||||
@ -320,8 +320,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
if ( hypre_GetActualMemLocation(memory_location) !=
|
||||
hypre_GetActualMemLocation(hypre_ParCSRMatrixMemoryLocation(A)) )
|
||||
HYPRE_MemoryLocation old_memory_location = hypre_ParCSRMatrixMemoryLocation(A);
|
||||
|
||||
if ( hypre_GetActualMemLocation(memory_location) != hypre_GetActualMemLocation(old_memory_location) )
|
||||
{
|
||||
hypre_CSRMatrix *A_diag = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixDiag(A), 1, memory_location);
|
||||
hypre_CSRMatrixDestroy(hypre_ParCSRMatrixDiag(A));
|
||||
@ -330,6 +331,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc
|
||||
hypre_CSRMatrix *A_offd = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixOffd(A), 1, memory_location);
|
||||
hypre_CSRMatrixDestroy(hypre_ParCSRMatrixOffd(A));
|
||||
hypre_ParCSRMatrixOffd(A) = A_offd;
|
||||
|
||||
hypre_TFree(hypre_ParCSRMatrixRowindices(A), old_memory_location);
|
||||
hypre_TFree(hypre_ParCSRMatrixRowvalues(A), old_memory_location);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@ -618,6 +618,9 @@ hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A )
|
||||
return hypre_error_flag;
|
||||
}
|
||||
|
||||
/* check if diagonal entry is the first one at each row
|
||||
* Return: the number of rows that do not have the first entry as diagonal
|
||||
*/
|
||||
__global__ void
|
||||
hypreCUDAKernel_CSRCheckDiagFirst( HYPRE_Int nrows,
|
||||
HYPRE_Int *ia,
|
||||
@ -643,7 +646,8 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A )
|
||||
dim3 gDim = hypre_GetDefaultCUDAGridDimension(hypre_CSRMatrixNumRows(A), "thread", bDim);
|
||||
|
||||
HYPRE_Int *result = hypre_TAlloc(HYPRE_Int, hypre_CSRMatrixNumRows(A), HYPRE_MEMORY_DEVICE);
|
||||
HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim, hypre_CSRMatrixNumRows(A),
|
||||
HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim,
|
||||
hypre_CSRMatrixNumRows(A),
|
||||
hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), result );
|
||||
|
||||
HYPRE_Int ierr = HYPRE_THRUST_CALL( reduce,
|
||||
@ -657,6 +661,62 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A )
|
||||
return ierr;
|
||||
}
|
||||
|
||||
/* For each row, check whether the diagonal entry is stored first, and if it
 * is and its value is numerically zero, overwrite that value with `v'.
 *
 * v      : replacement for a numerically zero diagonal value
 * nrows  : number of rows; threads whose id is >= nrows do nothing
 * ia, ja : CSR row pointers and column indices
 *          (assumes ia holds nrows + 1 entries, as in standard CSR)
 * data   : CSR values, modified in place
 * result : output flags, one per row:
 *          result[row] = 1 if the first stored entry of the row is NOT the
 *          diagonal (this includes rows with no stored entry), 0 otherwise.
 *
 * The kernel itself returns nothing; summing `result' gives the number of
 * rows that do not have the diagonal as their first entry.
 */
__global__ void
hypreCUDAKernel_CSRCheckDiagFirstSetValueZero( HYPRE_Complex  v,
                                               HYPRE_Int      nrows,
                                               HYPRE_Int     *ia,
                                               HYPRE_Int     *ja,
                                               HYPRE_Complex *data,
                                               HYPRE_Int     *result )
{
   const HYPRE_Int row = hypre_cuda_get_grid_thread_id<1,1>();

   if (row >= nrows)
   {
      return;
   }

   const HYPRE_Int row_begin = ia[row];
   const HYPRE_Int row_end   = ia[row + 1];

   /* An empty row has no first entry: do not dereference ja[row_begin], which
    * would belong to a later row or lie past the end of ja for the last row */
   if (row_begin >= row_end)
   {
      result[row] = 1;
      return;
   }

   const HYPRE_Int col = ja[row_begin];

   result[row] = col != row;

   /* only a diagonal entry stored first is eligible for replacement */
   if (col == row && data[row_begin] == 0.0)
   {
      data[row_begin] = v;
   }
}
|
||||
|
||||
/* Host-side driver for hypreCUDAKernel_CSRCheckDiagFirstSetValueZero.
 *
 * A : CSR matrix; its I, J and Data arrays are handed to the kernel, and
 *     Data may be modified by it
 * v : value passed through to the kernel as the replacement for a
 *     numerically zero first-stored diagonal entry
 *
 * Return: -1 if A is not square; otherwise the sum of the per-row flags
 *         written by the kernel, i.e. the number of rows whose first entry
 *         is not the diagonal.
 */
HYPRE_Int
hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A,
                                                 HYPRE_Complex    v )
{
   const HYPRE_Int num_rows = hypre_CSRMatrixNumRows(A);

   /* the check is only defined for square matrices */
   if (num_rows != hypre_CSRMatrixNumCols(A))
   {
      return -1;
   }

   dim3 block_dim = hypre_GetDefaultCUDABlockDimension();
   dim3 grid_dim  = hypre_GetDefaultCUDAGridDimension(num_rows, "thread", block_dim);

   /* one flag per row, filled in by the kernel */
   HYPRE_Int *row_flags = hypre_TAlloc(HYPRE_Int, num_rows, HYPRE_MEMORY_DEVICE);

   HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirstSetValueZero, grid_dim, block_dim,
                      v, num_rows,
                      hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), hypre_CSRMatrixData(A),
                      row_flags );

   /* add up the flags to count the offending rows */
   HYPRE_Int num_bad_rows = HYPRE_THRUST_CALL( reduce,
                                               row_flags,
                                               row_flags + num_rows );

   hypre_TFree(row_flags, HYPRE_MEMORY_DEVICE);

   hypre_SyncCudaComputeStream(hypre_handle());

   return num_bad_rows;
}
|
||||
|
||||
typedef thrust::tuple<HYPRE_Int, HYPRE_Int> Int2;
|
||||
struct Int2Unequal : public thrust::unary_function<Int2, bool>
|
||||
{
|
||||
@ -1234,6 +1294,13 @@ hypre_CSRMatrixTriLowerUpperSolveCusparse(char uplo,
|
||||
hypre_CSRMatrixSortedData(A) = A_sa = hypre_TAlloc(HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE);
|
||||
hypre_TMemcpy(A_sj, A_j, HYPRE_Int, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE);
|
||||
hypre_TMemcpy(A_sa, A_a, HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE);
|
||||
|
||||
#if defined(HYPRE_USING_CUDA)
|
||||
hypre_CSRMatrixData(A) = A_sa;
|
||||
HYPRE_Int err = hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice(A, INFINITY); hypre_assert(err == 0);
|
||||
hypre_CSRMatrixData(A) = A_a;
|
||||
#endif
|
||||
|
||||
hypre_SortCSRCusparse(nrow, ncol, nnzA, A_i, A_sj, A_sa);
|
||||
}
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM
|
||||
HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz);
|
||||
HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A );
|
||||
HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A );
|
||||
HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v );
|
||||
void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add);
|
||||
void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type);
|
||||
hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B);
|
||||
|
||||
@ -300,6 +300,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM
|
||||
HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz);
|
||||
HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A );
|
||||
HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A );
|
||||
HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v );
|
||||
void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add);
|
||||
void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type);
|
||||
hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B);
|
||||
|
||||
@ -179,6 +179,7 @@ FILES="\
|
||||
${TNAME}.out.322\
|
||||
${TNAME}.out.323\
|
||||
${TNAME}.out.324\
|
||||
${TNAME}.out.325\
|
||||
"
|
||||
|
||||
for i in $FILES
|
||||
|
||||
@ -24,63 +24,94 @@ fi
|
||||
|
||||
#echo "runcheck rtol = $RTOL, atol = $ATOL"
|
||||
|
||||
awk -v filename="$SNAME" 'BEGIN{
|
||||
awk -v ofilename="$FNAME" -v sfilename="$SNAME" 'BEGIN{
|
||||
FS=" ";
|
||||
key = 0;
|
||||
saved_key = 0;
|
||||
ln = 0;
|
||||
# Read saved file data into array
|
||||
while (getline < filename)
|
||||
{
|
||||
if(NF > 0 && substr($1,1,1) !~ /#/)
|
||||
{
|
||||
saved_line[key]=$0;
|
||||
saved_array[++key]=$NF;
|
||||
}
|
||||
}
|
||||
close(filename);
|
||||
|
||||
# read out file and compare
|
||||
filename="'"$FNAME"'"
|
||||
rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number
|
||||
atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number
|
||||
key=0;
|
||||
ln=0;
|
||||
pass=1;
|
||||
|
||||
while (getline < filename)
|
||||
while (getline < sfilename)
|
||||
{
|
||||
ln++;
|
||||
if(NF > 0 && substr($1,1,1) !~ /#/)
|
||||
{
|
||||
# get corresponding value in saved array
|
||||
val = saved_array[++key];
|
||||
# loop over fields in current line
|
||||
for(id=1; id<=NF; id++)
|
||||
{
|
||||
# check if field is numeric
|
||||
if($id ~ /^[0-9]+/)
|
||||
{
|
||||
ln_id[saved_key]=ln;
|
||||
saved_line[saved_key]=$0;
|
||||
saved_array[++saved_key]=$id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
close(sfilename);
|
||||
|
||||
# Read out file data into array
|
||||
out_key=0;
|
||||
ln=0;
|
||||
while (getline < ofilename)
|
||||
{
|
||||
if(NF > 0 && substr($1,1,1) !~ /#/)
|
||||
{
|
||||
# loop over fields in current line
|
||||
for(id=1; id<=NF; id++)
|
||||
{
|
||||
# check if field is numeric
|
||||
if($id ~ /^[0-9]+/)
|
||||
{
|
||||
out_line[out_key]=$0;
|
||||
out_array[++out_key]=$id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
close(ofilename);
|
||||
|
||||
# compare data arrays
|
||||
if(saved_key != out_key)
|
||||
{
|
||||
printf "Number of numeric entries do not match!!\n"
|
||||
printf "Saved file (%d entries) Output file (%d entries)\n\n", saved_key, out_key
|
||||
}
|
||||
|
||||
# compare numeric entries
|
||||
rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number
|
||||
atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number
|
||||
for(id=1; id<=saved_key; id++)
|
||||
{
|
||||
# get value from arrays
|
||||
saved_val = saved_array[id];
|
||||
out_val = out_array[id];
|
||||
|
||||
# floating point field comparison
|
||||
if($NF != int($NF))
|
||||
if(length(saved_val) != length(int(saved_val)) && length(out_val) != length(int(out_val)))
|
||||
{
|
||||
err = val - $NF;
|
||||
# get absolute value of err and val
|
||||
err = saved_val - out_val;
|
||||
# get absolute value of err and saved_val
|
||||
err = err < 0 ? -err : err;
|
||||
val = val < 0 ? -val : val;
|
||||
saved_val = saved_val < 0 ? -saved_val : saved_val;
|
||||
# abs err <= atol or rel err <= rtol
|
||||
if(err <= atol || err <= rtol*val)
|
||||
if(err <= atol || err <= rtol*saved_val)
|
||||
{
|
||||
#print "PASSED"
|
||||
}
|
||||
else
|
||||
{
|
||||
pass=0;
|
||||
printf "(%d) - %s\n", ln, saved_line[key-1]
|
||||
printf "(%d) + %s (err %.2e)\n\n", ln, $0, err
|
||||
#printf "(%d) + %s <-- %s, err %.2e\n", ln, $0, val, err
|
||||
printf "(%d) - %s\n", ln_id[id-1], saved_line[id-1]
|
||||
printf "(%d) + %s (err %.2e)\n\n", ln_id[id-1], out_line[id-1], err
|
||||
}
|
||||
}
|
||||
else # integer comparison
|
||||
else if(length(saved_val) == length(int(saved_val)) && length(out_val) == length(int(out_val)))# integer comparison
|
||||
{
|
||||
tau = val - $NF;
|
||||
tau = saved_val - out_val;
|
||||
# get absolute value of tau
|
||||
tau = tau < 0 ? -tau : tau;
|
||||
# get ceiling of rtol*val (= max allowed change)
|
||||
gamma = int(1.0 + rtol*val);
|
||||
# get ceiling of rtol*saved_val (= max allowed change)
|
||||
gamma = int(1.0 + rtol*saved_val);
|
||||
if(tau <= gamma)
|
||||
{
|
||||
#print "PASSED"
|
||||
@ -88,15 +119,15 @@ awk -v filename="$SNAME" 'BEGIN{
|
||||
else
|
||||
{
|
||||
pass=0;
|
||||
printf "(%d) %s <-- %s, err %d\n", ln, $0, val, tau
|
||||
printf "(%d) %s <-- %s, err %d\n", ln, $0, saved_val, tau
|
||||
}
|
||||
}
|
||||
else # type mismatch
|
||||
{
|
||||
printf "Numeric type mismatch in floating point or integer comparison!!\n"
|
||||
printf "(%d) - %s \n", ln_id[id-1], saved_line[id-1]
|
||||
printf "(%d) + %s \n\n", ln_id[id-1], out_line[id-1]
|
||||
}
|
||||
}
|
||||
}'
|
||||
|
||||
#if [ "x$PASSFAIL" != "x" ];
|
||||
#then
|
||||
# echo $PASSFAIL
|
||||
# diff -U3 -bI"time" $SNAME $FNAME >&2
|
||||
#fi
|
||||
|
||||
Loading…
Reference in New Issue
Block a user