Closed StarBaseOne closed 3 years ago
@hougiebear Thanks for reporting this issue! I believe this is caused by the 2D observation shape. I think the following wrapper should solve your problem.
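import gym
import numpy as np
from gym.spaces import Box

class FlattenWrapperEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        # this is important: advertise the flattened 1D observation shape
        shape = self.observation_space.shape
        self.observation_space = Box(shape=(shape[0] * shape[1],), low=-1, high=1, dtype=np.float32)

    def step(self, action):
        # delegate to the wrapped environment, then flatten the observation
        obs, reward, done, info = super().step(action)
        flat_obs = np.reshape(obs, [-1])
        return flat_obs, reward, done, info

    def reset(self):
        # delegate to the wrapped environment (calling self.reset() here would recurse without end)
        obs = super().reset()
        return np.reshape(obs, [-1])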
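As a quick illustration of what the flattening step does to a single observation, here is a minimal NumPy-only sketch. The 5 x 5 array is a made-up stand-in for one highway-env observation (5 vehicles by 5 features); it is not taken from the environment itself.

```python
import numpy as np

# Hypothetical stand-in for one observation: 5 vehicles x 5 features in [-1, 1].
obs = np.linspace(-1.0, 1.0, num=25, dtype=np.float32).reshape(5, 5)

# Same call the wrapper uses: collapse all axes into a single one.
flat_obs = np.reshape(obs, [-1])

print(obs.shape)       # (5, 5)
print(flat_obs.shape)  # (25,)

# Row-major order is preserved, so the features of vehicle i
# occupy flat_obs[i * 5 : (i + 1) * 5].
assert np.array_equal(flat_obs[5:10], obs[1])
```

Because the order is preserved, the original per-vehicle layout can be recovered with `flat_obs.reshape(5, 5)` if it is needed later.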
Thank you very much for responding, Takuma. That worked. (I had some issues with the max recursion limit but ironed that out.)
I also looked at the OpenAI Gym wrappers repo, and this works as well:
from gym import ObservationWrapper, spaces

class FlattenObservation(ObservationWrapper):
    r"""Observation wrapper that flattens the observation."""
    def __init__(self, env):
        super(FlattenObservation, self).__init__(env)
        self.observation_space = spaces.flatten_space(env.observation_space)

    def observation(self, observation):
        return spaces.flatten(self.env.observation_space, observation)
Hello all, many thanks for your great comments; I found them very clear and useful. I am using the same environment for an offline RL task and am facing almost the same issue. For my task, I need to collect data from the environment with some policy. To avoid the observation shape issue, I used a Gym wrapper to flatten the observation space and then collected data from the environment with a random policy, using the code provided in the d3rlpy documentation. However, I receive an error when starting the data collection process. I would appreciate any help.
The wrapper used to flatten the observation space:
import gym
import gym.spaces as spaces

class FlattenObservation(gym.ObservationWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.observation_space = spaces.flatten_space(env.observation_space)

    def observation(self, observation):
        return spaces.flatten(self.env.observation_space, observation)
The code used to collect data using random policy (from documentation):
import d3rlpy
# setup algorithm
random_policy = d3rlpy.algos.DiscreteRandomPolicy()
# prepare experience replay buffer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env) # env is the flattened (wrapped) environment from now on
# start data collection
random_policy.collect(env, buffer, n_steps=100000)
# export as MDPDataset
dataset = buffer.to_mdp_dataset()
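Before starting a long collection run, it can help to confirm that what the wrapped environment actually returns matches the flat shape it declares. The sketch below is a NumPy-only check; `check_observation` is a hypothetical helper (not part of gym or d3rlpy), and the arrays are made-up examples rather than real environment output.

```python
import numpy as np

def check_observation(obs, expected_shape, expected_dtype=np.float32):
    """Return a list of mismatch descriptions; empty if obs matches."""
    problems = []
    obs = np.asarray(obs)
    if obs.shape != tuple(expected_shape):
        problems.append(f"shape {obs.shape} != expected {tuple(expected_shape)}")
    if obs.dtype != expected_dtype:
        problems.append(f"dtype {obs.dtype} != expected {np.dtype(expected_dtype)}")
    return problems

# Made-up observations: one already flat, one still 2D with a different dtype.
flat_ok = np.zeros(25, dtype=np.float32)
not_flat = np.zeros((5, 5), dtype=np.float64)

print(check_observation(flat_ok, (25,)))   # []
print(check_observation(not_flat, (25,)))  # two mismatch messages
```

With a real environment, and assuming the older Gym API used in this thread where `reset()` returns only the observation, the same helper could be called as `check_observation(env.reset(), env.observation_space.shape)`.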
The received error:
Hello Takuma
I am working with highway-env (a custom environment) and have tried to test your DQN implementation, as I am interested in using the Discrete CQL and CQL implementations alongside SB3. The problem I am having is with the observation shapes of the environment (I have tried flattening the observations), and I would like to know whether you have any ideas for sorting this out. Perhaps you have seen this before? The observation is a 2D array. I tried tinkering with a custom policy, flattening the observations, and using the VectorEncoder, but to no avail.
return spaces.Box(shape=(self.vehicles_count, len(self.features)), low=-1, high=1, dtype=np.float32)
Snippet of the code
The error I receive with an observation array of shape (5,5) is
If I flatten it to 1D it still fails. Why is it expecting observation_shape to be 1? When I change the observation to image-based (using the Nature CNN with 4 stacked frames of size 128 x 64), I receive a different error. (I chose an n_frames of 4 for stacking.)
The images are uint8 and follow the channel-first layout (C x W x H). As you can see in the params.json output, the observation_shape is (1, 128, 64), which becomes (4, 128, 64) with 4 stacked frames.
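To make the shapes above concrete, here is a minimal NumPy sketch of channel-first frame stacking. It only illustrates how four (1, 128, 64) uint8 frames combine into one (4, 128, 64) array; it is not d3rlpy's internal implementation of n_frames, and filling the buffer with copies of the first frame at reset is an assumed convention.

```python
from collections import deque

import numpy as np

N_FRAMES = 4
FRAME_SHAPE = (1, 128, 64)  # one uint8 frame, channel-first, as in the post

def stack_frames(frames):
    """Concatenate single-channel frames along the channel axis."""
    return np.concatenate(list(frames), axis=0)

# Assumed convention: start with N_FRAMES copies of the first frame.
first = np.zeros(FRAME_SHAPE, dtype=np.uint8)
frames = deque([first] * N_FRAMES, maxlen=N_FRAMES)

# Appending a new frame drops the oldest one automatically (maxlen).
new_frame = np.full(FRAME_SHAPE, 255, dtype=np.uint8)
frames.append(new_frame)

stacked = stack_frames(frames)
print(stacked.shape, stacked.dtype)  # (4, 128, 64) uint8
```

The newest frame ends up in the last channel, so `stacked[-1]` is the most recent image and `stacked[0]` the oldest.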
I am not sure why it is throwing the assertion for the NumPy array; I confirmed that the input is indeed a <class 'numpy.ndarray'>.
Confirmed using the latest release of d3rlpy, 0.80.