AMReX-Codes / pyamrex

GPU-Enabled, Zero-Copy AMReX Python Bindings including AI/ML
http://pyamrex.readthedocs.io

PyTorch and MPI-enabled AMReX don't get along in `load_state_dict` #322

Open RTSandberg opened 1 month ago

RTSandberg commented 1 month ago

On my local machine, PyTorch's internal multithreading does not get along with AMReX: unless I call `torch.set_num_threads(1)` (or `2`), the attached script hangs when the neural network sets its initial parameters.
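For anyone hitting the same hang, the workaround described above can be applied at the top of a script roughly as follows. This is a minimal sketch, not a fix endorsed by the pyAMReX or PyTorch developers. `limit_threads` is a hypothetical helper, and also exporting `OMP_NUM_THREADS` rests on the unconfirmed assumption that competing OpenMP thread pools are involved. The environment variable only affects OpenMP runtimes that start after it is set, so the helper should run before `amrex` is imported.

```python
import os


def limit_threads(num_threads=1):
    """Hypothetical helper: cap CPU threading before the model is loaded.

    OMP_NUM_THREADS is read by OpenMP runtimes when they initialize, so it has
    to be set before torch/amrex start theirs. torch.set_num_threads is applied
    as well when torch is importable; returns the resulting torch thread count,
    or None if torch is not installed.
    """
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    try:
        import torch
    except ImportError:
        return None
    torch.set_num_threads(num_threads)
    return torch.get_num_threads()
```

For example, call `limit_threads(1)` as the first statement of the script, before `import amrex.space3d as amr`, and compare against a run without it.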

The script downloads neural-network parameters from a Zenodo archive and then loads them into the model; the specific point of failure is the `load_state_dict` call.
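To confirm where the script stalls without attaching a debugger, the standard-library `faulthandler` module can dump every Python thread's stack after a timeout. A sketch, with `watchdog` as a hypothetical helper name; it only shows Python frames, so PyTorch's or OpenMP's native worker threads will not appear in the dump.

```python
import contextlib
import faulthandler
import sys


@contextlib.contextmanager
def watchdog(timeout_s=60.0, label="block"):
    """Dump all Python thread stacks to stderr if the body exceeds timeout_s.

    exit=False keeps the process alive after the dump, so a hung MPI job
    still has to be killed by hand.
    """
    print(f"[watchdog] entering {label}", file=sys.stderr, flush=True)
    faulthandler.dump_traceback_later(timeout_s, exit=False)
    try:
        yield
    finally:
        # Disarm the timer once the body returns (or raises).
        faulthandler.cancel_dump_traceback_later()
        print(f"[watchdog] left {label}", file=sys.stderr, flush=True)
```

Wrapping the suspect call, e.g. `with watchdog(120, "load_state_dict"): model.load_state_dict(state)`, where `model` and `state` stand in for the script's own objects, prints the stuck stack to stderr if the call has not returned after two minutes.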

pytorch_amrex_hang_reproducer_v2.py.txt

ax3l commented 1 month ago

Thank you, @RTSandberg !

For reproducibility, can you please add the OS you used, the versions of Python, pyAMReX, PyTorch, and mpi4py, and your MPI flavor and version?
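The requested details can be gathered with a short script. This is a sketch: the distribution names `amrex` and `pyamrex` are assumptions about how pyAMReX was installed (a conda or source build may register differently), the MPI flavor and version come from mpi4py's `MPI.Get_library_version()`, and `torch.__config__.parallel_info()` is included because it summarizes PyTorch's threading setup.

```python
import platform
from importlib import metadata

# Assumed distribution names: pyAMReX may register as "amrex" or "pyamrex"
# depending on how it was installed; absent ones are reported as None.
PACKAGES = ("amrex", "pyamrex", "torch", "mpi4py")


def package_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def mpi_library_version():
    """MPI implementation and version string as reported by mpi4py, if installed."""
    try:
        from mpi4py import MPI
    except ImportError:
        return None
    return MPI.Get_library_version().strip()


def torch_parallel_info():
    """PyTorch's own summary of its threading backends and thread counts."""
    try:
        import torch
    except ImportError:
        return None
    return torch.__config__.parallel_info().strip()


def environment_report():
    report = {"os": platform.platform(), "python": platform.python_version()}
    report.update(package_versions())
    report["mpi"] = mpi_library_version()
    report["torch_parallel"] = torch_parallel_info()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {'not found' if value is None else value}")
```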

ax3l commented 1 month ago

If we can reduce this to a pure mpi4py + PyTorch problem, with no AMReX involved, then we could also report it upstream to PyTorch: https://github.com/pytorch/pytorch/issues
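A reduced reproducer along those lines might look like the sketch below. It is hypothetical: it replaces the Zenodo download with an in-memory `state_dict` from a randomly initialized model, and the layer sizes are guesses meant to make the tensor copies large enough to plausibly run on PyTorch's intra-op thread pool. If this version does not hang, that would suggest the AMReX initialization in the original script is part of the trigger.

```python
"""Hypothetical reduced reproducer for issue #322: mpi4py + PyTorch, no AMReX."""
import io


def reproduce(features=1024, layers=4):
    # Initialize MPI first, as an MPI-enabled pyAMReX would before any torch work.
    from mpi4py import MPI
    import torch

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    print(f"rank {rank}: torch threads = {torch.get_num_threads()}", flush=True)

    def make_model():
        return torch.nn.Sequential(
            *[torch.nn.Linear(features, features) for _ in range(layers)]
        )

    # Serialize one model's parameters to memory instead of downloading them.
    buffer = io.BytesIO()
    torch.save(make_model().state_dict(), buffer)
    buffer.seek(0)
    state = torch.load(buffer)

    model = make_model()
    model.load_state_dict(state)  # the call that hangs in the original report
    comm.Barrier()
    print(f"rank {rank}: load_state_dict finished", flush=True)
    return True


if __name__ == "__main__":
    reproduce()
```

Run it under the same launcher as the original script, for example `mpiexec -n 2 python reduced_reproducer.py` (file name hypothetical), once with the default thread count and once with `torch.set_num_threads(1)` added at the top.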