shankar1729 / jdftx

JDFTx: software for joint density functional theory
http://jdftx.org

Question regarding memory management with JDFTx #285

Closed mshammami closed 1 year ago

mshammami commented 1 year ago

Hello, I'm wondering how GPU memory is managed in JDFTx. For example, the following phonon perturbation seems to require about 88-105 GB of memory:

---------- Setting up k-points, bands, fillings ----------
No reducable k-points.
Computing the number of bands and number of electrons
Calculating initial fillings.
nElectrons: 4152.000000   nBands: 2232   nStates: 4

----- Setting up reduced wavefunction bases (one per k-point) -----
average nbasis = 492134.750 , ideal nbasis = 492191.929
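
For a rough sense of scale (my own back-of-the-envelope, assuming complex double-precision plane-wave coefficients at 16 bytes each):

nBands × nbasis × 16 B ≈ 2232 × 492135 × 16 B ≈ 17.6 GB per wavefunction-sized array (per state)

so five or six such arrays (wavefunctions, gradients, search directions, ...) already land in the 88-105 GB range.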

I'm using as many GPUs as nStates (so 4 in this case), with each GPU being a 40 GB NVIDIA A100. However, the calculation terminates during electronic minimization with the error "GPU memory allocation failed (out of memory)". I have tried setting JDFTX_MEMPOOL_SIZE to a value close to 40 GB, and also not using a memory pool at all, but neither solved the issue (the latter often ends in "CUDA Error: out of memory").

I have also tried higher-memory GPUs (the 80 GB NVIDIA A100) and made sure that the job affinity leaves all GPUs visible to all tasks (e.g. --gpu-bind=none).

I realize that the memory per process is still limited (to either 40 or 80 GB). However, can JDFTx utilize the entire memory available on a node, or benefit from shared memory management options like cudaMallocManaged?

Kindly, Matthew

shankar1729 commented 1 year ago

Hi Matthew,

We made the choice to only use physical GPU memory for performance reasons. When CUDA first added support for direct host memory access and the unified memory model, we tried expanding the available memory using those options, but found really poor performance compared to hosting everything on the GPU.
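
For context, the unified memory model in question looks essentially like this (a minimal standalone sketch, not actual JDFTx code; the 60 GiB size is just an illustrative oversubscription of a 40 GB card):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    // Request more memory than one GPU physically has (e.g. 60 GiB on a 40 GB A100).
    // cudaMallocManaged returns a pointer valid on both host and device; pages
    // migrate on demand, so the allocation can exceed device memory.
    size_t nBytes = 60ull << 30;  // 60 GiB
    double* data = nullptr;
    cudaError_t err = cudaMallocManaged(&data, nBytes);
    if (err != cudaSuccess) {
        std::printf("cudaMallocManaged failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    // Kernels then use 'data' exactly as they would a cudaMalloc'd pointer, but
    // every page migration stalls the kernel; that paging overhead is what made
    // this approach much slower than device-resident memory in our tests.
    cudaFree(data);
    return 0;
}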

So unfortunately, the calculations you mention are just a little too big to fit on the GPU (even with 1 process/GPU). Perhaps try a smaller supercell size, or check if you can lower your plane-wave cutoff?
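
For reference (standard plane-wave counting, not JDFTx-specific): nbasis ≈ V (2 Ecut)^(3/2) / (6 π²) in atomic units for cell volume V, so wavefunction memory falls off as Ecut^(3/2) and even a modest cutoff reduction can make the difference.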

Our new code, QimPy, supports spreading wavefunctions of a given k-point over many GPUs, and will therefore not have this issue. It does not have a phonon implementation yet; we hope to have feature parity with JDFTx there within the next year.

Best, Shankar

mshammami commented 1 year ago

I see, looking forward to QimPy's release!

Kindly, Matthew