huggingface / safetensors

Simple, safe way to store and distribute tensors
https://huggingface.co/docs/safetensors
Apache License 2.0

Add `safetensors.torch.save_state_dict()` to python bindings #501

Closed mohahf19 closed 2 months ago

mohahf19 commented 2 months ago

What does this PR do?

This PR adds the safetensors.torch.save_state_dict() method to the Python bindings. The existing safetensors.torch.save_model() method handles models that are instantiated in memory and may contain shared tensors. From the structure of safetensors.torch.save_model(), we can see that the model itself does not need to be instantiated; only the state dictionary is necessary.

This PR extracts the saving logic that runs in safetensors.torch.save_model() after the state dictionary has been obtained into a new safetensors.torch.save_state_dict() method, and refactors save_model() to simply call save_state_dict(). A rough sketch of what such a helper could look like follows.
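For concreteness, here is a hypothetical sketch of the proposed helper, not the actual patch; the use of the private `_remove_duplicate_names` helper is an assumption about how the shared-tensor handling in save_model() could be reused:

```python
from typing import Dict, Optional

import torch
from safetensors.torch import save_file, _remove_duplicate_names  # private helper, assumed


def save_state_dict(
    state_dict: Dict[str, torch.Tensor],
    filename: str,
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    # Work on a copy so the caller's dict is not mutated.
    state_dict = dict(state_dict)
    # Group tensors that share storage and keep one representative per group.
    to_remove_groups = _remove_duplicate_names(state_dict)
    metadata = dict(metadata or {})
    for kept_name, removed_names in to_remove_groups.items():
        for name in removed_names:
            metadata[name] = kept_name  # record the alias so a loader can re-tie it
            del state_dict[name]
    # safetensors requires contiguous tensors.
    state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
    save_file(state_dict, filename, metadata=metadata)
```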

Motivation for this change: make it easier to convert models saved as pytorch_model.bin files to model.safetensors files. The conversion process does not need the model to be instantiated; it only needs to know which tensors have to be converted. With the current interface, we would either have to use save_file, which does not support shared tensors, or load the state dictionary into an instantiated model and call save_model. This change allows us to skip that step and call save_state_dict directly, as sketched below.
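A minimal sketch of that conversion use case, assuming the proposed save_state_dict() exists; the file names are illustrative:

```python
import torch
# from safetensors.torch import save_state_dict  # the method proposed in this PR

# Load the raw state dictionary from the legacy checkpoint; no model class needed.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# Write it out, shared tensors included, without instantiating the architecture.
# save_state_dict(state_dict, "model.safetensors")
```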

Narsil commented 2 months ago

Hi @mohahf19,

Unfortunately your premise is wrong: we do need the model to load the model correctly (not to save it, you are right). The reason is that torch's load_state_dict does not respect the invariants those functions rely on, and everything would break if we allowed something like safetensors.torch.save_state_dict to be used as model.load_state_dict(safetensors.torch.load_state_dict(filename)).
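To make the shared-tensor issue concrete, here is a small, self-contained illustration (not from the thread) of why a deduplicated state dict cannot simply be fed back through model.load_state_dict():

```python
import torch


class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4, bias=False)
        self.decoder = torch.nn.Linear(4, 4, bias=False)
        # Tie the two weights: both state dict entries share the same storage.
        self.decoder.weight = self.encoder.weight


model = TiedModel()
state_dict = model.state_dict()

# Both keys point at the same underlying storage.
assert state_dict["encoder.weight"].data_ptr() == state_dict["decoder.weight"].data_ptr()

# save_model() drops one of the duplicates before writing, and load_model()
# knows how to re-tie the missing key onto the module. A plain
# model.load_state_dict() on the deduplicated dict would instead report a
# missing key and would not restore the weight sharing.
```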

load_model/save_model exist specifically to be symmetrical functions, which makes it easier to resolve weight sharing. If you are not using weight sharing, then save_file/load_file is what you are looking for.
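As a brief illustration of that API split (file names here are illustrative):

```python
import torch
from safetensors.torch import save_file, load_file, save_model, load_model

# Plain tensor dicts with no weight sharing: save_file / load_file.
tensors = {"weight": torch.zeros((2, 2)), "bias": torch.zeros((2,))}
save_file(tensors, "plain.safetensors")
reloaded = load_file("plain.safetensors")

# Models with tied (shared) tensors: the symmetrical save_model / load_model pair,
# which deduplicates shared storage on save and restores it onto the module on load.
model = torch.nn.Linear(2, 2)
save_model(model, "model.safetensors")
load_model(model, "model.safetensors")
```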

Thanks a lot for this PR and the work though.