Open: eli-osherovich opened this issue 2 weeks ago.
The result is correct if I cast the second parameter of `update_state` to a float or int.
Hi @eli-osherovich,
In `met.update_state(x > 0.5, res > 0.5)`, both `x > 0.5` and `res > 0.5` are boolean arrays, but the `BinaryAccuracy` metric accepts only numerical values (floats or integers).
Running the same code with the TensorFlow backend raises an error:
InvalidArgumentError: Value for attr 'T' of bool is not in the list of allowed values: float, double, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64 ; NodeDef: {{node Greater}}; Op<name=Greater; signature=x:T, y:T -> z:bool; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64]> [Op:Greater] name
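```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras

import keras
import numpy as np

# Model whose output is a boolean tensor (inp > 0.5)
inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out)

x = np.random.rand(32, 1)
res = mm.predict(x)  # boolean predictions

met = keras.metrics.BinaryAccuracy()
met.update_state(x > 0.5, res > 0.5)  # raises InvalidArgumentError with the TF backend
met.result()
```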
So the JAX backend should raise the same error when boolean inputs are passed to the `BinaryAccuracy` metric. You could open a new issue in the JAX repo to request that error message.
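A possible explanation for why JAX returns a value near 0.5 instead of raising is sketched below. It assumes that `BinaryAccuracy` thresholds predictions by casting the threshold to `y_pred`'s dtype and then evaluating `y_pred > threshold`. The `Greater` op in the TensorFlow error points to a comparison of this kind, but the helper names below are hypothetical and the code uses NumPy as a stand-in for the backend. Casting `0.5` to `bool` gives `True`, and `bool > True` is always `False`. Every prediction therefore collapses to `False`, and the accuracy equals the share of `False` labels, not 1.0.

```python
import numpy as np

# Hypothetical stand-in for the thresholding step we assume BinaryAccuracy performs:
# the threshold is cast to y_pred's dtype, then y_pred > threshold is evaluated.
def thresholded_predictions(y_pred, threshold=0.5):
    y_pred = np.asarray(y_pred)
    # Casting 0.5 to bool yields True, because any nonzero value is truthy.
    thr = np.asarray(threshold).astype(y_pred.dtype)
    return y_pred > thr

# Hypothetical accuracy helper built on the thresholding step above.
def binary_accuracy(y_true, y_pred, threshold=0.5):
    y_true = np.asarray(y_true)
    y_pred_bin = thresholded_predictions(y_pred, threshold).astype(y_true.dtype)
    return float(np.mean(y_true == y_pred_bin))

rng = np.random.default_rng(0)
x = rng.random((32, 1))
labels = x > 0.5  # boolean labels
preds = x > 0.5   # boolean predictions, identical to the labels

# bool > True is always False, so no prediction survives the threshold.
print(thresholded_predictions(preds).any())  # False
# Accuracy equals the fraction of False labels, not 1.0.
print(binary_accuracy(labels, preds))
```

Under this assumption, the result changes from run to run with `np.random.rand`, because the share of `False` labels changes.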
We could consider casting the values to `floatx()` in `update_state()`. Would you like to open a PR, @eli-osherovich?
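The cast-to-`floatx()` idea can be sketched in NumPy. `FLOATX` below is an assumption that stands in for the default `keras.backend.floatx()` value of `"float32"`. `binary_accuracy_cast` is a hypothetical helper, not the actual change to `update_state()`. In Keras itself, the equivalent would presumably cast `y_true` and `y_pred` with `keras.ops.cast(..., keras.backend.floatx())` before thresholding. Once both inputs are floats, `True` becomes `1.0` and passes the 0.5 threshold, so matching boolean inputs score 1.0.

```python
import numpy as np

FLOATX = np.float32  # assumption: stands in for keras.backend.floatx() == "float32"

def binary_accuracy_cast(y_true, y_pred, threshold=0.5):
    """Hypothetical helper: cast both inputs to a float dtype before
    thresholding, mirroring the proposed fix for update_state()."""
    y_true = np.asarray(y_true).astype(FLOATX)  # True -> 1.0, False -> 0.0
    y_pred = np.asarray(y_pred).astype(FLOATX)
    y_pred_bin = (y_pred > FLOATX(threshold)).astype(FLOATX)
    return float(np.mean(y_true == y_pred_bin))

rng = np.random.default_rng(0)
x = rng.random((32, 1))
print(binary_accuracy_cast(x > 0.5, x > 0.5))  # 1.0
```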
Hi @fchollet, I will raise a PR to cast the values to `floatx()` in `update_state()`.
I have some very strange results out of the `
Consider the code below:
I would expect to get 1 every single run. Instead I get some random result (close to 0.5).
Package versions (tf, keras, jax, np):