Closed by sanchit-gandhi 2 years ago.
Implements gradient checkpointing by using `remat` together with `scan_with_axes`. As a result, the maximum per-device batch size increases from 2 to 16 (an 8x increase), and compilation time drops by 70%.
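The combination can be sketched in core JAX. This is not the PR's code: it swaps Flax's `remat` and `scan_with_axes` for the core-JAX primitives `jax.checkpoint` and `jax.lax.scan`. The toy `layer`, `NUM_LAYERS`, and `HIDDEN` are hypothetical. Checkpointing each layer means its intermediate activations are recomputed during the backward pass instead of stored, which frees memory for larger batches. Scanning over stacked per-layer parameters means the layer body is traced and compiled once rather than unrolled `NUM_LAYERS` times, which is where a compile-time saving would come from.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes, for illustration only.
NUM_LAYERS = 4
HIDDEN = 8


def layer(x, params):
    """Hypothetical residual layer: x + tanh(x @ w + b)."""
    w, b = params
    return x + jnp.tanh(x @ w + b)


# jax.checkpoint (a.k.a. remat): do not keep this layer's intermediates for
# the backward pass; recompute them from the layer input when needed.
checkpointed_layer = jax.checkpoint(layer)


def scan_body(carry, layer_params):
    # One scan step applies one (checkpointed) layer to the running activation.
    return checkpointed_layer(carry, layer_params), None


def model(stacked_params, x):
    # lax.scan slices stacked_params along axis 0, one slice per layer,
    # so the layer body is traced once regardless of NUM_LAYERS.
    y, _ = jax.lax.scan(scan_body, x, stacked_params)
    return y


def model_unrolled(stacked_params, x):
    # Reference: plain Python loop, no checkpointing, fully unrolled.
    w, b = stacked_params
    for i in range(NUM_LAYERS):
        x = layer(x, (w[i], b[i]))
    return x


def loss_fn(stacked_params, x):
    return jnp.mean(model(stacked_params, x) ** 2)


def loss_fn_unrolled(stacked_params, x):
    return jnp.mean(model_unrolled(stacked_params, x) ** 2)


key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
# Parameters for all layers stacked along a leading "layers" axis.
stacked_params = (
    0.1 * jax.random.normal(k_w, (NUM_LAYERS, HIDDEN, HIDDEN)),
    jnp.zeros((NUM_LAYERS, HIDDEN)),
)
x = jax.random.normal(k_x, (2, HIDDEN))

loss, grads = jax.jit(jax.value_and_grad(loss_fn))(stacked_params, x)
ref_loss, ref_grads = jax.value_and_grad(loss_fn_unrolled)(stacked_params, x)
```

Checkpointing changes only what is stored versus recomputed, not the math, so the scanned, checkpointed model should produce the same loss and gradients as the unrolled reference, up to floating-point tolerance.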