lattice / quda

QUDA is a library for performing calculations in lattice QCD on GPUs.
https://lattice.github.io/quda

Add multi-RHS support throughout QUDA #1465

Closed maddyscientist closed 1 month ago

maddyscientist commented 2 months ago

A long time coming, this PR adds multi-RHS support to all stencil kernels in QUDA:

With respect to performance, the gain from multi-RHS depends on the operator in question, but 2-2.5x is observed for Wilson-like operators at large volume, in line with roofline projections. At smaller volumes the gain can exceed roofline expectations, since batching the right-hand sides also exposes additional parallelism.
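
The roofline argument can be illustrated with a deliberately simplified, bandwidth-only model. This is an assumption for illustration, not QUDA's own performance model, and the function names are hypothetical. If a stencil moves G bytes of gauge field per site (shared by all right-hand sides) and S bytes of spinor data per site for each right-hand side, then N separate applications move N(G + S) bytes while one N-RHS application moves G + N S bytes. The ideal speedup N(G + S) / (G + N S) approaches 1 + G/S for large N, so under this model a 2-2.5x gain corresponds to G/S between 1 and 1.5.

```cpp
#include <cassert>

// Simplified bandwidth-only model (illustrative assumption, not QUDA's model).
// gauge_bytes:  gauge-field bytes read per lattice site, shared by all RHS
// spinor_bytes: spinor bytes read + written per site for ONE right-hand side
// Returns the ideal speedup of one N-RHS kernel over N single-RHS kernels.
double mrhs_roofline_gain(int n_rhs, double gauge_bytes, double spinor_bytes)
{
  assert(n_rhs >= 1 && gauge_bytes >= 0.0 && spinor_bytes > 0.0);
  const double separate = n_rhs * (gauge_bytes + spinor_bytes); // gauge reloaded for every RHS
  const double batched = gauge_bytes + n_rhs * spinor_bytes;    // gauge loaded once per batch
  return separate / batched;
}

// Large-N limit of the model above: 1 + G/S.
double mrhs_roofline_gain_limit(double gauge_bytes, double spinor_bytes)
{
  assert(spinor_bytes > 0.0);
  return 1.0 + gauge_bytes / spinor_bytes;
}
```

For instance, mrhs_roofline_gain_limit(72.0, 48.0) returns 2.5. These inputs are illustrative only; real kernels also depend on cache reuse, precision, gauge reconstruction, and communication, none of which the model captures.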

Subsequent work will focus on adding multi-RHS solvers; this PR is focused on making the kernels available. It is worth noting, though, that the block Lanczos solver already works out of the box with this PR.

This PR is ready for review, but there are a few things outstanding:

maddyscientist commented 2 months ago

@Marcogarofalo @kostrzewa could you test tmLQCD with this branch please if you get a chance, including the force calculations? You should see a nice performance boost for the clover force, but good to verify correctness 😄

Marcogarofalo commented 2 months ago

@maddyscientist thank you for the update. I checked correctness with tmLQCD and everything seems OK. I see a performance boost in computeTMCloverForceQuda on a 64^3x128 lattice using 4 nodes of JUWELS Booster (4 × NVIDIA A100), with nvector=5:

and with nvector=3:

maddyscientist commented 2 months ago

Thanks for confirming at your end @Marcogarofalo.

Another ask: could you send me the output of the QUDA profile for the tmLQCD HMC runs (both the stdout printed at the end of the run and the tsk files)? I suspect that communication is currently the main limiter for the clover force, but I would like to confirm this. This function does twice as much communication as is needed, and I'm wondering how much overhead that adds. Thx

Marcogarofalo commented 2 months ago

@maddyscientist, here is the stdout output of the QUDA profile. I did not find any tsk file, could you please tell me how to generate them?


               initQuda Total time =     0.433 secs
                     init     =     0.433 secs ( 99.996%),   with        2 calls at 2.163e+05 us per call
        total accounted       =     0.433 secs ( 99.996%)
        total missing         =     0.000 secs (  0.004%)

          loadGaugeQuda Total time =     2.002 secs
                 download     =     1.867 secs ( 93.249%),   with       13 calls at 1.436e+05 us per call
                     init     =     0.093 secs (  4.645%),   with       91 calls at 1.022e+03 us per call
                  compute     =     0.040 secs (  1.989%),   with       65 calls at 6.126e+02 us per call
                     free     =     0.000 secs (  0.002%),   with      364 calls at 1.016e-01 us per call
        total accounted       =     2.000 secs ( 99.885%)
        total missing         =     0.002 secs (  0.115%)

         loadCloverQuda Total time =     1.027 secs
                     init     =     0.082 secs (  8.004%),   with       65 calls at 1.265e+03 us per call
                  compute     =     0.153 secs ( 14.848%),   with       26 calls at 5.866e+03 us per call
                    comms     =     0.456 secs ( 44.357%),   with       13 calls at 3.505e+04 us per call
                     free     =     0.096 secs (  9.300%),   with      410 calls at 2.330e+02 us per call
        total accounted       =     0.786 secs ( 76.509%)
        total missing         =     0.241 secs ( 23.491%)

             invertQuda Total time =   461.892 secs
                 download     =     2.089 secs (  0.452%),   with       30 calls at 6.962e+04 us per call
                   upload     =     0.812 secs (  0.176%),   with       30 calls at 2.708e+04 us per call
                     init     =    13.152 secs (  2.848%),   with   344740 calls at 3.815e+01 us per call
                 preamble     =    58.827 secs ( 12.736%),   with   409205 calls at 1.438e+02 us per call
                  compute     =   342.744 secs ( 74.205%),   with  7419610 calls at 4.619e+01 us per call
                 epilogue     =     2.413 secs (  0.523%),   with    10412 calls at 2.318e+02 us per call
                     free     =    12.484 secs (  2.703%),   with  6999292 calls at 1.784e+00 us per call
                    eigen     =     0.100 secs (  0.022%),   with    63079 calls at 1.593e+00 us per call
        total accounted       =   432.623 secs ( 93.663%)
        total missing         =    29.269 secs (  6.337%)

   invertMultiShiftQuda Total time =   184.085 secs
                 download     =     1.069 secs (  0.581%),   with       26 calls at 4.112e+04 us per call
                   upload     =    10.385 secs (  5.642%),   with      140 calls at 7.418e+04 us per call
                     init     =     2.766 secs (  1.502%),   with    11928 calls at 2.319e+02 us per call
                 preamble     =     0.750 secs (  0.408%),   with      192 calls at 3.907e+03 us per call
                  compute     =   166.006 secs ( 90.179%),   with   741854 calls at 2.238e+02 us per call
                 epilogue     =     1.892 secs (  1.028%),   with     5092 calls at 3.716e+02 us per call
                     free     =     1.052 secs (  0.572%),   with   757778 calls at 1.389e+00 us per call
        total accounted       =   183.920 secs ( 99.910%)
        total missing         =     0.165 secs (  0.090%)

   computeGaugeForceQuda Total time =     1.798 secs
                 download     =     0.694 secs ( 38.570%),   with        7 calls at 9.908e+04 us per call
                   upload     =     0.553 secs ( 30.747%),   with        7 calls at 7.898e+04 us per call
                     init     =     0.041 secs (  2.258%),   with       77 calls at 5.273e+02 us per call
                  compute     =     0.269 secs ( 14.963%),   with        7 calls at 3.844e+04 us per call
                    comms     =     0.183 secs ( 10.202%),   with        7 calls at 2.621e+04 us per call
                     free     =     0.032 secs (  1.779%),   with     1106 calls at 2.892e+01 us per call
        total accounted       =     1.772 secs ( 98.518%)
        total missing         =     0.027 secs (  1.482%)

   computeTMCloverForceQuda Total time =    10.860 secs
                 download     =     2.321 secs ( 21.369%),   with       60 calls at 3.868e+04 us per call
                   upload     =     2.000 secs ( 18.417%),   with       24 calls at 8.334e+04 us per call
                     init     =     0.353 secs (  3.247%),   with      631 calls at 5.588e+02 us per call
                  compute     =     1.923 secs ( 17.705%),   with     1614 calls at 1.191e+03 us per call
                    comms     =     3.290 secs ( 30.289%),   with      200 calls at 1.645e+04 us per call
                     free     =     0.552 secs (  5.087%),   with     6215 calls at 8.890e+01 us per call
        total accounted       =    10.438 secs ( 96.115%)
        total missing         =     0.422 secs (  3.885%)

                endQuda Total time =     0.909 secs
                     free     =     0.033 secs (  3.603%),   with       92 calls at 3.560e+02 us per call
        total accounted       =     0.033 secs (  3.603%)
        total missing         =     0.876 secs ( 96.397%)

       initQuda-endQuda Total time =   818.633 secs

                   QUDA Total time =   663.006 secs
                 download     =     8.039 secs (  1.213%),   with      136 calls at 5.911e+04 us per call
                   upload     =    13.751 secs (  2.074%),   with      201 calls at 6.841e+04 us per call
                     init     =    16.924 secs (  2.553%),   with   357534 calls at 4.734e+01 us per call
                 preamble     =    59.582 secs (  8.987%),   with   409397 calls at 1.455e+02 us per call
                  compute     =   511.220 secs ( 77.106%),   with  8163176 calls at 6.263e+01 us per call
                    comms     =     3.929 secs (  0.593%),   with      220 calls at 1.786e+04 us per call
                 epilogue     =     4.306 secs (  0.649%),   with    15504 calls at 2.777e+02 us per call
                     free     =    14.310 secs (  2.158%),   with  7765257 calls at 1.843e+00 us per call
                    eigen     =     0.101 secs (  0.015%),   with    63079 calls at 1.601e+00 us per call
        total accounted       =   632.162 secs ( 95.348%)
        total missing         =    30.845 secs (  4.652%)

Device memory used = 32039.1 MiB
Pinned device memory used = 3024.0 MiB
Managed memory used = 0.0 MiB
Shmem memory used = 0.0 MiB
Page-locked host memory used = 3040.4 MiB
Total host memory used >= 3623.8 MiB

kostrzewa commented 1 month ago

@maddyscientist Apologies for the delay here. Please find below two profiles.

QUDA develop, 8 trajectories, 64c128 at phys. point

QUDA feature/mrhs, 4 trajectories, 64c128 at phys. point

weinbe2 commented 1 month ago

Somewhat orthogonal comment to this PR, but while we're here can I request that we rename covDev.cu[h] to covariant_derivative.cu[h] or something of the sort? It's the only outlier file in terms of the naming convention, and iirc git should note that it's just a "rename" so we won't lose the commit history.

weinbe2 commented 1 month ago

Ghost accessor question---right now, source index offsets to ghost accessors are performed via explicit pointer arithmetic, a representative case below:

const Vector in = arg.halo.Ghost(d, 1, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity);

Where the pointer arithmetic I'm referring to is ghost_idx + src_idx * arg.dc.ghostFaceCB[d]. Can we make src_idx a separate argument so it looks something more like arg.halo.Ghost(d, 1, ghost_idx, src_idx, their_spinor_parity)? Perhaps ghostFaceCB could be a field of the halo accessor.

The reason I'm asking this question is because I see this explicit pointer offsetting as antithetical to an accessor---implementation details should be under the hood. It would also make it more challenging to quickly test some type of hypothetical multi-rhs packed halo format in the future. (I guess this concern is also true with the arrays of "local" accessors, but one thought at a time.)
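
A minimal sketch of the proposed accessor-side offset, under stated assumptions: ToyHaloAccessor, GhostFlat, and this Ghost overload are hypothetical names, not QUDA's API; each element is a single float rather than a color-spinor; the parity argument of the real accessor is dropped for brevity; and the ghost buffer is assumed to store the faces of all right-hand sides back to back, which is the layout implied by ghost_idx + src_idx * arg.dc.ghostFaceCB[d].

```cpp
#include <array>
#include <cassert>
#include <vector>

// Toy halo accessor (hypothetical; not QUDA's API). For each dimension d and
// direction dir, the ghost buffer holds the faces of all right-hand sides back
// to back, each face containing face_cb[d] elements.
struct ToyHaloAccessor {
  std::array<std::array<const float *, 2>, 4> ghost {}; // [dim][dir] buffers
  std::array<int, 4> face_cb {};                        // checkerboarded face volume per dim

  // Current style: the caller does the src_idx offset arithmetic itself.
  float GhostFlat(int d, int dir, int flat_idx) const { return ghost[d][dir][flat_idx]; }

  // Proposed style: src_idx is a separate argument and the layout stays hidden,
  // so a different multi-RHS halo packing would only need to change this function.
  float Ghost(int d, int dir, int ghost_idx, int src_idx) const
  {
    assert(ghost_idx >= 0 && ghost_idx < face_cb[d] && src_idx >= 0);
    return ghost[d][dir][ghost_idx + src_idx * face_cb[d]];
  }
};
```

With face_cb owned by the accessor, call sites no longer need arg.dc.ghostFaceCB at all, which is the encapsulation the comment asks for.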

weinbe2 commented 1 month ago

Not the highest priority, but it should be easy to apply the multi-rhs optimizations to the KD operator (lib/staggered_kd_apply_xinv.cu and include/kernels/staggered_kd_apply_xinv_kernel.cuh). Testing it is another story... I need to get it exposed directly through staggered_invert_test, etc. Feel free to punt this back to me for the future, just file an issue about it.

maddyscientist commented 1 month ago

This is an easy ask, so I've just gone and done it (071ccf717b61673c9fec664ddab3e59801a46772).

kostrzewa commented 1 month ago

@maddyscientist I can't pinpoint exactly what is happening yet, but it seems that this is causing either some tuning issues or even some kind of lockup. I've tested the latest commit 142ff on a quad-A100 node doing 2+1+1 twisted clover HMC, where it takes a very long time to tune (not sure which kernel; I would need to run with higher debug levels). I'm also testing on an octo-A40 node (also 2+1+1 HMC), where it tunes for a while and then appears to get stuck, the last output being:

On the A100 node I believe the slow tuning was because device memory was almost full, since the job was killed by cudaMalloc shortly thereafter. This also means that with feature/mrhs in place, our HMC runs will require more device memory (and thus more nodes in some cases). Is there any mechanism in place to get the "original" behaviour back (in order to save device memory) for those cases where this would be more efficient or would enable one to run at all (for example, where the number of nodes on a system is limited)?

maddyscientist commented 1 month ago

There is a cmake parameter that sets the maximum rhs allowable, and beyond that the code will split and recurse until the maximum value is met. I suppose if we set that to 1 the original behavior would be restored.
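
The split-and-recurse behaviour can be sketched as follows. This is a hedged illustration: kMaxMultiRhs stands in for the CMake-configured maximum (its real name and default are not given here), apply_batch is a hypothetical placeholder for a multi-RHS kernel launch, and peeling off full batches of the maximum size is an assumption about how the split is done.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for the compile-time (CMake) maximum number of RHS per kernel.
constexpr int kMaxMultiRhs = 4;

// Placeholder for a multi-RHS kernel launch on sources [begin, end); here it only
// records the batch so that the splitting can be inspected.
void apply_batch(int begin, int end, std::vector<std::pair<int, int>> &launches)
{
  assert(end - begin >= 1 && end - begin <= kMaxMultiRhs);
  launches.emplace_back(begin, end);
}

// If the request exceeds the cap, peel off a batch of kMaxMultiRhs sources and
// recurse on the remainder; otherwise launch once.
void apply_mrhs(int begin, int end, std::vector<std::pair<int, int>> &launches)
{
  if (end - begin > kMaxMultiRhs) {
    apply_batch(begin, begin + kMaxMultiRhs, launches);
    apply_mrhs(begin + kMaxMultiRhs, end, launches);
  } else if (end > begin) {
    apply_batch(begin, end, launches);
  }
}
```

With 10 sources and a cap of 4 this launches on [0,4), [4,8), and [8,10). With a cap of 1, every source would get its own launch, which corresponds to the single-RHS behaviour mentioned above.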

Also highlighting the env QUDA_ENABLE_DEVICE_MEMORY_POOL: if you disable the pool by setting it to 0, this should noticeably reduce the memory footprint used by QUDA, which may help (it does come at the cost of exposing malloc/free overheads, which are likely only a small single-digit percentage).

kostrzewa commented 1 month ago

In tmLQCD we even set QUDA_ENABLE_DEVICE_MEMORY_POOL=0 by default using setenv for exactly this reason, with an option for the user to enable it via our input file.

https://github.com/etmc/tmLQCD/blob/950a3a161c7a80e4ffd45ef4b773ea33adf1f4be/quda_interface.c#L403
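
The default-off pattern can be sketched with POSIX setenv. This is an assumption-laden illustration, not tmLQCD's actual code: the helper name and the enable_pool flag (standing in for an input-file option) are hypothetical, and the variable is set before initQuda on the assumption that QUDA reads it during initialization.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Hypothetical helper: disable QUDA's device memory pool unless the user asked for it.
// Call before initQuda(), assumed here to be when QUDA reads the variable.
// Note: setenv is POSIX (declared in <stdlib.h>), not standard C++.
void configure_quda_device_pool(bool enable_pool)
{
  const int rc = setenv("QUDA_ENABLE_DEVICE_MEMORY_POOL", enable_pool ? "1" : "0", /*overwrite=*/1);
  assert(rc == 0);
  (void)rc; // silence unused-variable warnings when NDEBUG is defined
}
```

Overwriting unconditionally lets the input-file setting win over whatever was in the environment; passing 0 as the last argument instead would let a user's exported value take precedence.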

> There is a cmake parameter that sets the maximum rhs allowable, and beyond that the code will split and recurse until the maximum value is met. I suppose if we set that to 1 the original behavior would be restored.

Oh I hadn't seen that. I think that's already good enough for our case. Having an additional env variable for runtime control would of course be nice.

maddyscientist commented 1 month ago

The split/recurse pattern will only work to limit the autotuning overhead and not the memory used. But I shall work on adding that, so we can control the footprint of the clover force.

maddyscientist commented 1 month ago

> Somewhat orthogonal comment to this PR, but while we're here can I request that we rename covDev.cu[h] to covariant_derivative.cu[h] or something of the sort? It's the only outlier file in terms of the naming convention, and iirc git should note that it's just a "rename" so we won't lose the commit history.

Done with c475655

maddyscientist commented 1 month ago

@kostrzewa I have added support for the env QUDA_MAX_MULTI_RHS, which can be used to cap the maximum number of RHS per kernel at run time (https://github.com/lattice/quda/pull/1465/commits/04fd7ea9c3314e0d48b8d5c252fd9b9f0bff3e98); all multi-RHS kernels will split and recurse accordingly to respect this limit. This can be used to control the autotuning time if that is ever an issue. It is the multi-RHS dslash kernels that take a long time to tune; specifically, most of the time goes into certain dslash policies that are only really relevant on systems with a fast host-device interconnect. The best solution here is likely to disable these policies by default, which fixes this headache. I'll do that shortly...
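
A minimal sketch of how a run-time cap like QUDA_MAX_MULTI_RHS might be combined with a compile-time maximum. Only the environment variable name comes from the comment above; kCompileTimeMaxRhs, effective_max_rhs, and the choice to ignore malformed values and clamp to the compile-time limit are assumptions, not QUDA's implementation.

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical compile-time maximum (stands in for the CMake setting).
constexpr int kCompileTimeMaxRhs = 64;

// Effective per-kernel RHS cap: QUDA_MAX_MULTI_RHS if it parses as a positive
// integer, clamped to the compile-time maximum; otherwise the compile-time maximum.
int effective_max_rhs()
{
  int cap = kCompileTimeMaxRhs;
  if (const char *s = std::getenv("QUDA_MAX_MULTI_RHS")) {
    char *end = nullptr;
    const long v = std::strtol(s, &end, 10);
    // accept only a fully parsed, positive value below the compile-time limit
    if (end != s && *end == '\0' && v >= 1 && v < kCompileTimeMaxRhs) cap = static_cast<int>(v);
  }
  assert(cap >= 1 && cap <= kCompileTimeMaxRhs);
  return cap;
}
```

Under this sketch, exporting QUDA_MAX_MULTI_RHS=1 would force one source per kernel launch, while an unset or invalid value falls back to the build-time limit.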

With respect to memory footprint, I've tweaked the clover force code a bit to reduce the memory footprint (552957e). This reduces the footprint of the clover force by about 15% on the tests I was doing locally. There's more scope for reduction if this is needed.

maddyscientist commented 1 month ago

@kostrzewa @weinbe2 I have now disabled all zero-copy policies by default. This dramatically reduces the autotuning time and saves us from pointlessly running dslash variants that are unnecessary on most systems.

I believe this PR is now functionally complete.

kostrzewa commented 1 month ago

@maddyscientist @Marcogarofalo I'm afraid that a change made between fc92ee628 and a9714ba25 appears to trigger what looks like an integrator instability from "way back when" in our HMC.

Log files from Juwels Booster below:

tmLQCD_nf211_hmc_64c128_n4_fc92ee628.zip

tmLQCD_nf211_hmc_64c128_n4_a9714ba25.zip

The second log shows the CG solver for a determinant monomial diverging.

Not quite sure how to catch this kind of thing with QUDA tests but I'm guessing the issue originates in the mrhs improvement of the fermion force for the partial fractions (perhaps the outer product?).

maddyscientist commented 1 month ago

@kostrzewa could you run with tuning disabled? One of the drawbacks of our current testing is the run-to-run variance that different tuning configs can introduce. If the issue reproduces with tuning disabled, this would be helpful information for debugging. Thx

kostrzewa commented 1 month ago

I think I have to take back what I said in https://github.com/lattice/quda/pull/1465#issuecomment-2141701489. While I did observe solvers diverging twice in a row with https://github.com/lattice/quda/commit/a9714ba25087e77c08c08ea28cabe531b249bc31, I cannot reproduce this any more. I think it must have been a transient machine fluke. I'll run a few more test runs but for now I can say that the mrhs branch works correctly within tmLQCD's HMC and shaves off another 10-15% of the time per trajectory in a simulation of a 64c128 ensemble on 4 Juwels Booster nodes.

kostrzewa commented 1 month ago

I'll have to retest on the A40-based machine.

maddyscientist commented 1 month ago

Thanks for the follow up @kostrzewa and restoring my sanity. 😅 I have been doing exhaustive testing since your prior report, and couldn't find anything bad. Great to hear of the performance improvement. Things will only get better from here when we add MRHS to the MG setup, which will hopefully be not too far away now....

maddyscientist commented 1 month ago

> This gets my conditional approval. It's passed a visual review and a correctness review. @maddyscientist is investigating some performance regressions, so no one else should hit merge until those have been addressed (please)!

The performance regression was caused by an excess of calls to qudaDeviceSynchronize being made. This is addressed in d01194619344a5934877ef59dace76176ace2e10 (and to a lesser degree 07f142c70d31236f7e0b1537e052e4e1423c85a5).