Qazalbash closed this issue 1 week ago.
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].
This indicates that your model contains an operation that tries to retrieve the eager value of a tensor.
Earlier, I see:
jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
So it sounds like you are calling predict() inside a tracing scope. This is impossible. Perhaps you meant to call self.logVT(m1m2) instead?
See also: https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
Oh, and if your call method is stateful in any way, you'll need to use stateless_call() instead and manage the state updates manually.
@fchollet This solved the problem. Thanks!
I have trained a simple Deep-MLP model and saved it in .keras format. I am utilizing JAX jitted functions for predictions, passing two inputs as jax.numpy.column_stack. Despite attempting alternative methods, including using numpy.column_stack and setting JAX_TRACEBACK_FILTERING=off, the issue persists. Notably, my Keras backend is configured as KERAS_BACKEND=jax.
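To illustrate why swapping jax.numpy.column_stack for numpy.column_stack is unlikely to help, here is a small sketch using only JAX and NumPy. The function names and the input values are made up for the example. Inside a jitted function the arguments are tracers, and NumPy functions need concrete values, so the NumPy version runs into the same kind of error:

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def stack_with_jnp(m1, q):
    # jnp functions accept tracers, so this works under jit.
    return jnp.column_stack((m1, q))


@jax.jit
def stack_with_np(m1, q):
    # NumPy has to convert its inputs to concrete ndarrays,
    # which a tracer cannot provide.
    return np.column_stack((m1, q))


m1 = jnp.linspace(5.0, 50.0, 4)
q = jnp.linspace(0.1, 1.0, 4)

print(stack_with_jnp(m1, q).shape)  # (4, 2)

try:
    stack_with_np(m1, q)
except jax.errors.TracerArrayConversionError as err:
    print(type(err).__name__)
```

JAX_TRACEBACK_FILTERING=off only changes how much of the traceback is shown, so it would not be expected to change the behaviour either.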