openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

Add support for async dynamic slice fusion #19834

Open shraiysh opened 3 days ago

shraiysh commented 3 days ago

This patch adds async handling to dynamic slice fusion when the hero operation is a collective. Currently, only reduce-scatter is supported as a hero operation in dynamic slice thunk, so this patch follows the same restriction.

I added a test with independent compute, to ensure that communication and compute overlap in the emitted thunks.
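For readers skimming the patch, here is a minimal JAX sketch of the pattern in question (the mesh, shapes, and names are assumptions, not code from this PR): a reduce-scatter hero whose result feeds a dynamic-update-slice at a runtime-computed offset, next to independent compute that an async collective can overlap with.

```python
# Hedged sketch, not from this PR: a reduce-scatter whose result feeds a
# dynamic-update-slice at a runtime offset, plus independent compute that an
# async collective thunk could overlap with. Mesh, shapes, names are assumed.
import functools

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))

@jax.jit
@functools.partial(
    shard_map, mesh=mesh,
    in_specs=(P("x"), P("x"), P("x"), P()),
    out_specs=(P("x"), P("x")),
)
def step(buf, chunk, a, i):
    # Collective hero: reduce-scatter of `chunk` over the "x" mesh axis.
    shard = jax.lax.psum_scatter(chunk, "x", tiled=True)
    # Dynamic-update-slice at an offset known only at runtime; together with
    # the reduce-scatter above, this is the dynamic slice fusion pattern.
    buf = jax.lax.dynamic_update_slice(buf, shard, (i * shard.shape[0],))
    # Independent compute with no data dependence on the collective; an async
    # reduce-scatter thunk can run while this matmul executes.
    return buf, a @ a.T

# Example invocation (per-device shapes are global shape / num_devices):
#   out_buf, b = step(buf, chunk, a, jnp.int32(0))
```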

shraiysh commented 2 days ago

> Do you have a workload for which you see improvements? How can we verify and track this performance improvement?

I see performance improvements with async execution combined with runtime offset calculation for Llama 2 7B on paxml. I have not tested this on other models, but I expect better performance there too. I don't know of a way to verify this on individual patches. I can share a branch on my fork with all the changes so you can reproduce the performance benefits. Does that work?
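One way to verify the overlap on a given workload is a standard JAX profile capture (a sketch; `train_step`, its arguments, and the log directory are placeholders, not something prescribed in this thread):

```python
# Hedged sketch of verifying comm/compute overlap with a JAX device trace.
# `train_step`, `params`, `batch`, and the log directory are placeholders.
import jax

with jax.profiler.trace("/tmp/xla_async_trace"):
    out = train_step(params, batch)  # hypothetical jitted step
    jax.block_until_ready(out)
# Open the trace in TensorBoard's profiler (or Perfetto) and check whether the
# reduce-scatter runs on a separate stream, overlapping the surrounding compute.
```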

golechwierowicz commented 1 day ago

I think one of the common benchmarks is MaxText. Can you verify speedups for llama{2,3}-70b (they use jax.lax.scan to express multiple layers)?
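For context, a minimal sketch of the scan pattern mentioned above (the toy layer body and shapes are assumptions, not MaxText code): `jax.lax.scan` stacks the layer weights along a leading axis and runs a single loop body, so the dynamic slice fusion has to fire inside the scanned computation for the speedup to show up on every layer.

```python
# Minimal sketch of expressing repeated layers with jax.lax.scan; the toy
# layer body and shapes are assumptions, not MaxText code.
import jax
import jax.numpy as jnp

def layer(x, w):
    # Stand-in for one transformer block; the carry is the activations.
    return jnp.tanh(x @ w), None

@jax.jit
def apply_layers(x, stacked_w):
    # stacked_w has shape [num_layers, d, d]; scan runs `layer` once per slice.
    x, _ = jax.lax.scan(layer, x, stacked_w)
    return x
```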