Open tbenthompson opened 2 years ago
currently, JAX code is kinda scary because we don't know what parts are slow! it's hard to profile!
JAX has profiling tools built in, but they mostly operate at a level above the jit-ted functions: https://jax.readthedocs.io/en/latest/profiling.html
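To make the "level above the jit-ted function" point concrete, here is a minimal sketch of host-side timing. The names `step` and `time_call` are hypothetical, made up for illustration and not taken from this issue. It only uses `jax.jit`, `block_until_ready`, and `time.perf_counter`, and it can only report one number for the whole compiled function, not for the operations inside it.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical example function: both stages are compiled together,
    # so host-side timing cannot tell them apart.
    y = jnp.tanh(x @ x.T)
    return jnp.sum(y * y)


def time_call(fn, *args, repeats=5):
    # Warm-up call so compilation time is not included in the measurement.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block until the result is ready
        # so the timer covers the actual computation.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.ones((256, 256))
elapsed = time_call(step, x)
print(f"step: {elapsed * 1e3:.3f} ms per call")
```

Without the `block_until_ready()` calls the loop would mostly measure dispatch overhead rather than the computation itself.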
some of this is inevitable due to fundamental CUDA behavior, but some of it can be fixed. Two things to explore:
@JamesYang007 just recorded some of our conversation here.
While it can't descend into individual operations inside a compiled JAX function, scalene seems to be able to attribute GPU time to JAX.