keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Jax Error with Stochastic Depth #18404

Open anas-rz opened 11 months ago

anas-rz commented 11 months ago

While converting a model from TensorFlow to Keras Core, I am hitting an error with the JAX backend at the end of the epoch. The error is thrown by the StochasticDepth layer used in the network. The code works with the PyTorch and TensorFlow backends, but with JAX it reports a tracer leak. I understand the concept of pure functions in JAX, but I still cannot figure out whether the problem is in my code or in Keras Core. The StochasticDepth layer works perfectly in the CCT example. Here's the stack trace, in case it helps:

Epoch 1/5
48/49 ━━━━━━━━━━━━━━━━━━━━ 0s 967ms/step - accuracy: 0.1990 - loss: 6.7065
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
<ipython-input-14-61822f405d65> in <cell line: 12>()
     10         ],
     11     )
---> 12 history = model.fit(
     13     pipeline_train,
     14     batch_size=BATCH_SIZE,

4 frames
/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    121             # To get the full stack trace, call:
    122             # `keras_core.config.disable_traceback_filtering()`
--> 123             raise e.with_traceback(filtered_tb) from None
    124         finally:
    125             del filtered_tb

    [... skipping hidden 20 frame]

/content/./focalnet-keras-core/focalnet_keras_core/layers.py in call(self, x, training)
    154             keep_prob = 1 - self.drop_path_rate
    155             shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
--> 156             random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)
    157             random_tensor = ops.floor(random_tensor)
    158             return (x / keep_prob) * random_tensor

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in op(self, *args)
    721 def _forward_operator_to_aval(name):
    722   def op(self, *args):
--> 723     return getattr(self.aval, f"_{name}")(self, *args)
    724   return op
    725 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
    254     args = (other, self) if swap else (self, other)
    255     if isinstance(other, _accepted_binop_types):
--> 256       return binary_op(*args)
    257     if isinstance(other, _rejected_binop_types):
    258       raise TypeError(f"unsupported operand type(s) for {opchar}: "

    [... skipping hidden 5 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/partial_eval.py in _assert_live(self)
   1577   def _assert_live(self) -> None:
   1578     if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1579       raise core.escaped_tracer_error(self, None)
   1580 
   1581   def get_referent(self):

UnexpectedTracerError: Exception encountered when calling StochasticDepth.call().

Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was compiled_train_step at /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/trainer.py:203 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py:867 (stateless_call)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/stateless_scope.py:66 (__exit__)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:370 (initialize_all_variables)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:87 (_deferred_initialize)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Arguments received by StochasticDepth.call():
  • x=jnp.ndarray(shape=(48, 3136, 96), dtype=float32)
  • training=True
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

fchollet commented 11 months ago

So you've verified the StochasticDepth layer works fine in JAX otherwise?

Typically the "tracer leak" type of issue happens when you take a tensor from an intermediate computation and store it as an attribute of a permanent object, like a layer or model.
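
For illustration, here is a minimal sketch of that kind of leak in plain JAX. The `LeakyLayer` class and its `last_activation` attribute are hypothetical names made up for this example; they are not taken from the reporter's code or from Keras. The sketch only shows the general pattern: an intermediate value is stored on a long-lived object during tracing and is then touched after the trace has ended.

```python
import jax
import jax.numpy as jnp


class LeakyLayer:
    """Hypothetical layer-like object, for illustration only."""

    def __init__(self):
        self.last_activation = None

    def call(self, x):
        y = x * 2.0
        # Side effect: under jax.jit, y is a tracer, and it is stored on a
        # permanent object that outlives the trace.
        self.last_activation = y
        return y + 1.0


layer = LeakyLayer()


@jax.jit
def forward(x):
    return layer.call(x)


leaky_out = forward(jnp.ones(3))  # runs; the leak is not reported yet

try:
    layer.last_activation + 1.0  # the escaped tracer is used after tracing ended
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True


@jax.jit
def pure_forward(x):
    # Pure variant: return every value needed later instead of storing it.
    y = x * 2.0
    return y + 1.0, y


out, activation = pure_forward(jnp.ones(3))
```

Note that the error only surfaces when the stored value is used again, which may be why a leak introduced in one place can show up much later in training.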

I think the line `random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)` in the trace is interesting. What happens if you replace it with a constant, like 0.6?
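
As a point of comparison, the computation on lines 154 to 158 of the trace can be sketched in plain JAX, where the random state is passed in explicitly as a key rather than drawn from any global state. This is only a sketch under assumptions: the function name `stochastic_depth`, its `key` argument, and the input shape are hypothetical, and this is neither the reporter's layer nor the Keras API.

```python
import jax
import jax.numpy as jnp


def stochastic_depth(x, key, drop_path_rate, training=True):
    """Hypothetical helper: drop whole samples with probability drop_path_rate."""
    if not training or drop_path_rate == 0.0:
        return x
    keep_prob = 1.0 - drop_path_rate
    # One random value per sample, broadcast over the remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = jnp.floor(keep_prob + jax.random.uniform(key, shape))
    # Kept samples are rescaled by 1 / keep_prob; dropped samples become zero.
    return (x / keep_prob) * random_tensor


# drop_path_rate and training are Python values, so they are marked static.
step = jax.jit(stochastic_depth, static_argnames=("drop_path_rate", "training"))

key = jax.random.PRNGKey(0)
x = jnp.ones((48, 4, 3))  # small stand-in for the (48, 3136, 96) input
y = step(x, key, drop_path_rate=0.1, training=True)
```

Because the key is an ordinary argument, nothing random is created or stored inside the traced function, so the function stays pure under `jax.jit`.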

sachinprasadhs commented 7 months ago

@anas-rz, could you please check with the latest Keras 3 package and let us know if you're still facing the issue?

If you could provide a small reproducible code sample, it would make it easier for us to debug.