carlthome / tensorflow-convlstm-cell

A ConvLSTM cell with layer normalization and peepholes for TensorFlow's RNN API.
MIT License

Working with varying cell-state sizes #19

Open anjany opened 7 years ago

anjany commented 7 years ago

When working with variable-shaped inputs (defined by [batch_size, None, None, channels]), I get the following error in dynamic_rnn (at line 115 of rnn_cell_impl.py in TF 1.2.1) during graph construction:

Provided a prefix or suffix of None: Tensor("rnn_7/strided_slice:0", shape=(), dtype=int32) and (?, ?, 1024)

I get the same error when I work with your example by changing the 'shape' to [None, None] instead of [640, 480]. So, is there a way to work with inputs of varying dimensions? (Note that for a given unrolled RNN, these dimensions would be fixed.)

I guess this might be a related commit: https://github.com/tensorflow/tensorflow/commit/54efd636b504aad368eea254eca2970a16d457f6

carlthome commented 7 years ago

AFAIK it is not mathematically possible to leave the shape of the convolution kernels unspecified.

If you just want to lazily initialize the shapes at runtime you could use tf.shape.

anjany commented 7 years ago

Well, it is not the convolution kernel size I am talking about. It is the shape of the input:

"I get the same error when I work with your example by changing the 'shape' to [None, None] instead of [640, 480]"

carlthome commented 7 years ago

Ah, my bad. Yes, that should be possible. Could you try with peephole=False and see if that works?

anjany commented 7 years ago

Sorry for the delay. I was away on a vacation. And, no luck with peephole=False. :/

DragonZzzz commented 6 years ago

@anjany Hello, have you solved this problem?

carlthome commented 6 years ago

Unfortunately, tf.nn.dynamic_rnn seems to assume static shapes.

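from tensorflow.python.util import nest  # where nest lives in TF 1.x
# Take the state shape from the first time step so it is resolved at runtime.
zero_state = nest.map_structure(lambda x: tf.zeros_like(inputs[:, 0]), cell.state_size)
tf.nn.dynamic_rnn(cell, inputs, initial_state=zero_state)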

or something similar should work, but you then get hit with https://github.com/tensorflow/tensorflow/blob/c81830af5d488de600a4f62392c63059e310c017/tensorflow/python/ops/rnn.py#L699-L702
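
For anyone landing here, a plain-Python sketch of what seems to be going on in this thread, with no TensorFlow calls. It rests on two assumptions that the thread does not confirm: that the error in the first post is raised while the default zero state is built from cell.state_size, and that the cell's state has the input's height and width but `filters` channels (which would match the (?, ?, 1024) in the message). The names zero_state_shape and runtime_state_shape are hypothetical, made up for illustration.

```python
def zero_state_shape(batch_size, state_size):
    """Static path: every entry of state_size must be known up front.

    Hypothetical stand-in for building a default zero state; it only
    imitates the failure mode reported above, not TensorFlow's code.
    """
    if any(dim is None for dim in state_size):
        raise ValueError(
            "Provided a prefix or suffix of None: %s and %s"
            % (batch_size, state_size)
        )
    return [batch_size] + list(state_size)


def runtime_state_shape(input_shape, filters):
    """Runtime path: read height and width off a concrete input batch.

    input_shape is [batch, time, height, width, channels]. The state is
    assumed to keep the spatial size of one frame but to have `filters`
    channels instead of the input's channel count.
    """
    batch, _time, height, width, _channels = input_shape
    return [batch, height, width, filters]


# Fixed spatial size: the static path works.
print(zero_state_shape(8, [640, 480, 1024]))

# Unknown spatial size: the static path fails, the runtime path does not.
try:
    zero_state_shape(8, [None, None, 1024])
except ValueError as err:
    print("static path failed:", err)
print(runtime_state_shape([8, 10, 64, 48, 3], filters=1024))
```

In a real graph the runtime path would presumably read those sizes with tf.shape(inputs), as suggested earlier in the thread. Under the same assumption about the state's channels, a zeros_like of an input frame only has the right channel count when the input channels happen to equal the number of filters.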