huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

'CLIPEncoder' object has no attribute '_gradient_checkpointing_func' #30264

Closed · TideDra closed this 4 months ago

TideDra commented 4 months ago

Who can help?

@amyeroberts
Reproduction

When gradient checkpointing is enabled, CLIPEncoder calls self._gradient_checkpointing_func (https://github.com/huggingface/transformers/blob/51bcadc10a569847b93a30dbe3a077037ae63bad/src/transformers/models/clip/modeling_clip.py#L628). However, this attribute is defined on PreTrainedModel, and CLIPEncoder is not a subclass of PreTrainedModel. The bug can be reproduced with the following code:

from transformers import CLIPVisionModel

vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
# The attribute is never set on the encoder, so this access fails.
print(vision_tower.vision_model.encoder._gradient_checkpointing_func)

which raises AttributeError: 'CLIPEncoder' object has no attribute '_gradient_checkpointing_func'.

Expected behavior

As a workaround, I currently replace self._gradient_checkpointing_func directly with torch.utils.checkpoint.checkpoint, and training with gradient checkpointing then works. However, this is not an ideal solution, since _gradient_checkpointing_func is needed to handle gradient_checkpointing_kwargs.
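For reference, a minimal sketch of that workaround (my own reconstruction, assuming the transformers version linked above, where CLIPEncoder gates the checkpointed path on its gradient_checkpointing attribute):

import torch.utils.checkpoint
from transformers import CLIPVisionModel

vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
encoder = vision_tower.vision_model.encoder

# Turn on the flag the encoder's forward pass checks, then patch in the plain
# PyTorch checkpoint function by hand. gradient_checkpointing_kwargs (e.g.
# use_reentrant) are not applied this way, which is the drawback noted above.
encoder.gradient_checkpointing = True
encoder._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint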

amyeroberts commented 4 months ago

Hi @TideDra, thanks for opening this issue!

You first have to enable gradient checkpointing for the model:

In [1]: from transformers import CLIPVisionModel
   ...: vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
   ...: vision_tower.gradient_checkpointing_enable({"use_reentrant": True})
   ...: vision_tower.vision_model.encoder._gradient_checkpointing_func
Out[1]: functools.partial(<function checkpoint at 0x15accd240>, use_reentrant=True)
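As a quick smoke test (my own addition, not from the thread; the 1x3x336x336 input matches this checkpoint's 336-pixel resolution), note that the checkpointed path is only taken in training mode:

import torch
from transformers import CLIPVisionModel

vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
vision_tower.gradient_checkpointing_enable({"use_reentrant": True})
vision_tower.train()  # the encoder only calls _gradient_checkpointing_func when self.training is True

pixel_values = torch.randn(1, 3, 336, 336)
outputs = vision_tower(pixel_values=pixel_values)
# Backward recomputes the checkpointed encoder layers instead of storing their activations.
outputs.pooler_output.sum().backward()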

TideDra commented 4 months ago

Thanks for your help!