AakashKumarNain opened this issue 3 months ago
Haha, I actually quite like using issues as discussions. Partly because half of people tend to use issues for that purpose anyway... ;)
So I think profiling Equinox should be the exact same as profiling JAX. `jax.profiler` is probably a natural place to start:
https://jax.readthedocs.io/en/latest/jax.profiler.html
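For concreteness, a minimal sketch of what attaching `jax.profiler` to a jitted step might look like (the toy `step` function and the log directory below are placeholders, not code from this thread):

```python
# Minimal sketch of tracing a jitted step with jax.profiler.
import jax
import jax.numpy as jnp
import jax.profiler

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((1024, 1024))
step(x).block_until_ready()  # warm-up call so compile time stays out of the trace

with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        step(x).block_until_ready()

# View the trace with: tensorboard --logdir /tmp/jax-trace
# (requires tensorboard and tensorboard-plugin-profile to be installed)
```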
I like the idea of creating some more docs on best practices here.
Cool, thanks a lot. I will push a PR once I manage to run this exercise successfully on my end. If I encounter any issues, I will post them in this thread.
I've also struggled to get some of my models that use attention layers to perform as well as PyTorch. The performance gap I've seen has been up to 50% in training throughput, but perhaps I'm not handling precision well enough. I would like to hear about what you find.
Edit: I can go through my code as well and take a look if that would help.
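On the precision point: one general JAX knob worth checking (not confirmed as the cause of the gap above) is the default matmul precision, which can be set per-region or globally and affects both speed and how fair a JAX-vs-PyTorch comparison is. A rough sketch:

```python
# Sketch: controlling matmul precision in JAX. Accepted values include
# "bfloat16", "tensorfloat32", and "float32".
import jax
import jax.numpy as jnp

a = jnp.ones((512, 512))
b = jnp.ones((512, 512))

with jax.default_matmul_precision("tensorfloat32"):
    c_fast = a @ b    # TF32-style matmul on GPUs that support it

with jax.default_matmul_precision("float32"):
    c_exact = a @ b   # full float32 accumulation
```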
Okay, I haven't attached the profiler to my runs yet, but I am seeing extremely poor performance on my side compared to torch. Here are some stats for a GPT-2 model with distributed training:

Causal attention: True
- Fused AdamW: True
- Compiled: True
- Forward pass only: 16-20 ms
- Forward and backward: 60-70 ms

Causal attention: False (ran self-attention without any causal mask, just for benchmarking)
- Fused AdamW: N/A (no fused implementation in Optax)
- Compiled: True
- Forward pass only: 45-50 ms
- Forward and backward: 194 ms
Even if I ignore the torch numbers for now, the forward pass is extremely slow, and the overhead from Optax is huge. I don't have any `for` loops in my code either.

PS: Given how hastily I have done these, there is a chance that the numbers might be a bit off. But I am sure there is no bug in the code that would change these numbers drastically.
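As a reference for how numbers like these are usually collected (a generic sketch, not the benchmark code used above): warm up once to trigger compilation, then time several iterations and block on JAX's async dispatch before reading the clock.

```python
# Generic JAX timing helper (illustrative only).
import time
import jax

def bench(fn, *args, iters=20):
    jax.block_until_ready(fn(*args))   # warm-up: compile + first run, excluded from timing
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    jax.block_until_ready(out)         # dispatch is async; block before stopping the clock
    return (time.perf_counter() - start) / iters * 1e3   # mean ms per iteration
```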
Usual warning to avoid including compile time in benchmarks, by the way.
FWIW I usually find the opposite -- PyTorch is usually slower than JAX.
Note that performance is a JAX-level thing, not an Equinox thing, so you'll probably find success by taking a look at some of the other JAX repos and seeing what tricks other people have.
> Usual warning to avoid including compile time in benchmarks, by the way.

Yes, I am aware of it, and I generally warm up the model with a single batch.

> FWIW I usually find the opposite -- PyTorch is usually slower than JAX.

That has been my experience for years too, which is why I am surprised by these numbers.

> what tricks other people have.

The only tricks in this case are `scan` over a `for` loop for the transformer layers, fused QKV, and efficient attention. I am pretty sure all of these are in place in my code.
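For reference, the scan-over-layers trick mentioned above usually looks something like the following in Equinox (a generic sketch with a stand-in MLP block rather than a real transformer layer):

```python
# Sketch: replace a Python `for` loop over identical blocks with jax.lax.scan.
import equinox as eqx
import jax
import jax.numpy as jnp

num_layers = 12
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
make_block = lambda k: eqx.nn.MLP(in_size=256, out_size=256, width_size=1024, depth=1, key=k)

# Build all layers at once; array leaves get a leading `num_layers` axis.
blocks = eqx.filter_vmap(make_block)(keys)
params, static = eqx.partition(blocks, eqx.is_array)

def apply_blocks(x, params):
    def body(carry, layer_params):
        block = eqx.combine(layer_params, static)   # re-assemble one layer
        return block(carry), None
    out, _ = jax.lax.scan(body, x, params)           # scan slices the leading layer axis
    return out

x = jnp.ones(256)
y = apply_blocks(x, params)
```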
Okay, I did some digging into this. I benchmarked each component separately, and found that all Equinox components are much faster than the torch counterparts. Beyond this, there are only two sources of slowness:

1. Data loader -> switched to `tf.data.Dataset` for now to resolve this, though I wanted to avoid it in the first place as it creates another dependency.
2. The major concern is `optax`. The forward pass is pretty efficient, but the backward pass is slow as hell. I don't know why it is so slow, but it instantly makes my JAX code ~1.5-2x slower than my torch code.

Also, I have found that the XLA flags for GPU either don't provide speedups or provide only minor ones. I will try this with an H100, but I am not very hopeful. Will keep updating this thread.
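For context, the `optax` path being discussed is the standard Equinox + Optax update step, which roughly looks like this (a sketch with a toy model in place of the GPT-2 code from this thread):

```python
# Sketch of a typical Equinox + Optax AdamW training step.
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.MLP(in_size=32, out_size=1, width_size=64, depth=2, key=jax.random.PRNGKey(0))
optim = optax.adamw(learning_rate=3e-4, weight_decay=0.1)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def train_step(model, opt_state, x, y):
    def loss_fn(m):
        pred = jax.vmap(m)(x)
        return jnp.mean((pred - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    # adamw needs the current params for decoupled weight decay
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

x, y = jnp.ones((8, 32)), jnp.zeros((8, 1))
model, opt_state, loss = train_step(model, opt_state, x, y)
```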
Since optax is the main large optimizer library in JAX, I'm assuming you haven't found a solution to this? The torch optimizers and JAX optimizers seem to work very differently due to in-place operations.
Though I haven't found a solution yet, optax is fine! I'm pretty sure I have some silly bug somewhere in my code, and it is mostly a mistake on my end rather than something related to jax/equinox/optax. I will eventually get to it once I have some dedicated time to spend on this (context switching is annoying!).
The only thing I am certain about at this point is that none of the XLA flags speeds up any operation on an A100. I am going to open an issue for that in the XLA repo.

PS: Having separate components is better than packing everything into a monolith. I consider this the biggest strength of the JAX ecosystem.
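For anyone trying to reproduce the XLA-flags experiment: the flags are passed through the `XLA_FLAGS` environment variable and must be set before JAX initializes its backend. The specific flag below is just one commonly suggested example; availability and effect depend on the XLA build.

```python
# Sketch: XLA flags are read from the environment when the backend starts,
# so set them before importing jax (or export XLA_FLAGS in the shell).
import os

os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"

import jax  # the backend picks up the flag on first initialization
print(jax.devices())
```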
I am building a GPT-like model in Equinox, and right now the forward pass is extremely slow compared to my torch implementation. I think this is one of the cases where I would like to attach a profiler and visualize the cost of every operation in my graph (model).

Are there any recommended practices for using the profiler with Equinox models? If there is no such guide, I am happy to contribute one, but I will need some guidance.
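One small trick that helps when reading the resulting trace (a general JAX feature, not something Equinox-specific): wrap parts of the forward pass in `jax.named_scope` so individual blocks show up under readable names in the profiler. A sketch with placeholder weights:

```python
# Sketch: labelling regions of a jitted forward pass for the profiler trace.
import jax
import jax.numpy as jnp

@jax.jit
def forward(x, w_attn, w_mlp):
    with jax.named_scope("attention"):
        x = x @ w_attn
    with jax.named_scope("mlp"):
        x = jax.nn.gelu(x @ w_mlp)
    return x

x = jnp.ones((8, 64))
y = forward(x, jnp.ones((64, 64)), jnp.ones((64, 64)))
```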
PS: I think it is high time we enable Discussions in this repo 😄