keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

On JAX, Keras replaces any exception inside `call` method of `keras.Model` subclass with misleading error #19675

Closed burnpanck closed 1 week ago

burnpanck commented 1 week ago

MWE:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

import keras


class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")


t = Test()
inp = keras.Input(shape=(32, 3))
t(inp)
```

Running the above example raises the following exception:

```text
TypeError: Exception encountered when calling Test.call().

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>
```

On TensorFlow, the exception instead reads:

```text
RuntimeError: Exception encountered when calling Test.call().

Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `Test.call()` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:

Random misspelling deeply nested in the model

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>
```

Note that under TensorFlow the error message contains the original exception string, whereas under JAX the message misleadingly suggests that there is a problem with a shape. Furthermore, the internal frames of the stack trace are erased (not shown in the example above, to keep the MWE minimal). If this happens deep inside a model, an unsuspecting user may be sent on a many-hours-long hunt for mismatched shapes that turns up nothing useful.
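The behaviour expected here, where both the original message and the frames of the failing code survive, is what Python's exception chaining provides. The sketch below is a generic illustration, not Keras's actual code: `infer_output_spec` and `call` are hypothetical stand-ins for a framework's shape-inference wrapper and a user's `call` method. The wrapper embeds the original text in its new message (modelled loosely on the TensorFlow output above) and uses `raise ... from err`, so the original exception and its traceback stay reachable through `__cause__`.

```python
import traceback


def call(x):
    # Hypothetical stand-in for a model's `call` method that fails deep inside.
    raise RuntimeError("Random misspelling deeply nested in the model")


def infer_output_spec(fn, *args):
    """Hypothetical shape-inference wrapper; not a Keras API.

    Runs `fn` and, if it fails, raises a new error that embeds the
    original message and chains the original exception as its cause.
    """
    try:
        return fn(*args)
    except Exception as err:
        raise RuntimeError(
            f"Could not automatically infer the output shape / dtype of "
            f"{fn.__name__!r}. Error encountered:\n\n{err}"
        ) from err


try:
    infer_output_spec(call, None)
except RuntimeError as wrapped:
    # The original message is part of the new one ...
    print(wrapped)
    # ... and the formatted, chained traceback still contains the frame inside `call`.
    report = "".join(
        traceback.format_exception(type(wrapped), wrapped, wrapped.__traceback__)
    )
    print("in call" in report)  # True
```

Printing the full traceback of `wrapped` shows both exceptions, joined by "The above exception was the direct cause of the following exception:", which is exactly the information the JAX path hides in the report above.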

burnpanck commented 1 week ago

This was using Keras 3.1.1, JAX 0.4.26, and Python 3.12.
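Version information like this can also be collected programmatically when filing a report. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `environment_report` and the default package list are illustrative assumptions, so adjust them to your setup:

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("keras", "jax", "jaxlib", "tensorflow")):
    """Hypothetical helper: map Python and each package name to its installed version."""
    report = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            report[name] = "not installed"
    return report


for name, ver in environment_report().items():
    print(f"{name}: {ver}")
```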

fchollet commented 1 week ago

Good catch. I fixed it at HEAD.