takuseno / d4rl-atari

Datasets for data-driven deep reinforcement learning with Atari (wrapper for datasets released by Google)
MIT License

different observation with/without `stack=True` #6

Closed: weiguowilliam closed this issue 2 years ago

weiguowilliam commented 3 years ago

When I set `stack=True` / `stack=False` for the same environment and take the first observation, reward, and action: in the stacked case the first observation is `dataset_s['observations'][0][0, :]`, and in the unstacked case it is `dataset['observations'][0, :]`.

The question is: in both cases the reward list and the action list are the same, but the observation lists differ between the stacked and unstacked cases. I attached the first observation for each case below. What is the reason for this? Could you please explain it? Thanks in advance.

[image: first observation, stacked case]

[image: first observation, unstacked case]

Here's the code:

import gym
import d4rl_atari
import pickle
import numpy as np
import matplotlib.pyplot as plt

def test_stack():
    env_s = gym.make('ms-pacman-expert-v0', stack=True) # -v{0, 1, 2, 3, 4} for datasets with the other random seeds
    env_s.reset()
    dataset_s = env_s.get_dataset()
    ob_s = dataset_s['observations'][0]
    # print(len(ob_s))      # 1M transitions
    # print(ob_s[0].shape)  # (4, 84, 84)
    re_s = dataset_s['rewards']
    # print(re_s.shape)     # (1M,)

    env = gym.make('ms-pacman-expert-v0', stack=False)
    env.reset()
    dataset = env.get_dataset()
    ob = dataset['observations'][0,:]
    re = dataset['rewards']
    print(np.sum(re != re_s))  # 0, so the reward sequences are the same
    a_s = dataset_s['actions']
    a = dataset['actions']
    print(np.sum(a_s != a))  # 0, so the action sequences are the same
    o_s = ob_s[0,:]
    plt.imshow(o_s)
    plt.show()
    o = ob[0,:]
    plt.imshow(o)
    plt.show()
    # print(np.sum(o_s != o))

if __name__ == '__main__':
    test_stack()
takuseno commented 3 years ago

@weiguowilliam Thanks for reporting this issue! When `stack=True`, the first 3 observations of an episode (with `frame_stack=4`) have black channels filling the missing frames, since there are no past observations yet. https://github.com/takuseno/d4rl-atari/blob/8d1d3ff621d822a65adc2227441d7c220324f445/d4rl_atari/offline_env.py#L57
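
For what it's worth, here is a minimal check of that padding behavior. It assumes the indexing and shapes printed in the script above, i.e. that `observations[0]` from the stacked dataset is a single `(4, 84, 84)` array with channels ordered oldest to newest, and that the first unstacked observation squeezes to `(84, 84)`; treat those layout details as assumptions, not guarantees:

import gym
import d4rl_atari
import numpy as np

env_s = gym.make('ms-pacman-expert-v0', stack=True)
env = gym.make('ms-pacman-expert-v0', stack=False)
env_s.reset()
env.reset()

# first stacked observation, assumed shape (4, 84, 84), oldest channel first
first_stacked = np.asarray(env_s.get_dataset()['observations'][0])
# first unstacked observation, squeezed to (84, 84)
first_frame = np.asarray(env.get_dataset()['observations'][0]).squeeze()

print(np.all(first_stacked[:3] == 0))                  # expected True: the 3 padding frames are black
print(np.array_equal(first_stacked[-1], first_frame))  # expected True: the newest channel is the real first frame

If both prints show True, the difference you saw is exactly this padding.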

There are two ideas for changing this behavior:

  1. fill past frames with the initial frame instead of black frames (rough sketch below)
  2. remove the first 3 steps when stack=True

Do you have any thoughts on this?
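
For reference, a rough sketch of idea 1 as a standalone post-processing step. The `stack_with_initial_padding` helper below is hypothetical, not part of this repo; it assumes a single episode's unstacked frames with shape `(N, 84, 84)`, oldest first:

import numpy as np

def stack_with_initial_padding(frames, num_stack=4):
    # frames: (N, 84, 84) observations of one episode, oldest first
    # pad the front with copies of the initial frame instead of black frames
    padded = np.concatenate([np.repeat(frames[:1], num_stack - 1, axis=0), frames])
    # sliding window of the last num_stack frames at every step -> (N, num_stack, 84, 84)
    return np.stack([padded[i:i + num_stack] for i in range(len(frames))])

Idea 2 would instead amount to dropping the first `num_stack - 1` entries from observations, actions, rewards, and terminals after stacking.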