phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License
2.59k stars, 590 forks

Feature request: redo timing comparisons in Tutorial 5 (DenseNet) comparing JAX with PyTorch 2.0 #85

Status: Open. murphyk opened this issue 1 year ago.

murphyk commented 1 year ago

In https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html you claim that JAX is faster than PyTorch 1. Is this still true when using torch.compile from PyTorch 2?

phlippe commented 1 year ago

In initial checks, PyTorch 2.0 is a bit faster than PyTorch 1 but still doesn't reach the performance of JAX, e.g., on the Transformer models. I noticed that it can give quite a boost for inference, though. However, PyTorch currently fails to compile the CNNs (some issue in the compilation, both locally and on Colab), so I will wait until that is stable and then redo the speed comparisons.
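When redoing such comparisons, two details matter for fairness: both jax.jit and torch.compile compile lazily on the first call, so those calls should be excluded from the timing, and GPU work is dispatched asynchronously, so the clock should only stop once the result is actually computed. The following is a minimal, framework-agnostic sketch of such a timing harness. It is not the tutorial's benchmarking code; the function name `benchmark` and its parameters are hypothetical.

```python
import statistics
import time
from typing import Any, Callable


def benchmark(
    fn: Callable[[], Any],
    sync: Callable[[Any], None] = lambda out: None,
    warmup: int = 3,
    repeats: int = 20,
) -> float:
    """Return the median wall-clock time (in seconds) of one call to `fn`.

    Hypothetical helper: `fn` is called `warmup` times first so that one-off
    costs such as JIT compilation are excluded from the measurement. `sync`
    is called on every result before the clock stops, so asynchronously
    dispatched device work is included in the timing.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    # Warm-up calls: trigger compilation/caching; the timings are discarded.
    for _ in range(warmup):
        sync(fn())
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn()
        sync(out)  # wait until the result has actually been computed
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to outliers (e.g., OS jitter) than the mean.
    return statistics.median(timings)


if __name__ == "__main__":
    _state = {"compiled": False}

    def step() -> int:
        # Stand-in for a compiled training step: slow only on the first call.
        if not _state["compiled"]:
            time.sleep(0.2)  # simulated one-off compilation cost
            _state["compiled"] = True
        return sum(range(1000))

    t = benchmark(step, warmup=1, repeats=5)
    print(f"median step time: {t * 1e3:.3f} ms")
```

As a usage assumption, `fn` would wrap one jitted JAX step or one torch.compile-d PyTorch step on fixed inputs. For JAX, `sync` could be `lambda out: out.block_until_ready()`; for PyTorch on a GPU, `lambda out: torch.cuda.synchronize()`. Reporting compilation time separately from the steady-state step time makes it clearer which of the two a speedup comes from.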