AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Fix circ storage check for delayed case #861

Closed gobbleturk closed 2 months ago

gobbleturk commented 2 months ago

With delayed activation forwarding, we have two buffers, shift and prev_outputs, each of length stages, so we can hold 2 * stages microbatches of activations before needing additional (circular) storage.

This PR fixes the check: previously the factor of 2 was on the wrong side of the comparison.
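
For illustration, the capacity argument can be sketched as below. This is not the code from the PR: the function names, parameter names, and the shape of the "before" check are hypothetical, and the non-delayed capacity of `stages` microbatches is an assumption based only on the description above.

```python
def buffer_capacity(num_stages: int, delayed: bool) -> int:
    """Microbatches of activations that fit without circular storage.

    Assumption: without delay only one buffer of length num_stages is
    available; with delay, shift and prev_outputs each hold num_stages.
    """
    num_buffers = 2 if delayed else 1
    return num_buffers * num_stages


def needs_circ_storage(num_microbatches: int, num_stages: int, delayed: bool) -> bool:
    """Circular storage is needed only once the buffers overflow."""
    return num_microbatches > buffer_capacity(num_stages, delayed)


def misplaced_factor_check(num_microbatches: int, num_stages: int) -> bool:
    """Hypothetical reconstruction of a check with the factor of 2 on the
    wrong side of the comparison (delayed case). Not taken from the repo."""
    return 2 * num_microbatches > num_stages
```

With 4 stages and 8 microbatches in the delayed case, the sketched check reports that no circular storage is needed, since 8 microbatches fit exactly in 2 * 4 buffer slots, whereas the misplaced-factor variant would request it unnecessarily.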