Closed marianokamp closed 6 months ago
No answer or re-action yet, but not stale either.
Gentle ping @muellerzr @pacman100
@pacman100, @muellerz Just re-ran with transformers 4.36.0, same result:
@pacman100, @muellerzr, @younesbelkada. Anything I can do here to help you acknowledge the ticket? If I am hearing nothing I will let it auto-close.
Hello @marianokamp, Thank you for your patience. As I don't have a clear minimal reproducer here, I ran the below experiments and don't see a diff in performance with and without gradient checkpointing.
set_seed
for deterministic runs:
import argparse
import os
import torch from torch.optim import AdamW from torch.utils.data import DataLoader from peft import ( get_peft_config, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, LoraConfig, PeftType, PrefixTuningConfig, PromptEncoderConfig, )
import evaluate from datasets import load_dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed from tqdm import tqdm
3. In gradient ckpt run, add the `model.gradient_checkpointing_enable` command:
```diff
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model
4. Run the notebooks with and without gradient ckpt.
5. mem usage:
![Screenshot 2024-01-10 at 11 32 35 AM](https://github.com/huggingface/transformers/assets/13534540/8095284b-e29b-4d14-a227-12a09ac70cd0)
6. Without gradient ckpt output logs:
0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__
method is faster than using a method to encode the text followed by a call to the pad
method to get a padded encoding.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:27<00:00, 4.18it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.53it/s]
epoch 0: {'accuracy': 0.7083333333333334, 'f1': 0.8210526315789474}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.52it/s]
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.31it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.53it/s]
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.52it/s]
epoch 3: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.52it/s]
epoch 4: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.54it/s]
epoch 5: {'accuracy': 0.8186274509803921, 'f1': 0.8766666666666666}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.54it/s]
epoch 6: {'accuracy': 0.8333333333333334, 'f1': 0.8885245901639344}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.50it/s]
epoch 7: {'accuracy': 0.875, 'f1': 0.9109947643979057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.52it/s]
epoch 8: {'accuracy': 0.8872549019607843, 'f1': 0.9184397163120569}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.51it/s]
epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9201388888888888}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.49it/s]
epoch 10: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.49it/s]
epoch 11: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.49it/s]
epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.9241622574955909}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.46it/s]
epoch 13: {'accuracy': 0.8970588235294118, 'f1': 0.926056338028169}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.43it/s]
epoch 14: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.43it/s]
epoch 15: {'accuracy': 0.8872549019607843, 'f1': 0.9181494661921709}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.49it/s]
epoch 16: {'accuracy': 0.8897058823529411, 'f1': 0.9211908931698775}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.48it/s]
epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9203539823008849}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.44it/s]
epoch 18: {'accuracy': 0.8872549019607843, 'f1': 0.9195804195804195}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00, 4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.43it/s]
epoch 19: {'accuracy': 0.8921568627450981, 'f1': 0.923076923076923}
7. with gradient checkpointing output logs:
0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__
method is faster than using a method to encode the text followed by a call to the pad
method to get a padded encoding.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:41<00:00, 2.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.37it/s]
epoch 0: {'accuracy': 0.7083333333333334, 'f1': 0.8210526315789474}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.37it/s]
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.39it/s]
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.41it/s]
epoch 3: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.40it/s]
epoch 4: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.40it/s]
epoch 5: {'accuracy': 0.8186274509803921, 'f1': 0.8766666666666666}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:41<00:00, 2.79it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.44it/s]
epoch 6: {'accuracy': 0.8333333333333334, 'f1': 0.8885245901639344}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.42it/s]
epoch 7: {'accuracy': 0.875, 'f1': 0.9109947643979057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.40it/s]
epoch 8: {'accuracy': 0.8872549019607843, 'f1': 0.9184397163120569}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.84it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.39it/s]
epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9201388888888888}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.40it/s]
epoch 10: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.39it/s]
epoch 11: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.40it/s]
epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.9241622574955909}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.39it/s]
epoch 13: {'accuracy': 0.8970588235294118, 'f1': 0.926056338028169}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.34it/s]
epoch 14: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.33it/s]
epoch 15: {'accuracy': 0.8872549019607843, 'f1': 0.9181494661921709}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.41it/s]
epoch 16: {'accuracy': 0.8897058823529411, 'f1': 0.9211908931698775}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.42it/s]
epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9203539823008849}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.35it/s]
epoch 18: {'accuracy': 0.8872549019607843, 'f1': 0.9195804195804195}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.39it/s]
epoch 19: {'accuracy': 0.8921568627450981, 'f1': 0.923076923076923}
Observations: No performance gap between runs with gradient checkpointing and without gradient checkpointing.
Thanks @pacman100. I got it now - a minimalist example is needed. I will try to create one over the weekend.
@pacman100. Hi Sourab, thanks for investing the time!
You didn't say otherwise, so it's confirmed that using gradient checkpointing should not change the functional impact of the model, correct?
I now have a minimal implementation sample notebook that shows the issue.
Background: The original code is from an article that illustrates for educational purposes how a simple LoRA implementation looks like. It's just Python code and worked fine, until I tried gradient checkpointing in the 2nd article.
I am not aware of specific expectations that the transformers lib has on code. But there are two things I do in my example that may be worth pointing out as not being in the middle of the road. (a) Freezing modules and (b) overwriting the forward function in the module to be adapted to point it to the adapter implementation in the forward pass. Both work fine without gradient checkpointing, but maybe they are problematic with gradient checkpointing? The code is in the example I linked above, but for easier consumption I reproduce this method here:
def adapt_model(model):
class MinimalLoRAAdapter(nn.Module):
def __init__(self,
adaptee):
super().__init__()
self.adaptee = adaptee
self.orig_forward = adaptee.forward
adaptee.forward = self.forward # <-----------------
r = 1
adaptee.lora_A = nn.Parameter(
torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
)
adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))
def forward(self, x, *args, **kwargs):
return (
self.orig_forward(x, *args, **kwargs) # <-----------------
+ F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
)
# freeze all layers, incl. embeddings, except for the classifier
for m in model.roberta.modules():
m.requires_grad_(False) # <-----------------
# Adapt linear modules in transformer layers
for m in model.roberta.encoder.modules():
if isinstance(m, nn.Linear):
MinimalLoRAAdapter(m)
Here is an excerpt from the output. Full output in the linked notebook (check eval_accuracy):
---- without gradient checkpointing ----
[..]
model.is_gradient_checkpointing=False
[..]
{'train_runtime': 457.1886, 'train_samples_per_second': 489.951, 'train_steps_per_second': 2.187, 'train_loss': 0.38296363830566404, 'epoch': 3.32}
{'eval_loss': 0.23593959212303162, 'eval_accuracy': 0.908256880733945, 'eval_runtime': 1.6902, 'eval_samples_per_second': 515.919, 'eval_steps_per_second': 64.49, 'epoch': 3.32}
---- with gradient checkpointing ----
[..]
model.is_gradient_checkpointing=True
[..]
{'train_runtime': 227.8506, 'train_samples_per_second': 983.101, 'train_steps_per_second': 4.389, 'train_loss': 0.6675097045898437, 'epoch': 3.32}
{'eval_loss': 0.6635248064994812, 'eval_accuracy': 0.5194954128440367, 'eval_runtime': 1.6397, 'eval_samples_per_second': 531.808, 'eval_steps_per_second': 66.476, 'epoch': 3.32}
[..]
I tried the above with both GPU and CPU and I can observe the same behavior. Hope that helps to narrow it down.
Gentle ping @pacman100
Hello @marianokamp,
Thank you for the minimal reproducer via the notebook. I ran it using the latest versions with the below changes:
+ gradient_checkpointing_kwargs = None
if cp_enabled:
- model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
+ gradient_checkpointing_kwargs = {"use_reentrant":False}
training_args = TrainingArguments(
gradient_checkpointing=cp_enabled,
+ gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
...
The issue you are facing with gradient checkpointing with LoRA is as follows:
use_reentrant=True
as mentioned in https://pytorch.org/docs/stable/checkpoint.html:
At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
but the Trainer was resetting it because of trainingArgument gradient_checkpointing
and as you didn't pass gradient_checkpointing_kwargs
, the default value of use_reentrant=True
was used. This is clear from the warning in your notebook output:
requires_grad=True
which is required when using use_reentrant=True
. As such, no gradients are computed and no learning happens leading to very low model accuracy. use_reentrant=False
. use_reentrant=True
is to make the outputs of the embedding layer require grads even though you won't be needing it as this fulfils the condition of least one input and output must have requires_grad=True for the reentrant variant
. You can see this being done in the PEFT codebase at https://github.com/huggingface/peft/blob/02b5aeddf9c1ea11451f10a8a26da7e5df8cca4a/src/peft/utils/other.py#L112-L122Output with the above changes:
Library versions:
Code:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, set_seed
from torch import nn
from torch.nn import functional as F
import math
hf_ckp = 'roberta-base'
set_seed(100)
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return {f"accuracy": (predictions == labels).mean()}
def count_parameters(m, verbose=True):
total_count = 0
learnable_count = 0
if verbose:
print("Parameters (name, tunable, count):")
output_width = max([len(n) for n, _ in m.named_parameters()])
for n, p in m.named_parameters():
count = p.data.numel()
if verbose:
print(f" {n:{output_width}} {p.requires_grad:5b} {count:>11d}")
total_count += count
if p.requires_grad:
learnable_count += count
print(
f"Total parameters: {total_count:,}, "
f"thereof learnable: {learnable_count:,} "
f"({learnable_count/total_count*100.:5.4f}%)"
)
return total_count, learnable_count
def adapt_model(model):
# Minimalized example in place of the original LoRA-from-Scratch
# implementation from the article:
# https://towardsdatascience.com/dive-into-lora-adapters-38f4da488ede
class MinimalLoRAAdapter(nn.Module):
def __init__(self,
adaptee):
super().__init__()
self.adaptee = adaptee
self.orig_forward = adaptee.forward
adaptee.forward = self.forward
r = 1
adaptee.lora_A = nn.Parameter(
torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
)
adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))
def forward(self, x, *args, **kwargs):
return (
self.orig_forward(x, *args, **kwargs)
+ F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
)
# freeze all layers, incl. embeddings, except for the classifier
for m in model.roberta.modules():
m.requires_grad_(False)
# Adapt linear modules in transformer layers
for m in model.roberta.encoder.modules():
if isinstance(m, nn.Linear):
MinimalLoRAAdapter(m)
%%time
tokenizer = AutoTokenizer.from_pretrained(hf_ckp)
collator = DataCollatorWithPadding(tokenizer=tokenizer)
datasets.logging.disable_progress_bar()
dataset = datasets.load_dataset("glue", "sst2")
train = dataset["train"]
valid = dataset["validation"]
def preprocess_function(examples):
return tokenizer(examples['sentence'], padding=False, truncation=True)
tokenized_train = train.map(preprocess_function, batched=False)
tokenized_valid = valid.map(preprocess_function, batched=False)
def train(cp_enabled, model):
gradient_checkpointing_kwargs = None
if cp_enabled:
gradient_c_heckpointing_kwargs = {"use_reentrant":False}
training_args = TrainingArguments(
gradient_checkpointing=cp_enabled,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
output_dir="out",
per_device_train_batch_size=224,
learning_rate=3e-5,
save_steps=10_000,
eval_steps= 250,
max_steps = 1_000,
evaluation_strategy="steps",
save_strategy="steps",
save_total_limit=1,
disable_tqdm=True,
metric_for_best_model='eval_accuracy',
report_to="none", # Disable wandb, tensorboard
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_valid,
tokenizer=tokenizer,
data_collator=collator,
compute_metrics=compute_metrics,
)
print(f'{model.is_gradient_checkpointing=}')
total, learnable = count_parameters(model, verbose=False)
trainer.train()
trainer.evaluate()
print('\n---- without gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)
adapt_model(model)
train(False, model)
del(model) # essential!
print('\n---- with gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)
adapt_model(model)
train(True, model)
@pacman100, thanks for your help and walking me through the solution in detail. I am still a bit confused by the API, but I understand the steps you showed me and following them fixed my issue in my original, non-minimal, code. All clear for me now. Much appreciated, Sourab!
System Info
Latest released and py3.10.
accelerate-0.21.0 aiohttp-3.8.5 aiosignal-1.3.1 async-timeout-4.0.3 bitsandbytes-0.41.0 datasets-2.14.5 evaluate-0.4.0 frozenlist-1.4.0 huggingface-hub-0.17.1 multidict-6.0.4 peft-0.4.0 pynvml-11.5.0 regex-2023.8.8 responses-0.18.0 safetensors-0.3.3 sagemaker-inference-1.10.0 tensorboardX-2.6.2.2 tokenizers-0.13.3 transformers-4.33.2 xxhash-3.3.0 yarl-1.9.2
Who can help?
@pacman100, @muellerzr
Information
Tasks
examples
folder (such as GLUE/SQuAD, ...)Reproduction
Hi @pacman100, @muellerzr.
I was wondering about the memory use of LoRA. Specifically what happens if I adapt modules that are
Given that the number of parameters to train remains the same in both cases, the memory usage should be the same, except that to calculate the gradients for (bottom) we would need to keep more activations around from the forward pass. If that were the case, then turning on gradient checkpointing should make (top) and (bottom) use the same memory, as we are discarding the activations and recalculating them on the backward pass. That is correct, no (@younesbelkada)?
Trying this out, I can see that behavior as expected. However, the accuracy also changed. My understanding would be that with gradient checkpointing we would now need less memory, more time, but the functional aspects, here model performance, should be unchanged. Hence the issue.
Details
Below you can see on the x-axis on which layer of a 12 layer RoBERTa Base the adapters were applied. As you can see the memory for (bottom - lower layer numbers, closer to the embeddings) are higher than for (top - higher layer numbers, closer to the head), when not using gradient checkpointing, and they are same when using gradient checkpointing.
However, when looking at the model performance we can see that we have a difference of 0.1 between using and not using checkpointing.
Not that it matters, but this is using the glue/sst-2 dataset. I am not changing anything, but passing 0 or 1 as an argument to Trainer's gradient_checkpointing attribute (and 0 and 1 to empty-cuda-cache every 30 seconds).
Expected behavior
No functional change when using gradient_checkpointing.