NanoComp / meep

free finite-difference time-domain (FDTD) software for electromagnetic simulations
GNU General Public License v2.0

Using parallel meep causes 8 GPU processes and memory exception #2572

Open Arcadianlee opened 1 year ago

Arcadianlee commented 1 year ago

Hi, when I run: mpirun -np 8 python filename.py, I end up with 8 parallel GPU processes, which exceeds the CUDA memory limit and causes the error:

RuntimeError: CUDA error: out of memory.


Any idea how to fix this?

smartalecH commented 1 year ago

This isn't a meep issue. Meep doesn't allocate any CUDA memory.

Arcadianlee commented 1 year ago

Then whose issue is it? And why did you close the post before I could resolve it?

smartalecH commented 1 year ago

This has nothing to do with meep. Meep can't possibly allocate any CUDA memory. Closing as it's off topic.

Arcadianlee commented 1 year ago

When I use the non-MPI version of meep everything works fine; however, when I use the MPI version it reports a CUDA error. I don't know what's causing it here. Still, I think closing an issue without at least trying to help me resolve it is inappropriate for a meep contributor to do.

Arcadianlee commented 1 year ago

@stevengj @oskooi

oskooi commented 1 year ago

One possible explanation for the large CUDA memory consumption that you are encountering is that you are using Meep's adjoint solver with JAX, which has been set up to use GPUs as part of its XLA backend. From the JAX documentation page on GPU memory allocation:

JAX will preallocate 90% of the total GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors.

This behavior may be exacerbated by JAX's aggressive threading model when running a parallel Meep simulation using MPI, as described in #1661. (This could explain why your serial runs did not have this problem.) Note that #1661 was filed two years ago. I recently discussed this issue with the JAX developers; unfortunately, there is still no reliable fix.
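If the XLA backend is indeed what is grabbing the memory, one thing worth trying (a sketch only, based on the environment variables documented on that same JAX page) is to relax the preallocation before jax is first imported, e.g. at the top of your filename.py:

import os

# Option 1: disable the 90% up-front preallocation and allocate GPU memory on demand
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option 2 (alternative): keep preallocation but cap each MPI rank's share of the device,
# e.g. roughly 10% per process so that 8 ranks do not oversubscribe a single GPU
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".10"

import jax  # import jax only after the variables above are set

Whether this avoids the OOM under mpirun -np 8 depends on how much memory each rank actually needs, so there is no guarantee it resolves the MPI case.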

A workaround is not to use JAX but rather autograd.
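For reference, the objective function for Meep's adjoint solver is then written with autograd.numpy instead of jax.numpy; a minimal sketch (the OptimizationProblem setup is omitted, and the names J and mode_coeff are just placeholders):

import meep.adjoint as mpa     # Meep's adjoint module (used to build the OptimizationProblem, omitted here)
import autograd.numpy as npa   # autograd's drop-in numpy wrapper; runs on the CPU, no CUDA allocation

# objective differentiated by autograd; mode_coeff stands for the monitor value
# (e.g. an eigenmode coefficient) that the adjoint solver passes in
def J(mode_coeff):
    return npa.abs(mode_coeff) ** 2

Since nothing here touches JAX, no XLA GPU memory is allocated by the adjoint calculation.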

Arcadianlee commented 1 year ago

Sounds like a possible cause. Will check it out.

Arcadianlee commented 1 year ago

Hi, I've also encountered another error: the simulation fields are NaN or inf. Bizarrely, this error only occurs when I run the parallel (MPI) meep and does NOT occur when I run the non-MPI meep. Any idea why it behaves like this?
