danijar / dreamerv2

Mastering Atari with Discrete World Models
https://danijar.com/dreamerv2
MIT License

Input shape incompatible #9

Closed: RyanRTJJ closed this issue 3 years ago.

RyanRTJJ commented 3 years ago

Hi authors, thanks for your paper and code. I am trying to test dreamerv2 on gym-retro games, and I have spent a long time reading the code and debugging, but I have no idea what is going on.

I ran

    python3 dreamerv2/train.py --logdir ~/logdir/atari_pong/dreamerv2/1 --configs defaults retro --task retro_Airstriker-Genesis

and the output looked fine for a while:

    Logdir /Users/ryantjj/logdir/atari_pong/dreamerv2/1
    Create envs.
    make_env(): suite is retro.
    task: Airstriker-Genesis

This shows that the arguments are parsed correctly and that gym-retro is hooked up. I also edited configs.yaml and envs.py to add support for retro.

But after what seems to be a few iterations, I run into this error:

    /dreamerv2/agent.py:79 train  *
        metrics.update(self._task_behavior.train(self.wm, start, reward))
    /dreamerv2/agent.py:212 train  *
        feat, state, action, disc = world_model.imagine(self.actor, start, hor)
    /dreamerv2/agent.py:150 step  *
        succ = self.rssm.img_step(state, action)
    ./common/other.py:41 static_scan  *
        last = fn(last, inp)
    ./common/nets.py:105 img_step  *
        x = self.get('img_in', tfkl.Dense, self._hidden, self._act)(x)
    /Users/ryantjj/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:1013 __call__  **
        input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
    /Users/ryantjj/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py:255 assert_input_compatibility
        ' but received input with shape ' + display_shape(x.shape))

    ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 1042 but received input with shape (2450, 1036)

So I printed the shapes of the two tensors that get concatenated in img_step (nets.py:105), using the print statements marked below:

  @tf.function
  def img_step(self, prev_state, prev_action, sample=True):
    prev_stoch = self._cast(prev_state['stoch'])
    prev_action = self._cast(prev_action)
    if self._discrete:
      shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
      prev_stoch = tf.reshape(prev_stoch, shape)
    x = tf.concat([prev_stoch, prev_action], -1)
    print("prev_stoch.shape: " + str(prev_stoch.shape))    # OVER HERE 
    print("prev_action.shape: " + str(prev_action.shape))   # OVER HERE
    x = self.get('img_in', tfkl.Dense, self._hidden, self._act)(x)
    deter = prev_state['deter']
    x, deter = self._cell(x, [deter])
    deter = deter[0]  # Keras wraps the state in a list.
    x = self.get('img_out', tfkl.Dense, self._hidden, self._act)(x)
    stats = self._suff_stats_layer('img_dist', x)
    dist = self.get_dist(stats)
    stoch = dist.sample() if sample else dist.mode()
    prior = {'stoch': stoch, 'deter': deter, **stats}
    return prior

This is the terminal output when I run the code:

    Create agent.
    prev_stoch.shape: (50, 1024)
    prev_action.shape: (50, 18)
    prev_stoch.shape: (50, 1024)
    prev_action.shape: (50, 18)
    prev_stoch.shape: (50, 1024)
    prev_action.shape: (50, 18)
    Found 19975379 model parameters.
    prev_stoch.shape: (2450, 1024)
    prev_action.shape: (2450, 12)
    Traceback (most recent call last):

Any idea why prev_action.shape changes from 18 to 12? Thanks for reading through this long post. I really appreciate your help! :)
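The two numbers in the error line up with the printed shapes. The img_in Dense layer was first built for a 1024-wide flattened stochastic state plus 18 action dimensions (1024 + 18 = 1042), but it now receives 1024 + 12 = 1036. The sketch below just reproduces that arithmetic. STOCH = DISCRETE = 32 is an assumed factorization of the printed 1024 (believed to be the dreamerv2 default), and img_in_width / check_dense_input are hypothetical helpers, not dreamerv2 code. That 12 is the button count of gym-retro's Genesis controller and 18 the full Atari action set is an inference from the numbers, not something the thread confirms.

```python
# Assumption: 32 categorical latents with 32 classes each, flattened to 1024.
STOCH, DISCRETE = 32, 32

def img_in_width(num_actions, stoch=STOCH, discrete=DISCRETE):
    """Width of the input to the img_in Dense layer: the flattened stochastic
    state concatenated with the action vector along the last axis."""
    return stoch * discrete + num_actions

def check_dense_input(built_width, received_width):
    """Hypothetical stand-in for the check Keras runs once a Dense kernel
    has been built: the last input axis must keep the same size."""
    if built_width != received_width:
        raise ValueError(
            f"expected last axis {built_width}, received {received_width}")

built = img_in_width(18)     # 1024 + 18 = 1042, size when the layer was built
received = img_in_width(12)  # 1024 + 12 = 1036, size with 12 actions
print(built, received)       # 1042 1036
try:
    check_dense_input(built, received)
except ValueError as err:
    print(err)
```

Because the kernel shape is fixed on the first call, any change in action size between the run that built the weights and the current environment surfaces here, not at environment creation.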

danijar commented 3 years ago

I think this means that your input images are not of shape 64x64x3. Can you verify that's the case?
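One way to check this, and the action-size difference visible in the printed shapes, is to compare the new environment's specs with what the model was built for. The helper below is hypothetical and not part of dreamerv2. The 64x64x3 default comes from the comment above, and how the observation shape and action size are read from the environment depends on the wrapper added to envs.py.

```python
# Hypothetical sanity check, not part of dreamerv2.
def spec_mismatches(obs_shape, action_dim,
                    expected_obs_shape=(64, 64, 3), expected_action_dim=None):
    """List every way the environment differs from what the model expects."""
    problems = []
    if tuple(obs_shape) != tuple(expected_obs_shape):
        problems.append(
            f"image shape {tuple(obs_shape)} != expected {tuple(expected_obs_shape)}")
    if expected_action_dim is not None and action_dim != expected_action_dim:
        problems.append(
            f"action size {action_dim} != expected {expected_action_dim}")
    return problems

# Illustrative values: a correctly resized image, but the action sizes from
# the printed shapes in this thread (12 now, 18 when the model was built).
print(spec_mismatches((64, 64, 3), 12, expected_action_dim=18))
# ['action size 12 != expected 18']
```

An empty list means both the image shape and the action size match.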

danijar commented 3 years ago

Closing due to inactivity. Feel free to comment below if you still have a question about this.

vmichals commented 2 years ago

I just stumbled on this issue. I guess OP has solved it by now, but in case someone runs into something similar: my guess is that the logdir contained a checkpoint from an older experiment on another task, which was loaded (--logdir ~/logdir/atari_pong/dreamerv2/1 suggests the logdir is being reused) and caused a mismatch between the action space sizes.
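If that diagnosis is right, the fix is to start each task in its own empty logdir. Below is a minimal sketch of a guard, assuming (not verified in this thread) that dreamerv2's train.py restores whatever it finds in --logdir, such as a saved variables.pkl checkpoint. The helpers leftover_files and fresh_logdir and the root/task/method/run layout are hypothetical.

```python
import pathlib

def leftover_files(logdir):
    """Names of entries already in logdir; an empty list means it is fresh."""
    path = pathlib.Path(logdir).expanduser()
    if not path.exists():
        return []
    return sorted(p.name for p in path.iterdir())

def fresh_logdir(root, task, method="dreamerv2", run=1):
    """Build root/task/method/run and refuse to reuse a non-empty directory.

    Hypothetical layout: one directory per task, so a Pong run's files are
    never picked up by a retro run by accident.
    """
    logdir = pathlib.Path(root).expanduser() / task / method / str(run)
    existing = leftover_files(logdir)
    if existing:
        raise FileExistsError(
            f"{logdir} already contains {existing}; use a new run number")
    return logdir

# Usage (paths are examples):
# logdir = fresh_logdir("~/logdir", "retro_Airstriker-Genesis")
# then: python3 dreamerv2/train.py --logdir <logdir> --configs defaults retro ...
```

Naming the directory after the task, rather than reusing atari_pong, makes an accidental resume from another game's run much less likely.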