- [ ] In reversible mode, one can further save memory by computing the backward pass in chunks (see the sketch after this list):
  - a few tokens at a time for feedforward layers, since `grad(concat(mlp(x1), mlp(x2))) = concat(grad(mlp(x1)), grad(mlp(x2)))`
  - a few queries at a time for self-attention, since `grad(head1 + head2) = grad(head1) + grad(head2)`, where `head1` and `head2` are attention outputs after the linear projection
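A minimal sketch of the feedforward case, assuming a generic position-wise `mlp` (module names and sizes below are illustrative): because the MLP processes each token independently, the input gradient can be recomputed and backpropagated a few tokens at a time, so only one chunk's activations are alive at once.

```python
import torch
import torch.nn as nn

# stand-in for any position-wise feedforward layer
mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

def chunked_input_grad(x, grad_output, chunk_tokens=2):
    """Rematerialize mlp a few tokens at a time and backprop grad_output through each chunk."""
    grads = []
    for x_chunk, g_chunk in zip(x.split(chunk_tokens), grad_output.split(chunk_tokens)):
        x_chunk = x_chunk.detach().requires_grad_(True)
        mlp(x_chunk).backward(g_chunk)  # only this chunk's activations are alive here
        grads.append(x_chunk.grad)
    return torch.cat(grads)

x = torch.randn(8, 64, requires_grad=True)
grad_out = torch.randn(8, 64)

mlp(x).backward(grad_out)  # reference: full backward in one pass
# grad(concat(...)) == concat(grad(...)) for a position-wise layer:
assert torch.allclose(x.grad, chunked_input_grad(x, grad_out), atol=1e-5)
```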
- improved checkpointing (a sketch follows after this list)
  - [x] allow user to specify the number of checkpoints, as in `checkpoint_sequential`
  - [x] do not rematerialize the last layer, as in `checkpoint_sequential`
  - [x] optionally cast checkpoints to a lower precision, as in `revlib`
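For reference, a minimal sketch of these features in plain PyTorch; the layer sizes and segment count are illustrative, and the `saved_tensors_hooks` cast below only emulates the idea revlib implements, it is not revlib's actual code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# toy stack of layers
layers = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(8)])
x = torch.randn(32, 64, requires_grad=True)

# user-specified number of checkpoints; checkpoint_sequential also runs the
# last segment without checkpointing, i.e. it is never rematerialized
out = checkpoint_sequential(layers, 4, x)  # 4 segments => 4 checkpoints
out.sum().backward()

# casting saved activations to a lower precision can be emulated with
# saved_tensors_hooks: store in bfloat16, upcast when backward needs them
def pack(t):
    return t.to(torch.bfloat16) if t.is_floating_point() else t

def unpack(t):
    return t.to(torch.float32) if t.is_floating_point() else t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = layers(x).sum()
loss.backward()
```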
- compacted params (a hypothetical sketch follows after this list)
  - [ ] compacted layernorms, biases
  - [ ] compacted adapters
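The items above do not spell out what "compacted" means; one plausible reading, sketched below purely as an assumption, is that the many small per-layer tensors (LayerNorm scales and biases) live as views into a single contiguous buffer instead of separate parameters. All names and the layout are hypothetical.

```python
import torch
import torch.nn as nn

num_layers, hidden = 12, 64
# one contiguous buffer holding every layer's LayerNorm weight and bias
flat = nn.Parameter(torch.empty(num_layers, 2, hidden))
with torch.no_grad():
    flat[:, 0].fill_(1.0)  # layernorm weights
    flat[:, 1].zero_()     # layernorm biases

def apply_layernorm(x: torch.Tensor, layer: int) -> torch.Tensor:
    # each layer uses views into the shared buffer; grads flow back into `flat`
    weight, bias = flat[layer, 0], flat[layer, 1]
    return nn.functional.layer_norm(x, (hidden,), weight, bias)

x = torch.randn(4, hidden)
for i in range(num_layers):
    x = apply_layernorm(x, i)
```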
- [ ] Attention could be computed in O(sqrt(n)) memory (Rabe et al., 2021), but this may be overkill (see the sketch after this item)
  - sparse or linear attention: these are great for very long sequences. However, for large models, attention is not a bottleneck in typical NLP and vision tasks (we tested GPT-3-sized models up to length 4096).
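A minimal sketch of the core trick from Rabe & Staats (2021): stream over key/value chunks with a running max and softmax normalizer, so the full n-by-n score matrix is never materialized. The full algorithm additionally chunks queries and checkpoints the per-chunk summaries to reach O(sqrt(n)) memory; chunk sizes below are illustrative.

```python
import torch

def chunked_attention(q, k, v, chunk=64):
    """Attention over k/v chunks with an online softmax; never builds the full score matrix."""
    out = torch.zeros_like(q)
    running_max = torch.full(q.shape[:-1], -float("inf"))
    running_sum = torch.zeros(q.shape[:-1])
    scale = q.shape[-1] ** -0.5
    for k_chunk, v_chunk in zip(k.split(chunk, dim=-2), v.split(chunk, dim=-2)):
        scores = q @ k_chunk.transpose(-2, -1) * scale       # [..., n_q, chunk]
        new_max = torch.maximum(running_max, scores.amax(dim=-1))
        correction = torch.exp(running_max - new_max)        # rescale old partial sums
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        running_sum = running_sum * correction + exp_scores.sum(dim=-1)
        out = out * correction.unsqueeze(-1) + exp_scores @ v_chunk
        running_max = new_max
    return out / running_sum.unsqueeze(-1)

q = k = v = torch.randn(2, 128, 32)
ref = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-5)
```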
- [ ] Per-block grad scaling as described in Ramesh et al. (2021). We rely on Sandwich Norm to maintain stability up to 96 layers (we did not test deeper models); however, it would be nice to have per-block scaling to avoid the need for an extra LayerNorm. A sketch follows below.
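A hedged sketch of what per-block gradient scaling could look like, in the spirit of Ramesh et al. (2021); the scale value and wrapper names are illustrative, not the paper's exact recipe. Gradients are enlarged inside each block (so fp16 values do not underflow) and shrunk back before reaching earlier blocks; the block's own parameter gradients must be divided by the scale before the optimizer step, as with ordinary loss scaling.

```python
import torch
import torch.nn as nn

class ScaleGrad(torch.autograd.Function):
    """Identity in forward; multiplies the incoming gradient by `scale` in backward."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

def run_block_with_grad_scale(block: nn.Module, x: torch.Tensor, scale: float):
    x = ScaleGrad.apply(x, 1.0 / scale)       # undo the scaling for earlier blocks
    return ScaleGrad.apply(block(x), scale)   # grads inside the block are `scale`x larger

block = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
run_block_with_grad_scale(block, x, scale=2.0 ** 10).sum().backward()
# x.grad has its normal magnitude; block.weight.grad is 2**10 times larger and
# would be divided by this block's scale before the optimizer step
```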
- [ ] Something else that we missed? Please find us on Discord.