DASpringate opened this issue 4 years ago
The InvalidArgumentError is happening in the second (LSTM) layer, which doesn't make sense to me, since that layer should just consume the output of the first (embedding) layer.
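For reference, the kind of model being described would look roughly like this; the layer sizes and vocabulary are placeholders, not taken from the issue (output_dim=20 is chosen only to match the `[?, ?, 20]` static shape mentioned later in the thread):

```python
import tensorflow as tf

# Placeholder sizes; the real vocabulary, sequence length and units
# are not given in the issue.
VOCAB_SIZE = 1000
SEQ_LEN = 100

model = tf.keras.Sequential([
    # The embedding emits (batch, SEQ_LEN, 20): exactly the rank-3
    # (batch, timesteps, features) input an LSTM expects, which is why
    # a failure inside the LSTM layer is surprising.
    tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=20,
                              input_length=SEQ_LEN),
    tf.keras.layers.LSTM(64),
    tf.keras.layers.Dense(VOCAB_SIZE, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```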
Not sure what's going on here - I don't have much experience working with Keras/LSTM. I tried reproducing your example locally and observed that the batches were constructed properly. However, you already knew that, since in your example fetching the data into Python and then feeding it back to Keras worked fine...
Thanks. There's a reproducible toy example of the issue here. Is support for Keras/LSTM on the roadmap?
I was under the impression that as long as we have a tf.data.Dataset object, we are covered with Keras. Was my assumption incorrect? Let's try to figure out what's going on here. Maybe it can be fixed?
I dug in a little bit. I am not a Keras/LSTM expert, so maybe you can help here. What I saw is that the issue is caused by a transpose operator in the `swap_batch_timestep` function in `tensorflow/python/keras/layers/recurrent.py`. During graph construction the static shape is assumed to be `[?, ?, 20]`, but during graph evaluation it gets `[10 1 100 1]`, hence the transpose fails.
I double-checked the static and dynamic shapes of the tensors returned by `make_petastorm_dataset`, and they appeared to be as expected.
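(For reference, that kind of check can be done by comparing the graph-time shape with the run-time shape; the dataset path below is a placeholder.)

```python
import tensorflow as tf
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

# Placeholder path; substitute the real Parquet location.
with make_batch_reader('file:///tmp/dataset.parquet') as reader:
    dataset = make_petastorm_dataset(reader)
    batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        for field in batch:
            # Static shape: what the graph believes at construction time.
            # Dynamic shape: what the tensor actually is at run time.
            print(field.shape, sess.run(tf.shape(field)))
```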
I added a debug print to the function:

```python
from tensorflow.python.ops import array_ops  # already imported in recurrent.py
import tensorflow as tf

def swap_batch_timestep(input_t):
    # Swap the batch and timestep dims for the incoming tensor.
    axes = list(range(len(input_t.shape)))
    axes[0], axes[1] = 1, 0
    # Debug print (TF1 API): log both the static and the dynamic shape.
    input_t = tf.Print(input_t, ['Static shape:', str(input_t.shape),
                                 'Dynamic shape:', tf.shape(input_t)],
                       summarize=1000)
    return array_ops.transpose(input_t, axes)
```
and got:

```
[Static shape:][(?, ?, 1)][Dynamic shape:][10 1 100 1]
```
Does that give you any hints?
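If the stray singleton dimensions turn out to be the culprit, one workaround that might be worth trying is mapping over the dataset to reshape them away before Keras sees the tensors. A rough sketch, with the column names and sequence length guessed from the shapes above:

```python
import tensorflow as tf
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

SEQ_LEN = 100  # guessed from the dynamic shape [10 1 100 1] above

def to_model_input(batch):
    # 'text' and 'label' are hypothetical column names. Reshaping drops
    # the stray singleton dims and pins a fully known static shape, so
    # Keras no longer sees (?, ?, 1).
    x = tf.reshape(batch.text, [-1, SEQ_LEN])
    y = tf.reshape(batch.label, [-1])
    return x, y

# Placeholder path; substitute the real Parquet location.
with make_batch_reader('file:///tmp/dataset.parquet') as reader:
    dataset = make_petastorm_dataset(reader).map(to_model_input)
    # dataset can now be handed to model.fit(...)
```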
I'm trying to read in data from Parquet for a language model. The Parquet file contains two columns:
When I try the code below I get an InvalidArgumentError. This seems to be because the int array in the Parquet file is not getting transformed properly for the tf.data.Dataset but is just passed through as-is.
warning:
error:
But if I convert the dataset to an iterator and then run the X and Y outputs separately, it runs as expected for that batch:
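(The original snippet isn't preserved in this excerpt; a rough reconstruction of that iterator-based check, with a placeholder path and column names, might look like the following.)

```python
import tensorflow as tf
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

# Placeholder path and column names; 'model' is a compiled Keras model
# like the Embedding/LSTM one sketched earlier in the thread.
with make_batch_reader('file:///tmp/dataset.parquet') as reader:
    batch = make_petastorm_dataset(reader).make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        X, y = sess.run([batch.text, batch.label])

# Once fetched, the batch is plain numpy, so stray singleton dimensions
# can be squeezed away before Keras ever sees them, which may be exactly
# why this path works while feeding the dataset directly fails.
model.fit(X.squeeze(), y.squeeze(), epochs=1)
```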
Why does this work when feeding the full dataset does not?