huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Gradient checkpointing should have no functional impact #26221

Closed marianokamp closed 6 months ago

marianokamp commented 1 year ago

System Info

Latest released transformers and Python 3.10.

accelerate-0.21.0 aiohttp-3.8.5 aiosignal-1.3.1 async-timeout-4.0.3 bitsandbytes-0.41.0 datasets-2.14.5 evaluate-0.4.0 frozenlist-1.4.0 huggingface-hub-0.17.1 multidict-6.0.4 peft-0.4.0 pynvml-11.5.0 regex-2023.8.8 responses-0.18.0 safetensors-0.3.3 sagemaker-inference-1.10.0 tensorboardX-2.6.2.2 tokenizers-0.13.3 transformers-4.33.2 xxhash-3.3.0 yarl-1.9.2

Who can help?

@pacman100, @muellerzr


Reproduction

Hi @pacman100, @muellerzr.

I was wondering about the memory use of LoRA. Specifically, what happens if I adapt modules that are (top) in the upper layers, close to the head, versus (bottom) in the lower layers, close to the embeddings?

Given that the number of parameters to train remains the same in both cases, the memory usage should be the same, except that to calculate the gradients for (bottom) we would need to keep more activations around from the forward pass. If that were the case, then turning on gradient checkpointing should make (top) and (bottom) use the same memory, as we are discarding the activations and recalculating them on the backward pass. That is correct, no (@younesbelkada)?

Trying this out, I see the expected memory behavior. However, the accuracy also changed. My understanding is that with gradient checkpointing we need less memory and more time, but the functional aspects, here model performance, should be unchanged. Hence this issue.
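To make the expectation concrete, here is a minimal sketch of the mechanism (illustrative only, not my reproduction code), using torch.utils.checkpoint directly on a small stand-in block: checkpointing trades memory for recomputation, but the gradients should be identical.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# A small deterministic block (no dropout), so the recomputed forward matches exactly.
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Regular forward: intermediate activations are kept for the backward pass.
g_plain = torch.autograd.grad(block(x).sum(), x)[0]

# Checkpointed forward: activations are discarded and recomputed during backward.
g_ckpt = torch.autograd.grad(checkpoint(block, x, use_reentrant=False).sum(), x)[0]

print(torch.allclose(g_plain, g_ckpt))  # expected: True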

Details

Below you can see on the x-axis on which layer of a 12-layer RoBERTa Base the adapters were applied. As you can see, the memory for (bottom: lower layer numbers, closer to the embeddings) is higher than for (top: higher layer numbers, closer to the head) when not using gradient checkpointing, and the two are the same when using gradient checkpointing.

[Figure: GPU memory usage by adapted layer index, with and without gradient checkpointing]

However, looking at the model performance, we can see a difference of 0.1 between using and not using checkpointing.

[Figure: model performance by adapted layer index, with and without gradient checkpointing]

Not that it matters, but this is using the glue/sst-2 dataset. I am not changing anything except passing 0 or 1 to the Trainer's gradient_checkpointing argument (and 0 or 1 to toggle emptying the CUDA cache every 30 seconds).
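(For reference, the toggle in question is roughly the following; cp_flag stands in for the 0/1 argument passed from the outside, and output_dir is a placeholder.)

from transformers import TrainingArguments

cp_flag = 1  # or 0
training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=bool(cp_flag),
)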

Expected behavior

No functional change when using gradient_checkpointing.

marianokamp commented 1 year ago

No answer or reaction yet, but not stale either.

amyeroberts commented 11 months ago

Gentle ping @muellerzr @pacman100

marianokamp commented 10 months ago

@pacman100, @muellerzr Just re-ran with transformers 4.36.0, same result:

[Figure: same behavior reproduced with transformers 4.36.0]

marianokamp commented 9 months ago

@pacman100, @muellerzr, @younesbelkada. Anything I can do here to help you acknowledge the ticket? If I am hearing nothing I will let it auto-close.

pacman100 commented 9 months ago

Hello @marianokamp, thank you for your patience. As I don't have a clear minimal reproducer here, I ran the experiments below and don't see a difference in performance with and without gradient checkpointing.

  1. Code: https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb
  2. Use the set_seed for deterministic runs:
    
    import argparse
    import os

    import torch
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    from peft import (
        get_peft_config,
        get_peft_model,
        get_peft_model_state_dict,
        set_peft_model_state_dict,
        LoraConfig,
        PeftType,
        PrefixTuningConfig,
        PromptEncoderConfig,
    )

    import evaluate
    from datasets import load_dataset
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        get_linear_schedule_with_warmup,
        set_seed,
    )
    from tqdm import tqdm

Observations: No performance gap between runs with gradient checkpointing and without gradient checkpointing.

marianokamp commented 9 months ago

Thanks @pacman100. I got it now - a minimalist example is needed. I will try to create one over the weekend.

marianokamp commented 9 months ago

@pacman100. Hi Sourab, thanks for investing the time!

You didn't say otherwise, so it's confirmed that gradient checkpointing should have no functional impact on the model, correct?

I now have a sample notebook with a minimal implementation that shows the issue.

Background: The original code is from an article that illustrates, for educational purposes, what a simple LoRA implementation looks like. It's just Python code and worked fine until I tried gradient checkpointing in the 2nd article.

I am not aware of specific expectations that the transformers lib has on user code. But there are two things I do in my example that may be worth pointing out as not entirely conventional: (a) freezing modules and (b) overwriting the forward function of the module to be adapted, so that the forward pass goes through the adapter implementation. Both work fine without gradient checkpointing, but maybe they are problematic with gradient checkpointing? The code is in the example I linked above, but for easier consumption I reproduce the method here:

# Imports needed to run this excerpt on its own
import math

import torch
from torch import nn
from torch.nn import functional as F


def adapt_model(model):

    class MinimalLoRAAdapter(nn.Module): 
        def __init__(self, 
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward # <-----------------

            r = 1
            adaptee.lora_A = nn.Parameter(
                torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            )
            adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))

        def forward(self, x, *args, **kwargs):
            return (
                self.orig_forward(x, *args, **kwargs) # <-----------------
                + F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
            )

    # freeze all layers, incl. embeddings, except for the classifier
    for m in model.roberta.modules():    
        m.requires_grad_(False) # <-----------------

    # Adapt linear modules in transformer layers
    for m in model.roberta.encoder.modules():    
        if isinstance(m, nn.Linear):
            MinimalLoRAAdapter(m)
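For completeness, a usage sketch (the checkpoint name and the print are illustrative): after adapt_model, only the classifier head and the injected LoRA parameters should remain trainable.

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
adapt_model(model)

# Expect only classifier.* and the injected lora_A / lora_B parameters here.
print([n for n, p in model.named_parameters() if p.requires_grad])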

Here is an excerpt from the output. Full output in the linked notebook (check eval_accuracy):

---- without gradient checkpointing ----

[..]
model.is_gradient_checkpointing=False
[..]
{'train_runtime': 457.1886, 'train_samples_per_second': 489.951, 'train_steps_per_second': 2.187, 'train_loss': 0.38296363830566404, 'epoch': 3.32}
{'eval_loss': 0.23593959212303162, 'eval_accuracy': 0.908256880733945, 'eval_runtime': 1.6902, 'eval_samples_per_second': 515.919, 'eval_steps_per_second': 64.49, 'epoch': 3.32}

---- with gradient checkpointing ----

[..]
model.is_gradient_checkpointing=True
[..]
{'train_runtime': 227.8506, 'train_samples_per_second': 983.101, 'train_steps_per_second': 4.389, 'train_loss': 0.6675097045898437, 'epoch': 3.32}
{'eval_loss': 0.6635248064994812, 'eval_accuracy': 0.5194954128440367, 'eval_runtime': 1.6397, 'eval_samples_per_second': 531.808, 'eval_steps_per_second': 66.476, 'epoch': 3.32}
[..]

I tried the above on both GPU and CPU and observe the same behavior. Hope that helps to narrow it down.

amyeroberts commented 7 months ago

Gentle ping @pacman100

pacman100 commented 6 months ago

Hello @marianokamp,

Thank you for the minimal reproducer via the notebook. I ran it using the latest versions with the below changes:

+    gradient_checkpointing_kwargs = None
     if cp_enabled:
-        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        gradient_checkpointing_kwargs = {"use_reentrant": False}

     training_args = TrainingArguments(
         gradient_checkpointing=cp_enabled,
+        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
...

The issue you are facing with gradient checkpointing with LoRA is as follows:

  1. Let's see the behaviour for use_reentrant=True as mentioned in https://pytorch.org/docs/stable/checkpoint.html:

    At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.

  2. You were correctly setting model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}), but the Trainer was resetting it because of the TrainingArguments flag gradient_checkpointing, and as you didn't pass gradient_checkpointing_kwargs, the default value of use_reentrant=True was used. This is clear from the warning in your notebook output. [screenshot of the Trainer warning]
  3. Now, as the embedding layer is frozen, neither the input nor the output has requires_grad=True which is required when using use_reentrant=True. As such, no gradients are computed and no learning happens leading to very low model accuracy.
  4. The above changes rectify this to use the recommended use_reentrant=False.
  5. Another alternative, if you still want to use use_reentrant=True, is to make the outputs of the embedding layer require grads even though you won't be needing them, as this fulfils the condition that at least one input and output must have requires_grad=True for the reentrant variant (a minimal sketch follows after this list). You can see this being done in the PEFT codebase at https://github.com/huggingface/peft/blob/02b5aeddf9c1ea11451f10a8a26da7e5df8cca4a/src/peft/utils/other.py#L112-L122
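A minimal sketch of that point-5 workaround, assuming model is the classification model from the notebook (this mirrors the linked PEFT code; transformers' PreTrainedModel also exposes enable_input_require_grads(), which registers a similar hook):

# Force the (frozen) embedding output to require grad so the checkpointed blocks
# see at least one input with requires_grad=True, as the reentrant variant requires.
def make_inputs_require_grad(module, inputs, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)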

Output with the above changes: [screenshot of the notebook output]

Library versions: [screenshot of the installed library versions]

Code:

import math

import datasets
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    set_seed,
)

hf_ckp = 'roberta-base'
set_seed(100)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {f"accuracy": (predictions == labels).mean()}

def count_parameters(m, verbose=True):
    total_count = 0
    learnable_count = 0
    if verbose:
        print("Parameters (name, tunable, count):")

    output_width = max([len(n) for n, _ in m.named_parameters()])
    for n, p in m.named_parameters():
        count = p.data.numel()
        if verbose:
            print(f" {n:{output_width}} {p.requires_grad:5b} {count:>11d}")
        total_count += count
        if p.requires_grad:
            learnable_count += count

    print(
        f"Total parameters: {total_count:,}, "
        f"thereof learnable: {learnable_count:,} "
        f"({learnable_count/total_count*100.:5.4f}%)"
    )

    return total_count, learnable_count

def adapt_model(model):

    # Minimalized example in place of the original LoRA-from-Scratch 
    # implementation from the article: 
    # https://towardsdatascience.com/dive-into-lora-adapters-38f4da488ede
    class MinimalLoRAAdapter(nn.Module): 
        def __init__(self, 
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward

            r = 1
            adaptee.lora_A = nn.Parameter(
                torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            )
            adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))

        def forward(self, x, *args, **kwargs):
            return (
                self.orig_forward(x, *args, **kwargs)
                + F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
            )

    # freeze all layers, incl. embeddings, except for the classifier
    for m in model.roberta.modules():    
        m.requires_grad_(False)

    # Adapt linear modules in transformer layers
    for m in model.roberta.encoder.modules():    
        if isinstance(m, nn.Linear):
            MinimalLoRAAdapter(m)

%%time

tokenizer = AutoTokenizer.from_pretrained(hf_ckp)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

datasets.logging.disable_progress_bar()
dataset = datasets.load_dataset("glue", "sst2")
train = dataset["train"]
valid = dataset["validation"]

def preprocess_function(examples):
        return tokenizer(examples['sentence'], padding=False, truncation=True)

tokenized_train = train.map(preprocess_function, batched=False)
tokenized_valid = valid.map(preprocess_function, batched=False)

def train(cp_enabled, model):
    gradient_checkpointing_kwargs = None
    if cp_enabled:
        gradient_checkpointing_kwargs = {"use_reentrant": False}

    training_args = TrainingArguments(
        gradient_checkpointing=cp_enabled,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        output_dir="out",    
        per_device_train_batch_size=224,
        learning_rate=3e-5,
        save_steps=10_000,
        eval_steps=250,
        max_steps=1_000,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_total_limit=1,
        disable_tqdm=True,
        metric_for_best_model='eval_accuracy',
        report_to="none", # Disable wandb, tensorboard
    )

    trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_valid,
            tokenizer=tokenizer,
            data_collator=collator,
            compute_metrics=compute_metrics,
    )
    print(f'{model.is_gradient_checkpointing=}')
    total, learnable = count_parameters(model, verbose=False)

    trainer.train()
    trainer.evaluate()

print('\n---- without gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)   
adapt_model(model)
train(False, model)

del(model) # essential!

print('\n---- with gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)
adapt_model(model)

train(True, model)
marianokamp commented 6 months ago

@pacman100, thanks for your help and walking me through the solution in detail. I am still a bit confused by the API, but I understand the steps you showed me and following them fixed my issue in my original, non-minimal, code. All clear for me now. Much appreciated, Sourab!