keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

RNN backward layers illegally apply zero_output_for_mask #252

Open mergian opened 1 year ago

mergian commented 1 year ago

System information.

Describe the problem.

According to https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN, zero_output_for_mask is only applied when return_sequences=True. However, this restriction is ignored when go_backwards=True. As the reproducer shows, with return_sequences=False, zero_output_for_mask has no effect when go_backwards=False, but it zeroes the output when go_backwards=True.
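To make the documented rule concrete, here is a minimal pure-Python sketch of the masking semantics as described in the RNN docs. It is a hypothetical reference helper, not Keras code: apply_mask_semantics and its arguments are illustrative names. It assumes the per-step outputs and the mask are already given in processing order (reversed when go_backwards=True), so the traversal direction plays no part in whether zero_output_for_mask applies.

```python
from typing import List, Sequence


def apply_mask_semantics(
    step_outputs: Sequence[List[float]],
    mask: Sequence[bool],
    return_sequences: bool,
    zero_output_for_mask: bool,
):
    """Hypothetical reference for the documented masking rules (not Keras code).

    step_outputs and mask are in processing order, i.e. already reversed
    when the layer runs with go_backwards=True.
    """
    outputs = []
    prev = [0.0] * len(step_outputs[0])  # output before any unmasked step
    for out_t, keep in zip(step_outputs, mask):
        if keep:
            prev = list(out_t)
            outputs.append(prev)
        elif zero_output_for_mask:
            outputs.append([0.0] * len(out_t))  # masked step reported as zeros
        else:
            outputs.append(prev)  # masked step repeats the previous output
    if return_sequences:
        return outputs
    # Per the docs, zero_output_for_mask is ignored here: the result is the output
    # of the last unmasked step, carried through any trailing masked steps.
    return prev


steps = [[1.0], [2.0], [3.0], [4.0]]
mask = [True, False, True, False]
print(apply_mask_semantics(steps, mask, True, False))  # [[1.0], [1.0], [3.0], [3.0]]
print(apply_mask_semantics(steps, mask, True, True))   # [[1.0], [0.0], [3.0], [0.0]]
print(apply_mask_semantics(steps, mask, False, True))  # [3.0]
```

Under this reading, the return_sequences=False result is the same for both values of zero_output_for_mask, which is what the issue expects from the go_backwards=True case as well.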

Describe the current behavior. With go_backwards=True and return_sequences=False, setting zero_output_for_mask=True zeroes all output values, although zero_output_for_mask should not apply when return_sequences=False.

Describe the expected behavior. The layer should behave as documented: when return_sequences=False, zero_output_for_mask should have no effect, regardless of go_backwards.

Contributing.

I have not looked into the details yet, but I suspect that something is wrong with the conditions in the go_backwards=True code path.
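One detail that may help narrow this down: in the reproducer, timesteps 0, 2, 3 and 5 are masked (data[:, -2, :] is timestep 5 for seq=7). With go_backwards=False the last timestep processed is 6, which is unmasked; with go_backwards=True it is timestep 0, which is masked. The sketch below only derives this from the reproducer's data; the helper name last_processed_timestep is hypothetical, and it does not establish whether the zeroing is tied to the backward path itself or to the last processed timestep being masked.

```python
SEQ = 7
MASKED_TIMESTEPS = {0, 2, 3, 5}  # data[:, 0/2/3/-2, :] = -10 in the reproducer (-2 == 5 for seq=7)


def last_processed_timestep(seq_len: int, go_backwards: bool) -> int:
    # Forward runs visit 0, 1, ..., seq_len - 1; backward runs visit them in reverse.
    order = range(seq_len - 1, -1, -1) if go_backwards else range(seq_len)
    return list(order)[-1]


for go_backwards in (False, True):
    last = last_processed_timestep(SEQ, go_backwards)
    print(f"go_backwards={go_backwards}: last processed timestep={last}, "
          f"masked={last in MASKED_TIMESTEPS}")
```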

Standalone code to reproduce the issue.

import tensorflow as tf
import numpy as np

seq             = 7
in_channels     = 8
out_channels    = 8
inp = tf.keras.Input(batch_shape=[1, seq, in_channels])

np.random.seed(314159)
data = np.random.rand(1, seq, in_channels).astype(np.float32)

data[:, 0, :]   = -10
data[:, 2, :]   = -10
data[:, 3, :]   = -10
data[:, -2, :]  = -10

mask = tf.keras.layers.Masking(-10)(inp)
labels = []
outputs = []
weights = None

for return_sequences in [False, True]:
    for go_backwards in [False, True]:
        for zero_output_for_mask in [False, True]:
            tf.keras.utils.set_random_seed(314159)
            l = tf.keras.layers.LSTM(units=out_channels, go_backwards=go_backwards, return_sequences=return_sequences)
            l.zero_output_for_mask = zero_output_for_mask
            outputs.append(l(mask))
            # Weights only exist after the call above builds the layer; check that every layer got the same ones
            if weights is None: weights = l.get_weights()
            else:               assert all(np.array_equal(a, b) for a, b in zip(l.get_weights(), weights))
            labels.append(f'go_backwards={go_backwards}, return_sequences={return_sequences}, zero_output_for_mask={zero_output_for_mask}:')

model = tf.keras.Model(inputs=inp, outputs=outputs)

for k, v in zip(labels, model(data)):
    print(k, v)
    print()

Source code / logs.

go_backwards=False, return_sequences=False, zero_output_for_mask=False: tf.Tensor(
[[-0.02429421 -0.09013332 -0.10361656 -0.13974014 -0.20900866 -0.26837015
  -0.07549819 -0.33952928]], shape=(1, 8), dtype=float32)

go_backwards=False, return_sequences=False, zero_output_for_mask=True: tf.Tensor(
[[-0.02429421 -0.09013332 -0.10361656 -0.13974014 -0.20900866 -0.26837015
  -0.07549819 -0.33952928]], shape=(1, 8), dtype=float32)

go_backwards=True, return_sequences=False, zero_output_for_mask=False: tf.Tensor(
[[-0.05535218 -0.11682055 -0.04727266 -0.14445937 -0.16172363 -0.2500477
  -0.00716362 -0.29466543]], shape=(1, 8), dtype=float32)

go_backwards=True, return_sequences=False, zero_output_for_mask=True: tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 8), dtype=float32)

go_backwards=False, return_sequences=True, zero_output_for_mask=False: tf.Tensor(
[[[ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.0499199  -0.05824665  0.02581842 -0.0982305  -0.09205388
   -0.0882687   0.02808495 -0.13197206]
  [-0.0499199  -0.05824665  0.02581842 -0.0982305  -0.09205388
   -0.0882687   0.02808495 -0.13197206]
  [-0.0499199  -0.05824665  0.02581842 -0.0982305  -0.09205388
   -0.0882687   0.02808495 -0.13197206]
  [-0.0581002  -0.08441985  0.00981993 -0.11347026 -0.10313985
   -0.14142308 -0.00306875 -0.1604866 ]
  [-0.0581002  -0.08441985  0.00981993 -0.11347026 -0.10313985
   -0.14142308 -0.00306875 -0.1604866 ]
  [-0.02429421 -0.09013332 -0.10361656 -0.13974014 -0.20900866
   -0.26837015 -0.07549819 -0.33952928]]], shape=(1, 7, 8), dtype=float32)

go_backwards=False, return_sequences=True, zero_output_for_mask=True: tf.Tensor(
[[[ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.0499199  -0.05824665  0.02581842 -0.0982305  -0.09205388
   -0.0882687   0.02808495 -0.13197206]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.0581002  -0.08441985  0.00981993 -0.11347026 -0.10313985
   -0.14142308 -0.00306875 -0.1604866 ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.02429421 -0.09013332 -0.10361656 -0.13974014 -0.20900866
   -0.26837015 -0.07549819 -0.33952928]]], shape=(1, 7, 8), dtype=float32)

go_backwards=True, return_sequences=True, zero_output_for_mask=False: tf.Tensor(
[[[ 0.01506061 -0.03969907 -0.1244414  -0.08451919 -0.15994504
   -0.1811547  -0.0766468  -0.19487911]
  [ 0.01506061 -0.03969907 -0.1244414  -0.08451919 -0.15994504
   -0.1811547  -0.0766468  -0.19487911]
  [-0.00483934 -0.08338716 -0.09934962 -0.10254624 -0.12932442
   -0.22227708 -0.06278189 -0.19041967]
  [-0.00483934 -0.08338716 -0.09934962 -0.10254624 -0.12932442
   -0.22227708 -0.06278189 -0.19041967]
  [-0.00483934 -0.08338716 -0.09934962 -0.10254624 -0.12932442
   -0.22227708 -0.06278189 -0.19041967]
  [-0.05535218 -0.11682055 -0.04727266 -0.14445937 -0.16172363
   -0.2500477  -0.00716362 -0.29466543]
  [-0.05535218 -0.11682055 -0.04727266 -0.14445937 -0.16172363
   -0.2500477  -0.00716362 -0.29466543]]], shape=(1, 7, 8), dtype=float32)

go_backwards=True, return_sequences=True, zero_output_for_mask=True: tf.Tensor(
[[[ 0.01506061 -0.03969907 -0.1244414  -0.08451919 -0.15994504
   -0.1811547  -0.0766468  -0.19487911]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.00483934 -0.08338716 -0.09934962 -0.10254624 -0.12932442
   -0.22227708 -0.06278189 -0.19041967]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [-0.05535218 -0.11682055 -0.04727266 -0.14445937 -0.16172363
   -0.2500477  -0.00716362 -0.29466543]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]], shape=(1, 7, 8), dtype=float32)
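For the return_sequences=True cases, the zeroed rows in the logs look consistent with the documented masking: rows 0, 2, 3 and 5 are zero in the forward output, and rows 1, 3, 4 and 6 in the backward output. The sketch below maps the masked timesteps to rows, assuming (as the log suggests) that the backward layer returns its sequence in processing order rather than flipping it back to input order. The helper zero_rows is hypothetical.

```python
SEQ = 7
MASKED_TIMESTEPS = {0, 2, 3, 5}  # timesteps filled with -10 in the reproducer


def zero_rows(seq_len, masked, go_backwards):
    # Assumes the returned sequence is in processing order, so with go_backwards=True
    # row k holds the output for input timestep seq_len - 1 - k.
    if go_backwards:
        return sorted(seq_len - 1 - t for t in masked)
    return sorted(masked)


print(zero_rows(SEQ, MASKED_TIMESTEPS, go_backwards=False))  # [0, 2, 3, 5]
print(zero_rows(SEQ, MASKED_TIMESTEPS, go_backwards=True))   # [1, 3, 4, 6]
```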
tilakrayal commented 1 year ago

@sachinprasadhs, I was able to reproduce the issue on TensorFlow v2.10, v2.11, and nightly. Kindly find the gist of it here.