keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Inconsistent manner of the metric `SpecificityAtSensitivity` among different backends #19376

Open Star9daisy opened 7 months ago

Star9daisy commented 7 months ago

Hi, developers. I want to monitor the values of the SpecificityAtSensitivity metric during training. I checked the docs to confirm that it can be used with the compile API.

However, this metric does not work out of the box the way others do (for example, TruePositives). It fails on several backends, and the outcome sometimes even depends on whether "accuracy" is also passed! The tables below show the results for the three backends; "with acc" means metrics=["acc", SpecificityAtSensitivity(...)]. A Colab notebook that reproduces the errors: issue_inconsistent_manner_of_SpecificityAtSensitivity.ipynb
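A minimal sketch of the experiment behind the tables, assuming Keras 3 with the backend chosen through the KERAS_BACKEND environment variable before keras is first imported. The toy data, the model, the sensitivity target of 0.5 and the helper name run_case are hypothetical choices for illustration; they are not taken from the linked notebook.

```python
def run_case(run_eagerly: bool, with_acc: bool) -> str:
    """Fit a tiny binary classifier once; return "ok" or the exception type name."""
    # Imports live inside the function so that defining it does not require Keras.
    import numpy as np
    import keras  # backend is fixed by KERAS_BACKEND at first import

    x = np.random.rand(64, 4).astype("float32")
    y = (x[:, :1] > 0.5).astype("float32")  # shape (64, 1), matches the model output

    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    metrics = [keras.metrics.SpecificityAtSensitivity(0.5)]
    if with_acc:
        metrics.insert(0, "acc")  # the "with acc" column
    try:
        model.compile(optimizer="adam", loss="binary_crossentropy",
                      metrics=metrics, run_eagerly=run_eagerly)
        model.fit(x, y, epochs=1, batch_size=16, verbose=0)
        return "ok"
    except Exception as exc:  # record the failure instead of stopping the sweep
        return type(exc).__name__
```

Calling it for the four combinations, e.g. `for eager in (False, True): for acc in (True, False): print(eager, acc, run_case(eager, acc))`, fills in one table per backend.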

tensorflow

run_eager \ metrics | with "acc" | without "acc"
False (default)     | ❌         | ❌
True                | ✔️         | ✔️

<class 'NotImplementedError'> Cannot convert a symbolic tf.Tensor (Cast_12:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

torch

run_eager \ metrics with "acc" without "acc"
False (default) ✔️
True ✔️

<class 'NotImplementedError'> Cannot copy out of meta tensor; no data!

jax

run_eager \ metrics | with "acc" | without "acc"
False (default)     | ❌         | ❌
True                | ❌         | ❌

<class 'jax.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. The error occurred while tracing the function wrapped_fn at /usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/core.py:153 for make_jaxpr. This concrete value was not available in Python because it depends on the values of the arguments args[1] and args[2].
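The JAX traceback points at jnp.nonzero: selecting the thresholds that satisfy the sensitivity constraint with a nonzero-style gather produces an array whose length depends on the data, which, as the error message says, cannot be traced under JAX transformations. The NumPy sketch below is a reference computation of specificity at sensitivity under simplifying assumptions (evenly spaced thresholds in [0, 1], and a prediction counts as positive when it is strictly above the threshold; Keras uses its own threshold grid). It contrasts the nonzero-style selection with a boolean-mask selection whose shape is fixed. The function names are hypothetical; this is neither the Keras implementation nor the fix from #19407.

```python
import numpy as np


def confusion_at_thresholds(y_true, y_pred, thresholds):
    """Return TP, FN, TN, FP counts, one entry per threshold (pred > t means positive)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=float)
    pred_pos = y_pred[None, :] > thresholds[:, None]  # shape (T, N)
    tp = np.sum(pred_pos & y_true, axis=1)
    fn = np.sum(~pred_pos & y_true, axis=1)
    tn = np.sum(~pred_pos & ~y_true, axis=1)
    fp = np.sum(pred_pos & ~y_true, axis=1)
    return tp, fn, tn, fp


def safe_div(a, b):
    """Elementwise a / b, with 0 wherever the denominator is 0."""
    a = a.astype(float)
    b = b.astype(float)
    return np.divide(a, b, out=np.zeros_like(a), where=b > 0)


def specificity_at_sensitivity(y_true, y_pred, sensitivity, num_thresholds=200):
    """Best specificity over thresholds whose sensitivity is >= `sensitivity`."""
    thresholds = np.linspace(0.0, 1.0, num_thresholds)
    tp, fn, tn, fp = confusion_at_thresholds(y_true, y_pred, thresholds)
    sens = safe_div(tp, tp + fn)
    spec = safe_div(tn, tn + fp)
    feasible = sens >= sensitivity  # boolean mask, always shape (num_thresholds,)

    # nonzero-style selection: the length of `idx` depends on the data,
    # which is the pattern JAX cannot trace without a static `size`.
    idx = np.nonzero(feasible)[0]
    via_nonzero = spec[idx].max() if idx.size else 0.0

    # mask-style selection: infeasible entries become 0, the shape never changes.
    via_mask = np.max(np.where(feasible, spec, 0.0))

    assert np.isclose(via_nonzero, via_mask)  # both selections agree
    return float(via_mask)
```

Because specificities are never negative, replacing infeasible entries with 0 gives the same maximum as gathering only the feasible ones, and it also covers the case where no threshold is feasible.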

SuryanarayanaY commented 7 months ago

Hi @Star9daisy ,

I have replicated the reported behaviour and attached a gist here. I need to dig into it further and will come back.

grasskin commented 6 months ago

Possibly related to #19433

grasskin commented 6 months ago

Thank you @Star9daisy! Would you be able to submit a fix for this? Otherwise we're leaving this open for contributions from the community.

Star9daisy commented 6 months ago

Hi @grasskin, it looks somewhat difficult for me, but I'd like to give it a try!

grasskin commented 6 months ago

Go ahead @Star9daisy! The nonzero error is likely related to #19407, which we just closed with a fix. Feel free to update this thread with concerns or questions. One thing I would strongly recommend is disabling graph/JIT compilation to make debugging easier: specifically, tf.config.run_functions_eagerly(True) for TensorFlow and jax.config.update("jax_disable_jit", True) for JAX.

You will also likely benefit from disabling Keras traceback filtering keras.config.disable_traceback_filtering() to get full error tracebacks.

github-actions[bot] commented 1 week ago

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.