phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License

jax tutorial 5: cannot run on GPU (due to clash with pytorch RNG) #84

Closed murphyk closed 1 year ago

murphyk commented 1 year ago

Tutorial: 5 (JAX)

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html

Describe the bug

Same issue as in https://github.com/phlippe/uvadlc_notebooks/issues/83. However, here it crops up right at the start: merely importing PyTorch causes problems. See the screenshot below.

This can be fixed by moving

from jax import random

# Seeding for random operations
main_rng = random.PRNGKey(42)

so that it runs before PyTorch is imported.
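A minimal sketch of the workaround, assuming the notebook imports `torch` only for data loading (an assumption, not stated in this thread): the first JAX PRNG key is created, which forces JAX to initialize its default backend, before `torch` is imported. Why this ordering avoids the clash is not explained in the thread; a plausible guess is that JAX then loads its own CUDA libraries first. The `try`/`except` around the PyTorch import is added only so the sketch also runs where PyTorch is not installed; it is not part of the tutorial.

```python
import jax
from jax import random

# Seeding for random operations. Creating the key runs a JAX computation,
# which initializes JAX's default backend (the GPU, if one is visible).
main_rng = random.PRNGKey(42)

# Import PyTorch only after JAX has been initialized. The try/except is an
# addition for this sketch so it also runs without PyTorch installed.
try:
    import torch
except ImportError:
    torch = None

print("JAX backend:", jax.default_backend())
print("PyTorch available:", torch is not None)
```

If the imports are spread over several notebook cells, the same ordering applies: the cell that creates `main_rng` has to run before any cell that imports `torch` or `torchvision`.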

[Screenshot 2023-03-18 at 6:41:29 PM]
murphyk commented 1 year ago

Here is a minimal Colab notebook that reproduces the core problem: https://colab.research.google.com/drive/1gGDAGotSA0aJ9Nd1ilV9_xWBy6uFvVkU?usp=sharing

phlippe commented 1 year ago

Closing this issue, since it was fixed in commit efbc9de8a08857f1d074d1063c8c3c861475eea4; the discussion continues in #83.