huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Correct loading of models with shared tensors when using accelerator.load_state() #2875

Closed jkuntzer closed 2 months ago

jkuntzer commented 3 months ago

What does this PR do?

I ran into problems with PyTorch's load_state_dict complaining about missing keys. These keys belonged to shared tensors, which are intentionally omitted by the safetensors library when saving. To load such a model correctly, one has to use safetensors' load_model function instead of the default load_state_dict function (described here). This was previously not done in the Accelerator's load_state function.
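
For illustration, here is a minimal sketch (not part of this PR) of the failure mode and of the safetensors-level fix; the TiedModel class and the file name are made up for this example:

import torch
from safetensors.torch import load_file, load_model, save_model

class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear2.weight = self.linear1.weight  # shared tensor

model = TiedModel()
save_model(model, "tied.safetensors")  # safetensors drops the duplicate entry for the shared weight

# model.load_state_dict(load_file("tied.safetensors"))  # would raise: Missing key(s) for the shared weight
load_model(model, "tied.safetensors")  # accounts for the shared tensor and loads cleanly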

Fixes # (issue)

I think this issue might be relevant, as they also report problems when loading with accelerator.load_state: https://github.com/huggingface/accelerate/issues/2155

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

HuggingFaceDocBuilderDev commented 3 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

jkuntzer commented 2 months ago

Yes, I'll have a look into it.

jkuntzer commented 2 months ago

You can verify that the shared weights are implemented correctly by checking the output, since safetensors warns you about them.

jkuntzer commented 2 months ago

Just did. This is the expected error message I get when reverting my changes.

[Screenshot from 2024-07-09 15-15-51]

SunMarc commented 2 months ago

Just did. This is the expected error message I get when reverting my changes.

I was only expecting linear2.weight and linear2.bias to be missing. Maybe this is due to

self.weight = self.linear1.weight
self.bias = self.linear1.bias
jkuntzer commented 2 months ago

Just did. This is the expected error message I get when reverting my changes.

I was only expecting linear2.weight and linear2.bias to be missing. Maybe this is due to

self.weight = self.linear1.weight
self.bias = self.linear1.bias

After removing the unnecessary bits, it now correctly throws an error only for the weight and bias of the second linear layer.

[Screenshot from 2024-07-09 16-41-56]
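
For context, the round trip being tested looks roughly like the following; the model class, sizes, and checkpoint directory below are placeholders rather than the actual test code from the PR:

import torch
from accelerate import Accelerator

class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        # tie the second layer's parameters to the first
        self.linear2.weight = self.linear1.weight
        self.linear2.bias = self.linear1.bias

accelerator = Accelerator()
model = accelerator.prepare(TiedModel())

accelerator.save_state("ckpt")  # weights are written as a safetensors file, with shared tensors deduplicated
accelerator.load_state("ckpt")  # before this PR: missing-key error for linear2.weight / linear2.bias; with it: loads cleanly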

SunMarc commented 2 months ago

Nice! Could you just fix the quality issue (make style) and we are good to merge!