AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Delay Activation Forwarding #857

Closed. gobbleturk closed this 2 months ago.

gobbleturk commented 2 months ago

Adds an option to delay activation forwarding between pipeline stages by one iteration, so that the communication can easily be overlapped with computation.
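The sketch below is a toy pure-Python simulation of the idea. It is not MaxText's implementation. The function `run_pipeline` and the stage functions are hypothetical. One "iteration" means every stage computes once, and the list shift stands in for the inter-stage send. Without the delay, the values sent in iteration t are that iteration's outputs, so the send must wait for the compute. With the delay, iteration t sends the outputs buffered in iteration t-1. That send therefore has no data dependency on iteration t's compute and could run concurrently with it.

```python
from typing import Callable, List, Optional, Tuple

Stage = Callable[[int], int]


def run_pipeline(stages: List[Stage], microbatches: List[int],
                 delay: bool) -> Tuple[List[int], int]:
    """Simulate a linear pipeline; returns (results, iterations used)."""
    num_stages = len(stages)
    # Value each stage received from its predecessor for this iteration.
    received: List[Optional[int]] = [None] * num_stages
    # Delay mode only: last iteration's outputs, buffered and not yet sent.
    pending: List[Optional[int]] = [None] * num_stages
    results: List[int] = []
    iteration = 0
    while len(results) < len(microbatches):
        inputs = list(received)
        # Stage 0 reads a fresh microbatch while any remain.
        inputs[0] = microbatches[iteration] if iteration < len(microbatches) else None
        # Compute: every stage with an input applies its function.
        outputs = [None if x is None else f(x) for f, x in zip(stages, inputs)]
        # The last stage's output is a finished result and is never forwarded.
        if outputs[-1] is not None:
            results.append(outputs[-1])
        if delay:
            # Send the previous iteration's outputs; independent of `outputs`.
            to_send, pending = pending, outputs
        else:
            # Send this iteration's outputs; must wait for the compute above.
            to_send = outputs
        # "Communication": stage s forwards its value to stage s + 1.
        received = [None] + to_send[:-1]
        iteration += 1
    return results, iteration


if __name__ == "__main__":
    stages = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
    print(run_pipeline(stages, [0, 1, 2, 3], delay=False))  # ([-1, 1, 3, 5], 6)
    print(run_pipeline(stages, [0, 1, 2, 3], delay=True))   # ([-1, 1, 3, 5], 8)
```

Both modes produce the same results. In this toy model, each stage-to-stage hop takes two iterations instead of one, so the delayed schedule needs `num_microbatches + 2 * (num_stages - 1)` iterations instead of `num_microbatches + num_stages - 1`. The trade-off is a larger pipeline bubble in exchange for communication that can be hidden behind compute.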

gobbleturk commented 2 months ago

Very cool! I'm curious: does breaking the data dependency with this change alone enable overlapped comms, or do we need further support from XLA?

In initial experiments I have seen near-perfect overlap over ICI, and partial but improved overlap over DCN. The DCN results have been inconsistent; messages may need to reach a certain size for the overlap to be stable. With this option the exposed DCN communication always decreases, but the AI analysis suggests it should be hidden completely.