DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
9.03k stars, 1.69k forks

[Question] Can a model be used in environments with different observation_space sizes? #2031

Open SummerDiver opened 2 hours ago

SummerDiver commented 2 hours ago

❓ Question

I am trying to use Stable-Baselines3 for a graph-related problem. Graphs of different sizes have different numbers of nodes and edges, so the observation space of an environment built on them varies in size.

My goal is to train an agent in environment A and then use it in environment B. I have written a custom feature extractor based on a graph neural network, which should be able to handle input graphs of varying sizes.

However, when I try to input the observation generated by environment B into the model, an error occurs:

    from stable_baselines3 import PPO

    # GymMISEnv and GNNFeatureExtractor are custom classes defined elsewhere.
    env_small = GymMISEnv(data_folder_path)  # observation shape (148640, 2)
    env_big = GymMISEnv(test_data_folder_path)  # observation shape (250151, 2)
    policy_kwargs = {
        "features_extractor_class": GNNFeatureExtractor,
        "features_extractor_kwargs": {"features_dim": 64},
        "net_arch": dict(pi=[128, 64], vf=[128, 64]),
    }
    # The model is built on env_small, the training environment.
    model = PPO("MultiInputPolicy", env_small, policy_kwargs=policy_kwargs, verbose=1)
    obs, _ = env_big.reset()  # switch to env_big
    model.predict(obs)
    # Error: Unexpected observation shape (250151, 2) for Box environment,
    # please use (148640, 2) or (n_env, 148640, 2) for the observation shape.

Is there a way to handle this problem?


araffin commented 2 hours ago

Hello, there is no "out of the box" solution for your problem. You will probably need to add an adapter layer so that the observations have the same size at the end. The model might also complain, as we have some checks at load time.
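
One possible reading of the adapter idea, as a minimal sketch rather than an SB3 feature: pad every observation up to a fixed maximum number of rows and carry a boolean mask so that a custom extractor can ignore the padding. The function name `pad_observation` and the `max_rows` parameter are hypothetical, and the sketch assumes observations are NumPy arrays whose first dimension is the one that varies.

```python
import numpy as np


def pad_observation(obs: np.ndarray, max_rows: int, pad_value: float = 0.0):
    """Pad `obs` along its first axis to `max_rows` rows.

    Returns the padded array and a boolean mask that is True for real rows
    and False for padding. (Hypothetical helper, not part of SB3.)
    """
    n_rows = obs.shape[0]
    if n_rows > max_rows:
        raise ValueError(
            f"observation has {n_rows} rows, more than max_rows={max_rows}"
        )
    padded = np.full((max_rows,) + obs.shape[1:], pad_value, dtype=obs.dtype)
    padded[:n_rows] = obs
    mask = np.zeros(max_rows, dtype=bool)
    mask[:n_rows] = True
    return padded, mask


# Toy stand-ins for the two environments' observations.
obs_small = np.ones((3, 2), dtype=np.float32)
obs_big = np.ones((5, 2), dtype=np.float32)

padded_small, mask_small = pad_observation(obs_small, max_rows=5)
padded_big, mask_big = pad_observation(obs_big, max_rows=5)
```

With this approach both environments would have to declare the same padded observation space (sized for the largest graph expected), and the feature extractor would need to use the mask so that padded rows do not affect the result.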

SummerDiver commented 1 hour ago

I see, thanks a lot for replying. I'll see if I can handle this. Otherwise, I'm afraid I will have to implement the whole thing without SB3. Still, I really appreciate your excellent work!