keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Wrong binary accuracy with Jax #20178

Open eli-osherovich opened 2 weeks ago

eli-osherovich commented 2 weeks ago

I have some very strange results out of the `

Consider the code below:

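import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras
import keras
import numpy as np

inp = keras.Input(shape=(1,))
out = inp > 0.5  # the model outputs a boolean tensor
mm = keras.Model(inputs=inp, outputs=out)

x = np.random.rand(32, 1)

res = mm.predict(x)  # boolean predictions
met = keras.metrics.BinaryAccuracy()
met.update_state(x > 0.5, res > 0.5)  # both arguments are boolean arrays
print(met.result())  # I expect 1.0 here on every run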

I would expect to get 1.0 on every run. Instead, I get a seemingly random result (close to 0.5).

Package versions: tensorflow 2.17.0, keras 3.5.0, jax 0.4.26, numpy 1.26.4

eli-osherovich commented 2 weeks ago

The result is correct if I cast the second argument of update_state (y_pred) to a float or int.
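To illustrate why the cast matters, here is a minimal NumPy sketch of the thresholding step. It assumes (this thread does not confirm it) that `binary_accuracy` casts the 0.5 threshold to `y_pred`'s dtype before computing `y_pred > threshold`. For a boolean `y_pred`, 0.5 then becomes `True`, no prediction passes the threshold, and the score collapses to the fraction of `False` labels, which is close to 0.5 for uniform random `x`. The function `sketch_binary_accuracy` is hypothetical and is not Keras source code.

```python
import numpy as np


def sketch_binary_accuracy(y_true, y_pred, threshold=0.5):
    """Hypothetical re-creation of the assumed thresholding order (not Keras source)."""
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true).astype(y_pred.dtype)
    # Assumption: the threshold is cast to y_pred's dtype before comparing.
    # For a bool y_pred, 0.5 becomes True.
    thr = np.asarray(threshold).astype(y_pred.dtype)
    # For bool inputs, y_pred > True is False everywhere.
    y_pred = (y_pred > thr).astype(y_true.dtype)
    return float(np.mean(y_true == y_pred))


rng = np.random.default_rng(0)
x = rng.random((32, 1))
labels = x > 0.5  # boolean ground truth
preds = x > 0.5   # boolean predictions, identical to the labels

bool_score = sketch_binary_accuracy(labels, preds)
float_score = sketch_binary_accuracy(labels.astype("float32"), preds.astype("float32"))
print(f"bool inputs:  {bool_score:.3f}")   # fraction of False labels
print(f"float inputs: {float_score:.3f}")  # 1.000
```

Casting either input to a numeric dtype keeps the threshold at 0.5, which is consistent with the workaround described above.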

mehtamansi29 commented 2 weeks ago

Hi @eli-osherovich,

In met.update_state(x>0.5, res>0.5), both x>0.5 and res>0.5 are boolean arrays, but the BinaryAccuracy metric accepts only numerical values (floats or integers).

Running the same code with the TensorFlow backend raises an error: InvalidArgumentError: Value for attr 'T' of bool is not in the list of allowed values: float, double, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64 ; NodeDef: {{node Greater}}; Op<name=Greater; signature=x:T, y:T -> z:bool; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64]> [Op:Greater] name

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np

inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out) 

x = np.random.rand(32, 1)

res = mm.predict(x)
met = keras.metrics.BinaryAccuracy()
met.update_state(x > 0.5, res > 0.5)  # raises InvalidArgumentError (Greater on bool) with the TensorFlow backend
met.result()

So the JAX backend should raise the same error when boolean inputs are passed to the BinaryAccuracy metric. You can open a new issue in the JAX repo to request that error message.
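For comparison, the fail-fast behavior the TensorFlow backend exhibits can be sketched as an explicit dtype check in plain NumPy. This is only an illustration: `check_numeric_inputs` is a hypothetical helper, not part of Keras or JAX, and the error type and message are assumptions.

```python
import numpy as np


def check_numeric_inputs(y_true, y_pred):
    """Hypothetical guard: reject boolean inputs instead of scoring them silently."""
    for name, value in (("y_true", y_true), ("y_pred", y_pred)):
        if np.asarray(value).dtype == np.bool_:
            raise TypeError(
                f"{name} has dtype bool; BinaryAccuracy expects float or int values. "
                f"Cast it first, e.g. {name}.astype('float32')."
            )


x = np.random.default_rng(0).random((32, 1))
check_numeric_inputs((x > 0.5).astype("float32"), x)  # numeric inputs pass silently
try:
    check_numeric_inputs(x > 0.5, x > 0.5)  # boolean inputs, as in the report
except TypeError as err:
    print("rejected:", err)
```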

fchollet commented 2 weeks ago

We could consider casting the values to floatx() in update_state(). Would you like to open a PR, @eli-osherovich?
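The proposed fix can be sketched in NumPy as "cast both inputs to a float dtype before thresholding". This is a hypothetical illustration, not the actual patch: `binary_accuracy_cast_first` is an invented name, and `FLOATX = "float32"` is assumed to mirror Keras's default `keras.backend.floatx()`.

```python
import numpy as np

FLOATX = "float32"  # assumption: mirrors Keras's default keras.backend.floatx()


def binary_accuracy_cast_first(y_true, y_pred, threshold=0.5):
    """Hypothetical sketch of the proposed fix: cast to floatx before thresholding."""
    y_true = np.asarray(y_true).astype(FLOATX)  # bool -> 0.0 / 1.0
    y_pred = np.asarray(y_pred).astype(FLOATX)
    # With float inputs the threshold stays 0.5 instead of collapsing to True.
    y_pred = (y_pred > np.asarray(threshold, dtype=FLOATX)).astype(FLOATX)
    return float(np.mean(y_true == y_pred))


x = np.random.default_rng(0).random((32, 1))
# Boolean inputs, exactly as in the original report.
print(binary_accuracy_cast_first(x > 0.5, x > 0.5))  # 1.0
```

Inside Keras, the analogous change would presumably use `keras.ops.cast(value, keras.backend.floatx())` in `update_state()`; that is a guess at the shape of the fix, not the code of the eventual PR.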

mehtamansi29 commented 2 weeks ago

Hi @fchollet, I will raise a PR to cast the values to floatx() in update_state().