sanchit-gandhi / seq2seq-speech

Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Gradient checkpointing and scan #13

Closed: sanchit-gandhi closed this 2 years ago

sanchit-gandhi commented 2 years ago

Implements gradient checkpointing by combining remat with scan_with_axes. The result: an 8x increase in the maximum per-device batch size (from 2 to 16) and a 70% reduction in compilation time.
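The technique can be illustrated with a minimal pure-JAX sketch, assuming a toy stack of identical dense layers in place of the model's Transformer blocks. jax.checkpoint (the primitive behind remat) has each layer recompute its activations during the backward pass instead of storing them. jax.lax.scan traces a single layer and loops over parameters stacked along a leading layer axis. This is not the PR's code. The layer, shapes, and names (layer, forward, loss_fn) are hypothetical. In Flax, the corresponding wrappers would be nn.remat around a module, scanned with flax.linen.partitioning.scan_with_axes.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy layer standing in for a Transformer block.
# Signature follows lax.scan's body convention: (carry, x) -> (new_carry, y).
def layer(h, params):
    w, b = params
    return jnp.tanh(h @ w + b), None

# jax.checkpoint (remat) drops the layer's intermediate activations in the
# forward pass and recomputes them during backprop, trading FLOPs for memory.
# prevent_cse=False is the setting JAX's docs suggest when used inside scan.
remat_layer = jax.checkpoint(layer, prevent_cse=False)

def forward(stacked_params, x, use_remat=True):
    body = remat_layer if use_remat else layer
    # lax.scan traces the body once and iterates over the leading (layer)
    # axis of stacked_params, instead of unrolling every layer into the graph.
    h, _ = jax.lax.scan(body, x, stacked_params)
    return h

def loss_fn(stacked_params, x):
    return jnp.mean(forward(stacked_params, x) ** 2)

num_layers, batch, dim = 12, 16, 64
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# Parameters for all layers stacked along axis 0.
stacked_params = (
    0.1 * jax.random.normal(key_w, (num_layers, dim, dim)),
    jnp.zeros((num_layers, dim)),
)
x = jax.random.normal(key_x, (batch, dim))

grads = jax.jit(jax.grad(loss_fn))(stacked_params, x)
```

Checkpointing changes only what is kept in memory, not the result, so the gradients should match those of the un-checkpointed model. Scanning over one traced layer keeps compilation cost roughly independent of depth.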