patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

[Question] Best practices for profiling Equinox models #793

Open AakashKumarNain opened 1 month ago

AakashKumarNain commented 1 month ago

I am building a GPT-like model in Equinox, and right now the forward pass is extremely slow compared to my torch implementation. I think this is one of those cases where I would like to attach a profiler and visualize the cost of every operation in my graph (model).

Are there any recommended practices for using the profiler with Equinox models? If there is no such guide, I am ready to contribute one, but I will need some guidance.

PS: I think it is high time we enabled Discussions in this repo 😄

patrick-kidger commented 1 month ago

Haha, I actually quite like using issues as discussions. Partly because half of people tend to use issues for that purpose anyway... ;)

So I think profiling Equinox should be the exact same as profiling JAX. jax.profiler is probably a natural place to start:

https://jax.readthedocs.io/en/latest/jax.profiler.html

I like the idea of creating some more docs on best practices here.
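As a starting point, here is a minimal sketch of capturing a trace with `jax.profiler`. The toy `forward` function, the array shapes, and the temporary log directory are placeholders rather than anything from this thread; an Equinox model called inside a jitted function would be traced the same way.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def forward(w, x):
    # Toy stand-in for a model forward pass (hypothetical).
    return jnp.tanh(x @ w).sum()


w = jnp.ones((256, 256))
x = jnp.ones((32, 256))

# Compile once before tracing so the trace shows steady-state execution.
forward(w, x).block_until_ready()

log_dir = tempfile.mkdtemp()  # placeholder location for the trace files
with jax.profiler.trace(log_dir):
    # TraceAnnotation labels this region in the trace viewer.
    with jax.profiler.TraceAnnotation("forward"):
        out = forward(w, x)
    # Wait for the asynchronous computation to finish inside the trace.
    out.block_until_ready()

trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
print(f"wrote {len(trace_files)} trace file(s) under {log_dir}")
```

The resulting trace can then be opened in TensorBoard's profile plugin or in Perfetto to see how long each operation takes.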

AakashKumarNain commented 1 month ago

Cool. Thanks a lot. I will push a PR once I have run this exercise successfully on my end. If I encounter any issues, I will post them in this thread.

haydn-jones commented 1 month ago

I've also struggled to get some of my models that use attention layers to perform as well as PyTorch. The gap I've seen has been up to 50% in training throughput, but perhaps I'm not handling precision well enough. Would like to hear about what you find.

Edit: I can go through my code as well and take a look if that would help.

AakashKumarNain commented 1 month ago

Okay, I haven't attached the profiler to my runs yet, but I am seeing extremely poor performance on my side compared to torch. Here are some stats for a GPT-2 model with distributed training:

PyTorch

Causal attention : True
Fused AdamW: True
Compiled: True
Forward pass only: 16-20 ms
Forward and backward: 60-70 ms

Equinox

Causal attention : False (ran self attention without any causal mask just for benchmarking)
Fused AdamW: N/A (no fused implementation in Optax)
Compiled: True
Forward pass only: 45-50 ms
Forward and backward: 194 ms

Even if I ignore the torch numbers for now, the forward pass is extremely slow, and the overhead from Optax is huge. I don't have any for loops in my code either.

PS: Given how hastily I have done these, there is a chance that the numbers might be a bit off. But I am sure I don't have any bug in the code that could change these numbers drastically.

patrick-kidger commented 1 month ago

Usual warning to avoid including compile time in benchmarks, by the way.

FWIW I usually find the opposite -- PyTorch is usually slower than JAX.

Note that performance is a JAX-level thing, not an Equinox thing, so you'll probably find success by taking a look at some of the other JAX repos and seeing what tricks other people have.
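For reference, a minimal sketch of a timing loop that keeps compilation out of the measurement. The `benchmark` helper and the toy `forward` function are hypothetical, not part of JAX or Equinox. The two points it illustrates are calling the jitted function once before timing, and blocking on the result, since JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def forward(w, x):
    # Toy stand-in for a model forward pass (hypothetical).
    return jnp.tanh(x @ w)


def benchmark(fn, *args, warmup=1, iters=10):
    """Return the mean wall-clock seconds per call, excluding compilation."""
    # Warm-up calls trigger tracing and compilation; they are not timed.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # Block so that we time the computation, not just its dispatch.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters


w = jnp.ones((512, 512))
x = jnp.ones((64, 512))
per_call_s = benchmark(forward, w, x)
print(f"forward: {per_call_s * 1e3:.3f} ms per call")
```

Note that calling the function with new input shapes or dtypes triggers a fresh compilation, so the warm-up batch should match the batches being timed.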

AakashKumarNain commented 1 month ago

Usual warning to avoid including compile time in benchmarks, by the way.

Yes, I am aware of it, and I generally warm up the model with a single batch.

FWIW I usually find the opposite -- PyTorch is usually slower than JAX.

That has been my experience for years, and this is why I am surprised by these numbers

what tricks other people have.

The only tricks that apply here are using scan instead of a for loop over the transformer layers, fused QKV, and efficient attention. I am pretty sure all of these are in place in my code.
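To make the first of those tricks concrete, here is a minimal sketch of scanning over stacked layer weights instead of unrolling a Python loop. The toy `layer` function, the sizes, and the variable names are all illustrative assumptions; a real transformer block would take its place. The main benefit is that the layer body is traced and compiled once rather than once per layer.

```python
import jax
import jax.numpy as jnp


def layer(h, w):
    # Toy stand-in for one transformer block (hypothetical).
    return jnp.tanh(h @ w)


num_layers, dim = 4, 16
# All layer weights stacked along a leading "layer" axis.
ws = jax.random.normal(jax.random.PRNGKey(0), (num_layers, dim, dim)) / jnp.sqrt(dim)
x = jnp.ones((2, dim))


@jax.jit
def forward_scan(x, ws):
    def step(h, w):
        # The carry is the activation; no per-layer output is collected.
        return layer(h, w), None

    h, _ = jax.lax.scan(step, x, ws)
    return h


def forward_loop(x, ws):
    # Reference implementation: a plain Python loop over the layers.
    for w in ws:
        x = layer(x, w)
    return x


out_scan = forward_scan(x, ws)
out_loop = forward_loop(x, ws)
print(jnp.max(jnp.abs(out_scan - out_loop)))
```

Both versions compute the same result; the difference shows up in compile time and program size as the number of layers grows.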

AakashKumarNain commented 3 weeks ago

Okay, I did some digging into this. I benchmarked each component separately, and found that all Equinox components are much faster than their torch counterparts. Beyond this, only two components are responsible for the slowness:

  1. Data loader -> Switched to tf.data.Dataset for now to resolve this, though I wanted to avoid it in the first place as it adds another dependency.
  2. The major concern is optax. The forward pass is pretty efficient, but the backward pass is slow as hell. I don't know why it is so slow, but it instantly makes my JAX code ~1.5-2x slower than my torch code.

Also, I have found that XLA flags for GPU provide either no speedup or only a minor one. I will try this on an H100, but I am not very hopeful. Will keep updating this thread.
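One way to narrow down where the time goes is to time the forward pass, the forward plus backward pass, and the full training step as three separately jitted functions. The sketch below is illustrative only: the two-layer model, the `time_ms` helper, and the plain SGD update (a stand-in for an Optax optimizer, which is not used here) are all assumptions.

```python
import time

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy two-layer model (hypothetical stand-in for the real network).
    h = jnp.tanh(x @ params["w1"])
    return jnp.mean((h @ params["w2"] - y) ** 2)


forward = jax.jit(loss_fn)
forward_backward = jax.jit(jax.value_and_grad(loss_fn))


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update, standing in for the real optimizer step.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return new_params, loss


def time_ms(fn, *args, iters=20):
    jax.block_until_ready(fn(*args))  # warm-up call: compilation is not timed
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(fn(*args))
    return 1e3 * (time.perf_counter() - start) / iters


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": 0.1 * jax.random.normal(k1, (64, 128)),
    "w2": 0.1 * jax.random.normal(k2, (128, 1)),
}
x = jnp.ones((32, 64))
y = jnp.zeros((32, 1))

timings = {
    "forward": time_ms(forward, params, x, y),
    "forward+backward": time_ms(forward_backward, params, x, y),
    "full step": time_ms(train_step, params, x, y),
}
for name, ms in timings.items():
    print(f"{name}: {ms:.3f} ms")
```

The gap between "forward" and "forward+backward" is the cost of the gradient computation, and the gap between "forward+backward" and "full step" is the cost of the parameter update, which helps separate a slow backward pass from a slow optimizer.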

Xynonners commented 1 week ago

Since Optax is the main optimizer library in JAX, I'm assuming you haven't found a solution to this?

The torch optimizers and the JAX optimizers seem to work very differently because of in-place operations.

AakashKumarNain commented 1 week ago

Though I haven't found a solution yet, optax is fine! I am pretty sure I have some silly bug somewhere in my code, and that this is a mistake on my end rather than something related to jax/equinox/optax. I will eventually get to it once I have some dedicated time to spend on this (context switching is annoying!)

The only thing I am certain about at this point is that none of the XLA flags speeds up any operation on an A100. I am going to open an issue for that in the XLA repo.

PS: Having separate components is more useful and better than packing everything into a monolith. I consider this the biggest strength of the JAX ecosystem.