poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Bug] Elegy crash on GPU #246

Status: Open. Opened by jiyuuchc 2 years ago.

jiyuuchc commented 2 years ago

Describe the bug

Running mnist_cnn.py from the examples directory crashes the instance at the end of the first epoch.

This was previously reported on a Colab GPU instance, but I can also reproduce it from the command line.

Running on the CPU does not trigger the crash.

Running in eager mode on the GPU does not trigger it either.
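Because the crash appears only when the GPU is combined with compiled (non-eager) execution, one way to narrow it down is to run the same step function once jitted and once under `jax.disable_jit()`, then compare the results. The sketch below assumes only that JAX is installed. `loss_fn`, `train_step` and `run` are hypothetical stand-ins (a linear model trained with plain SGD), not code from mnist_cnn.py or part of elegy's API, and the sketch does not reproduce the crash itself.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for one training step (NOT the elegy example's code):
# a linear model trained with plain SGD on mean squared error.
def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


def train_step(w, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(w, x, y)
    return w - lr * grads


jitted_step = jax.jit(train_step)


def run(step, n_steps=5):
    # Fixed seed so both execution modes see identical data.
    kx, kw = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(kx, (32, 4))
    true_w = jax.random.normal(kw, (4,))
    y = x @ true_w
    w = jnp.zeros(4)
    for _ in range(n_steps):
        w = step(w, x, y)
    return w


print("default backend:", jax.default_backend())

# Compiled path: XLA compiles train_step once and reuses it.
w_jit = run(jitted_step)

# Eager path: inside disable_jit, the same jitted function runs op by op.
with jax.disable_jit():
    w_eager = run(jitted_step)

print("max abs diff:", float(jnp.max(jnp.abs(w_jit - w_eager))))
```

If a toy step like this runs in both modes on the GPU while the example still crashes only when compiled, the problem is more likely in the example's compiled step than in the GPU setup in general. To repeat the check on CPU, setting the `JAX_PLATFORMS=cpu` environment variable before starting Python forces the CPU backend in recent JAX releases.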

Minimal code to reproduce

python mnist_cnn.py

Expected behavior

The script should finish training without crashing or getting stuck.

Library Info

CentOS Linux release 7.6.1810, elegy 0.8.6

Additional context

Installed packages:

absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
async-timeout==4.0.2
attrs==22.1.0
certifi==2021.10.8
charset-normalizer==2.1.1
chex==0.1.4
click==8.1.3
cloudpickle==1.6.0
colorama==0.4.5
commonmark==0.9.1
cycler==0.11.0
datasets==2.4.0
dill==0.3.5.1
dm-tree==0.1.7
docker-pycreds==0.4.0
einops==0.4.1
elegy==0.8.6
etils==0.7.1
filelock==3.8.0
flax==0.4.2
fonttools==4.36.0
frozenlist==1.3.1
fsspec==2022.7.1
gitdb==4.0.9
GitPython==3.1.27
h5py==3.6.0
huggingface-hub==0.8.1
idna==3.3
importlib-resources==5.9.0
jax==0.3.16
jaxlib==0.3.15+cuda11.cudnn82
kiwisolver==1.4.4
matplotlib==3.5.3
msgpack==1.0.4
multidict==6.0.2
multiprocess==0.70.13
numpy==1.22.3
opt-einsum==3.3.0
optax==0.1.3
packaging==21.3
pandas==1.4.3
pathtools==0.1.2
Pillow==9.2.0
promise==2.3
protobuf==3.20.1
psutil==5.9.1
pyarrow==9.0.0
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.2.1
PyYAML==6.0
requests==2.28.1
responses==0.18.0
rich==11.2.0
scipy==1.8.0
sentry-sdk==1.9.5
setproctitle==1.3.2
shortuuid==1.0.9
six==1.16.0
smmap==5.0.0
tensorboardX==2.5.1
toolz==0.12.0
tqdm==4.64.0
treeo==0.0.10
treex==0.6.10
typing_extensions==4.3.0
urllib3==1.26.11
wandb==0.12.21
xxhash==3.0.0
yarl==1.8.1
zipp==3.8.1

jiyuuchc commented 2 years ago

Related to #234 and #238

jiyuuchc commented 2 years ago

The cause appears to be #226.