HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

FrameStack Bug. #692

Closed Liuzy0908 closed 1 year ago

Liuzy0908 commented 1 year ago

Bug description

I want to use frame stacking (4 consecutive image frames as the model input), which works well when I use plain PPO in SB3.

But after running the program below (which uses GAIL), the observations collected by the rollouts have shape (1024, 59, 256, 1). It seems that single-frame images are being collected instead of 4 consecutive frames. I think the correct shape should be (1024, 59, 256, 4).

This later causes the policy network to fail in GAIL.train(): PPO expects observations of shape (batch, 4, 59, 256), while the rollouts provide observations of shape (batch, 1, 59, 256).
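The shapes quoted in this description can be traced with shape arithmetic alone. The sketch below is only an illustration under two assumptions: that frame stacking concatenates frames along the last (channel) axis of a channel-last image, and that the channel axis is moved to the front before the policy sees the observation. Both are consistent with the printed outputs in the reproduction script. The helper names `stack_last_axis` and `channels_first` are hypothetical and belong to neither SB3 nor imitation.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def stack_last_axis(shape: Shape, n_stack: int) -> Shape:
    """Shape after stacking n_stack frames along the last (channel) axis."""
    return shape[:-1] + (shape[-1] * n_stack,)


def channels_first(shape: Shape) -> Shape:
    """Move the last (channel) axis to the front: (H, W, C) -> (C, H, W)."""
    return (shape[-1],) + shape[:-1]


single_frame = (59, 256, 1)                      # raw environment observation
stacked = stack_last_axis(single_frame, 4)       # after stacking 4 frames
policy_input = channels_first(stacked)           # what the policy expects
unstacked_input = channels_first(single_frame)   # what single frames look like

print(stacked)
print(policy_input)
print(unstacked_input)
print(policy_input == unstacked_input)
```

Under these assumptions, the mismatch is confined to the channel dimension (4 versus 1), which matches the symptom of rollouts storing single frames rather than stacked ones.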

How should I solve this problem? Looking forward to your reply.

Steps to reproduce

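import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv

from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper

# MyEnv and model_dir are defined by the reporter and are not shown in this issue.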

print(MyEnv().observation_space.shape)              # out: (59, 256, 1). 

venv = DummyVecEnv([lambda: RolloutInfoWrapper(MyEnv())])
env = VecFrameStack(venv, n_stack=4)

print(env.observation_space.shape)                  # out: (59, 256, 4). Frame Stack.

expert = PPO.load(model_dir, env=env)

print(expert.get_env().observation_space.shape)     # out: (4, 59, 256)

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=1),
    rng=rng,
)


Environment

ernestum commented 1 year ago

Hi @Liuzy0908, to replicate this bug I need to know which environment you used. The code says MyEnv, which could be anything. Could you please provide a code snippet that I can execute right away?