AMReX-Codes / pyamrex

GPU-Enabled, Zero-Copy AMReX Python Bindings including AI/ML
http://pyamrex.readthedocs.io

PyTorch and MPI-enabled AMReX don't get along in `load_state_dict` #322

Open RTSandberg opened 5 months ago

RTSandberg commented 5 months ago

On my local machine, PyTorch's internal multithreading doesn't get along with AMReX. Unless I call `torch.set_num_threads` with 1 or 2, the attached script hangs when the neural network tries to set its initial parameters.

The script downloads neural network parameters from a Zenodo archive and then loads them; the `load_state_dict` call is the specific point of failure.
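A minimal sketch of the workaround described above, assuming it is enough to cap PyTorch's thread count before the parameters are loaded. The small `nn.Sequential` model and the in-memory buffer are hypothetical stand-ins for the real network and the Zenodo download in the attached script; AMReX itself is left out here.

```python
import io

import torch
import torch.nn as nn

# Workaround from the report: cap PyTorch's intra-op thread pool
# (1 or 2 threads reportedly avoids the hang) before loading parameters.
torch.set_num_threads(1)

# Hypothetical stand-in for the network used in the attached script.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Stand-in for the downloaded parameter file: save the state dict to an
# in-memory buffer and read it back.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)

# This is the call that reportedly hangs without the thread limit.
result = model.load_state_dict(state)
print(f"threads={torch.get_num_threads()}, missing keys={result.missing_keys}")
```

If this pattern helps in the full script, that would point at the thread pool rather than at the parameter file itself.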

pytorch_amrex_hang_reproducer_v2.py.txt

ax3l commented 5 months ago

Thank you, @RTSandberg !

For reproducibility, can you please add the OS you used, versions of Python, pyAMReX, PyTorch, MPI flavor and version, and mpi4py version?

ax3l commented 5 months ago

If we can reduce this problem to a pure mpi4py + PyTorch issue, then we could also report this upstream in PyTorch: https://github.com/pytorch/pytorch/issues