keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Possible JIT compilation bug with JAX #20165

Closed: neo-alex closed this issue 3 days ago

neo-alex commented 3 months ago

I use the minimal code below to check that JIT-compiled model outputs match the non-JIT ones:

import os
os.environ["KERAS_BACKEND"] = ...  # "jax" or "tensorflow"

import numpy as np
import keras

x = keras.ops.convert_to_tensor([
    [[1], [2], [3]],
    [[1], [2], [-99]],
    [[1], [-99], [-99]],
])

model = get_model()  # get_model() is defined below
model_output = model(x)  # this is NOT JIT-compiled

model.compile(jit_compile=True)
jit_model_output = model.predict_on_batch(x)  # this is JIT-compiled

assert np.allclose(model_output, jit_model_output)
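For concreteness, the values the assertion compares can be computed by hand with plain NumPy (a quick sketch of what the model below is meant to produce: the average of each row, ignoring the -99 padding):

```python
import numpy as np

# Same data as above; -99 marks the padded timesteps to be ignored.
x = np.array([
    [[1.0], [2.0], [3.0]],
    [[1.0], [2.0], [-99.0]],
    [[1.0], [-99.0], [-99.0]],
], dtype="float32")

valid = (x != -99)  # True wherever the timestep is real
expected = (x * valid).sum(axis=1) / valid.sum(axis=1)
print(expected)  # rows average to 2.0, 1.5 and 1.0
```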

For this example, assume that we want a model that averages x "row by row" above, ignoring the -99 values, which we mask. The straightforward version below works as expected (the assertion passes):

def get_model():
    return keras.Sequential([
        keras.layers.Masking(-99),
        keras.layers.GlobalAveragePooling1D()
    ])

However, the equivalent implementation below, which wraps the same two layers in a custom layer, makes the assertion fail on the JAX backend when jit_compile=True:

class MaskedGlobalAveragePooling1D(keras.layers.Layer):
    def __init__(self, mask_value, **kwargs):
        super().__init__(**kwargs)
        self.masking = keras.layers.Masking(mask_value)
        self.pooling = keras.layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.masking(inputs)
        return self.pooling(x)

def get_model():
    return keras.Sequential([
        MaskedGlobalAveragePooling1D(mask_value=-99)
    ])

Note: I know that using keras.layers.Masking inside a custom layer is not common (I actually need it for a more advanced use case), but I see no reason why it shouldn't work consistently across all backends.

I would appreciate any help fixing this bug, thank you!

neo-alex commented 3 months ago

In the meantime, I also tried the "torch" backend and everything works fine, as with "tensorflow" (so the issue mentioned above seems specific to JAX with JIT compilation).

sachinprasadhs commented 2 months ago

I was able to reproduce the reported behavior here.

neo-alex commented 2 months ago

My bad, I think the issue is solved if I change the call method of my MaskedGlobalAveragePooling1D to:

    def call(self, inputs):
        mask = self.masking.compute_mask(inputs)
        return self.pooling(inputs, mask=mask)
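For completeness, here is the corrected layer as a self-contained script assembled from the snippets above (a sketch; run it with KERAS_BACKEND=jax to exercise the originally failing path):

```python
import numpy as np
import keras

class MaskedGlobalAveragePooling1D(keras.layers.Layer):
    """Averages over the time axis, ignoring masked timesteps."""

    def __init__(self, mask_value, **kwargs):
        super().__init__(**kwargs)
        self.masking = keras.layers.Masking(mask_value)
        self.pooling = keras.layers.GlobalAveragePooling1D()

    def call(self, inputs):
        # Compute the mask explicitly and hand it to the pooling layer,
        # instead of relying on implicit mask propagation between the
        # sublayers (which is what broke under JIT on the JAX backend).
        mask = self.masking.compute_mask(inputs)
        return self.pooling(inputs, mask=mask)

x = np.array([
    [[1.0], [2.0], [3.0]],
    [[1.0], [2.0], [-99.0]],
    [[1.0], [-99.0], [-99.0]],
], dtype="float32")

model = keras.Sequential([MaskedGlobalAveragePooling1D(mask_value=-99)])
eager_out = model(x)  # not JIT-compiled

model.compile(jit_compile=True)
jit_out = model.predict_on_batch(x)  # JIT-compiled

assert np.allclose(eager_out, jit_out)  # now consistent
```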

Still, I would argue that the original issue is rather tricky and can happen quite "silently": it is at least unexpected that the output can differ across backends, and I don't know if there would be an easy way to warn users so as to mitigate it. By the way, it would be nice for the Masking & Padding guide to find its way back into the documentation (it seems to have disappeared from the Developer guides). Thanks!

mehtamansi29 commented 1 month ago

Hi @neo-alex -

I tried to reproduce the issue with both the keras Masking layer get_model() function and the MaskedGlobalAveragePooling1D subclass on the latest Keras 3.6.0, and it works fine in both cases with the jax and tensorflow backends.

Attached gist for your reference here.

github-actions[bot] commented 2 weeks ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 3 days ago

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

google-ml-butler[bot] commented 3 days ago

Are you satisfied with the resolution of your issue?