zfang opened this issue 1 year ago
Hi @zfang, thanks for opening this feature request and for the detailed explanation and diff!
Configuring gradient checkpointing is indeed an interesting feature. In general, we try to keep our forward passes as simple as possible. Since gradient checkpointing is used by most of our models, this is a change that would have to be applied across the library, and so it requires careful consideration of how it's configured, e.g. whether it would lead to many additional arguments in the future, and how backwards compatibility would be handled.
What I suggest is leaving this issue open. If we see more interest in the community (measured by 👍 on this issue) in adding this, then we can consider its integration.
cc @ArthurZucker
Just wanted to indicate interest in this issue. For training very large models (bigger than fit on a single GPU), there is a big gap in speed and VRAM between gradient checkpointing + naive model parallelism (MP) and more sophisticated methods such as DeepSpeed, FSDP, etc.
Being able to control the checkpointing strategy would let us incrementally take advantage of any additional VRAM and increase speed with naive MP + checkpointing, without having to resort to those other methods (which typically need a lot more VRAM to pay off and don't support the same features).
EDIT: Looking more closely at the diff, I'm not sure the code will work as intended (except maybe for the case of 2 segments). I think we'd need to wrap a chunk of decoder layers into a single forward and pass that to the checkpoint function (in the new transformers implementation). In the above diff, instead, the intermediate layers inside each segment (when `idx % self.gradient_checkpointing_segment_size != 0`) are run with the usual `forward()`, which is neither under `no_grad()` nor using the correct non-reentrant logic.
EDIT2: Given the description of `gradient_checkpointing_segment_size`, it does seem to do what the author intended: checkpointing is applied every that many layers, and the remaining layers are processed normally without checkpointing. However, this is not the same meaning of "segment" as in `checkpoint_sequential`, which goes in the opposite direction: there, a segment is a set of layers where only the input to the first layer of the set is saved/checkpointed, and everything in the interior of the set is processed, e.g., under `no_grad()` in reentrant mode.
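To make the two semantics concrete, here is a minimal sketch. It assumes a hypothetical `layers` list of modules that each take a single tensor and an arbitrary `segment_size`; real decoder layers take more arguments (attention masks, position ids, etc.), so this is illustrative only:

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_checkpoint_every_nth(layers, x, segment_size):
    # What the diff appears to do (per EDIT2): checkpoint one layer out of
    # every `segment_size` layers; the other layers run normally and keep
    # their activations in memory.
    for idx, layer in enumerate(layers):
        if idx % segment_size == 0:
            x = checkpoint(layer, x, use_reentrant=False)
        else:
            x = layer(x)
    return x

def forward_checkpoint_groups(layers, x, segment_size):
    # checkpoint_sequential-style semantics: each group of `segment_size`
    # layers is one checkpointed unit; only the group's input is saved and
    # the whole group is recomputed during the backward pass.
    def run_group(group, inp):
        for layer in group:
            inp = layer(inp)
        return inp

    for start in range(0, len(layers), segment_size):
        group = list(layers[start:start + segment_size])
        x = checkpoint(run_group, group, x, use_reentrant=False)
    return x
```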
@grimulkan Given the difference in the two applications, what would the diff look like for your proposal?
Sorry about the late reply. Thinking about it further, I think leaving the OP's proposal as-is and manually using `checkpoint_sequential` for my proposal is a decent compromise.
1. Increase speed at the cost of memory (the OP's proposal): only checkpoint every n-th layer. This is relatively easy to implement, as the diff shows.
2. Reduce memory at the cost of speed: group n layers together and checkpoint them as one unit. This is much more difficult to implement in a generalized way, I think. PyTorch currently provides `torch.utils.checkpoint.checkpoint_sequential` to do this, as long as we can manually wrap our models (or parts of them) as `torch.nn.Sequential`; see the sketch after this list. Supporting this would be different for each supported model.
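As a rough illustration of option 2, here is a minimal sketch using `checkpoint_sequential` on a toy stack of single-tensor layers. The layer sizes and segment count are arbitrary, and real decoder layers would need a wrapper that closes over attention masks and other extra arguments:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stand-in for a stack of decoder layers that each take a single tensor.
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(512, 512), nn.GELU()) for _ in range(8)]
)

x = torch.randn(4, 512, requires_grad=True)

# 4 segments of 2 layers each: only each segment's input is saved; the
# interior of each segment is recomputed during the backward pass.
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()
```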
Feature request
Currently, when we enable gradient checkpointing, e.g. in `LlamaModel`, we call `torch.utils.checkpoint.checkpoint` on every `LlamaDecoderLayer`. As per Training Deep Nets with Sublinear Memory Cost, https://github.com/cybertronai/gradient-checkpointing, and as hinted at by `torch.utils.checkpoint.checkpoint_sequential`, we can strike a balance between memory and compute by supporting a `gradient_checkpointing_segment_size` option.
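For illustration only, a hypothetical usage sketch of the proposed option; `gradient_checkpointing_segment_size` does not exist in transformers today, and the checkpoint path below is a placeholder:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder

# Existing API (recent transformers): checkpoint every decoder layer.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

# Proposed (not yet existing) knob: only checkpoint every 4th decoder layer,
# running the others normally to trade some memory back for speed.
model.config.gradient_checkpointing_segment_size = 4
```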
Motivation
Gradient checkpointing makes training slow, and people enable it mainly to avoid VRAM OOMs. It would be great if we could leverage `gradient_checkpointing_segment_size` to make gradient checkpointing configurable.
Your contribution
Here is a diff to start off with: