kit1980 opened 8 months ago
@araffin you should specify `weights_only=False` if you need pickle. Otherwise, when PyTorch soon changes the default of `weights_only` to `True`, your code will break.
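To make the risk concrete: unpickling can run arbitrary code, which is why the default is moving to `True`. A stdlib-only sketch (the `Malicious` class and the directory name are invented for this demo; a hostile pickle could invoke anything, not just `os.mkdir`):

```python
import os
import pickle

class Malicious:
    def __reduce__(self):
        # pickle invokes this callable with these args at *load* time;
        # an attacker could just as easily return (os.system, ("...",)).
        return (os.mkdir, ("pwned_by_pickle",))

payload = pickle.dumps(Malicious())
pickle.loads(payload)                       # merely loading runs os.mkdir
created = os.path.isdir("pwned_by_pickle")  # True: side effect happened
os.rmdir("pwned_by_pickle")                 # clean up the demo artifact
```

Nothing in `payload` looks like code, yet loading it performs a filesystem action, which is exactly what `weights_only=True` forbids.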
Could you elaborate a bit? Where was that change announced?
@araffin I don't think there is an announcement, but we're definitely thinking of it. See this comment https://github.com/pytorch/pytorch/issues/111806#issuecomment-1785685208
Thanks. By the way, what is the minimum PyTorch version needed to set `weights_only=True`?
@araffin The PR that added the option is https://github.com/pytorch/pytorch/pull/86812, first release with it is PyTorch 1.13.0
After trying it out, we cannot use `weights_only=True` in SB3 as it breaks some functionality, see https://github.com/DLR-RM/stable-baselines3/pull/1866. It would be nice to be able to extend `_get_allowed_globals()` for the unpickler.
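The idea of an extensible allowlist can be sketched with plain `pickle` (an illustrative toy, not PyTorch's actual implementation; `AllowlistUnpickler` and `allowlist_loads` are invented names, and PyTorch 2.4+ later exposed `torch.serialization.add_safe_globals` for exactly this purpose):

```python
import io
import pickle

class AllowlistUnpickler(pickle.Unpickler):
    # Toy allowlist in the spirit of _get_allowed_globals(): only
    # (module, qualname) pairs listed here may be resolved during load.
    ALLOWED = {("builtins", "complex")}

    def find_class(self, module, name):
        if (module, name) not in self.ALLOWED:
            raise pickle.UnpicklingError(f"Unsupported class {module}.{name}")
        return super().find_class(module, name)

def allowlist_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()

# An allowed type round-trips; anything else is rejected at load time,
# much like the "WeightsUnpickler error: Unsupported class ..." above.
ok = allowlist_loads(pickle.dumps(complex(1, 2)))
try:
    allowlist_loads(pickle.dumps(Exception("boom")))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```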
So glad I found this!
If you're getting the error:

```
Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar
```

when loading a model, this is a possible cause. 2.2.1 doesn't throw the error, but is that because it doesn't use `weights_only=True`? `stable_baselines3==2.3.0` throws this if I do a `model.save` and `model.load` of a PPO model.
I think it's this line here? https://github.com/DLR-RM/stable-baselines3/blob/5623d98f9d6bcfd2ab450e850c3f7b090aef5642/stable_baselines3/common/save_util.py#L450
Maybe add a param of some kind that passes `weights_only` through to the underlying `torch.load` and lets the dev / user decide if they trust the source?
I'll drop a PR when my testing comes out OK! I promise it'll be cleaner than the heap of late night hacking I've been trying to get to play Pokémon ;)
> If you're getting the error:
Please provide a minimal example to reproduce the error.
> throws this if I do a model.save and model.load of a PPO model.
I guess you are doing something custom, because the tests pass on the CI server.
> I think it's this line here?
Yes, probably.
> Please provide a minimal example to reproduce the error.
For sure! Right now the only example I have is my 16mb blob of model, so I'm trying to find the minimal reproduction of that here.
> I guess you are doing something custom because the tests pass on the CI server.
Yup! I too think it's something strange I'm doing with my model in particular, because I don't think I'm doing anything fancy with model.save and model.load themselves.
> yes, probably.
Excellent! Hopefully my weights_only idea is overkill and it's just a weird quirk of my model that I can adjust, and the PR will just be a test and maybe a warning if someone makes the same mistake as I'm making.
If not, the pull request I'm working on still defaults to weights_only=True and throws a warning if it's overridden to false. My theory is that won't disrupt existing users but would allow people doing weird stuff to be able to load models they trust.
@araffin Figured it out! My learning_rate_schedule was using np.pi and np.sin. I've got a test to reproduce now and a pull request ready if my approach is OK! How would you feel about me adding an enhancement to warn if model.save() is called with objects that won't unpickle with weights_only=True?
Thanks for finding out. I can reproduce with:

```python
from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: np.sin(1))
model.save("demo")
model = PPO.load("demo")
```
It comes from `policy.optimizer`, although I'm a bit confused as to why, because the optimizer should only receive `float` from the learning schedule. I guess it has type `np.ndarray` instead of `float` and that crashes everything.
I have a simpler fix in your case: cast to `float`, as it is the required type for the lr schedule:

```python
from stable_baselines3 import PPO
import numpy as np

model = PPO("MlpPolicy", "CartPole-v1", learning_rate=lambda _: float(np.sin(1)))
model.save("demo")
model = PPO.load("demo")
```
EDIT: a better PR would be to cast any call to `learning_rate()` to `float`.
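That cast-everywhere idea could be sketched as a small wrapper (`float_schedule` is a hypothetical helper name, not an SB3 API; `Fraction` stands in for a NumPy scalar so the sketch stays stdlib-only):

```python
from fractions import Fraction

def float_schedule(schedule):
    # Hypothetical helper (not an SB3 API): wrap a learning-rate schedule
    # so every value it returns is a plain Python float, keeping optimizer
    # state free of exotic numeric types that weights_only=True rejects.
    return lambda progress: float(schedule(progress))

# Stand-in for a user schedule that returns a non-float numeric type
# (with NumPy this would be e.g. lambda _: np.sin(1), an np.float64):
raw = lambda progress: Fraction(3, 1000)
lr = float_schedule(raw)
```

Any value the wrapped schedule produces, `lr(progress)` returns as a plain `float`, so nothing NumPy-specific ends up in the saved optimizer state.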
Oooh, good call! I'll start on that a bit later today! I'm curious about the gymnasium loading issue mentioned earlier as well. Maybe something similar where it's using fancy numpy types?
> Maybe something similar where it's using fancy numpy types?
It's different: the problem occurs because we want to save the complete `nn.Module` object, which contains types (from gymnasium, potentially from numpy) that are not on the PyTorch whitelist.
I also ran into the same error as @markscsmith. I'm storing some values (some probably `np.nan`) in the env that are probably causing the error. I made a workaround by setting `weights_only=False`. I understand that the issue originates from me doing something that I probably shouldn't, but I don't really see the harm in doing this:
> Maybe a param of some kind that passes weights_only through to the underlying torch.load and let the dev / user decide if they trust the source?
@araffin I took a look at the logic around the learning_rate solution and am running tests on a fix now. I'll open a new issue and PR for that fix in particular. Thank you again for the help figuring this out!
@Franziac do you have more detail about what you mean by values in the env? Based on araffin's previous comments it might be as simple as converting the types into safely-unpickleable types in the right spot.
That said, @araffin, if you think you're going to get a bunch of bug reports like this from people doing odd stuff with models, I've cleaned up my original changes for `weights_only` a bit, and am noodling on how to do a warning on save: "hey, when you try to unpickle this you're going to get an error!"
My first instinct was "how do I get weights_only=False?" but you led me to ask "why is weights_only=False suddenly necessary?" I imagine this would help @Franziac identify when the special case is needed as well. Given that torch is the one creating that option in the first place, maybe something to do upstream in pytorch?
They seem to be getting a fair number of issues around this as well, and a proactive "hey this object contains weird stuff, weights_only=False will be necessary. Here are the objects that are weird:" might be a nudge to devs to cast to safer types? It aligns well with the advice you gave me that worked for me too!
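One stdlib-only way such a save-time warning could be prototyped is to scan the pickle stream for global references outside a safe set. Everything here is invented for illustration (`SAFE_MODULES`, `risky_globals`); PyTorch's real allowlist lives inside its weights-only unpickler:

```python
import pickle
import pickletools
from fractions import Fraction

SAFE_MODULES = {"builtins", "collections", "torch"}  # invented allowlist

def risky_globals(obj):
    # Pickle at protocol 0 so every class reference appears as a GLOBAL
    # opcode whose argument is "module qualname"; report any module that
    # falls outside SAFE_MODULES -- loading those back would need
    # weights_only=False.
    payload = pickle.dumps(obj, protocol=0)
    flagged = []
    for opcode, arg, _pos in pickletools.genops(payload):
        if opcode.name == "GLOBAL" and arg.split(" ")[0].split(".")[0] not in SAFE_MODULES:
            flagged.append(arg)
    return flagged

plain = risky_globals({"lr": 0.003})   # plain types: nothing flagged
fancy = risky_globals(Fraction(1, 2))  # drags in fractions.Fraction
```

A `model.save()` hook could run a check like this over the state being pickled and emit exactly the "here are the objects that are weird" warning described above.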
@Franziac please provide a minimal working example to reproduce the issue.
@Franziac please have a look at https://github.com/DLR-RM/stable-baselines3/issues/1911; it seems to be due to an old version of PyTorch. In the meantime, I will revert the change and release SB3 2.3.2, which should solve the issue.
This is found via https://github.com/pytorch-labs/torchfix/
```
torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.

stable_baselines3/common/policies.py:176:27
stable_baselines3/common/save_util.py:450:33
```