huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add gradient_checkpointing_segment_size #26103

Open zfang opened 1 year ago

zfang commented 1 year ago

Feature request

Currently, when we enable gradient checkpointing, e.g. in LlamaModel, we call torch.utils.checkpoint.checkpoint on every LlamaDecoderLayer. As per Training Deep Nets with Sublinear Memory Cost (https://github.com/cybertronai/gradient-checkpointing), and as hinted at by torch.utils.checkpoint.checkpoint_sequential, we could strike a balance between memory and compute by supporting a gradient_checkpointing_segment_size.
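
For illustration, here is a minimal standalone sketch (plain PyTorch, not transformers code) of how checkpoint_sequential trades memory for compute by splitting a layer stack into segments:

# Standalone sketch (plain PyTorch): checkpoint_sequential splits a stack of
# layers into `segments`; only the activations at segment boundaries are kept,
# and each segment is recomputed during the backward pass.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])
x = torch.randn(4, 256, requires_grad=True)

# segments equal to the number of layers approximates today's per-layer behavior;
# segments=2 keeps far fewer boundary activations but recomputes more in backward.
out = checkpoint_sequential(layers, segments=2, input=x)
out.sum().backward()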

Motivation

Gradient checkpointing slows training down, and people enable it mainly to avoid VRAM OOM. It would be great if we could make gradient checkpointing configurable by leveraging a gradient_checkpointing_segment_size.

Your contribution

Here is a diff to start off with:

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index bbf229b58..685400ffd 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1745,6 +1745,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

+    def set_gradient_checkpointing_segment_size(self, value: int):
+        """
+        Set gradient checkpointing segment size for the current model.
+        """
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        if not hasattr(self, "_set_gradient_checkpointing_segment_size"):
+            raise ValueError(f"{self.__class__.__name__} does not support configuring gradient checkpointing segment size.")
+        self.apply(partial(self._set_gradient_checkpointing_segment_size, value=value))
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5e7a879c0..682e49b21 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -491,6 +491,10 @@ class LlamaPreTrainedModel(PreTrainedModel):
         if isinstance(module, LlamaModel):
             module.gradient_checkpointing = value

+    def _set_gradient_checkpointing_segment_size(self, module, value=1):
+        if isinstance(module, LlamaModel):
+            module.gradient_checkpointing_segment_size = value
+

 LLAMA_INPUTS_DOCSTRING = r"""
     Args:
@@ -578,6 +582,7 @@ class LlamaModel(LlamaPreTrainedModel):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
+        self.gradient_checkpointing_segment_size = 1
         # Initialize weights and apply final processing
         self.post_init()

@@ -689,7 +694,7 @@ class LlamaModel(LlamaPreTrainedModel):

             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
+            if self.gradient_checkpointing and self.training and (idx % self.gradient_checkpointing_segment_size == 0):

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 2b8cb013f..f7d5e87ea 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1687,6 +1687,9 @@ class Trainer:
         if args.gradient_checkpointing:
             self.model.gradient_checkpointing_enable()

+            if args.gradient_checkpointing_segment_size > 1:
+                self.model.set_gradient_checkpointing_segment_size(args.gradient_checkpointing_segment_size)
+
         model = self._wrap_model(self.model_wrapped)

         if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 2f72622d9..ecffc5660 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -586,6 +586,8 @@ class TrainingArguments:
             Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
         gradient_checkpointing (`bool`, *optional*, defaults to `False`):
             If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+        gradient_checkpointing_segment_size (`int`, *optional*, defaults to `1`):
+            If gradient_checkpointing is True, only checkpoint every gradient_checkpointing_segment_size-th layer (the rest run without checkpointing).
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
@@ -1144,6 +1146,12 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
+    gradient_checkpointing_segment_size: int = field(
+        default=1,
+        metadata={
+            "help": "If gradient_checkpointing is True, use gradient checkpointing for every segment size."
+        },
+    )
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
@@ -2137,6 +2145,7 @@ class TrainingArguments:
         gradient_accumulation_steps: int = 1,
         seed: int = 42,
         gradient_checkpointing: bool = False,
+        gradient_checkpointing_segment_size: int = 1,
     ):
         """
         A method that regroups all basic arguments linked to the training.
@@ -2179,6 +2188,8 @@ class TrainingArguments:
                 parameters.
             gradient_checkpointing (`bool`, *optional*, defaults to `False`):
                 If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+            gradient_checkpointing_segment_size (`int`, *optional*, defaults to `1`):
+                If gradient_checkpointing is True, only checkpoint every gradient_checkpointing_segment_size-th layer (the rest run without checkpointing).

         Example:
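
For reference, usage of the proposed argument would look roughly like this (hypothetical until something along the lines of the diff above is merged; gradient_checkpointing_segment_size is not an existing TrainingArguments field):

# Hypothetical usage of the proposed argument (not in transformers today):
# checkpoint only every 2nd decoder layer, trading some memory for speed.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_segment_size=2,  # proposed field from the diff above
)
# ... then pass `args` to Trainer as usual.
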
amyeroberts commented 1 year ago

Hi @zfang, thanks for opening this feature request and for the detailed explanation and diff!

Configuring the gradient checkpointing is indeed an interesting feature. In general, we try to keep our forward passes as simple as possible. As gradient checkpointing is used by most of our models, this is a change that would have to be applied across the library, and so requires careful consideration of how it's configured (e.g. would this lead to many additional arguments in the future?) and of backwards compatibility.

What I suggest is leaving this issue open. If we see more interest in the community (measured by 👍 on this issue) in adding this, then we can consider its integration.

cc @ArthurZucker

grimulkan commented 8 months ago

Just wanted to indicate interest in this issue. There is a big gap in speed & VRAM between gradient checkpointing + naive MP and more sophisticated methods such as DeepSpeed, FSDP, etc., for training very large models (bigger than can fit on a single GPU).

Being able to control the checkpointing strategy will let us incrementally take advantage of any additional VRAM, and increase speed with naive MP + checkpointing, without having to resort to those other methods (which typically need a lot more VRAM to pay off and don't support the same features).

EDIT: Looking closer at the diff, I'm not sure the code will work as intended (except maybe for the case of 2 segments). I think we'd need to wrap a chunk of decoder layers into a single forward and pass that to the checkpoint function (in the new transformers implementation). In the diff above, the intermediate layers inside each segment (when idx % self.gradient_checkpointing_segment_size != 0) are instead run with the usual forward(), i.e. neither under no_grad() nor with the correct non-reentrant logic.

EDIT2: Given the description of gradient_checkpointing_segment_size, it does seem to do what the author intended (checkpointing is applied every that many layers, and the rest are processed normally without checkpointing). However, this is not the same meaning of "segment" as in checkpoint_sequential, which moves in the opposite direction: there, a segment is a set of layers where only the input to the first layer in the set is saved/checkpointed, and everything in the interior of the set is processed, e.g., under no_grad() in reentrant mode.
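
For concreteness, here is a standalone sketch of the two meanings of "segment" (hypothetical helper functions, not transformers code; layers is any list of modules mapping a tensor to a tensor of the same shape):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

def forward_checkpoint_every_nth(layers, x, n=2):
    # The diff's semantics: checkpoint only every n-th layer; the layers in
    # between run normally and keep their activations (faster, more memory).
    for idx, layer in enumerate(layers):
        x = checkpoint(layer, x, use_reentrant=False) if idx % n == 0 else layer(x)
    return x

def forward_checkpoint_segments(layers, x, n=2):
    # checkpoint_sequential semantics: group n layers per segment and save only
    # the segment inputs; interior activations are recomputed in the backward
    # pass (slower, less memory).
    return checkpoint_sequential(nn.Sequential(*layers), segments=len(layers) // n, input=x)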

amyeroberts commented 8 months ago

@grimulkan Given the difference in the two applications, what would the diff look like for your proposal?

grimulkan commented 7 months ago

Sorry about the late reply. Thinking about it further, I think leaving the OP's proposal as-is and manually using checkpoint_sequential for my proposal is a decent compromise.

Increase speed at the cost of memory (OP's proposal): Only checkpoint every n-th layer. This is relatively easy to implement as the diff shows.

Reduce memory at the cost of speed: group n layers together and checkpoint them as one segment. This is much more difficult to implement in a generalized way, I think; PyTorch currently provides torch.utils.checkpoint.checkpoint_sequential to do this, as long as we manually wrap our models (or parts of them) as torch.nn.Sequential. Supporting this would be different for each supported model. A rough sketch of the manual wrapping is below.
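
For what it's worth, here is a rough training-only sketch of that manual wrapping (hypothetical names, not transformers code; assumes each decoder layer returns a tuple whose first element is the new hidden states, and ignores use_cache / output_attentions):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class LayerAdapter(nn.Module):
    """Binds the non-tensor arguments so a decoder layer looks like y = f(x)."""
    def __init__(self, layer, **layer_kwargs):
        super().__init__()
        self.layer = layer
        self.layer_kwargs = layer_kwargs

    def forward(self, hidden_states):
        # Assumes the wrapped layer returns a tuple with hidden_states first.
        return self.layer(hidden_states, **self.layer_kwargs)[0]

def run_decoder_in_segments(decoder_layers, hidden_states, segments, **layer_kwargs):
    # Only the activations at segment boundaries are stored; everything inside
    # a segment is recomputed during the backward pass.
    seq = nn.Sequential(*(LayerAdapter(layer, **layer_kwargs) for layer in decoder_layers))
    return checkpoint_sequential(seq, segments=segments, input=hidden_states)

This trades speed for memory in the opposite direction from the diff above, and it drops the cache and attention outputs along the way, so it only makes sense during training.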