huggingface / huggingface_hub

The official Python client for the Huggingface Hub.
https://huggingface.co/docs/huggingface_hub
Apache License 2.0

Serialization: support saving torch state dict to disk #2314

Closed Wauplin closed 3 months ago

Wauplin commented 4 months ago

Implement save_torch_state_dict to save a torch state dictionary to disk (first part of https://github.com/huggingface/huggingface_hub/issues/2065). It uses split_torch_state_dict_into_shards under the hood (https://github.com/huggingface/huggingface_hub/pull/1938).

State dict is saved either to a single file (if less than 5GB) or to shards with the corresponding index.json. By default, shards are saved as safetensors, but safe_serialization=False can be passed to save as pickle. A warning is logged when saving as pickle, and hopefully we will be able to drop support for it once transformers/diffusers/accelerate/... completely phase out .bin saving. cc @LysandreJik I'd like to get your opinion on this. I'm fine with not adding support for .bin files at all but worry it would slow down adoption in our libraries.

For the implementation, I took inspiration from https://github.com/huggingface/diffusers/pull/7830 + accelerate/transformers. What it does:

  1. Split state dict into shards (logic already exists)
  2. Clean existing directory (remove previous shard/index files)
  3. Write shards to disk
  4. Write index to disk (optional)
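The four steps above can be sketched with the standard library alone. This is a toy illustration, not the code from this PR: the state is a plain dict of bytes rather than tensors, and the file naming scheme, the index layout, and the helper names `split_into_shards` and `save_state` are all hypothetical. It also assumes the index is only written when there is more than one shard.

```python
import json
import re
import tempfile
from pathlib import Path

# Hypothetical naming scheme, chosen for this sketch only.
SHARD_PATTERN = "model-{idx:05d}-of-{total:05d}.bin"
SHARD_REGEX = re.compile(r"^model-\d{5}-of-\d{5}\.bin$")
INDEX_NAME = "model.index.json"


def split_into_shards(state, max_shard_size):
    """Step 1: greedily group entries so a shard holds at most max_shard_size bytes.

    An entry larger than max_shard_size still gets a shard of its own.
    """
    shards, current, current_size = [], {}, 0
    for name, blob in state.items():
        if current and current_size + len(blob) > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = blob
        current_size += len(blob)
    if current:
        shards.append(current)
    return shards


def save_state(state, folder, max_shard_size):
    """Run the four steps and return the entry-name -> filename mapping."""
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    shards = split_into_shards(state, max_shard_size)

    # Step 2: remove shard/index files left over from a previous save.
    for path in folder.iterdir():
        if path.is_file() and (SHARD_REGEX.match(path.name) or path.name == INDEX_NAME):
            path.unlink()

    # Step 3: write each shard (toy on-disk format: concatenated payloads).
    weight_map = {}
    for idx, shard in enumerate(shards, start=1):
        filename = SHARD_PATTERN.format(idx=idx, total=len(shards))
        (folder / filename).write_bytes(b"".join(shard.values()))
        for name in shard:
            weight_map[name] = filename

    # Step 4: write the index, only when the state was actually sharded.
    if len(shards) > 1:
        index = {"weight_map": weight_map}
        (folder / INDEX_NAME).write_text(json.dumps(index, indent=2))
    return weight_map


# Usage: three 4-byte entries with an 8-byte limit end up in two shards.
with tempfile.TemporaryDirectory() as tmp:
    demo_map = save_state({"a": b"x" * 4, "b": b"y" * 4, "c": b"z" * 4}, tmp, 8)
    demo_files = sorted(p.name for p in Path(tmp).iterdir())
```

Cleaning the directory before writing matters when a later save produces fewer shards than an earlier one: without step 2, stale shard files and an outdated index would remain next to the new files.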

Example:

>>> from huggingface_hub import save_torch_state_dict
>>> model = ... # A PyTorch model

# Save state dict to "path/to/folder"
# The model is split into shards of at most 5GB each and saved as safetensors.
>>> state_dict = model.state_dict()
>>> save_torch_state_dict(state_dict, "path/to/folder")

cc @amyeroberts / @ArthurZucker for transformers, @sayakpaul for diffusers, @SunMarc @muellerzr for accelerate. Happy to get feedback on such a critical part. The goal is to standardize things to be consistent across libraries, so please let me know if you want to add/remove something!

(documentation has also been updated)


Note: I also removed split_numpy_state_dict_into_shards, which is a breaking change, but I don't expect anything to break in the wild. Better to just remove it to avoid future maintenance (I shouldn't have added it in the first place).

(failing CI is unrelated)

HuggingFaceDocBuilderDev commented 4 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Wauplin commented 3 months ago

Thanks everyone for the reviews! :heart: