NVIDIA / CUDALibrarySamples

CUDA Library Samples

Incorrect implementation of cuSPARSE/bicgstab and cuSPARSE/CG #197

Closed: ninotreve closed this issue 1 month ago

ninotreve commented 1 month ago

BiCGSTAB issue

Description: The implementation of the BiCGStab algorithm in cuSPARSE/bicgstab/bicgstab_example.c contains an error that slows convergence. According to the BiCGStab algorithm, the operation on line 4 of the pseudocode (P_i = R_i) should run only when i = 1. In the current implementation, however, it runs on every iteration because it is not guarded by the appropriate if clause.
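To make the intended control flow concrete, here is the textbook BiCGStab search-direction step as a small C function. It is a CPU sketch, not the sample's cuBLAS code, and the name `bicgstab_update_p` is hypothetical. On the first iteration p is copied from r; on later iterations p is updated in place and reads its own previous value, so the copy must not run unconditionally.

```c
#include <stddef.h>

/* Textbook BiCGStab direction step (hypothetical helper, CPU sketch).
 * i     : 1-based iteration number
 * p     : previous direction p_{i-1} on entry, p_i on exit
 * r     : current residual r_{i-1}
 * v     : previous A * p_{i-1}
 * beta  : (rho_i / rho_{i-1}) * (alpha_{i-1} / omega_{i-1}), unused when i == 1 */
void bicgstab_update_p(int i, double *p, const double *r, const double *v,
                       double beta, double omega, size_t n)
{
    for (size_t k = 0; k < n; ++k) {
        if (i == 1) {
            p[k] = r[k];                                 /* p_1 = r_0 */
        } else {
            /* p_i = r_{i-1} + beta * (p_{i-1} - omega * v_{i-1}) */
            p[k] = r[k] + beta * (p[k] - omega * v[k]);
        }
    }
}
```

Overwriting p with r before the else branch would replace p_{i-1} with r_{i-1}, which is exactly the restart-every-iteration behavior described in this issue.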

Steps to Reproduce

  1. Clone the repository and navigate to the BiCGStab example.
  2. Compile and run the bicgstab example with the provided data.
  3. Observe that the algorithm converges after 78 iterations.

Expected Behavior: The BiCGStab algorithm should converge in 13 iterations when implemented correctly.

Actual Behavior: The algorithm takes 78 iterations to converge because the operation on line 4 of the pseudocode is not guarded by the i = 1 condition.

Suggested Fix: Upon reviewing the implementation, I noticed that lines 246 to 249 of cuSPARSE/bicgstab/bicgstab_example.c should be changed as follows:

Incorrect Code:

        // ### 4, 7 ### P_i = R_i
        CHECK_CUDA( cudaMemcpy(d_P.ptr, d_R.ptr, m * sizeof(double),
                               cudaMemcpyDeviceToDevice) )
        if (i > 1) {

to

Correct Code:

        if (i == 1) {
            // ### 4 ### P_i = R_i
            CHECK_CUDA( cudaMemcpy(d_P.ptr, d_R.ptr, m * sizeof(double),
                                   cudaMemcpyDeviceToDevice) )
        } else {

(Assigning p = r unconditionally before the if clause does not work, because the code that follows still depends on p from the previous iteration.)

With this fix, the algorithm should converge in the expected 13 iterations instead of 78.

BiCGStab loop:
  Initial Residual: Norm 1.670180e+02' threshold 1.670180e-08
  Iteration = 1; Error Norm = 1.670180e+02
  Iteration = 2; Error Norm = 1.817961e+01
  Iteration = 3; Error Norm = 1.039428e+00
  Iteration = 4; Error Norm = 1.606126e-01
  Iteration = 5; Error Norm = 2.878677e-02
  Iteration = 6; Error Norm = 5.218989e-03
  Iteration = 7; Error Norm = 9.456954e-04
  Iteration = 8; Error Norm = 1.724395e-04
  Iteration = 9; Error Norm = 3.169380e-05
  Iteration = 10; Error Norm = 5.875212e-06
  Iteration = 11; Error Norm = 1.100554e-06
  Iteration = 12; Error Norm = 2.083203e-07
  Iteration = 13; Error Norm = 4.015379e-08
Check Solution
Final error norm = 1.522333e-08

CG issue

Description: Similarly, the implementation of the CG procedure contains errors that slow convergence.

According to the pseudocode, line 242 should be:

    CHECK_CUBLAS( cublasDdot(cublasHandle, m, d_R.ptr, 1, d_R_aux.ptr, 1, &delta) )
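For reference, these are the textbook preconditioned CG recurrences. They are written under the assumption, suggested by the fixes proposed in this issue but not confirmed from the sample's source, that `d_R_aux` holds the preconditioned residual z = M^{-1} r. Under that reading, the inner product stored in `delta` is r^T z rather than r^T r (the two coincide only when M = I), and the new search direction is z + beta p.

```latex
\begin{aligned}
r_0 &= b - A x_0, \qquad z_0 = M^{-1} r_0, \qquad p_0 = z_0,\\
\alpha_k &= \frac{r_k^{T} z_k}{p_k^{T} A p_k},\\
x_{k+1} &= x_k + \alpha_k p_k,\\
r_{k+1} &= r_k - \alpha_k A p_k,\\
z_{k+1} &= M^{-1} r_{k+1},\\
\beta_k &= \frac{r_{k+1}^{T} z_{k+1}}{r_k^{T} z_k},\\
p_{k+1} &= z_{k+1} + \beta_k p_k .
\end{aligned}
```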

In addition, lines 312 to 318 do not compute the correct search-direction update. The suggested fix is:

        //    (a) P = beta * P
        CHECK_CUBLAS( cublasDscal(cublasHandle, m, &beta, d_P.ptr, 1) )
        //    (b) P = R_aux + P
        CHECK_CUBLAS( cublasDaxpy(cublasHandle, m, &one, d_R_aux.ptr, 1,
                                    d_P.ptr, 1) )
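The two calls above compute p = R_aux + beta * p in place: scale first, then add. The CPU sketch below uses hypothetical helpers `scal` and `axpy` that follow the standard BLAS semantics of cublasDscal (x = alpha * x) and cublasDaxpy (y = alpha * x + y) with unit stride; `cg_update_p` is likewise a hypothetical name.

```c
#include <stddef.h>

/* x = alpha * x  (BLAS dscal semantics, unit stride) */
void scal(size_t n, double alpha, double *x)
{
    for (size_t i = 0; i < n; ++i) x[i] *= alpha;
}

/* y = alpha * x + y  (BLAS daxpy semantics, unit stride) */
void axpy(size_t n, double alpha, const double *x, double *y)
{
    for (size_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}

/* CG direction update p = z + beta * p, in the same order as the
 * suggested fix: (a) p = beta * p, (b) p = z + p. */
void cg_update_p(size_t n, double beta, const double *z, double *p)
{
    scal(n, beta, p);
    axpy(n, 1.0, z, p);
}
```

Doing the scale before the add is what lets the update reuse the single buffer `d_P` without a temporary.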

With the above fixes, the algorithm should converge in 40 iterations instead of 88. (I suspect there is also an unknown bug in the ichol (incomplete Cholesky) factorization; otherwise it should converge faster.)

CG loop:
  Initial Residual: Norm 4.633034e+01' threshold 4.633034e-07
  Iteration = 0; Error Norm = 4.633034e+01
  Iteration = 1; Error Norm = 4.350999e+01
  Iteration = 2; Error Norm = 4.540506e+01
  Iteration = 3; Error Norm = 2.375638e+01
  Iteration = 4; Error Norm = 1.164941e+01
  Iteration = 5; Error Norm = 6.184499e+00
  Iteration = 6; Error Norm = 3.519462e+00
  Iteration = 7; Error Norm = 2.090597e+00
  Iteration = 8; Error Norm = 1.272138e+00
  Iteration = 9; Error Norm = 7.866908e-01
  Iteration = 10; Error Norm = 4.934229e-01
  Iteration = 11; Error Norm = 3.126723e-01
  Iteration = 12; Error Norm = 1.991794e-01
  Iteration = 13; Error Norm = 1.272908e-01
  Iteration = 14; Error Norm = 8.146527e-02
  Iteration = 15; Error Norm = 5.208671e-02
  Iteration = 16; Error Norm = 3.330293e-02
  Iteration = 17; Error Norm = 2.132552e-02
  Iteration = 18; Error Norm = 1.365135e-02
  Iteration = 19; Error Norm = 8.731633e-03
  Iteration = 20; Error Norm = 5.584883e-03
  Iteration = 21; Error Norm = 3.560831e-03
  Iteration = 22; Error Norm = 2.260165e-03
  Iteration = 23; Error Norm = 1.434660e-03
  Iteration = 24; Error Norm = 9.110929e-04
  Iteration = 25; Error Norm = 5.767032e-04
  Iteration = 26; Error Norm = 3.646304e-04
  Iteration = 27; Error Norm = 2.308476e-04
  Iteration = 28; Error Norm = 1.457810e-04
  Iteration = 29; Error Norm = 9.181084e-05
  Iteration = 30; Error Norm = 5.791789e-05
  Iteration = 31; Error Norm = 3.656195e-05
  Iteration = 32; Error Norm = 2.305098e-05
  Iteration = 33; Error Norm = 1.455168e-05
  Iteration = 34; Error Norm = 9.194049e-06
  Iteration = 35; Error Norm = 5.798299e-06
  Iteration = 36; Error Norm = 3.659231e-06
  Iteration = 37; Error Norm = 2.315736e-06
  Iteration = 38; Error Norm = 1.465358e-06
  Iteration = 39; Error Norm = 9.266568e-07
  Iteration = 40; Error Norm = 5.872029e-07
Check Solution
Final error norm = 3.726714e-07

essex-edwards commented 1 month ago

Thank you for making this bug report. I agree that these look like mistakes in the example code. We'll post an update here when we fix it.

essex-edwards commented 1 month ago

@JanuszL the label should be cuSPARSE, not cuSPARSELt. Could you change that?

essex-edwards commented 1 month ago

@ninotreve This has been fixed in https://github.com/NVIDIA/CUDALibrarySamples/commit/8cd8d122ed8a0e916ec761ffa993fb4174ed7c82. There was one additional bug in the code: it applied L^-1 L^-T instead of L^-T L^-1. With the fixes in that commit, the results exactly match Octave's bicgstab+ilu and pcg+ichol.
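Why the order matters: with an incomplete Cholesky factorization M = L L^T, the preconditioner is M^{-1} r = L^{-T} (L^{-1} r), i.e. a forward solve with L followed by a backward solve with L^T. Doing the solves in the other order computes L^{-1} L^{-T} r = (L^T L)^{-1} r, a different operator. The dense CPU sketch below illustrates the correct order; the helper name `apply_llt_inverse` and the dense row-major storage are assumptions (the sample works with sparse factors through cuSPARSE).

```c
#include <stddef.h>

/* Solve M z = r with M = L L^T, i.e. z = L^{-T} (L^{-1} r).
 * L: dense, row-major, lower triangular, nonzero diagonal.
 * z must not alias r. */
void apply_llt_inverse(const double *L, const double *r, double *z, size_t n)
{
    /* Forward substitution: y = L^{-1} r (y stored in z). */
    for (size_t i = 0; i < n; ++i) {
        double s = r[i];
        for (size_t j = 0; j < i; ++j) s -= L[i * n + j] * z[j];
        z[i] = s / L[i * n + i];
    }
    /* Backward substitution with L^T: z = L^{-T} y, in place.
     * Row i of L^T holds column i of L, i.e. L[j * n + i] for j >= i. */
    for (size_t i = n; i-- > 0;) {
        double s = z[i];
        for (size_t j = i + 1; j < n; ++j) s -= L[j * n + i] * z[j];
        z[i] = s / L[i * n + i];
    }
}
```

For example, with L = [[2, 0], [1, 3]] (so M = [[4, 2], [2, 10]]) and r = (8, 22), the function returns z = (1, 2), which satisfies M z = r.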

ninotreve commented 1 month ago

Cheers!