lattice/quda

QUDA is a library for performing calculations in lattice QCD on GPUs.
https://lattice.github.io/quda

Coarse dslash multi-rhs with MMA #1355

Closed: hummingtree closed this 1 year ago

hummingtree commented 1 year ago

This PR adds the ability to run the coarse dslash with multiple right-hand sides (multi-rhs) using MMA (tensor-core matrix-multiply-accumulate) instructions.

The coarse dslash multi-rhs MMA path can be tested and benchmarked with the following command, where $l is the lattice extent in each dimension, $n is the number of right-hand sides, $c is the number of coarse vectors, $t is the test index, and --prec-sloppy takes either single or half:

tests/multigrid_benchmark_test --dim $l $l $l $l --nsrc $n --mg-nvec 0 $c --prec-sloppy single/half --mg-dslash-use-mma 0 true --test $t
maddyscientist commented 1 year ago

Doing some tests, and I'm getting

../include/kernels/dslash_coarse_mma.cuh(129): error: static assertion failed with "N %% Arg::bN != 0."

What are the valid combinations of nrhs and nvec? We should allow compilation with unsupported combinations, e.g., to permit numerical experimentation, and give a suitable run-time error instead of failing at compile time.
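
For illustration, a minimal sketch of what such a guard could look like, with hypothetical names (the actual dispatch in the PR may differ): instantiate the MMA specialization only when the tile width divides the number of right-hand sides, and otherwise emit a run-time error rather than a static assertion.

#include <cstdio>

// Stand-in for QUDA's errorQuda macro, so the sketch is self-contained.
#define errorQuda(...) (std::printf(__VA_ARGS__), std::printf("\n"))

// Hypothetical dispatch: every combination compiles, but unsupported tile
// shapes are rejected at run time instead of via static_assert.
template <int N, int bN> void launch_coarse_mma_dslash()
{
  if constexpr (N % bN == 0) {
    // Safe to instantiate: the bN-wide MMA tile evenly covers the N right-hand sides.
    // ... launch the <N, bN> kernel specialization here ...
  } else {
    errorQuda("Unsupported combination: N = %d is not divisible by bN = %d", N, bN);
  }
}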

maddyscientist commented 1 year ago

Given the Jenkins build failure, which is caused by a bug in libcu++ (since fixed upstream), I'm wondering if we should package our own copy of libcu++ with QUDA. This would aid compatibility on machines that are stuck on older toolkits. Perhaps make it an option to grab a blessed release using CPM, and only enable cuda::pipeline if we have a minimum version of libcu++? Thoughts @mathiaswagner?
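
For illustration, a minimal sketch of the version-gating idea, assuming the _LIBCUDACXX_CUDA_API_VERSION macro exposed by libcu++'s <cuda/std/version> header; the threshold and the QUDA_USE_CUDA_PIPELINE feature macro are illustrative assumptions, not QUDA code.

// Enable the cuda::pipeline path only when the libcu++ in use is new enough;
// older toolkits would fall back to a plain copy loop.
#include <cuda/std/version>

#if defined(_LIBCUDACXX_CUDA_API_VERSION) && _LIBCUDACXX_CUDA_API_VERSION >= 1008000
#define QUDA_USE_CUDA_PIPELINE 1 // hypothetical feature macro
#else
#define QUDA_USE_CUDA_PIPELINE 0
#endif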

mathiaswagner commented 1 year ago

I guess that is an option. I have not looked at fetching libcu++ via CPM at all, but it should be possible.

hummingtree commented 1 year ago

Both the libcu++ issue and the MRHS combination issue should be gone now with the latest commits.

weinbe2 commented 1 year ago

I think QUDA_MULTIGRID_SETUP_USE_SMMA needs a #cmakedefine somewhere?

weinbe2 commented 1 year ago

A somewhat generic comment on the static std::string get_type_name() functions (which use std::to_string): since the outputs of these functions go into tune keys, we may want to replace them with strcat/optimized-itoa style versions.
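
For illustration, a sketch of what such a replacement could look like, using snprintf into a fixed buffer as a stand-in for an optimized itoa; the helper name is hypothetical.

#include <cstdio>
#include <cstring>

// Hypothetical helper: append "<label><value>" to a fixed-size tune-key buffer
// without any heap allocation (std::snprintf stands in for an optimized itoa).
inline void key_append_int(char *key, const char *label, int value)
{
  char tmp[32];
  std::snprintf(tmp, sizeof(tmp), "%s%d", label, value);
  std::strcat(key, tmp);
}

// Usage sketch:
//   char key[128] = "coarse_mma";
//   key_append_int(key, ",n_rhs=", 16);
//   key_append_int(key, ",Nc=", 64);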

hummingtree commented 1 year ago

> I think QUDA_MULTIGRID_SETUP_USE_SMMA needs a #cmakedefine somewhere?

I have this in lib/CMakeLists.txt:

if(QUDA_MULTIGRID_SETUP_USE_SMMA)
  target_compile_definitions(quda PUBLIC QUDA_MULTIGRID_SETUP_USE_SMMA)
endif()

Should that be enough?

weinbe2 commented 1 year ago

> > I think QUDA_MULTIGRID_SETUP_USE_SMMA needs a #cmakedefine somewhere?
>
> I have this in lib/CMakeLists.txt:
>
> if(QUDA_MULTIGRID_SETUP_USE_SMMA)
>   target_compile_definitions(quda PUBLIC QUDA_MULTIGRID_SETUP_USE_SMMA)
> endif()
>
> Should that be enough?

I may not have checked thoroughly enough, but when I looked at the generated header files after the CMake configure step, I didn't see a #define QUDA_MULTIGRID_SETUP_USE_SMMA anywhere. Admittedly, I didn't check whether a -DQUDA_MULTIGRID_SETUP_USE_SMMA flag was included in the subsequent compile commands instead.

hummingtree commented 1 year ago

> I may not have checked thoroughly enough, but when I looked at the generated header files after the CMake configure step, I didn't see a #define QUDA_MULTIGRID_SETUP_USE_SMMA anywhere. Admittedly, I didn't check whether a -DQUDA_MULTIGRID_SETUP_USE_SMMA flag was included in the subsequent compile commands instead.

If you check compile_commands.json, -DQUDA_MULTIGRID_SETUP_USE_SMMA is there.
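
For reference, the two mechanisms are interchangeable from the source's point of view: target_compile_definitions puts -DQUDA_MULTIGRID_SETUP_USE_SMMA on every compile line, while a #cmakedefine entry in a configured header (the quda_define.h.in pattern) would bake it into a generated header instead. Either way, the guard in code looks the same; a minimal sketch:

// Fires whether the macro comes from the compile line (target_compile_definitions)
// or from a generated header (#cmakedefine); the guarded code cannot tell the
// difference.
#ifdef QUDA_MULTIGRID_SETUP_USE_SMMA
// ... enable the SMMA setup path ...
#else
// ... fall back to the non-MMA setup path ...
#endif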

weinbe2 commented 1 year ago

My initial tests using this with coarse-level deflation on a dummy 16^4, beta = 6.3 staggered config look good. Using block TRLM with n_rhs == 16 gave accurate eigenvalues relative to non-block (arguably an endorsement of the 3xMMA instructions), but it took far more iterations to converge. That likely reflects needing to use a smaller block size for TRLM, not an issue with the multi-rhs operator.

Representative multi-rhs profile_0.tsv info:

    0.242422          24.253          24.253             532      0.00045568             2x4x4x4        N4quda22DslashCoarsePolicyTuneINS_18DslashCoarseLaunchINS_1DELb0ELi64ELb0ELi1EEEEE      policy,dslash,vol=128,parity=1,precision=4,order=2,Ns=2,Nc=64,gauge_prec=4,halo_prec=4,comm=0000,topo=1111,order=,p2p=0,gdr=0,nvshmem=0,pol=11110000000,full,n_rhs=16        # 4711.55 Gflop/s, 4754.12 GB/s, tuning took 0.014680 seconds at Tue Jun  6 09:11:30 2023

Representative non-multi-rhs:

     0.16542         22.5893         81.7534            1810      9.1392e-05             2x4x4x4        N4quda22DslashCoarsePolicyTuneINS_18DslashCoarseLaunchINS_1DELb0ELi64ELb0ELi1EEEEE      policy,dslash,vol=128,parity=1,precision=4,order=2,Ns=2,Nc=64,gauge_prec=4,halo_prec=4,comm=0000,topo=1111,order=,p2p=0,gdr=0,nvshmem=0,pol=11110000000,full,n_rhs=1 # 1468.24 Gflop/s, 1481.50 GB/s, tuning took 0.006369 seconds at Tue Jun  6 09:11:20 2023

I won't complain about a >3x increase in kernel performance!

I just need to test this through the MILC interface, and then we should be good to go.

weinbe2 commented 1 year ago

Commands for the above test, for posterity:

REGULAR_DEFLATE="--mg-eig 2 true --mg-eig-type 2 trlm --mg-eig-use-dagger 2 false --mg-eig-use-normop 2 true"
REGULAR_DEFLATE="$REGULAR_DEFLATE --mg-nvec 2 32 --mg-eig-n-ev 2 48 --mg-eig-n-kr 2 64 --mg-eig-tol 2 1e-4 --mg-eig-use-poly-acc 2 false"
REGULAR_DEFLATE="$REGULAR_DEFLATE --mg-eig-max-restarts 2 1000"

BLOCK_DEFLATE="--mg-eig 2 true --mg-eig-type 2 blktrlm --mg-eig-use-dagger 2 false --mg-eig-use-normop 2 true"
BLOCK_DEFLATE="$BLOCK_DEFLATE --mg-nvec 2 32 --mg-eig-n-ev 2 48 --mg-eig-n-kr 2 64 --mg-eig-tol 2 1e-4 --mg-eig-use-poly-acc 2 false"
BLOCK_DEFLATE="$BLOCK_DEFLATE --mg-eig-max-restarts 2 1000 --mg-eig-block-size 2 16"

mpirun -np 1 ./staggered_invert_test --inv-multigrid true --dim 16 16 16 16 --verbosity verbose --nsrc 1 \
  --dslash-type asqtad --solve-type direct --solution-type mat --compute-fat-long true \
  --load-gauge l16t16b6p3 --mass 0.1 --tadpole-coeff 0.9 --tol 1e-10 \
  --prec double --prec-sloppy single --prec-precondition single --prec-null single \
  --recon 13 --recon-sloppy 13 --recon-precondition 9 \
  --mg-levels 3 --mg-coarse-solve-type 0 direct --mg-staggered-coarsen-type kd-optimized \
  --mg-block-size 0 1 1 1 1 --mg-nvec 0 3 \
  --mg-smoother-solve-type 0 direct --mg-smoother 0 ca-gcr --mg-nu-pre 0 0 --mg-nu-post 0 8 --mg-smoother-tol 0 1e-10 --mg-verbosity 0 verbose \
  --mg-coarse-solve-type 1 direct --mg-coarse-solver-tol 1 5e-2 \
  --mg-coarse-solver-maxiter 1 16 --mg-coarse-solver 1 gcr \
  --mg-setup-inv 1 cgnr --mg-setup-maxiter 1 1000 --mg-setup-tol 1 1e-6 \
  --mg-block-size 1 4 4 4 4 --mg-nvec 1 64 --mg-n-block-ortho 1 2 \
  --mg-smoother-solve-type 1 direct --mg-smoother 1 ca-gcr --mg-nu-pre 1 0 --mg-nu-post 1 2 --mg-smoother-tol 1 1e-10 --mg-verbosity 1 verbose \
  --mg-coarse-solve-type 2 direct-pc --mg-coarse-solver-tol 2 0.25 \
  --mg-coarse-solver-maxiter 2 16   --mg-coarse-solver 2 gcr --mg-verbosity 2 verbose \
  $REGULAR_DEFLATE

weinbe2 commented 1 year ago

@hummingtree can you merge in the latest develop when you have the chance, just as a sanity check? I don't foresee any (non-trivial) conflicts.

hummingtree commented 1 year ago

> Commands for the above test, for posterity: [...]

As @maddyscientist mentioned offline, you need something like --mg-dslash-use-mma 0 true to enable MMA for the dslash.

weinbe2 commented 1 year ago

Thanks all, new numbers from a slightly different config...

non-multi-rhs:

    0.113806         7.13976         54.0597            1798      6.3296e-05             2x4x4x4        N4quda12DslashCoarseIfssLi2ELi64ELb1ELb0ELb0ELNS_10DslashTypeE2EEE      policy_kernel,GPU-offline,vol=128,parity=1,precision=4,order=2,Ns=2,Nc=64,comm=0000,full,halo=00000000,n_rhs=1  # 2119.96 Gflop/s, 1078.88 GB/s, tuning took 0.566794 seconds at Wed Jun  7 10:27:05 2023

multi-rhs, non-MMA:

     0.35414         11.1818         29.2985             786      0.00045056             2x4x4x4        N4quda12DslashCoarseIfssLi2ELi64ELb1ELb0ELb1ELNS_10DslashTypeE2EEE      policy_kernel,GPU-offline,vol=128,parity=1,precision=4,order=2,Ns=2,Nc=64,comm=0000,full,halo=00000000,n_rhs=16 # 4765.09 Gflop/s, 2425.02 GB/s, tuning took 8.218729 seconds at Wed Jun  7 10:28:34 2023

multi-rhs, MMA:

    0.125911         4.55102          63.702             812     0.000155063             2x4x4x4        N4quda15DslashCoarseMmaIfssLi2ELi64ELb1ELb0ELb1ELNS_10DslashTypeE2ELi16EEE      policy_kernel,2x4x4x4,comm=0000,full,halo=00000000,mma,3xbfloat16,m16n8k8,dslash,n_rhs=16       # 13845.74 Gflop/s, 554.51 GB/s, tuning took 0.058369 seconds at Wed Jun  7 10:29:14 2023

Not bad :)

weinbe2 commented 1 year ago

I took the liberty of re-simplifying the MMA logic in the MILC HISQ MG interface, returning it to a single use_mma true/false "hammer" that maps to setting all of the appropriate flags to true/false under the hood. This decision can be revisited in the future as appropriate, but I don't see any issues in the short term.
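
A minimal sketch of that mapping, with field names that are assumptions based on QudaMultigridParam rather than the interface's actual code:

#include <quda.h>

// Hypothetical "hammer": one user-facing boolean fans out to the per-level
// MMA toggles (field names are illustrative assumptions).
static void set_mma_flags(QudaMultigridParam &mg_param, bool use_mma)
{
  QudaBoolean flag = use_mma ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
  for (int i = 0; i < mg_param.n_level; i++) {
    mg_param.setup_use_mma[i] = flag;  // MMA in the setup (coarsening) phase
    mg_param.dslash_use_mma[i] = flag; // MMA in the coarse dslash itself
  }
}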

weinbe2 commented 1 year ago

The MILC run looked good; I can take care of merging in develop and running clang-format, I just want to do one last code review...

maddyscientist commented 1 year ago

@hummingtree it looks like there's badness if we compile for pre-Volta architectures, e.g., Pascal, when QUDA_ENABLE_MMA is set. The resulting error message is rather unintuitive (CMake complains about division by an unset variable MRHS_ATOM). We also get badness at link time if QUDA_ENABLE_MMA is disabled and we compile for Pascal (the same issue as the failing ROCm tests).

hummingtree commented 1 year ago

> @hummingtree it looks like there's badness if we compile for pre-Volta architectures, e.g., Pascal, when QUDA_ENABLE_MMA is set. The resulting error message is rather unintuitive (CMake complains about division by an unset variable MRHS_ATOM). We also get badness at link time if QUDA_ENABLE_MMA is disabled and we compile for Pascal (the same issue as the failing ROCm tests).

Both issues should be fixed now. I have tested that a Pascal build works with both QUDA_ENABLE_MMA on and off.
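
For posterity, a sketch of the shape of such a fix, with illustrative names (QUDA's actual guard may be structured differently): keep the symbol defined in every build, but make the non-MMA/pre-Volta variant a stub that errors at run time, so nothing is left dangling at link time.

#include <cstdio>

// Stand-in for QUDA's errorQuda macro, so the sketch is self-contained.
#define errorQuda(...) (std::printf(__VA_ARGS__), std::printf("\n"))

struct Args { /* kernel arguments elided */ };

#ifdef QUDA_ENABLE_MMA
// Real path: only compiled when the build enables MMA (sm_70 and newer).
void apply_coarse_mma_dslash(Args &args) { (void)args; /* launch the tensor-core kernels */ }
#else
// Stub keeps the symbol defined so pre-Volta and ROCm builds still link,
// and any attempt to use the MMA path fails cleanly at run time.
void apply_coarse_mma_dslash(Args &) { errorQuda("MMA coarse dslash not available in this build"); }
#endif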