keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

RaggedTensor #19646

Open markub3327 opened 2 weeks ago

markub3327 commented 2 weeks ago

Hi @fchollet,

I am trying to use an Embedding layer with a RaggedTensor input in TF 2.16.1:

inputs = tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True)
x = tf.keras.layers.Embedding(input_dim=vectorize_layer.vocabulary_size(), output_dim=16)(inputs)
m = tf.keras.Model(inputs=inputs, outputs=x)

where vectorize_layer = tf.keras.layers.TextVectorization(..., ragged=True). I need to feed sentences of different lengths into the model. For more info, see the docs: https://www.tensorflow.org/guide/ragged_tensor#tensorflow_apis

The issue is:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 inputs = tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True)
      2 x = tf.keras.layers.Embedding(input_dim=vectorize_layer.vocabulary_size(), output_dim=16)(inputs)
      3 m = tf.keras.Model(inputs=inputs, outputs=x)

TypeError: Input() got an unexpected keyword argument 'ragged'

Thanks for your reply.

fchollet commented 2 weeks ago

Keras 3 does not support Ragged tensors at this time. You should drop the ragged argument.
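A minimal sketch of what dropping the ragged argument might look like. The sample texts and the use of mask_zero=True are assumptions for illustration, not part of the original report. With ragged left at its default, TextVectorization pads each batch with zeros up to the longest sentence, and Embedding can be asked to treat those zeros as padding:

```python
import tensorflow as tf

# Hypothetical sample sentences of different lengths (assumption for this sketch).
texts = tf.constant(["A B", "A B C", "A B C D E"])

# ragged is left at its default (False), so the output is a dense,
# zero-padded int64 tensor instead of a RaggedTensor.
vectorize_layer = tf.keras.layers.TextVectorization()
vectorize_layer.adapt(texts)
token_ids = vectorize_layer(texts)

# Same model as in the report, minus the ragged argument.
# shape=(None,) still allows a different sequence length per batch.
inputs = tf.keras.layers.Input(shape=(None,), dtype="int64")
x = tf.keras.layers.Embedding(
    input_dim=vectorize_layer.vocabulary_size(),
    output_dim=16,
    mask_zero=True,  # assumption: treat index 0 (padding) as masked
)(inputs)
m = tf.keras.Model(inputs=inputs, outputs=x)

embedded = m(token_ids)
print(token_ids.shape, embedded.shape)
```

The sequence length is only fixed within a batch here, so different batches can still have different lengths. The cost is the zero padding inside each batch.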

markub3327 commented 2 weeks ago

That is not an option for me. Is there any other way to solve this? I need to support sentences of different lengths, and padding them consumes memory unnecessarily. Thanks.

fchollet commented 2 weeks ago

You can bucket your inputs by sequence length and reduce your batch size in order to minimize padding requirements. With small batch sizes and sufficient bucketing, padding becomes pretty much unnecessary.
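To make the effect of bucketing concrete, here is a rough, framework-free sketch that compares the share of padded slots with and without bucketing. The sentence lengths, the bucket width, and the helper names (padding_fraction, fixed_batches, bucketed_batches) are all hypothetical and chosen only for illustration:

```python
from collections import defaultdict


def padding_fraction(batches):
    """Fraction of slots that hold padding when each batch is padded to its longest item."""
    total_slots = 0
    real_tokens = 0
    for batch in batches:
        total_slots += max(batch) * len(batch)
        real_tokens += sum(batch)
    return 1 - real_tokens / total_slots


def fixed_batches(lengths, batch_size):
    """Group lengths into consecutive batches in their original order."""
    return [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]


def bucketed_batches(lengths, bucket_width, batch_size):
    """Group lengths into buckets of similar size, then batch within each bucket."""
    buckets = defaultdict(list)
    for n in lengths:
        buckets[n // bucket_width].append(n)
    batches = []
    for key in sorted(buckets):
        batches.extend(fixed_batches(buckets[key], batch_size))
    return batches


# Hypothetical sentence lengths in tokens (assumption for this sketch).
lengths = [2, 12, 3, 11, 4, 10, 5, 9, 6, 8, 7, 2]

unbucketed = padding_fraction(fixed_batches(lengths, batch_size=4))
bucketed = padding_fraction(bucketed_batches(lengths, bucket_width=3, batch_size=4))

print(f"padding without bucketing: {unbucketed:.1%}")
print(f"padding with bucketing:    {bucketed:.1%}")
```

Narrower buckets reduce padding further, at the price of smaller and less uniform batches.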

markub3327 commented 2 weeks ago

@fchollet This works:

sample = [
    'A B',
    'A B C',
    'A B C D',
    'A B C D E',
    'A B C D E F',
    'A B C D E F G',
    'A B C D E F G H',
    'A B C D E F G H I',
    'A B C D E F G H I J',
    'A B C D E F G H I J K',
    'A B C D E F G H I J K L',
]
ds = tf.data.Dataset.from_tensor_slices(sample)
ds = ds.map(lambda x: tf.strings.split(x)).repeat().bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    bucket_batch_sizes=[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
)

for x in ds.take(10):
    print(x)

Result

tf.Tensor(
[[b'A' b'B']
 [b'A' b'B']
 [b'A' b'B']
 [b'A' b'B']], shape=(4, 2), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C']
 [b'A' b'B' b'C']
 [b'A' b'B' b'C']
 [b'A' b'B' b'C']], shape=(4, 3), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D']
 [b'A' b'B' b'C' b'D']
 [b'A' b'B' b'C' b'D']
 [b'A' b'B' b'C' b'D']], shape=(4, 4), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E']
 [b'A' b'B' b'C' b'D' b'E']
 [b'A' b'B' b'C' b'D' b'E']
 [b'A' b'B' b'C' b'D' b'E']], shape=(4, 5), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F']
 [b'A' b'B' b'C' b'D' b'E' b'F']
 [b'A' b'B' b'C' b'D' b'E' b'F']
 [b'A' b'B' b'C' b'D' b'E' b'F']], shape=(4, 6), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F' b'G']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G']], shape=(4, 7), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H']], shape=(4, 8), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I']], shape=(4, 9), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J']], shape=(4, 10), dtype=string)
tf.Tensor(
[[b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J' b'K']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J' b'K']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J' b'K']
 [b'A' b'B' b'C' b'D' b'E' b'F' b'G' b'H' b'I' b'J' b'K']], shape=(4, 11), dtype=string)

The repeat() call is only there for illustration, so that each bucket fills up to a full batch. Thanks for your time.