From 2186a8fb343f84f109405eed3e298cb1a2615e95 Mon Sep 17 00:00:00 2001 From: Ruipeng Li Date: Fri, 15 Jan 2021 20:46:59 -0800 Subject: [PATCH] triangular solve on GPUs; runcheck (#256) This PR fixes triangular solve on GPUs, and runcheck.sh Co-authored-by: Daniel Osei-Kuffuor --- AUTOTEST/machine-lassen.sh | 2 +- AUTOTEST/machine-ray.sh | 2 +- src/parcsr_mv/par_csr_matrix.c | 8 +- src/seq_mv/csr_matop_device.c | 69 ++++++++++++++- src/seq_mv/protos.h | 1 + src/seq_mv/seq_mv.h | 1 + src/test/TEST_ij/solvers.sh | 1 + src/test/runcheck.sh | 155 ++++++++++++++++++++------------- 8 files changed, 172 insertions(+), 67 deletions(-) diff --git a/AUTOTEST/machine-lassen.sh b/AUTOTEST/machine-lassen.sh index 4c02beb49..7ee578296 100755 --- a/AUTOTEST/machine-lassen.sh +++ b/AUTOTEST/machine-lassen.sh @@ -38,7 +38,7 @@ shift # Basic build and run tests mo="-j test" eo="" -roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3" +roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2" ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6" rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8" rocuda="-cuda_lassen -rt -mpibind" diff --git a/AUTOTEST/machine-ray.sh b/AUTOTEST/machine-ray.sh index 2db70f67c..dbbef44d1 100755 --- a/AUTOTEST/machine-ray.sh +++ b/AUTOTEST/machine-ray.sh @@ -38,7 +38,7 @@ shift # Basic build and run tests mo="-j test" eo="" -roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 8e-3" +roij="-ij -ams -rt -mpibind -rtol 1e-3 -atol 2e-2" ross="-struct -sstruct -rt -mpibind -rtol 1e-6 -atol 1e-6" rost="-struct -rt -mpibind -rtol 1e-8 -atol 1e-8" rocuda="-cuda_ray -rt -mpibind" diff --git a/src/parcsr_mv/par_csr_matrix.c b/src/parcsr_mv/par_csr_matrix.c index 830040263..c31b34115 100644 --- a/src/parcsr_mv/par_csr_matrix.c +++ b/src/parcsr_mv/par_csr_matrix.c @@ -320,8 +320,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc return hypre_error_flag; } - if ( hypre_GetActualMemLocation(memory_location) != - hypre_GetActualMemLocation(hypre_ParCSRMatrixMemoryLocation(A)) ) + HYPRE_MemoryLocation old_memory_location = hypre_ParCSRMatrixMemoryLocation(A); + + if ( hypre_GetActualMemLocation(memory_location) != hypre_GetActualMemLocation(old_memory_location) ) { hypre_CSRMatrix *A_diag = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixDiag(A), 1, memory_location); hypre_CSRMatrixDestroy(hypre_ParCSRMatrixDiag(A)); @@ -330,6 +331,9 @@ hypre_ParCSRMatrixMigrate(hypre_ParCSRMatrix *A, HYPRE_MemoryLocation memory_loc hypre_CSRMatrix *A_offd = hypre_CSRMatrixClone_v2(hypre_ParCSRMatrixOffd(A), 1, memory_location); hypre_CSRMatrixDestroy(hypre_ParCSRMatrixOffd(A)); hypre_ParCSRMatrixOffd(A) = A_offd; + + hypre_TFree(hypre_ParCSRMatrixRowindices(A), old_memory_location); + hypre_TFree(hypre_ParCSRMatrixRowvalues(A), old_memory_location); } else { diff --git a/src/seq_mv/csr_matop_device.c b/src/seq_mv/csr_matop_device.c index 953817055..4469958c0 100644 --- a/src/seq_mv/csr_matop_device.c +++ b/src/seq_mv/csr_matop_device.c @@ -618,6 +618,9 @@ hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A ) return hypre_error_flag; } +/* check if diagonal entry is the first one at each row + * Return: the number of rows that do not have the first entry as diagonal + */ __global__ void hypreCUDAKernel_CSRCheckDiagFirst( HYPRE_Int nrows, HYPRE_Int *ia, @@ -643,7 +646,8 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A ) dim3 gDim = hypre_GetDefaultCUDAGridDimension(hypre_CSRMatrixNumRows(A), "thread", bDim); HYPRE_Int *result = hypre_TAlloc(HYPRE_Int, hypre_CSRMatrixNumRows(A), HYPRE_MEMORY_DEVICE); - HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim, hypre_CSRMatrixNumRows(A), + HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirst, gDim, bDim, + hypre_CSRMatrixNumRows(A), hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), result ); HYPRE_Int ierr = HYPRE_THRUST_CALL( reduce, @@ -657,6 +661,62 @@ hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A ) return ierr; } +/* check if diagonal entry is the first one at each row, and + * assign numerical zero diag value `v' + * Return: the number of rows that do not have the first entry as diagonal + */ +__global__ void +hypreCUDAKernel_CSRCheckDiagFirstSetValueZero( HYPRE_Complex v, + HYPRE_Int nrows, + HYPRE_Int *ia, + HYPRE_Int *ja, + HYPRE_Complex *data, + HYPRE_Int *result ) +{ + const HYPRE_Int row = hypre_cuda_get_grid_thread_id<1,1>(); + if (row < nrows) + { + const HYPRE_Int j = ia[row]; + const HYPRE_Int col = ja[j]; + + result[row] = col != row; + + if (col == row && data[j] == 0.0) + { + data[j] = v; + } + } +} + +HYPRE_Int +hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, + HYPRE_Complex v ) +{ + if (hypre_CSRMatrixNumRows(A) != hypre_CSRMatrixNumCols(A)) + { + return -1; + } + + dim3 bDim = hypre_GetDefaultCUDABlockDimension(); + dim3 gDim = hypre_GetDefaultCUDAGridDimension(hypre_CSRMatrixNumRows(A), "thread", bDim); + + HYPRE_Int *result = hypre_TAlloc(HYPRE_Int, hypre_CSRMatrixNumRows(A), HYPRE_MEMORY_DEVICE); + HYPRE_CUDA_LAUNCH( hypreCUDAKernel_CSRCheckDiagFirstSetValueZero, gDim, bDim, + v, hypre_CSRMatrixNumRows(A), + hypre_CSRMatrixI(A), hypre_CSRMatrixJ(A), hypre_CSRMatrixData(A), + result ); + + HYPRE_Int ierr = HYPRE_THRUST_CALL( reduce, + result, + result + hypre_CSRMatrixNumRows(A) ); + + hypre_TFree(result, HYPRE_MEMORY_DEVICE); + + hypre_SyncCudaComputeStream(hypre_handle()); + + return ierr; +} + typedef thrust::tuple Int2; struct Int2Unequal : public thrust::unary_function { @@ -1234,6 +1294,13 @@ hypre_CSRMatrixTriLowerUpperSolveCusparse(char uplo, hypre_CSRMatrixSortedData(A) = A_sa = hypre_TAlloc(HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE); hypre_TMemcpy(A_sj, A_j, HYPRE_Int, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE); hypre_TMemcpy(A_sa, A_a, HYPRE_Complex, nnzA, HYPRE_MEMORY_DEVICE, HYPRE_MEMORY_DEVICE); + +#if defined(HYPRE_USING_CUDA) + hypre_CSRMatrixData(A) = A_sa; + HYPRE_Int err = hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice(A, INFINITY); hypre_assert(err == 0); + hypre_CSRMatrixData(A) = A_a; +#endif + hypre_SortCSRCusparse(nrow, ncol, nnzA, A_i, A_sj, A_sa); } diff --git a/src/seq_mv/protos.h b/src/seq_mv/protos.h index c5dc3824c..675e70a9f 100644 --- a/src/seq_mv/protos.h +++ b/src/seq_mv/protos.h @@ -34,6 +34,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz); HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A ); HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A ); +HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v ); void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add); void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type); hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B); diff --git a/src/seq_mv/seq_mv.h b/src/seq_mv/seq_mv.h index fdb820bd9..7c85462b6 100644 --- a/src/seq_mv/seq_mv.h +++ b/src/seq_mv/seq_mv.h @@ -300,6 +300,7 @@ hypre_CSRMatrix* hypre_CSRMatrixAddPartialDevice( hypre_CSRMatrix *A, hypre_CSRM HYPRE_Int hypre_CSRMatrixColNNzRealDevice( hypre_CSRMatrix *A, HYPRE_Real *colnnz); HYPRE_Int hypre_CSRMatrixMoveDiagFirstDevice( hypre_CSRMatrix *A ); HYPRE_Int hypre_CSRMatrixCheckDiagFirstDevice( hypre_CSRMatrix *A ); +HYPRE_Int hypre_CSRMatrixCheckDiagFirstSetValueZeroDevice( hypre_CSRMatrix *A, HYPRE_Complex v ); void hypre_CSRMatrixComputeRowSumDevice( hypre_CSRMatrix *A, HYPRE_Int *CF_i, HYPRE_Int *CF_j, HYPRE_Complex *row_sum, HYPRE_Int type, HYPRE_Complex scal, const char *set_or_add); void hypre_CSRMatrixExtractDiagonalDevice( hypre_CSRMatrix *A, HYPRE_Complex *d, HYPRE_Int type); hypre_CSRMatrix* hypre_CSRMatrixStack2Device(hypre_CSRMatrix *A, hypre_CSRMatrix *B); diff --git a/src/test/TEST_ij/solvers.sh b/src/test/TEST_ij/solvers.sh index 781183f32..6f804df88 100755 --- a/src/test/TEST_ij/solvers.sh +++ b/src/test/TEST_ij/solvers.sh @@ -179,6 +179,7 @@ FILES="\ ${TNAME}.out.322\ ${TNAME}.out.323\ ${TNAME}.out.324\ + ${TNAME}.out.325\ " for i in $FILES diff --git a/src/test/runcheck.sh b/src/test/runcheck.sh index ab60ac6d0..84eb7407b 100755 --- a/src/test/runcheck.sh +++ b/src/test/runcheck.sh @@ -24,79 +24,110 @@ fi #echo "runcheck rtol = $RTOL, atol = $ATOL" -awk -v filename="$SNAME" 'BEGIN{ +awk -v ofilename="$FNAME" -v sfilename="$SNAME" 'BEGIN{ FS=" "; - key = 0; + saved_key = 0; + ln = 0; # Read saved file data into array - while (getline < filename) - { - if(NF > 0 && substr($1,1,1) !~ /#/) - { - saved_line[key]=$0; - saved_array[++key]=$NF; - } - } - close(filename); - - # read out file and compare - filename="'"$FNAME"'" - rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number - atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number - key=0; - ln=0; - pass=1; - - while (getline < filename) + while (getline < sfilename) { ln++; if(NF > 0 && substr($1,1,1) !~ /#/) { - # get corresponding value in saved array - val = saved_array[++key]; - - # floating point field comparison - if($NF != int($NF)) + # loop over fields in current line + for(id=1; id<=NF; id++) { - err = val - $NF; - # get absolute value of err and val - err = err < 0 ? -err : err; - val = val < 0 ? -val : val; - # abs err <= atol or rel err <= rtol - if(err <= atol || err <= rtol*val) + # check if field is numeric + if($id ~ /^[0-9]+/) { - #print "PASSED" - } - else - { - pass=0; - printf "(%d) - %s\n", ln, saved_line[key-1] - printf "(%d) + %s (err %.2e)\n\n", ln, $0, err - #printf "(%d) + %s <-- %s, err %.2e\n", ln, $0, val, err - } - } - else # integer comparison - { - tau = val - $NF; - # get absolute value of tau - tau = tau < 0 ? -tau : tau; - # get ceiling of rtol*val (= max allowed change) - gamma = int(1.0 + rtol*val); - if(tau <= gamma) - { - #print "PASSED" - } - else - { - pass=0; - printf "(%d) %s <-- %s, err %d\n", ln, $0, val, tau + ln_id[saved_key]=ln; + saved_line[saved_key]=$0; + saved_array[++saved_key]=$id; } } } } + close(sfilename); + + # Read out file data into array + out_key=0; + ln=0; + while (getline < ofilename) + { + if(NF > 0 && substr($1,1,1) !~ /#/) + { + # loop over fields in current line + for(id=1; id<=NF; id++) + { + # check if field is numeric + if($id ~ /^[0-9]+/) + { + out_line[out_key]=$0; + out_array[++out_key]=$id; + } + } + } + } + close(ofilename); + + # compare data arrays + if(saved_key != out_key) + { + printf "Number of numeric entries do not match!!\n" + printf "Saved file (%d entries) Output file (%d entries)\n\n", saved_key, out_key + } + + # compare numeric entries + rtol = "'"$RTOL"'" + 0. # adding zero is necessary to convert from string to number + atol = "'"$ATOL"'" + 0. # adding zero is necessary to convert from string to number + for(id=1; id<=saved_key; id++) + { + # get value from arrays + saved_val = saved_array[id]; + out_val = out_array[id]; + + # floating point field comparison + if(length(saved_val) != length(int(saved_val)) && length(out_val) != length(int(out_val))) + { + err = saved_val - out_val; + # get absolute value of err and saved_val + err = err < 0 ? -err : err; + saved_val = saved_val < 0 ? -saved_val : saved_val; + # abs err <= atol or rel err <= rtol + if(err <= atol || err <= rtol*saved_val) + { + #print "PASSED" + } + else + { + pass=0; + printf "(%d) - %s\n", ln_id[id-1], saved_line[id-1] + printf "(%d) + %s (err %.2e)\n\n", ln_id[id-1], out_line[id-1], err + } + } + else if(length(saved_val) == length(int(saved_val)) && length(out_val) == length(int(out_val)))# integer comparison + { + tau = saved_val - out_val; + # get absolute value of tau + tau = tau < 0 ? -tau : tau; + # get ceiling of rtol*saved_val (= max allowed change) + gamma = int(1.0 + rtol*saved_val); + if(tau <= gamma) + { + #print "PASSED" + } + else + { + pass=0; + printf "(%d) %s <-- %s, err %d\n", ln, $0, saved_val, tau + } + } + else # type mismatch + { + printf "Numeric type mismatch in floating point or integer comparison!!\n" + printf "(%d) - %s \n", ln_id[id-1], saved_line[id-1] + printf "(%d) + %s \n\n", ln_id[id-1], out_line[id-1] + } + } }' -#if [ "x$PASSFAIL" != "x" ]; -#then -# echo $PASSFAIL -# diff -U3 -bI"time" $SNAME $FNAME >&2 -#fi