rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

Loading checkpoint trained on GPU on a CPU device fails. #67

Closed vuoristo closed 5 years ago

vuoristo commented 5 years ago

It seems that saving model parameters with pickle.dump in rlkit/core/logging.py produces checkpoints that cannot be loaded on a machine without a GPU if the model was trained on a GPU. Loading the parameters of a SAC model trained on a GPU with scripts/run_policy.py produces the same error message as in this PyTorch issue. I tried the different map_location arguments from that issue, but they did not fix the problem for me.

Changing pickle.dump to torch.save fixes the problem in my case. I'm not sure whether that change has side effects elsewhere.

Verified this happens on commit c138bae3b3904c25de2c37c950e315410b3c0b99
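A minimal sketch of the proposed change, for illustration only: it saves with torch.save and loads with torch.load(map_location="cpu"), which remaps tensor storages to the CPU at load time. The nn.Linear module is a hypothetical stand-in for a trained policy, the in-memory buffer stands in for the checkpoint file, and saving a state_dict is an assumption made to keep the example small; it is not the snapshot format rlkit actually writes.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for a trained policy network.
policy = nn.Linear(4, 2)

# Save with torch.save rather than pickle.dump.
# An in-memory buffer is used here in place of a checkpoint file.
buffer = io.BytesIO()
torch.save(policy.state_dict(), buffer)

# Load the checkpoint, asking PyTorch to place every tensor on the CPU.
# This is what allows a GPU-trained checkpoint to open on a CPU-only machine.
buffer.seek(0)
state = torch.load(buffer, map_location="cpu")

# Rebuild the module and restore its parameters.
restored = nn.Linear(4, 2)
restored.load_state_dict(state)
```

The map_location argument only takes effect when the file was written by torch.save, which is presumably why it did not help with checkpoints written by pickle.dump.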

vitchyr commented 5 years ago

Good catch and suggestion! Could you submit a PR for this (and test that loading the policy works)?

vuoristo commented 5 years ago

The above PR applies the proposed fix. With the fix, policy recovery works on a Mac when the model was trained on a Linux machine with a GPU.

vitchyr commented 5 years ago

Thank you!

nanbaima commented 5 years ago

I was facing a similar problem. I'm running my training without a GPU, and when I tried to load the model I got a "Magic Number Failed" error. Here they say it is caused by loading data that was saved with a library other than PyTorch. This solution fixed the problem.

vitchyr commented 5 years ago

@nanbaima The admin in the forum you linked to concludes, "If you save it with another library and try to load it using PyTorch, you’ll encounter this error."

Do you know if that's what's happening? In particular, if you saved it with the old rlkit version (prior to fed75c6) and tried to load it with the new rlkit version (after fed75c6), then you'll get this error since you saved it with pickle.dump but loaded it with torch.load.
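To illustrate the mismatch described above, here is a hedged sketch of a loader that accepts both kinds of file. The function name load_checkpoint is hypothetical and not part of rlkit. It assumes the new-style file was written by torch.save in its default zip-based format (PyTorch 1.6 or later), so zipfile.is_zipfile can tell the two apart; files from older PyTorch versions would need a different check.

```python
import pickle
import zipfile

import torch


def load_checkpoint(path):
    """Load a checkpoint written by either torch.save or pickle.dump.

    Hypothetical helper, not part of rlkit.
    """
    # Assumption: torch.save used its default zip-based format
    # (PyTorch >= 1.6), so a zip archive means a torch.save file.
    if zipfile.is_zipfile(path):
        return torch.load(path, map_location="cpu")
    # Otherwise treat the file as a plain pickle from the old code path.
    # Only unpickle files you trust: pickle can execute arbitrary code.
    with open(path, "rb") as f:
        return pickle.load(f)
```

The pickle branch only avoids the magic-number error. A pickled checkpoint that contains GPU tensors can still fail to load on a CPU-only machine, which is the original problem in this issue.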

nanbaima commented 5 years ago

Sorry for not being clear. I meant that your new version, the one with the fix for this issue (after fed75c6), also fixed my problem. I just wanted to make sure that people who run into the same problem can find the solution here, which is simply to update rlkit.