Closed BenjaminBossan closed 2 months ago
Okay, so I investigated the failing tests with PyTorch 2.4 a bit further. Right now, we get a warning:
> FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
The reason this warning leads to an error is pure coincidence: the given tests use a filter that catches warnings for unrelated reasons, and this new `FutureWarning` now triggers it.
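For illustration, a strict warnings filter like the following (a hypothetical stdlib sketch, not skorch's actual test setup) is enough to turn the new `FutureWarning` into a hard error:

```python
import warnings

# Hypothetical sketch of a strict test-suite filter: any FutureWarning
# is promoted to an error, so the new torch.load warning fails the test.
with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)
    try:
        warnings.warn(
            "You are using torch.load with weights_only=False ...",
            FutureWarning,
        )
        caught = None
    except FutureWarning as exc:
        caught = str(exc)
```

This is the same mechanism pytest's `filterwarnings` uses under the hood: a warning matching an "error" filter is raised as an exception.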
Anyway, it's good that we got an early indicator that this will break in the future. However, fixing the problem is not trivial. Here is why:
Since PyTorch 1.13, `torch.load` has an option called `weights_only`. If set to `True`, only a select few types can be loaded, with the intent of making `torch.load` more secure. Since it uses pickle under the hood, there is a danger of malicious code being executed when loading torch checkpoints, so this is a sensible decision (note that alternatives like safetensors and gguf don't suffer from this).
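To see why pickle-based loading is risky, here is a minimal, harmless stdlib demonstration (not torch-specific): unpickling can invoke arbitrary callables via `__reduce__`.

```python
import contextlib
import io
import pickle

class Evil:
    def __reduce__(self):
        # Executed during unpickling: any callable could go here,
        # e.g. os.system. We use print to keep it harmless.
        return (print, ("arbitrary code executed during unpickling",))

payload = pickle.dumps(Evil())
out = io.StringIO()
with contextlib.redirect_stdout(out):
    pickle.loads(payload)  # runs print() instead of restoring an Evil instance
captured = out.getvalue()
```

A malicious checkpoint works the same way, which is exactly what `weights_only=True` is meant to prevent.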
Right now, the default is `weights_only=False`, but as the warning indicates, this will be flipped to `weights_only=True` in the future. Here are some ideas to address this.
**Solution 1: switch to `weights_only=True`**
My first idea was to fix the warning by switching to `weights_only=True` in skorch wherever we use `torch.load` for deserialization. However, this makes a bunch of tests fail because they use types that are not considered secure.
As an example, each time we define a custom `nn.Module`, it is considered unsafe, and we would have to call `torch.serialization.add_safe_globals([MyModule])` if the test involves serialization. But that's not enough: even builtin types like `set` and PyTorch types like `nn.Linear` are not considered secure, so all of these would have to be added too.
The latter could be done once inside of `conftest` and that would be fine, but I really don't want to scatter the code with `torch.serialization.add_safe_globals` calls each time a custom class is defined. Moreover, if we make this switch, a bunch of user code would start failing. Yes, this is bound to happen when PyTorch flips the default, but it's still not a nice experience.
What's also annoying is that PyTorch reports these insecure types only one at a time, so we have to add one, run the tests again, get a new error, add the new type, and so on.
**Solution 2: hard-code `weights_only=False`**
We could hard-code this and thus ensure that all code that used to work continues working. Neither tests nor user code would require adjusting. This would also be no less secure than the status quo, but it defeats the whole idea of making PyTorch more secure.
If we take this route, we should allow users to override the value by exposing the argument.
**Solution 3: keep the `torch.load` default**
That is, leave the code as is and use whatever default the installed PyTorch version provides. The failing test would still fail, but it could be fixed by excepting this `FutureWarning`. User code would work as normal. When the new PyTorch version with the flipped default is released, users will have to start dealing with it, same as all other PyTorch users. Similarly, we will have to deal with it for skorch, as discussed in solution 1.
For now, I have reported this internally to PyTorch devs, let's see what comes out of it.
Input from others would be appreciated: @ottonemo @thomasjpfan
In the long term, I'll want a way to allow `weights_only=True`, even if it takes some time to get right with `torch.serialization.add_safe_globals`. For skorch, I propose:

1. Keep `weights_only=False` as the default.
2. Add a `torch_load_kwargs` parameter to `NeuralNet.__init__` so that users can override the `torch.load` kwargs and set `weights_only=True`.

Concretely, whenever we call `torch.load`:
```python
default_load_kwargs = {"weights_only": False}
torch_load_kwargs = {**default_load_kwargs, **self.torch_load_kwargs}
torch.load(..., **torch_load_kwargs)
```
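To make the override semantics concrete, here is a standalone sketch of the merge (plain dict unpacking; `self.torch_load_kwargs` from the proposal is simulated as a local dict):

```python
# Standalone sketch of the proposed kwargs merge, outside of any class.
default_load_kwargs = {"weights_only": False}        # proposed skorch default
instance_torch_load_kwargs = {"weights_only": True}  # what a user might pass

merged = {**default_load_kwargs, **instance_torch_load_kwargs}
# In dict unpacking, later entries win, so the user's value takes effect.
```

The user-supplied dict always overrides the default, so opting in to `weights_only=True` is a one-argument change.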
Thanks for the input, this sounds reasonable. It's not pretty, but since we cannot directly pass arguments to `__setstate__`, I don't see a better way.
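The `__setstate__` limitation can be seen in a minimal stdlib example: pickle calls it with the saved state only, so any load-time options must already live on the instance (the `Net` class here is a hypothetical stand-in, not skorch's `NeuralNet`):

```python
import pickle

class Net:
    # Hypothetical stand-in for a skorch NeuralNet, for illustration only.
    def __init__(self):
        self.weights = [1.0, 2.0]
        self.torch_load_kwargs = {}  # load options must live on the instance

    def __setstate__(self, state):
        # pickle calls __setstate__ with the saved state dict only; there is
        # no hook to inject extra arguments (e.g. weights_only) at load time.
        self.__dict__.update(state)

restored = pickle.loads(pickle.dumps(Net()))
```

Hence the proposal to store the kwargs as an attribute set in `__init__` rather than passing them at load time.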
As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?
I found that there is also a context manager, `torch.serialization.safe_globals`. For test-specific classes, we can use that; for the rest, like `set`, we can use `add_safe_globals` in `conftest.py`.
Edit: This was only added recently, so it's not available for older releases.
Edit: Planned release is v2.6.0.
> As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?
I like this. It would mean that we expose a way of handling model-loading security to the user while keeping PyTorch's defaults. Since this is a long-standing security issue, I'd say we should at least follow the PyTorch default as soon as they deem the ecosystem ready for it.
We could simply use the PyTorch release version as a default indicator (might be better than using `inspect`?)
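Using the release version as the indicator could look like this (a hypothetical helper, with v2.6 as the cut-off per the planned flip mentioned in this thread):

```python
def default_weights_only(torch_version: str) -> bool:
    # Hypothetical helper for resolving an "auto" setting from the installed
    # PyTorch release; PyTorch plans to flip the default in v2.6.
    base = torch_version.split("+")[0]  # drop the local part, e.g. "+cu121"
    major, minor = (int(p) for p in base.split(".")[:2])
    return (major, minor) >= (2, 6)
```

A version-string comparison like this avoids `inspect`-based signature probing, at the cost of having to update the cut-off if PyTorch's plans change.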
I assume that the need for a class variable for the `load` kwargs comes from the fact that we support pickling skorch models?
> I found that there is also a context manager, `torch.serialization.safe_globals`. For test-specific classes, we can use that; for the rest, like `set`, we can use `add_safe_globals` in `conftest.py`.
I was going to say that it might be beneficial to have the tests look as close to user code as possible, so that we run into approximately the same issues (in terms of functionality but also in terms of "design") as our users do. The context manager plus allowlisting generic classes is probably a good middle ground.
I'm happy with an "auto" option.
Also: