Confirm-Solutions / confirmasaurus


Learn to profile/optimize JAX code. #61

Open tbenthompson opened 2 years ago

tbenthompson commented 2 years ago

currently, JAX code is kinda scary because we don't know what parts are slow! it's hard to profile!

JAX has profiling tools built in, but they mostly operate at a level above the jit-ted functions: https://jax.readthedocs.io/en/latest/profiling.html
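As a rough starting point, here is a minimal sketch of timing a jit-ted function from the outside while labeling regions inside it. The function `step`, the helper `time_call`, the scope names, and the array shape are all made up for illustration and are not from this repo. It assumes a JAX version recent enough to have `jax.named_scope`.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # named_scope attaches a label to the ops traced inside it, which should
    # make them easier to locate in a profiler trace. (Hypothetical example.)
    with jax.named_scope("matmul_block"):
        y = x @ x.T
    with jax.named_scope("reduce_block"):
        return jnp.sum(jnp.tanh(y))


def time_call(fn, *args, repeats=5):
    """Return (first_call_seconds, best_later_call_seconds)."""
    start = time.perf_counter()
    # The first call includes tracing and compilation.
    fn(*args).block_until_ready()
    first = time.perf_counter() - start

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so block before stopping the clock.
        fn(*args).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return first, best


x = jnp.ones((512, 512))
first, best = time_call(step, x)
print(f"first call (with compile): {first:.4f}s, steady state: {best:.6f}s")
```

This only separates compile time from steady-state run time. To see the labeled regions, the timed calls could be wrapped in `with jax.profiler.trace("/tmp/jax-trace"):` (the directory is an arbitrary choice) and the result opened in a trace viewer as described in the docs linked above.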

some of this is inevitable due to fundamental CUDA behavior, but some of it can be fixed. Two things to explore:

tbenthompson commented 2 years ago

@JamesYang007 just recorded some of our conversation here.

tbenthompson commented 1 year ago

While it can't descend into individual operations inside a compiled JAX function, Scalene seems to be able to attribute GPU time to JAX: