keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Models do not propagate symbolic masks when called with symbolic inputs. #18417

Open jackd opened 1 year ago

jackd commented 1 year ago

Describe the bug
Models (both functional and sequential) do not propagate symbolic masks when called with symbolic inputs.

Example:

# from tensorflow import keras  # everything works with this import
import keras_core as keras  # errors as outlined below with this import

model = keras.Sequential(
    [keras.layers.InputLayer((3,)), keras.layers.Embedding(10, 100, mask_zero=True)]
)

x = keras.Input((3,))
out = model(x)
print(out._keras_mask)  #  AttributeError: 'KerasTensor' object has no attribute '_keras_mask'

inp = keras.Input((3,))
out = keras.layers.Embedding(10, 100, mask_zero=True)(inp)
print(out._keras_mask)  # works fine, <KerasTensor shape=(None, 3), dtype=float32, name=keras_tensor_3>
model = keras.Model(inp, out)

out2 = model(x)
print(out2._keras_mask)  # AttributeError: 'KerasTensor' object has no attribute '_keras_mask'

Expected behavior
To be consistent with tf.keras.
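
For reference, a short sketch (not re-run here) of the tf.keras behavior being compared against, i.e. the same snippet as above but with the import mentioned in its first comment:

from tensorflow import keras

model = keras.Sequential(
    [keras.layers.InputLayer((3,)), keras.layers.Embedding(10, 100, mask_zero=True)]
)

x = keras.Input((3,))
out = model(x)
# Per the report above, tf.keras attaches a symbolic mask here instead of raising:
print(out._keras_mask)  # expected: a symbolic mask tensor of shape (None, 3)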

Additional context
This came up in the keras-nlp example porting issue.

fchollet commented 1 year ago

That's a private API, so consistency with tf.keras is not a concern here -- what did you need the symbolic masks for? They don't even exist in the functional API, they're only computed at runtime.
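
(For readers following along: a minimal sketch of computing a mask at runtime through the public compute_mask API, as opposed to reading the private _keras_mask attribute; keras_core import and eager inputs assumed.)

import numpy as np
import keras_core as keras

embedding = keras.layers.Embedding(10, 100, mask_zero=True)
x = np.array([[1, 2, 0]])
out = embedding(x)  # eager call; the mask is computed as part of the call
mask = embedding.compute_mask(x)  # boolean mask, True where the input is non-zero
print(mask)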

jackd commented 1 year ago

@fchollet this is relevant when building compound models, e.g. a transformer with separate encoder and decoder models.

Minimal example demonstrating a difference that is visible through public APIs:

import keras_core as keras

inp = keras.Input((3,))
x = keras.layers.Embedding(10, 100, mask_zero=True)(inp)
x = keras.layers.Conv1D(3, 3)(x)
model = keras.Model(inp, x)

z1 = model(inp)**2

model1 = keras.Model(inp, z1)

z2 = x**2
model2 = keras.Model(inp, z2)

model1 and model2 should represent the same computation, but in reality they differ significantly during training due to the way masked inputs are treated.

This is obviously a highly contrived example, but there's a realistic example here, which is a port of the keras-nlp Spanish-to-English translation transformer example to use keras_core rather than tensorflow.keras (works using tf.keras, raises a mask-related error using keras-core). Further discussion at the issue linked above.

Also, I'm not sure what you mean by "they don't even exist in the functional API". The result of calling an Embedding layer with mask_zero=True with a KerasTensor input has a _keras_mask attribute (shown in above example).

fchollet commented 1 year ago

model1 and model2 should represent the same computation, but in reality they differ significantly during training due to the way masked inputs are treated.

I don't understand the nature of the difference by reading the code, can you explain?

jackd commented 1 year ago

@fchollet apologies, I'm ballsing this up by trying to make it a minimal example - the above code does indeed seem to behave identically.

Below is the most minimal example I can get that illustrates a difference. Apologies for not being able to simplify further. Note this may mean the error is from keras-nlp and not keras-core.

import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # ensure keras imports are consistent
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # I have other models training...
import keras_core as keras
import keras_nlp
import numpy as np

ENG_VOCAB_SIZE = 3
SPA_VOCAB_SIZE = 11
MAX_SEQUENCE_LENGTH = 5
EMBED_DIM = 2
INTERMEDIATE_DIM = 7
NUM_HEADS = 1

def build_without_component_models():
    encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")

    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=ENG_VOCAB_SIZE,
        sequence_length=MAX_SEQUENCE_LENGTH,
        embedding_dim=EMBED_DIM,
        mask_zero=True,
    )(encoder_inputs)
    encoder_outputs = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
    )(x)

    # Decoder
    decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")

    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=SPA_VOCAB_SIZE,
        sequence_length=MAX_SEQUENCE_LENGTH,
        embedding_dim=EMBED_DIM,
        mask_zero=True,
    )(decoder_inputs)

    x = keras_nlp.layers.TransformerDecoder(
        intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
    )(decoder_sequence=x, encoder_sequence=encoder_outputs)
    x = keras.layers.Dropout(0.5)(x)
    decoder_outputs = keras.layers.Dense(SPA_VOCAB_SIZE, activation="softmax")(x)

    transformer = keras.Model(
        (encoder_inputs, decoder_inputs),
        decoder_outputs,
        name="transformer",
    )
    transformer.summary()
    return transformer

def build_with_component_models():
    encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")

    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=ENG_VOCAB_SIZE,
        sequence_length=MAX_SEQUENCE_LENGTH,
        embedding_dim=EMBED_DIM,
        mask_zero=True,
    )(encoder_inputs)
    encoder_outputs = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
    )(x)

    # Decoder
    decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
    encoded_seq_inputs = keras.Input(
        shape=(None, EMBED_DIM), name="decoder_state_inputs"
    )

    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=SPA_VOCAB_SIZE,
        sequence_length=MAX_SEQUENCE_LENGTH,
        embedding_dim=EMBED_DIM,
        mask_zero=True,
    )(decoder_inputs)

    x = keras_nlp.layers.TransformerDecoder(
        intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
    )(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
    x = keras.layers.Dropout(0.5)(x)
    decoder_outputs = keras.layers.Dense(SPA_VOCAB_SIZE, activation="softmax")(x)
    decoder = keras.Model(
        (decoder_inputs, encoded_seq_inputs),
        decoder_outputs,
    )
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])

    transformer = keras.Model(
        (encoder_inputs, decoder_inputs),
        decoder_outputs,
        name="transformer",
    )
    transformer.summary()
    return transformer

model = build_with_component_models()
# model = build_without_component_models()
encoder_inputs = np.array([[1, 1, 1, 2, 0]])
decoder_inputs = np.array([[1, 2, 4, 3, 0]])
print(model((encoder_inputs, decoder_inputs)))

For convenience, the diff: diff

The code runs fine using build_without_component_models, but build_with_component_models raises the following:

/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/layers/layer.py:764: UserWarning: Layer 'position_embedding1' (of type PositionEmbedding) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
  warnings.warn(
/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/layers/layer.py:764: UserWarning: Layer 'query' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
  warnings.warn(
/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/layers/layer.py:764: UserWarning: Layer 'key' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
  warnings.warn(
/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/layers/layer.py:764: UserWarning: Layer 'value' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
  warnings.warn(
Traceback (most recent call last):
  File "/home/jackd/Development/python/keras-ema/examples/nlp/eng-spa-translation/play.py", line 105, in <module>
    print(model((encoder_inputs, decoder_inputs)))
  File "/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/jackd/anaconda3/envs/keras-ema/lib/python3.10/site-packages/keras_core/src/layers/layer.py", line 1319, in __init__
    raise ValueError(
ValueError: Exception encountered when calling Functional.call().

In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: mask=[None, <tf.Tensor: shape=(1, 5), dtype=bool, numpy=array([[ True,  True,  True,  True, False]])>]

Arguments received by Functional.call():
  • inputs=('array([[1, 1, 1, 2, 0]])', 'array([[1, 2, 4, 3, 0]])')
  • training=None
  • mask=None

Maybe the issue can be traced to the warnings, but the fact that it behaves differently from the non-component-model implementation seems fishy. Note that even the non-error-raising implementation trains differently from the same code run with from tensorflow import keras, so there's something else going on here too...

jackd commented 1 year ago

Here's a more minimal example using only keras-core (no keras-nlp):

import numpy as np

## use either of the following pairs of imports
# from tensorflow import keras
# from tensorflow import logical_and

import keras_core as keras
from keras_core.ops import logical_and

class MySum(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        a, b = inputs
        return a + b

    def compute_output_shape(self, input_shape):
        a, b = input_shape
        assert a == b
        return a

    def compute_mask(self, inputs, previous_mask):
        if previous_mask is None:
            return None
        a, b = previous_mask
        return logical_and(a, b)

embedding = keras.layers.Embedding(3, 5, mask_zero=True)
inp = keras.Input((3,))
out = embedding(inp)
# embedding_model is a Model wrapper around just embedding layer
embedding_model = keras.Model(inp, out)

# construct a model without model components
out1 = embedding(inp)
s = MySum()((out, out1))
model1 = keras.Model(inp, s)

# construct a model with model components
out2 = embedding_model(inp)
s = MySum()((out, out2))  # <- error here with keras_core implementation
model2 = keras.Model(inp, s)

x = np.array([[1, 1, 0]], dtype="int64")
print(embedding(x))
print(model1(x))
print(model2(x))

Traceback (most recent call last):
  File ".../play.py", line 45, in <module>
    s = MySum()((out, out2))  # <- error here with keras_core implementation
  File ".../keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File ".../keras_core/src/layers/layer.py", line 1319, in __init__
    raise ValueError(
ValueError: Exception encountered when calling MySum.call().

In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: mask=(<KerasTensor shape=(None, 3), dtype=float32, name=keras_tensor_2>, None)

Arguments received by MySum.call():
  • args=(('<KerasTensor shape=(None, 3, 5), dtype=float32, name=keras_tensor_1>', '<KerasTensor shape=(None, 3, 5), dtype=float32, name=keras_tensor_6>'),)
  • kwargs={'mask': ('<KerasTensor shape=(None, 3), dtype=float32, name=keras_tensor_2>', 'None')}

fchollet commented 1 year ago

Here's a more minimal example using only keras-core (no keras-nlp):

Thanks for the code snippet. Are you sure this is the same issue, though? This is a known behavior difference in Keras Core: it does not allow None as part of a nested structure of input tensors. An easy workaround is to split your nested argument into separate arguments, e.g.

def call(self, a, b, mask_a=None, mask_b=None):

We can also look at lifting the limitation entirely, though that would require quite a bit of work.
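
For concreteness, a sketch of that workaround applied to the MySum example above (untested; the per-argument mask names follow the suggestion above and are illustrative):

import numpy as np
import keras_core as keras


class MySumSplit(keras.layers.Layer):
    # Same idea as MySum, but with the nested (a, b) tuple split into two
    # separate tensor arguments, so no None ends up inside a nested structure.
    def __init__(self):
        super().__init__()
        self.supports_masking = True

    def call(self, a, b, mask_a=None, mask_b=None):
        return a + b


embedding = keras.layers.Embedding(3, 5, mask_zero=True)
inp = keras.Input((3,))
out = embedding(inp)
embedding_model = keras.Model(inp, out)
out2 = embedding_model(inp)

s = MySumSplit()(out, out2)  # separate arguments rather than a single tuple
model2 = keras.Model(inp, s)

x = np.array([[1, 1, 0]], dtype="int64")
print(model2(x))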

jackd commented 1 year ago

I'm fairly confident it's the same issue.

Notwithstanding the workaround, it's surprising to have an error raised as the result of wrapping a layer in a Model - a refactoring that, in my opinion, should be encouraged (not in the context of the contrived example above, but e.g. when separating encoders/decoders in transformers).
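
To make that expectation concrete, a minimal sketch of the kind of behavior-preserving refactor being described (names are illustrative):

import keras_core as keras

inp = keras.Input((3,))
embedding = keras.layers.Embedding(10, 100, mask_zero=True)

# Direct use of the layer.
direct_out = embedding(inp)

# The same layer wrapped in a Model - the refactor that currently drops the
# symbolic mask, per the examples above.
wrapped_out = keras.Model(inp, embedding(inp))(inp)

# Expectation: downstream layers see equivalent mask information in both cases.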