Closed by neo-alex 3 days ago
In the meantime, I also tried the "torch" backend and everything works fine, just like with "tensorflow" (so the issue mentioned above seems specific to JAX with JIT compilation).
My bad, I think the issue is solved if I change the call() method of my MaskedGlobalAveragePooling1D to:
def call(self, inputs):
    # Compute the mask explicitly and pass it to the pooling layer,
    # instead of relying on implicit mask propagation.
    mask = self.masking.compute_mask(inputs)
    return self.pooling(inputs, mask=mask)
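For context, a minimal sketch of what the full layer could look like, assuming it simply wraps keras.layers.Masking and keras.layers.GlobalAveragePooling1D (the mask_value default is an assumption based on the -99 example further down):

import keras

class MaskedGlobalAveragePooling1D(keras.layers.Layer):
    def __init__(self, mask_value=-99.0, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers: Masking computes the mask, pooling consumes it.
        self.masking = keras.layers.Masking(mask_value=mask_value)
        self.pooling = keras.layers.GlobalAveragePooling1D()

    def call(self, inputs):
        mask = self.masking.compute_mask(inputs)
        return self.pooling(inputs, mask=mask)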
Still, I would argue that the original issue is rather tricky and can happen quite "silently": it is unexpected, at the very least, that the output can differ across backends. I don't know if there would be an easy way to warn users about it. By the way, it would be nice for the Masking & Padding guide to find its way back into the documentation (it seems to have disappeared from the Developer guides). Thanks!
Hi @neo-alex -
I tried to reproduce the issue with the Keras Masking layer in the get_model() function, and also with the MaskedGlobalAveragePooling1D subclass, on the latest Keras 3.6.0, and it works fine in both cases with the JAX and TensorFlow backends.
I have attached a gist here for your reference.
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
I have the minimal code below to check that JIT-compiled model outputs match non-JIT ones:
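A minimal sketch of such a check, assuming a small batch of sequences where -99.0 marks the padded values (x, the shapes, and the tolerance are assumptions; get_model() is the function defined below):

import numpy as np
import keras

# Hypothetical test data (assumption): shape (batch, timesteps, features),
# with -99.0 marking the entries to ignore.
x = np.array(
    [[[1.0], [2.0], [-99.0]],
     [[4.0], [-99.0], [-99.0]]],
    dtype="float32",
)

model = get_model()

# Predict without JIT compilation, then with it, and compare outputs.
model.compile(jit_compile=False)
y_no_jit = model.predict(x)

model.compile(jit_compile=True)
y_jit = model.predict(x)

np.testing.assert_allclose(y_no_jit, y_jit, rtol=1e-5)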
For this example, assume that we want to create a model that averages the x above "line by line", ignoring the -99 values that we will mask.
With the get_model() function below, the test is successful both with the "tensorflow" and "jax" backends.
Note: I know that using keras.layers.Masking inside a custom layer is not common (I actually need it for a more advanced use case), but I see no reason why it shouldn't work consistently across all backends.
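A minimal sketch of such a get_model(), assuming it simply wires an Input through the custom layer (the input shape is an assumption):

import keras

def get_model():
    # Hypothetical reconstruction: variable-length sequences of 1 feature.
    inputs = keras.Input(shape=(None, 1))
    outputs = MaskedGlobalAveragePooling1D()(inputs)
    return keras.Model(inputs, outputs)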
I would appreciate any help fixing this bug, thank you!