DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: Manually setting net_arch=None causes crash when loading model #1928

Closed: jak3122 closed this issue 1 month ago

jak3122 commented 1 month ago

šŸ› Bug

I've run into this issue a few times because my wrapper script takes command-line arguments, defaults net_arch to None, and passes that value through to policy_kwargs.

I was able to fix it locally by changing this line in stable_baselines3/common/base_class.py from

if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:

to

if data.get("policy_kwargs", {}).get("net_arch"):
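To see why the replacement condition avoids the crash, here is a standalone sketch that re-creates only the two guard expressions as plain functions and evaluates them on a few saved-data dicts. The function names old_guard and new_guard and the sample dicts are hypothetical; the body of the if statement in base_class.py is not reproduced here.

```python
from typing import Any, Dict


def old_guard(data: Dict[str, Any]) -> bool:
    # Original condition: calls len() on whatever is stored under "net_arch",
    # so a stored None raises TypeError.
    return "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0


def new_guard(data: Dict[str, Any]) -> bool:
    # Proposed condition: a missing key, None, or an empty list/dict are all
    # falsy, so no len() call is needed.
    return bool(data.get("policy_kwargs", {}).get("net_arch"))


saved_none = {"policy_kwargs": {"net_arch": None}}
saved_custom = {"policy_kwargs": {"net_arch": [64, 64]}}
saved_default = {"policy_kwargs": {}}

try:
    old_guard(saved_none)
except TypeError as exc:
    print("old guard:", exc)  # old guard: object of type 'NoneType' has no len()

print(new_guard(saved_none), new_guard(saved_custom), new_guard(saved_default))  # False True False
```

One small behavioral difference: new_guard also returns False when the data has no policy_kwargs key at all, where the original expression would raise KeyError.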

To Reproduce

import gymnasium as gym

from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", env, policy_kwargs=dict(net_arch=None))
model.learn(total_timesteps=5000)
model.save("ppo_cartpole")

del model

model = PPO.load("ppo_cartpole")

Relevant log output / Error message

Traceback (most recent call last):
  File "/Users/hyzer/stable-baselines3/min_example.py", line 13, in <module>
    model = PPO.load("ppo_cartpole")
  File "/Users/hyzer/stable-baselines3/stable_baselines3/common/base_class.py", line 695, in load
    if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
TypeError: object of type 'NoneType' has no len()

araffin commented 1 month ago

Hello, why do you want to set net_arch=None?

jak3122 commented 1 month ago

In my training script I have CLI args, including an optional net_arch arg:

net_arch: Tuple[int, ...] | None = None

Which I then pass to the sb3 model:

model = PPO(
    policy_type,
    env,
    policy_kwargs=dict(
        net_arch=args.net_arch,
        ...
    )
)
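The issue does not show how args is parsed; the Tuple[int, ...] | None annotation suggests a dataclass-based CLI parser. As an assumption, here is a rough equivalent using the standard-library argparse (the --net-arch flag name and the parse_args helper are hypothetical), showing how an omitted flag becomes net_arch=None:

```python
import argparse
from typing import Optional, Sequence


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    # Hypothetical wrapper-script parser; not part of stable-baselines3.
    parser = argparse.ArgumentParser(description="Hypothetical PPO training wrapper")
    # Hidden-layer sizes, e.g. `--net-arch 64 64` (argparse returns a list of ints).
    # When the flag is omitted, args.net_arch stays None, which is the value
    # that ends up in policy_kwargs when it is forwarded unconditionally.
    parser.add_argument("--net-arch", type=int, nargs="+", default=None)
    return parser.parse_args(argv)


print(parse_args([]).net_arch)                          # None
print(parse_args(["--net-arch", "64", "64"]).net_arch)  # [64, 64]
```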

Since the documentation says that net_arch can be None, I would expect this to work, with None meaning the default net_arch. It does work, except when loading the saved model.

I also realize it's possible to just omit net_arch from policy_kwargs instead, like this:

policy_kwargs = dict()
if args.net_arch is not None:
    policy_kwargs["net_arch"] = args.net_arch
model = PPO(
    policy_type,
    env,
    policy_kwargs=policy_kwargs
)

But anyone like me who has already saved a model with net_arch manually set to None will still be unable to load it.
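The conditional workaround generalizes to any optional CLI value for which None means "use the library default". A minimal sketch, assuming that convention holds for every key you pass (drop_none is a hypothetical helper, not part of stable-baselines3):

```python
from typing import Any, Dict


def drop_none(**kwargs: Any) -> Dict[str, Any]:
    """Return only the keyword arguments that were explicitly set.

    Uses `is not None` rather than truthiness so that legitimate falsy
    values such as False or 0 are kept.
    """
    return {key: value for key, value in kwargs.items() if value is not None}


print(drop_none(net_arch=None))                        # {}
print(drop_none(net_arch=[64, 64], ortho_init=False))  # {'net_arch': [64, 64], 'ortho_init': False}
```

Used as policy_kwargs=drop_none(net_arch=args.net_arch), a newly saved model never contains net_arch=None. It does not help models that were already saved with None; loading those still depends on a fix to the check in load.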

araffin commented 1 month ago

> Since the documentation says that net_arch can be None, I would expect this to work, and have None indicate the default net_arch, which it does, except for loading the saved model

I would be happy to receive a PR that solves this issue =)

That said, in your case I would recommend leaving net_arch out of policy_kwargs rather than passing None.