murphyk opened 1 year ago
In first checks, PyTorch 2.0 is a bit faster than PyTorch 1 but does not reach the performance of JAX, e.g., on the Transformer models. I noticed that PyTorch 2.0 can give quite a boost for inference, though. Still, PyTorch currently fails to compile the CNNs (some issue during compilation, both locally and on Colab), so I will wait until that is stable and then redo the speed comparisons.
In https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html you claim that JAX is faster than PyTorch 1. Is this still true when using torch.compile from PyTorch 2?
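A fair speed comparison between `torch.compile` and `jax.jit` usually has to keep the one-off compilation cost out of the steady-state timings, and has to wait for asynchronously dispatched GPU work to finish before stopping the clock. The following is a minimal, framework-agnostic timing sketch under those assumptions. The `benchmark` helper and its `warmup`, `repeats`, and `sync` parameters are hypothetical illustrations, not part of PyTorch, JAX, or the tutorial notebooks.

```python
import statistics
import time
from typing import Any, Callable, Optional


def benchmark(
    fn: Callable[..., Any],
    *args: Any,
    warmup: int = 3,
    repeats: int = 20,
    sync: Optional[Callable[[Any], Any]] = None,
) -> dict:
    """Time fn(*args), reporting the first (compiling) call separately.

    `sync` (hypothetical hook) is called on each result before the clock
    stops, so asynchronous device work is included in the measurement.
    """
    if warmup < 1 or repeats < 1:
        raise ValueError("warmup and repeats must both be at least 1")

    def timed_call() -> float:
        start = time.perf_counter()
        out = fn(*args)
        if sync is not None:
            sync(out)  # wait until the device has actually finished
        return time.perf_counter() - start

    # The first call of a compiled function includes tracing and
    # compilation, so it is kept out of the steady-state statistics.
    first_call_s = timed_call()
    for _ in range(warmup - 1):
        timed_call()

    timings = [timed_call() for _ in range(repeats)]
    return {
        "first_call_s": first_call_s,
        "median_s": statistics.median(timings),
        "min_s": min(timings),
    }
```

With PyTorch 2 on a GPU, one might pass `torch.compile(model)` as `fn` and `lambda _: torch.cuda.synchronize()` as `sync` (inside `torch.no_grad()` when measuring inference). With JAX, one might pass `jax.jit(apply_fn)` with `sync=jax.block_until_ready`. Here `model`, `apply_fn`, and their inputs are placeholders. Reporting the median of several runs, rather than a single run, makes the comparison less sensitive to one-off noise.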