DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Best practice for transferring saved models from server #441

Closed. axkoenig closed this issue 3 years ago

axkoenig commented 3 years ago

Question

I am training models on a research cluster with SB3 and a custom environment. I then transfer the saved .zip models via sftp to my local computer to visualize the learned behavior in my robotics simulator. I run into an annoying problem here: the .zip file seems to get corrupted on the way. When I run model.load("my_path") I get the following error:

Traceback (most recent call last):
  File "main.py", line 141, in <module>
    main(args)
  File "main.py", line 79, in main
    model.load(args.eval_model_path)
  File "/home/parallels/.local/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 600, in load
    data, params, pytorch_variables = load_from_zip_file(path, device=device)
  File "/home/parallels/.local/lib/python3.8/site-packages/stable_baselines3/common/save_util.py", line 396, in load_from_zip_file
    data = json_to_data(json_data)
  File "/home/parallels/.local/lib/python3.8/site-packages/stable_baselines3/common/save_util.py", line 165, in json_to_data
    deserialized_object = cloudpickle.loads(base64_object)
AttributeError: Can't get attribute 'TrainFreq' on <module 'stable_baselines3.common.type_aliases' from '/home/parallels/.local/lib/python3.8/site-packages/stable_baselines3/common/type_aliases.py'>

Loading the model into a new session on the research cluster works without problems, so I really think the model transfer is the problem. How do you guys transfer your trained models from A to B?
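One way to test the transfer hypothesis directly is to compare a checksum of the file on both machines and to run the archive's own integrity check. The sketch below uses only the Python standard library. The helper names `sha256_of` and `inspect_model_zip` are made up for illustration, and the `_stable_baselines3_version` archive entry is an assumption about how SB3 lays out its saved .zip files, so the code treats it as optional.

```python
import hashlib
import zipfile


def sha256_of(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def inspect_model_zip(path):
    """Check that a saved model is an intact zip archive and list its members."""
    report = {"valid_zip": False, "bad_member": None, "members": [], "saved_with": None}
    if not zipfile.is_zipfile(path):
        return report
    with zipfile.ZipFile(path) as archive:
        report["valid_zip"] = True
        # testzip() returns the name of the first corrupt member, or None if all CRCs match.
        report["bad_member"] = archive.testzip()
        report["members"] = archive.namelist()
        # Assumption: SB3 stores the saving library version under this entry name.
        if "_stable_baselines3_version" in report["members"]:
            report["saved_with"] = archive.read("_stable_baselines3_version").decode().strip()
    return report
```

If `sha256_of` gives the same digest on the cluster and on the local machine, and `bad_member` is `None`, the transfer itself is unlikely to be the cause.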

Additional context

Here is my main script if this is of any interest

# ...

env = GazeboEnv(hparams)
if args.check_env:
    check_env(env)

n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1, tensorboard_log=log_path)

if args.train:
    rospy.loginfo("Training model...")
    model.learn(total_timesteps=args.time_steps, tb_log_name=args.log_name, log_interval=args.log_interval)
    model.save(model_path)
    rospy.loginfo("Saved final model under: " + model_path)

else:
    rospy.loginfo("Loading model from: " + args.eval_model_path)
    # load() is a classmethod that returns a new model, so keep the result
    model = TD3.load(args.eval_model_path, env=env) # problem occurs here ...
    rospy.loginfo("Evaluating model...")

    for episode in range(20):
        obs = env.reset()
        for t in range(1000):
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            if done:
                rospy.loginfo(f"Episode finished after {t+1} timesteps.")
                break

# ...

Miffyli commented 3 years ago

Make sure that your Python, PyTorch and stable-baselines3 versions match on both machines. Unfortunately, some of the saving relies on Python's pickling, which does not always play nicely across Python versions (the same can happen when the PyTorch version changes).
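The traceback above says that the local `stable_baselines3.common.type_aliases` module has no `TrainFreq` attribute, which is what a version mismatch between the two installs would look like. A minimal sketch for comparing the environments: run it on both the cluster and the local machine and compare the output. The helper names `collect_versions` and `diff_versions` and the default package list are assumptions for illustration, not part of SB3.

```python
import platform
from importlib import metadata  # standard library since Python 3.8


def collect_versions(packages=("stable-baselines3", "torch", "cloudpickle", "gym")):
    """Return the Python version and installed versions of the given packages."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is not installed in this environment
    return versions


def diff_versions(a, b):
    """Return {name: (version_a, version_b)} for every entry that differs."""
    names = sorted(set(a) | set(b))
    return {n: (a.get(n), b.get(n)) for n in names if a.get(n) != b.get(n)}


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```

Any entry that differs between the two outputs is a candidate for the failed unpickling. Installing the same stable-baselines3 version on both sides is the first thing to try.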

axkoenig commented 3 years ago

Thanks for this lightning fast reply! That was it, thank you! I was on a different stable_baselines3 version ...