huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

save_state removes shared weights but load_state cannot load properly #2807

Open MiladInk opened 1 month ago

MiladInk commented 1 month ago

System Info

accelerate version: 0.27.2
python: 3.11

Reproduction

I am saving the state_dict of a 'facebook/opt-125m' model. In this model, the weights are shared between the token embeddings and the language modelling head. When I save the model's state dictionary, I see this warning:

WARNING: Removed shared tensor {'pretrained_model.lm_head.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

The problem is that when I call load_state on the same object, I get this error:

Missing key(s) in state_dict: "pretrained_model.lm_head.weight". 

I understand that the weights are removed because they are shared, but then how am I supposed to work with models that have shared weights?
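To make the failure mode concrete, here is a minimal sketch in plain PyTorch, without accelerate or transformers. `TinyTiedLM` and `drop_shared` are hypothetical names made up for illustration; `drop_shared` only approximates the idea of dropping duplicate tensors before saving and is not accelerate's actual implementation. It shows that a strict load fails on the removed key, while a non-strict load into a model whose weights are already tied restores both tensors, because the two names point at the same parameter.

```python
import torch
from torch import nn


class TinyTiedLM(nn.Module):
    """Toy stand-in for a model that ties its embedding and LM head."""

    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the weights: both modules now hold the same Parameter object.
        self.lm_head.weight = self.embed_tokens.weight


def drop_shared(state_dict):
    """Hypothetical helper: keep only the first name for each shared storage."""
    seen, kept, removed = set(), {}, []
    for name, tensor in state_dict.items():
        ptr = tensor.data_ptr()
        if ptr in seen:
            removed.append(name)
        else:
            seen.add(ptr)
            kept[name] = tensor.clone()
    return kept, removed


src = TinyTiedLM()
saved, removed = drop_shared(src.state_dict())

dst = TinyTiedLM()
try:
    dst.load_state_dict(saved)  # strict=True is the default
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)

# A non-strict load reports the missing key instead of raising.
result = dst.load_state_dict(saved, strict=False)

print("removed while saving:", removed)
print("strict load failed:", strict_error is not None)
print("missing keys (non-strict):", result.missing_keys)
print("still tied:", dst.lm_head.weight is dst.embed_tokens.weight)
print("head restored:", torch.equal(dst.lm_head.weight, src.lm_head.weight))
```

In this sketch the non-strict load works only because `dst` ties its weights at construction time. If the target model were built without the tie, the head would keep its freshly initialized values.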

Interestingly, the code worked with previous versions of the libraries. Unfortunately, I no longer have the old environment, so I can't tell you exactly where things break.

Thanks in advance.

Expected behavior

I expected save_state and load_state to be able to restore the original model no matter what. This does not work.

muellerzr commented 1 month ago

cc @SunMarc

SunMarc commented 1 month ago

Hi @MiladInk, thanks for the report. Could you share a minimal reproducer? When we load a model with shared weights, we make sure to tie the shared weights together.
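When putting together a reproducer, it may help to confirm which parameters are actually tied in the model at the moment save_state or load_state is called. The following is a rough sketch in plain PyTorch. `tied_parameter_groups` is a hypothetical helper written for this example, and it assumes a PyTorch version whose `named_parameters` accepts `remove_duplicate`.

```python
from collections import defaultdict

from torch import nn


def tied_parameter_groups(model):
    """Group parameter names that refer to the same Parameter object."""
    groups = defaultdict(list)
    # remove_duplicate=False yields every name, including aliases of a shared parameter.
    for name, param in model.named_parameters(remove_duplicate=False):
        groups[id(param)].append(name)
    return [sorted(names) for names in groups.values() if len(names) > 1]


# Toy model with a tied embedding and head (a stand-in, not a real checkpoint).
embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embed.weight
tied_model = nn.ModuleDict({"embed_tokens": embed, "lm_head": head})

# The same layout without the tie, for comparison.
untied_model = nn.ModuleDict(
    {"embed_tokens": nn.Embedding(10, 4), "lm_head": nn.Linear(4, 10, bias=False)}
)

tied = tied_parameter_groups(tied_model)
untied = tied_parameter_groups(untied_model)
print(tied)
print(untied)
```

Running the helper on the model right before loading would show whether the tie is still in place at that point, for example if the model is wrapped inside another module as the `pretrained_model.` prefix in the error suggests.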

raghavgarg97 commented 4 weeks ago

I am facing a similar issue when trying to load and save 'google/gemma-2b'.

SunMarc commented 4 weeks ago

Hi @raghavgarg97, could you share a minimal reproducer?

github-actions[bot] commented 4 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.