Open shraiysh opened 3 days ago
Do you have a workload for which you see improvements? How can we verify and track this performance improvement?
I see performance improvements from the async execution plus runtime offset calculation with Llama 2 7B on Paxml. I have not tested other models, but I expect improvements there too. I don't know of a way to verify this on individual patches. I can share a branch on my fork with all the changes so you can reproduce the performance benefits. Does that work?
I think one of the common benchmarks is MaxText. Can you verify speedups for llama{2,3}-70b? (They use `jax.lax.scan` to express multiple layers.)
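For context, a minimal sketch (not taken from MaxText or this PR, all names are illustrative) of what "expressing multiple layers with `jax.lax.scan`" looks like: the per-layer weights are stacked along a leading axis and the layer body is rolled into one scanned loop, so XLA compiles a single loop body rather than `n_layers` unrolled copies.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a stacked-layer model body.
def layer(h, w):
    # Carry the activations forward; no per-step output is collected.
    return jnp.tanh(h @ w), None

n_layers, d = 4, 8
ws = jnp.stack([jnp.eye(d)] * n_layers)  # per-layer weights, stacked on axis 0
h0 = jnp.ones((2, d))

# scan applies `layer` n_layers times, threading the carry through.
h_final, _ = jax.lax.scan(layer, h0, ws)
```

Inside the compiled loop, per-iteration weight access lowers to dynamic slices of the stacked array, which is the pattern dynamic slice fusion targets.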
This patch adds async handling to dynamic slice fusion when the hero operation is a collective. Currently, only reduce-scatter is supported as a hero operation in the dynamic slice thunk, so this patch follows suit.
Added a test that includes compute, to ensure that communication and compute overlap in the emitted thunks.
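A minimal sketch (assumed, not from the PR's test) of the kind of program where this matters: a matmul followed by a reduce-scatter, written with `shard_map` and `jax.lax.psum_scatter`. With async handling, the collective can be launched without blocking the stream, letting surrounding compute overlap it. This runs on a single device only as an illustration; the overlap is only meaningful with multiple devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Single-device mesh purely for illustration.
mesh = Mesh(np.array(jax.devices()[:1]), ("x",))

def body(h, w):
    # Compute (matmul) feeding communication (reduce-scatter).
    y = h @ w
    return jax.lax.psum_scatter(y, "x", scatter_dimension=0, tiled=True)

f = shard_map(body, mesh,
              in_specs=(P("x", None), P(None, None)),
              out_specs=P("x", None))

out = f(jnp.ones((4, 4)), jnp.eye(4))
```

With more than one device on the `"x"` axis, the `psum_scatter` is the reduce-scatter hero op that the dynamic slice thunk would now run asynchronously.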