huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

BertForSequenceClassification.from_pretrained broken when using FSDP #32068

Open ojh31 opened 2 months ago

ojh31 commented 2 months ago

System Info

- `Accelerate` version: 0.29.2
- Platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/oskar/projects/robust-llm/venv/bin/accelerate
- Python version: 3.10.12
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.2+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 14.85 GB
- GPU type: NVIDIA GeForce RTX 3050 Ti Laptop GPU
- `Accelerate` default config:
        Not found

Reproduction

The following code successfully loads the model checkpoint when run with `python foo.py` or `accelerate launch foo.py`, but fails when FSDP is enabled with `accelerate launch --use_fsdp foo.py`.

This looks like a bug. The docstring of `PreTrainedModel.from_pretrained` says that `pretrained_model_name_or_path` can be `None` "if you are both providing the configuration and state dictionary", which is exactly what I do here. However, when `is_fsdp_enabled()` is `True`, `from_pretrained` sets `low_cpu_mem_usage = True`, which in turn sets `state_dict = None`. With no checkpoint path and no state dict left, loading fails.
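A simplified, hypothetical model of the control flow described above may make the failure easier to follow. The function names below are illustrative only and are not the actual `transformers` internals: enabling FSDP forces the lazy, file-based loading path, the in-memory `state_dict` is discarded, and with `pretrained_model_name_or_path=None` there is no file path, so the code ends up calling `.endswith` on `None`.

```python
from typing import Optional


def hypothetical_load_from_file(checkpoint_file):
    # Stand-in for the failing line in the traceback: a string method is
    # called on the checkpoint path, which raises if the path is None.
    if checkpoint_file.endswith(".safetensors"):
        return {"source": "safetensors"}
    return {"source": "pickle"}


def hypothetical_from_pretrained(
    pretrained_model_name_or_path: Optional[str],
    state_dict: Optional[dict],
    fsdp_enabled: bool,
) -> dict:
    low_cpu_mem_usage = False
    if fsdp_enabled:
        # FSDP switches loading to the memory-saving path.
        low_cpu_mem_usage = True
    if low_cpu_mem_usage:
        # The memory-saving path re-reads weights from a file, so the
        # user-supplied state_dict is dropped.
        state_dict = None
    if state_dict is None:
        # With no path this passes None, mirroring the reported AttributeError.
        state_dict = hypothetical_load_from_file(pretrained_model_name_or_path)
    return state_dict
```

Without FSDP the supplied dict is returned unchanged; with FSDP the same call raises `AttributeError: 'NoneType' object has no attribute 'endswith'`, while passing a real file path would take the lazy path successfully.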

```python
import torch
from transformers import BertForSequenceClassification
from accelerate import Accelerator

checkpoint_path = "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_original-c1212f89.ckpt"
model_type = "bert-base-uncased"
num_classes = 6

accelerator = Accelerator()
loaded = torch.hub.load_state_dict_from_url(
    checkpoint_path, map_location=accelerator.device
)
state_dict = loaded["state_dict"]
config = BertForSequenceClassification.config_class.from_pretrained(
    model_type, num_labels=num_classes
)
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict,
    local_files_only=False,
)
print(type(model))
```

Error message:

```
    model = BertForSequenceClassification.from_pretrained(
  File "/home/oskar/projects/robust-llm/venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3754, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/oskar/projects/robust-llm/venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4194, in _load_pretrained_model
    state_dict = load_state_dict(shard_file, is_quantized=is_quantized)
  File "/home/oskar/projects/robust-llm/venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 506, in load_state_dict
    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
AttributeError: 'NoneType' object has no attribute 'endswith'
```

Expected behavior

Should output <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>
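Until this is fixed, one possible workaround is to skip `from_pretrained` entirely: instantiate the model from the config and copy the weights in with PyTorch's `load_state_dict`. This is a sketch that has not been verified under FSDP; `build_from_state_dict` is a hypothetical helper name, and the tiny config in the demo exists only so the snippet runs offline.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification


def build_from_state_dict(config, state_dict):
    """Hypothetical helper: build the model from `config`, then copy in `state_dict`.

    `from_pretrained` is never called, so the FSDP-triggered
    `low_cpu_mem_usage` path that discards `state_dict` is not involved.
    """
    # Randomly initialised model with the architecture given by the config
    # (including num_labels for the classification head).
    model = BertForSequenceClassification(config)
    # strict=False returns mismatched keys instead of raising, so key-name
    # differences in a third-party checkpoint can be inspected.
    result = model.load_state_dict(state_dict, strict=False)
    return model, result


if __name__ == "__main__":
    # Tiny config for an offline demo; not the bert-base-uncased config.
    tiny = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=37, num_labels=6,
    )
    source = BertForSequenceClassification(tiny)
    model, result = build_from_state_dict(tiny, source.state_dict())
    print(result.missing_keys, result.unexpected_keys)
    print(torch.equal(model.classifier.weight, source.classifier.weight))
```

In the reproduction script, the `from_pretrained(...)` call would become `model, result = build_from_state_dict(config, state_dict)`. Checking `result.missing_keys` and `result.unexpected_keys` shows whether the Detoxify checkpoint's key names line up with the model's. The trade-off is that every process materialises the full model in CPU memory, which is what `low_cpu_mem_usage` is meant to avoid.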

muellerzr commented 2 months ago

Transferring to transformers as this is really a transformers issue.

Things I need:

The versions of transformers as well as accelerate. Can you also try updating both to their latest releases?
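For reference, a small sketch for collecting the requested versions using only the standard library (`importlib.metadata`); the helper name `package_versions` is hypothetical. `transformers` also ships a `transformers-cli env` command that prints similar environment information.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def package_versions(
    names: Iterable[str] = ("transformers", "accelerate", "torch"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print in the same "- `package` version: X" style as the system info above.
    for name, version in package_versions().items():
        print(f"- `{name}` version: {version or 'not installed'}")
```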

ojh31 commented 2 months ago

I'm already on the latest transformers (4.42.4), with accelerate==0.29.2.

muellerzr commented 2 months ago

Can you try the latest accelerate as well?

ojh31 commented 2 months ago

Same behavior with accelerate==0.32.1

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

LysandreJik commented 4 days ago

Thanks for your issue @ojh31!

This is indeed an issue, but it isn't related to FSDP per se. It's related to this unsafe code path, in which we replace the existing state_dict with None in order to load it differently:

https://github.com/huggingface/transformers/blob/74026b473e8748706a7a86fd20d6a275306d8ffb/src/transformers/modeling_utils.py#L3830-L3834

This isn't safe as we can see here. @SunMarc, would you have the bandwidth to take a look at this issue?
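One possible shape for a guard on that code path, as a hypothetical sketch only (this is not the `transformers` implementation, and `resolve_low_cpu_mem_usage` is an illustrative name): only take the memory-saving path when there is a checkpoint to re-read, and otherwise keep the caller's `state_dict`.

```python
import warnings
from typing import Optional


def resolve_low_cpu_mem_usage(
    pretrained_model_name_or_path: Optional[str],
    state_dict: Optional[dict],
    fsdp_enabled: bool,
    low_cpu_mem_usage: bool = False,
) -> bool:
    """Hypothetical guard: decide whether the lazy, file-based path is safe."""
    if fsdp_enabled:
        low_cpu_mem_usage = True
    if low_cpu_mem_usage and state_dict is not None and pretrained_model_name_or_path is None:
        # The lazy path reloads weights from a checkpoint file, but none was
        # given; dropping the in-memory state_dict would leave nothing to load.
        warnings.warn(
            "low_cpu_mem_usage is ignored because state_dict was passed "
            "without a checkpoint path; loading from the given state_dict."
        )
        return False
    return low_cpu_mem_usage
```

Raising a clear `ValueError` instead of warning would be the stricter alternative; the fallback shown here keeps the documented `pretrained_model_name_or_path=None` usage working, at the cost of higher CPU memory use.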