DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

`torch.load` without `weights_only` parameter is unsafe #1852

Open · kit1980 opened this issue 8 months ago

kit1980 commented 8 months ago

This is found via https://github.com/pytorch-labs/torchfix/

torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.
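
For illustration, here is what the two modes look like (a minimal, self-contained sketch using an in-memory buffer rather than a real checkpoint):

```python
import io

import torch as th

# Build a tiny in-memory "checkpoint" so the example is self-contained.
buffer = io.BytesIO()
th.save({"weight": th.zeros(2, 2)}, buffer)

# Safe mode: the unpickler only reconstructs tensors and a small allowlist
# of types, so a malicious file cannot execute code during loading.
buffer.seek(0)
state = th.load(buffer, weights_only=True)

# Full pickle mode: anything in the file can run code on load; use this
# only for files from a trusted source.
buffer.seek(0)
obj = th.load(buffer, weights_only=False)
```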

stable_baselines3/common/policies.py:176:27

```diff
--- /home/sdym/repos/stable-baselines3/stable_baselines3/common/policies.py
+++ /home/sdym/repos/stable-baselines3/stable_baselines3/common/policies.py
@@ -171,11 +171,11 @@
         :param path:
         :param device: Device on which the policy should be loaded.
         :return:
         """
         device = get_device(device)
-        saved_variables = th.load(path, map_location=device)
+        saved_variables = th.load(path, map_location=device, weights_only=True)
 
         # Create policy object
         model = cls(**saved_variables["data"])
         # Load weights
         model.load_state_dict(saved_variables["state_dict"])
```

stable_baselines3/common/save_util.py:450:33

```diff
--- /home/sdym/repos/stable-baselines3/stable_baselines3/common/save_util.py
+++ /home/sdym/repos/stable-baselines3/stable_baselines3/common/save_util.py
@@ -445,11 +445,11 @@
                     file_content.write(param_file.read())
                     # go to start of file
                     file_content.seek(0)
                     # Load the parameters with the right ``map_location``.
                     # Remove ".pth" ending with splitext
-                    th_object = th.load(file_content, map_location=device)
+                    th_object = th.load(file_content, map_location=device, weights_only=True)
                     # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                     if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
                         # PyTorch variables (not state_dicts)
                         pytorch_variables = th_object
                     else:
```
araffin commented 8 months ago

Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1831

kit1980 commented 8 months ago

@araffin you should specify weights_only=False if you need pickle. Otherwise, when PyTorch soon changes the default of weights_only to True, your code will break.

araffin commented 8 months ago

Could you elaborate a bit? Where was that change announced?

kit1980 commented 8 months ago

@araffin I don't think there is an announcement, but we're definitely thinking of it. See this comment https://github.com/pytorch/pytorch/issues/111806#issuecomment-1785685208

araffin commented 8 months ago

Thanks. By the way, what is the minimum PyTorch version that supports setting weights_only=True?

kit1980 commented 8 months ago

@araffin The PR that added the option is https://github.com/pytorch/pytorch/pull/86812; the first release with it is PyTorch 1.13.0.

araffin commented 7 months ago

After trying it out, we cannot use weights_only=True in SB3, as it breaks some functionality; see https://github.com/DLR-RM/stable-baselines3/pull/1866. It would be nice to be able to extend _get_allowed_globals() for the unpickler.
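
For reference, PyTorch later grew exactly this escape hatch: torch.serialization.add_safe_globals, available from PyTorch 2.4 (so after this exchange). A sketch of how it could cover the numpy scalar case from this thread:

```python
import io

import numpy as np
import torch as th

# Assumes PyTorch >= 2.4, which added torch.serialization.add_safe_globals.
# The entries below match the "Unsupported class numpy.core.multiarray.scalar"
# error seen later in this thread (numpy 1.x module layout); newer numpy
# versions may need extra entries such as the concrete dtype classes.
th.serialization.add_safe_globals([np.core.multiarray.scalar, np.dtype])

# Round-trip a numpy scalar to show that the safe unpickler now accepts it.
buffer = io.BytesIO()
th.save({"lr": np.float64(1e-3)}, buffer)
buffer.seek(0)
loaded = th.load(buffer, weights_only=True)
```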

markscsmith commented 6 months ago

So glad I found this!

If you're getting the error:

```
Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar
```

when loading a model, this is a possible cause. stable_baselines3==2.2.1 doesn't throw the error, presumably because it doesn't use weights_only=True; stable_baselines3==2.3.0 throws this if I do a model.save and model.load of a PPO model.

I think it's this line here? https://github.com/DLR-RM/stable-baselines3/blob/5623d98f9d6bcfd2ab450e850c3f7b090aef5642/stable_baselines3/common/save_util.py#L450

Maybe a param of some kind that passes weights_only through to the underlying torch.load and lets the dev/user decide if they trust the source?
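
Something like this, perhaps (a hypothetical sketch; load_trusted and its signature are not part of the SB3 API):

```python
import warnings

import torch as th

def load_trusted(path, device="cpu", weights_only=True):
    """Pass weights_only through to torch.load, warning when the caller
    opts out of the safe unpickler (hypothetical helper, not SB3 code)."""
    if not weights_only:
        warnings.warn(
            "weights_only=False runs the full pickle machinery and can "
            "execute arbitrary code; only load files you trust."
        )
    return th.load(path, map_location=device, weights_only=weights_only)
```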

I'll drop a PR when my testing comes out OK! I promise it'll be cleaner than the heap of late-night hacking I've been trying to get to play Pokémon ;)

araffin commented 6 months ago

> If you're getting the error:

Please provide a minimal example to reproduce the error.

> throws this if I do a model.save and model.load of a PPO model.

I guess you are doing something custom, because the tests pass on the CI server.

> I think it's this line here?

Yes, probably.

markscsmith commented 6 months ago

> Please provide a minimal example to reproduce the error.

For sure! Right now the only example I have is my 16 MB blob of model, so I'm trying to find the minimal reproduction of that here.

> I guess you are doing something custom, because the tests pass on the CI server.

Yup! I too think it's something strange I'm doing with my model in particular, because I don't think I'm doing anything fancy with model.save and model.load themselves.

> Yes, probably.

Excellent! Hopefully my weights_only idea is overkill and it's just a weird quirk of my model that I can adjust, and the PR will just be a test and maybe a warning if someone makes the same mistake as I'm making.

If not, the pull request I'm working on still defaults to weights_only=True and throws a warning if it's overridden to False. My theory is that this won't disrupt existing users but would allow people doing weird stuff to load models they trust.

markscsmith commented 6 months ago

@araffin Figured it out! My learning_rate_schedule was using np.pi and np.sin. I've got a test to reproduce now and a pull request ready if my approach is OK! How would you feel about me adding an enhancement to warn if model.save() is called with objects that won't unpickle with weights_only=True?

araffin commented 6 months ago

Thanks for finding that out. I can reproduce with:

```python
from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: np.sin(1))
model.save("demo")
model = PPO.load("demo")
```

It comes from policy.optimizer, although I'm a bit confused about why, because the optimizer should only receive a float from the learning schedule. I guess it has type np.float64 instead of float, and that crashes everything.
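
A quick way to see the type mismatch (a minimal check):

```python
import numpy as np

# np.sin returns a numpy scalar, not a builtin float; the numpy scalar is
# what trips the safe unpickler once it ends up in the optimizer state.
print(type(np.sin(1)))         # <class 'numpy.float64'>
print(type(float(np.sin(1))))  # <class 'float'>
```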

I have a simpler fix in your case: cast to float, as that is the required type for the lr schedule:

```python
from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: float(np.sin(1)))
model.save("demo")
model = PPO.load("demo")
```

EDIT: a better PR would be to cast any call to learning_rate() to float
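
A sketch of that cast (the helper name as_float_schedule is hypothetical, not the actual SB3 implementation):

```python
from typing import Callable, Union

import numpy as np

Schedule = Callable[[float], float]

def as_float_schedule(value_schedule: Union[float, Schedule]) -> Schedule:
    """Wrap a schedule so every call returns a plain Python float, keeping
    numpy scalars out of the optimizer state (hypothetical helper)."""
    if isinstance(value_schedule, (int, float)):
        constant = float(value_schedule)
        return lambda _: constant
    return lambda progress_remaining: float(value_schedule(progress_remaining))

# The problematic schedule from above now yields builtin floats.
schedule = as_float_schedule(lambda _: np.sin(1))
assert isinstance(schedule(1.0), float)
```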

markscsmith commented 6 months ago

> EDIT: a better PR would be to cast any call to learning_rate() to float

Oooh, good call! I'll start on that a bit later today! I'm curious about the gymnasium loading issue mentioned earlier as well. Maybe something similar where it's using fancy numpy types?

araffin commented 6 months ago

> Maybe something similar where it's using fancy numpy types?

It's different: the problem occurs because we want to save the complete nn.Module object, which contains types (from gymnasium, potentially from numpy) that are not on the PyTorch whitelist.
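
For context, a minimal illustration of the difference (a sketch, not SB3 code):

```python
import torch as th
import torch.nn as nn

net = nn.Linear(4, 2)

# Saving only the state_dict stores plain tensors, which the safe
# unpickler accepts.
th.save(net.state_dict(), "weights.pth")
state = th.load("weights.pth", weights_only=True)

# Saving the complete module pickles the class itself (and, for SB3
# policies, everything they reference: gymnasium spaces, numpy dtypes,
# ...), which the safe unpickler rejects:
th.save(net, "module.pth")
# th.load("module.pth", weights_only=True)  # raises UnpicklingError
```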

Franziac commented 6 months ago

I also ran into the same error as @markscsmith. I'm storing some values (probably including np.nan) in the env that are likely causing the error. I worked around it by setting weights_only=False.

I understand that the issue originates from me doing something that I probably shouldn't, but I don't really see the harm in doing this:

> Maybe a param of some kind that passes weights_only through to the underlying torch.load and lets the dev/user decide if they trust the source?

markscsmith commented 6 months ago

@araffin I took a look at the logic around the learning_rate solution and am running tests on a fix now. I'll open a new issue and PR for that fix in particular. Thank you again for the help figuring this out!

@Franziac do you have more detail about what you mean by values in the env? Based on araffin's previous comments, it might be as simple as converting the types into safely unpickleable types in the right spot.

That said, @araffin, if you think you're going to get a bunch of bug reports like this from people doing odd stuff with models, I've cleaned up my original changes for weights_only a bit, and I'm noodling on how to emit a warning on save: "hey, when you try to unpickle this, you're going to get an error!"

My first instinct was "how do I get weights_only=False?", but you led me to ask "why is weights_only=False suddenly necessary?" I imagine this would help @Franziac identify when the special case is needed as well. Given that torch is the one creating that option in the first place, maybe this is something to do upstream in PyTorch?

They seem to be getting a fair number of issues around this as well, and a proactive message like "this object contains types that won't unpickle safely; weights_only=False will be necessary. Here are the offending objects:" might be a nudge to devs to cast to safer types. It aligns well with the advice you gave me, which worked for me too!
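
A sketch of that save-time check (warn_if_unsafe_pickle is a hypothetical helper, not part of SB3 or PyTorch):

```python
import io
import warnings

import numpy as np
import torch as th

def warn_if_unsafe_pickle(obj) -> None:
    """Round-trip obj through torch.save/torch.load(weights_only=True) and
    warn if the safe unpickler would reject it later (hypothetical helper)."""
    buffer = io.BytesIO()
    th.save(obj, buffer)
    buffer.seek(0)
    try:
        th.load(buffer, weights_only=True)
    except Exception as exc:
        warnings.warn(f"Loading this back will require weights_only=False: {exc}")

# A numpy scalar in the object likely triggers the warning, since numpy
# scalars are not on the default allowlist.
warn_if_unsafe_pickle({"lr": np.float64(1e-3)})
```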

araffin commented 6 months ago

@Franziac please provide a minimal working example to reproduce the issue.

araffin commented 6 months ago

@Franziac please have a look at https://github.com/DLR-RM/stable-baselines3/issues/1911; it seems to be due to an old version of PyTorch. In the meantime, I will revert the change and release SB3 2.3.2, which should solve the issue.