keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

JAX array conversion failure in Keras model prediction #19674

Closed: Qazalbash closed this issue 1 week ago

Qazalbash commented 1 week ago

I trained a simple deep MLP model and saved it in the .keras format. I use it for predictions inside JAX-jitted functions, passing the two inputs stacked with jax.numpy.column_stack. I also tried numpy.column_stack and set JAX_TRACEBACK_FILTERING=off, but the error persists. My Keras backend is JAX (KERAS_BACKEND=jax).

  File "/media/project/inference/lippl.py", line 114, in exp_rate_integral
    jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
                                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/optree/ops.py", line 594, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/jax/_src/core.py", line 684, in __array__
    raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].
The error occurred while tracing the function likelihood at /media/project/inference/lippl.py:119 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:107:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
  name=_uniform
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
      h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
      i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
      j:u32[10000] = shift_right_logical i 9
      k:u32[10000] = or j 1065353216
      l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
      m:f32[10000] = sub l 1.0
      n:f32[1] = sub h g
      o:f32[10000] = mul m n
      p:f32[10000] = add o g
      q:f32[10000] = max g p
    in (q,) }
] r s t
    from line /media/project/inference/lippl.py:107:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:108:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
  name=_uniform
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
      h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
      i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
      j:u32[10000] = shift_right_logical i 9
      k:u32[10000] = or j 1065353216
      l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
      m:f32[10000] = sub l 1.0
      n:f32[1] = sub h g
      o:f32[10000] = mul m n
      p:f32[10000] = add o g
      q:f32[10000] = max g p
    in (q,) }
] r s t
    from line /media/project/inference/lippl.py:108:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

fchollet commented 1 week ago

jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].

This indicates that your model contains an operation that tries to retrieve the eager value of a tensor.
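For illustration, the same error can be reproduced without Keras. This is a minimal sketch, not the code from the issue: the function name `bad` and the input shape are made up. Inside `jax.jit` the argument is a tracer with no concrete value, so forcing it into a NumPy array fails.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def bad(x):
    # np.asarray calls x.__array__(); under jit, x is a tracer with no
    # concrete (eager) value, so the conversion is rejected.
    return np.asarray(x).sum()


try:
    bad(jnp.ones((3,)))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

print("TracerArrayConversionError raised:", raised)
```

Replacing `np.asarray(x).sum()` with `jnp.sum(x)` keeps the computation in JAX and traces without error.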

Earlier, I see:

jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),

So it sounds like you are calling predict() inside a tracing scope. This is impossible. Perhaps you meant to call self.logVT(m1m2) instead?

See also: https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call

fchollet commented 1 week ago

Oh, and if your call method is stateful in any way, you'll need to use stateless_call() instead and manage the state updates manually.

Qazalbash commented 1 week ago

So it sounds like you are calling predict() inside a tracing scope. This is impossible. Perhaps you meant to call self.logVT(m1m2) instead?

@fchollet This solved the problem. Thanks!