Are you willing to contribute it (Yes/No): Not sure; I will submit a PR if I have the bandwidth.
Describe the new feature and the current behavior/state
JAX has a built-in profiler, jax.profiler, that can capture execution traces. However, when we used jax.profiler to trace a model parallelized with alpa.parallelize (instead of jax.jit), the resulting trace contained no CUDA kernel events.
Is there a way to profile and analyze traces of Alpa-parallelized models? Thanks!
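For reference, here is a minimal sketch of the usual jax.profiler workflow with jax.jit, the setup in which CUDA kernels do show up in the trace. The function `train_step`, the array shapes, and the log directory are illustrative assumptions, not the model from this report. The trace is written in TensorBoard profile format and can be opened with TensorBoard's profile plugin.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


# Hypothetical toy "model" step, used only for illustration.
@jax.jit
def train_step(w, x):
    return jnp.tanh(x @ w).sum()


# Assumed output location for the trace files.
LOG_DIR = os.path.join(tempfile.gettempdir(), "jax-trace")

w = jnp.ones((512, 512))
x = jnp.ones((64, 512))

# Warm up once so that compilation is not mixed into the captured trace.
train_step(w, x).block_until_ready()

# Work executed inside this context manager is recorded to LOG_DIR.
with jax.profiler.trace(LOG_DIR):
    result = train_step(w, x)
    # JAX dispatches asynchronously; block so the computation finishes
    # before the trace is stopped.
    result.block_until_ready()

print(float(result))
```

With alpa.parallelize the same context manager is what we tried, but the captured trace lacked the kernel-level events that this jax.jit version produces on a GPU.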
Will this change the current API? How?
No
Describe alternatives you've considered
System information
Additional context
None