alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0
3.07k stars 357 forks

How to profile Alpa models and get the trace #971

Open zigzagcai opened 11 months ago

zigzagcai commented 11 months ago

System information

Describe the new feature and the current behavior/state

JAX has a built-in jax.profiler that can capture execution traces. However, when we used jax.profiler on a model parallelized with alpa.parallelize instead of jax.jit, the resulting trace contained no CUDA kernel activity.

Is there a way to profile Alpa models and analyze their traces? Thanks!

Will this change the current API? How?

N

Describe alternatives you've considered

Additional context

None

zigzagcai commented 11 months ago

When I used jax.profiler to collect the traces, the output contained only CPU traces; no CUDA kernel traces were recorded.
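For reference, the profiling pattern in question looks roughly like the sketch below. It is not taken from the original report: the toy train_step function, the array shapes, and the profile_step helper are all hypothetical, and jax.jit is used as a stand-in so the snippet runs without Alpa installed. In the setup described above, the function would be decorated with alpa.parallelize instead.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit  # stand-in for illustration; the issue uses alpa.parallelize here
def train_step(w, x):
    # Toy computation (hypothetical): one matmul followed by a reduction.
    return jnp.sum(jnp.dot(x, w) ** 2)


def profile_step(log_dir):
    """Run one traced call of train_step and write the trace under log_dir."""
    w = jnp.ones((64, 64))
    x = jnp.ones((8, 64))

    # Warm-up call so that compilation is not part of the captured trace.
    train_step(w, x).block_until_ready()

    with jax.profiler.trace(log_dir):
        out = train_step(w, x)
        # Block so the device work finishes while the profiler is still active.
        out.block_until_ready()
    return float(out)


if __name__ == "__main__":
    log_dir = tempfile.mkdtemp(prefix="jax_trace_")
    profile_step(log_dir)
    print("trace written under", log_dir)
```

With jax.jit on a GPU backend, a trace captured this way would normally be expected to include device activity alongside host activity when opened in TensorBoard's profile plugin. The behavior reported here is that the same pattern with alpa.parallelize yields host (CPU) events only.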

image