Closed mohahf19 closed 2 months ago
Hi @mohahf19,
Unfortunately your premise is wrong: we do need the model object to load the weights correctly (not to save them, you are right about that). The reason is that `torch.load_state_dict` does not respect the invariants those functions maintain, and everything would bug out if we enabled something like `safetensors.torch.save_state_dict` to be used with `model.load_state_dict(safetensors.torch.load_state_dict(filename))`.
`load_model`/`save_model` exist specifically as symmetrical functions, which makes it easier to resolve weight sharing. If you are not using weight sharing, then `save_file`/`load_file` is what you are looking for.
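To illustrate why weight sharing needs special bookkeeping, here is a pure-Python sketch (no torch required) of the core problem: two state-dict keys can alias the same underlying buffer, so a naive per-key save would duplicate or reject the data, and a naive load would silently break the aliasing. Note this is a simplification: real safetensors/torch code detects sharing via storage pointers, not Python object identity.

```python
def find_shared_groups(state_dict):
    """Group state-dict keys that point at the same underlying buffer.

    Simplified sketch: identity of the Python object stands in for
    torch's storage-pointer comparison.
    """
    groups = {}
    for name, tensor in state_dict.items():
        groups.setdefault(id(tensor), []).append(name)
    return list(groups.values())


# A tied embedding: two keys, one buffer (a plain list stands in for a tensor).
buf = [1.0, 2.0, 3.0]
state = {"embed.weight": buf, "lm_head.weight": buf, "bias": [0.0]}

# "embed.weight" and "lm_head.weight" land in one shared group;
# a symmetric save/load pair must store the buffer once and
# re-establish the aliasing on load.
shared = [group for group in find_shared_groups(state) if len(group) > 1]
```

This is the bookkeeping that `save_model`/`load_model` handle for you; `save_file`/`load_file` deliberately do not, which is why they reject shared tensors.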
Thanks a lot for this PR and the work though.
What does this PR do?
This PR adds the `safetensors.torch.save_state_dict()` method to the Python bindings. The `safetensors.torch.save_model()` method handles models that are instantiated in memory and that may contain shared tensors. From `safetensors.torch.save_model()`'s structure, we see that there is no need for the model to be instantiated: only the state dictionary is necessary.

This PR moves the saving logic that runs after getting the state dictionary out of the `safetensors.torch.save_model()` method and into its own `safetensors.torch.save_state_dict()` method, and refactors `save_model` to just call `save_state_dict` instead.

Motivation for this change: make converting loaded models from `pytorch_model.bin` files to `model.safetensors` files easier. The conversion process does not need the model to be instantiated; it only needs to know which tensors have to be converted. With the current interface, we would have to use `save_file`, which does not support shared tensors, or else load the state dictionary into a model first. This change allows us to skip that step and call `save_state_dict` directly.
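The refactor described above can be sketched in pure Python. This is not the actual safetensors implementation; the writer is a stub dictionary, sharing is detected by object identity instead of storage pointers, and `TinyModel` is a hypothetical stand-in for an `nn.Module` with tied weights. It only shows the shape of the change: the serialization logic lives in `save_state_dict`, and `save_model` becomes a thin wrapper that extracts the state dict first.

```python
def save_state_dict(state_dict, filename, storage):
    """Sketch: deduplicate aliased entries, then record one copy per buffer.

    `storage` is a plain dict standing in for a real file writer.
    """
    seen = {}
    for name, tensor in state_dict.items():
        key = id(tensor)  # simplification: real code compares storage pointers
        if key in seen:
            # Shared tensor: record an alias instead of duplicating the data.
            storage.setdefault("aliases", {})[name] = seen[key]
        else:
            seen[key] = name
            storage.setdefault("tensors", {})[name] = list(tensor)
    storage["filename"] = filename


def save_model(model, filename, storage):
    # The model object is only needed to produce the state dict;
    # everything else is delegated to save_state_dict.
    save_state_dict(model.state_dict(), filename, storage)


class TinyModel:
    """Hypothetical stand-in for an nn.Module with tied weights."""
    def __init__(self):
        self.w = [1.0, 2.0]
    def state_dict(self):
        return {"a.weight": self.w, "b.weight": self.w}


store = {}
save_model(TinyModel(), "model.safetensors", store)
# "b.weight" is recorded as an alias of "a.weight" instead of duplicating data
```

With this split, a conversion script that has only a raw state dictionary (e.g. loaded from a `pytorch_model.bin`) can call `save_state_dict` directly, without ever instantiating the model class.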