Open Star9daisy opened 7 months ago
Hi @Star9daisy ,
I have replicated the reported behaviour and attached a gist here. I need to dig more and will come back.
Possibly related to #19433
Thank you @Star9daisy! Would you be able to submit a fix for this? Otherwise we're leaving this open for contributions from the community.
Hi @grasskin, it looks kind of difficult for me. But I'd like to give it a try!
Go ahead @Star9daisy! The nonzero error is likely related to #19407, which we just closed with a fix. Feel free to update this thread with concerns/questions. One thing I would strongly recommend is disabling compilation to make debugging easier, specifically tf.config.run_functions_eagerly(True) for TensorFlow and jax.config.update("jax_disable_jit", True) for JAX.
You will also likely benefit from disabling Keras traceback filtering with keras.config.disable_traceback_filtering() to get full error tracebacks.
This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.
Hi, developers. I want to monitor the values of the SpecificityAtSensitivity metric during training. I've checked the docs to make sure it can be used with the compile API. However, I find that this metric does not work out of the box like others, for example TruePositives. It cannot be used directly on the different backends, and sometimes not even with "accuracy"! Below are the results I found for the three backends. "with acc" means setting metrics=["acc", SpecificityAtSensitivity(...)]. Here is a Colab link to reproduce the error: issue_inconsistent_manner_of_SpecificityAtSensitivity.ipynb
tensorflow
<class 'NotImplementedError'> Cannot convert a symbolic tf.Tensor (Cast_12:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
torch
<class 'NotImplementedError'> Cannot copy out of meta tensor; no data!
jax
<class 'jax.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. The error occurred while tracing the function wrapped_fn at /usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/core.py:153 for make_jaxpr. This concrete value was not available in Python because it depends on the values of the arguments args[1] and args[2].
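The JAX message above points at a general constraint: inside jit, array shapes must be known at trace time, while the number of nonzero entries depends on the data. Below is a minimal sketch of the failure and of the static size workaround that the error message names. The helper names and the toy mask are hypothetical, and this is not the Keras metric code itself.

```python
import jax
import jax.numpy as jnp


def indices_dynamic(mask):
    # Output length equals the number of True entries, which is data-dependent.
    (idx,) = jnp.nonzero(mask)
    return idx


def indices_static(mask):
    # size is taken from the (static) shape; unused slots are padded with -1.
    (idx,) = jnp.nonzero(mask, size=mask.shape[0], fill_value=-1)
    return idx


mask = jnp.array([False, True, False, True])

# Eagerly, the dynamic version works because the values are concrete.
eager_idx = indices_dynamic(mask)

# Under jit the same function cannot be traced.
# The thread reports this as jax.errors.ConcretizationTypeError.
jit_failed = False
try:
    jax.jit(indices_dynamic)(mask)
except Exception as err:
    jit_failed = True
    print("jit failed with", type(err).__name__)

# With a static size the function traces and returns a padded result.
padded_idx = jax.jit(indices_static)(mask)
print(eager_idx, padded_idx)
```

The padded entries (-1 here) then have to be masked out by the caller, which is the price of keeping the output shape static.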