skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Add PyTorch 2.4.0 to CI #1063

Closed BenjaminBossan closed 2 months ago

BenjaminBossan commented 4 months ago

Also:

BenjaminBossan commented 3 months ago

Okay, so I investigated the failing tests with PyTorch 2.4 a bit further. Right now, we get a warning:

FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

The reason this warning leads to an error is pure coincidence: in the affected tests, we have a warnings filter in place for an unrelated reason, and this new FutureWarning happens to trigger it:

https://github.com/skorch-dev/skorch/blob/346d7050117e0abd19f12831aa26e214b710716f/skorch/tests/test_hf.py#L715-L719

Anyway, it's good that we got an early indicator that this will break in the future. However, fixing the problem is not trivial. Here is why:

Since PyTorch 1.13, torch.load has an option called weights_only. If it is set to True, only a select few types can be loaded, with the intent of making torch.load more secure. Because torch.load uses pickle under the hood, loading an untrusted checkpoint can execute malicious code, so this is a sensible decision (note that alternative formats like safetensors and gguf don't suffer from this).
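To make the pickle risk concrete, here is a minimal, stdlib-only sketch (no PyTorch involved): any pickled object can instruct the unpickler to call an arbitrary callable at load time via `__reduce__`. The `Payload` class and the harmless `eval` expression are purely illustrative; a real attack would run something far worse.

```python
import pickle


class Payload:
    """Illustrative object whose unpickling runs code chosen by its author."""

    def __reduce__(self):
        # pickle stores "call eval with this string" instead of the object itself
        return (eval, ("sum(range(10))",))


data = pickle.dumps(Payload())

# The loader never asked for eval, yet the expression runs during loading.
result = pickle.loads(data)
print(result)  # 45
```

Restricting which globals the unpickler may resolve, as weights_only=True does, is what blocks this kind of payload.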

Right now, the default is weights_only=False, but as the warning indicates, this will be flipped to weights_only=True in a future release. Here are some ideas to address this.

1. Defaulting to weights_only=True

My first idea was to fix the warning by switching to weights_only=True in skorch, where we use torch.load for deserialization. However, this makes a bunch of tests fail because they use types that are not considered to be secure.

As an example, each time we define a custom nn.Module, it is considered unsafe and we would have to call torch.serialization.add_safe_globals([MyModule]) if the test involves serialization. But that's not enough: Even builtin types like set and PyTorch types like nn.Linear are not considered secure, so all of these would have to be added too.

The latter could be done once in conftest.py, which would be fine, but I really don't want to litter the code with torch.serialization.add_safe_globals calls each time a custom class is defined. Moreover, making this switch means that a lot of user code would start failing. Yes, this is bound to happen once PyTorch makes the switch, but it's still not a nice experience.

What's also annoying is that PyTorch reports these insecure types only one at a time, so we have to add a type, run the tests again, get a new error, add the next type, and so on.
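The allowlisting mechanism, and why missing types surface one at a time, can be mimicked with the standard library's documented "restricting globals" pattern: subclass `pickle.Unpickler` and override `find_class`. This is a stdlib analogy, not PyTorch's weights_only implementation; `DisallowedGlobal`, `AllowlistUnpickler`, and `restricted_loads` are hypothetical names. Because unpickling stops at the first global that is not allowlisted, the loop needs one retry per missing type, much like the add-run-fail cycle described above.

```python
import collections
import datetime
import io
import pickle


class DisallowedGlobal(pickle.UnpicklingError):
    """Raised for the first global that is not on the allowlist."""

    def __init__(self, module, name):
        super().__init__(f"{module}.{name} is not allowlisted")
        self.key = (module, name)


class AllowlistUnpickler(pickle.Unpickler):
    def __init__(self, file, allowed):
        super().__init__(file)
        self.allowed = allowed

    def find_class(self, module, name):
        # Called for every class/function referenced by the pickle stream.
        if (module, name) not in self.allowed:
            raise DisallowedGlobal(module, name)
        return super().find_class(module, name)


def restricted_loads(data, allowed):
    return AllowlistUnpickler(io.BytesIO(data), allowed).load()


checkpoint = {
    "created": datetime.date(2024, 7, 1),
    "params": collections.OrderedDict(weight=[0.1, 0.2]),
}
data = pickle.dumps(checkpoint)

allowed = set()
attempts = 0
while True:
    attempts += 1
    try:
        loaded = restricted_loads(data, allowed)
        break
    except DisallowedGlobal as exc:
        print("attempt", attempts, "failed:", exc)
        allowed.add(exc.key)  # allowlist this one type, then try again

print(attempts, sorted(allowed))
```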

2. Setting weights_only=False

We could hard-code this and thus ensure that all code that used to work continues to work. Neither tests nor user code would need adjusting. This also wouldn't be less secure than the status quo, but it defeats the whole idea of making PyTorch more secure.

If we take this route, we should allow users to override this by exposing the argument.

3. Not setting anything in torch.load

I.e., just leaving the code as is and using whatever default the installed PyTorch version has. The failing test would still fail, but it could be fixed by ignoring this FutureWarning. User code would work as normal. When the new PyTorch version with flipped defaults is released, users will have to start dealing with this, same as other PyTorch users. Similarly, we will have to deal with this for skorch, as discussed in solution 1.
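For option 3, the test-side fix could be a targeted filter: keep turning warnings into errors, but ignore this one FutureWarning by matching the start of its message. This is a stdlib-only sketch; the regex assumes the message begins as quoted above (the optional backticks allow for inline-code formatting in the real text), and the `emit` helper only simulates what torch.load would emit.

```python
import warnings

# Matched against the start of the warning message (case-insensitive regex).
TORCH_LOAD_WARNING = r"You are using `?torch\.load`? with `?weights_only=False`?"


def emit(message, category):
    """Simulate library code emitting a warning."""
    warnings.warn(message, category, stacklevel=2)


with warnings.catch_warnings():
    warnings.simplefilter("error")  # every warning becomes an exception...
    # ...except this one: filterwarnings inserts at the front, so it takes precedence.
    warnings.filterwarnings("ignore", message=TORCH_LOAD_WARNING, category=FutureWarning)

    emit("You are using torch.load with weights_only=False (the current default value)", FutureWarning)
    try:
        emit("some unrelated deprecation", FutureWarning)
        unrelated_raised = False
    except FutureWarning:
        unrelated_raised = True

print(unrelated_raised)  # True
```

The narrow message match keeps the original purpose of the filter intact: any other warning still fails the test.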

For now, I have reported this internally to the PyTorch devs; let's see what comes out of it.

Input from others would be appreciated @ottonemo @thomasjpfan

thomasjpfan commented 3 months ago

In the long term, I'd want a way to allow weights_only=True, even if it takes some time to get right with torch.serialization.add_safe_globals. For skorch, I propose:

  1. Use weights_only=False as the default
  2. Add torch_load_kwargs to NeuralNet.__init__ to allow users to override the torch.load kwargs, e.g. to set weights_only=True.

Concretely, whenever we call torch.load:

default_load_kwargs = {"weights_only": False}  # default from point 1

torch_load_kwargs = {**default_load_kwargs, **self.torch_load_kwargs}
torch.load(..., **torch_load_kwargs)

BenjaminBossan commented 3 months ago

Thanks for the input, this sounds reasonable. It's not pretty, but since we cannot directly pass arguments to __setstate__, I don't see a better way.
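Why the kwargs have to live on the instance: pickle calls `__setstate__(self, state)` with nothing but the state, so whatever load options were chosen must travel inside that state. Below is a stdlib-only sketch of the proposal under stated assumptions: `TinyNet`, `fake_torch_load`, `LOAD_CALLS`, and the `module_` attribute are hypothetical stand-ins (plain `pickle` replaces `torch.save`/`torch.load`), and the `weights_only=False` default follows point 1 of the proposal.

```python
import pickle

LOAD_CALLS = []  # records the kwargs each (fake) load received


def fake_torch_load(blob, **kwargs):
    """Hypothetical stand-in for torch.load: records kwargs, then unpickles."""
    LOAD_CALLS.append(kwargs)
    return pickle.loads(blob)


class TinyNet:
    def __init__(self, torch_load_kwargs=None):
        self.torch_load_kwargs = torch_load_kwargs
        self.module_ = {"weight": [0.1, 0.2]}  # stand-in for a trained module

    def get_load_kwargs(self):
        default_load_kwargs = {"weights_only": False}
        # User-supplied values override the defaults.
        return {**default_load_kwargs, **(self.torch_load_kwargs or {})}

    def __getstate__(self):
        state = self.__dict__.copy()
        state["module_"] = pickle.dumps(self.module_)  # like torch.save into a buffer
        return state

    def __setstate__(self, state):
        # No extra arguments arrive here, so the kwargs must come from the state.
        blob = state.pop("module_")
        self.__dict__.update(state)
        self.module_ = fake_torch_load(blob, **self.get_load_kwargs())


net = TinyNet(torch_load_kwargs={"weights_only": True})
restored = pickle.loads(pickle.dumps(net))
print(LOAD_CALLS[-1])  # {'weights_only': True}
```

Using `None` as the constructor default with `or {}` avoids sharing one mutable dict across instances.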

As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?

I found that there is also a context manager, torch.serialization.safe_globals. For test-specific classes, we can use that; for the rest, like set, we can use add_safe_globals in conftest.py. Edit: This was only added recently, so it's not available in older releases.

Edit: Planned release is v2.6.0.

ottonemo commented 3 months ago

As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?

I like this. It would mean that we expose a way of handling model loading security to the user while keeping PyTorch's defaults. Since this is a long-standing security issue, I'd say we should at least follow the PyTorch default as soon as they deem the ecosystem to be ready for it.

We could simply use the PyTorch release version to determine the default (might be better than using inspect?)
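A sketch of how an "auto" value might be resolved from the installed PyTorch version, as suggested above. The function name and the regex-based parsing are assumptions; the 2.6 threshold reflects that PyTorch 2.6 is the release that flipped the torch.load default to weights_only=True. Comparing integer tuples rather than raw strings matters, since "2.10" sorts before "2.6" as a string.

```python
import re


def resolve_weights_only(value, torch_version):
    """Map "auto" to the torch.load default of the given PyTorch version."""
    if value != "auto":
        return value  # an explicit True/False from the user wins
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"cannot parse version {torch_version!r}")
    major_minor = (int(match.group(1)), int(match.group(2)))
    # PyTorch 2.6 flipped the default to weights_only=True.
    return major_minor >= (2, 6)


for version in ["2.4.0+cu121", "2.6.0", "2.10.1"]:
    print(version, resolve_weights_only("auto", version))
```

In real code, the version string would come from `torch.__version__`; local suffixes like `+cu121` are ignored by the regex.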

I assume that the need for a class variable for the load kwarg comes from the fact that we support pickling skorch models?

I found that there is also a context manager, torch.serialization.safe_globals. For test-specific classes, we can use that; for the rest, like set, we can use add_safe_globals in conftest.py.

I was going to say that it might be beneficial to have the tests look as close to user code as possible, so that we run into approximately the same issues (in terms of functionality but also in terms of 'design') as our users do. The context manager + whitelisting generic classes is probably a good middle ground.
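The middle ground described here (generic types allowlisted once for the whole test session, test-specific classes allowlisted only inside a `with` block) can be modeled in a few lines. This is a toy, stdlib-only registry for illustration; `ALLOWLIST`, `allow_globals`, `allowed_globals`, and `MyModule` are hypothetical names and do not reflect how torch.serialization.add_safe_globals or safe_globals are implemented.

```python
import contextlib

ALLOWLIST = set()  # process-wide, like entries added once in conftest.py


def allow_globals(items):
    """Permanently allowlist the given types (conftest-style)."""
    ALLOWLIST.update(items)


@contextlib.contextmanager
def allowed_globals(items):
    """Allowlist types only for the duration of the with block."""
    added = set(items) - ALLOWLIST  # don't revoke entries that were already there
    ALLOWLIST.update(added)
    try:
        yield
    finally:
        ALLOWLIST.difference_update(added)


class MyModule:  # stand-in for a test-specific custom module
    pass


allow_globals([set])  # generic type: allowlist once

with allowed_globals([MyModule, set]):
    inside = set(ALLOWLIST)

print(MyModule in inside, MyModule in ALLOWLIST, set in ALLOWLIST)  # True False True
```

Computing `added` before updating means the scoped block never removes a type that the session-wide allowlist already contained.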

thomasjpfan commented 3 months ago

I'm happy with an "auto" option.