tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Error in stats.kendalls_tau as Keras Metric #1417

Open gonzalesMK opened 3 years ago

gonzalesMK commented 3 years ago

I am trying to use a TensorFlow Probability function as a Keras metric. With tfp.stats.kendalls_tau, the following code raises an error:

import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np 

def kendalls_tau(y_true, y_pred):
    a = tf.reshape(y_true, shape=(-1,))
    b = tf.reshape(y_pred, shape=(-1,))
    kendall = tfp.stats.kendalls_tau(a, b)
    return kendall

inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer="Adam", loss="mse", metrics=kendalls_tau)

x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2))
model.fit(x, y)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/data/Mestrado/Ensaios/drbc_tf.py in 
     257 x = np.random.random((2, 3))
     258 y = np.random.randint(0, 2, (2, 2) )
---> 259 model.fit(x, y)
     260 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1190                 _r=1):
   1191               callbacks.on_train_batch_begin(step)
-> 1192               tmp_logs = self.train_function(iterator)
   1193               if data_handler.should_sync:
   1194                 context.async_wait()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    927       # This is the first call of __call__, so we have to initialize.
    928       initializers = []
--> 929       self._initialize(args, kwds, add_initializers_to=initializers)
    930     finally:
    931       # At this point we know that the initialization is complete (or less

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3057       args, kwargs = None, None
   3058     with self._lock:
-> 3059       graph_function, _ = self._maybe_define_function(args, kwargs)
   3060     return graph_function
   3061 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3454 
   3455           self._function_cache.missed.add(call_context_key)
-> 3456           graph_function = self._create_graph_function(args, kwargs)
   3457           self._function_cache.primary[cache_key] = graph_function
   3458 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3289     arg_names = base_arg_names + missing_arg_names
   3290     graph_function = ConcreteFunction(
-> 3291         func_graph_module.func_graph_from_py_func(
   3292             self._name,
   3293             self._python_function,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    992           except Exception as e:  # pylint:disable=broad-except
    993             if hasattr(e, "ag_error_metadata"):
--> 994               raise e.ag_error_metadata.to_exception(e)
    995             else:
    996               raise

TypeError: in user code:

    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function  *
        return step_function(self, iterator)
    <ipython-input-4-14a2210abe73>:5 kendalls_tau  *
        kendall = tfp.stats.kendalls_tau(a, b)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau  **
        lexa = lexicographical_indirect_sort(y_true, y_pred)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
        left, _, lexicographic = tf.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
        return while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
        return while_v2.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
        body_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
        tf.cond(not_equal, secondary_sort, lambda: lexicographic))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
        return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
        true_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
        tensorshape_util.set_shape(x, [n])
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
        tensor.set_shape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
        shape = tensor_shape.TensorShape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
        six.raise_from(
    <string>:3 raise_from

    TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'

How can I fix this?

murphyja11 commented 3 years ago

You can wrap kendalls_tau with tf.py_function:

def tf_kendalls_tau(y_true, y_pred):
    kt = tf.py_function(
        kendalls_tau,
        (y_true, y_pred),
        tf.float32
    )
    return kt

model.compile(optimizer="Adam", loss="mse", metrics=tf_kendalls_tau)

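The error comes from tracing the function into a graph. When you call model.fit(), Keras runs the training step as a tf.function, so AutoGraph converts kendalls_tau into a TensorFlow graph. As the traceback shows, tfp.stats.kendalls_tau then fails when it tries to set a static shape from a size that is only available as a symbolic tensor. tf.py_function avoids this: it inserts the Python function into the graph as a single op, and its body runs eagerly on concrete values.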

While this works around the problem, I don't think the issue is resolved: calling the AutoGraph-converted version of tfp.stats.kendalls_tau still throws this error. I'm going to look into why.

Does someone with more experience already see an issue here? The AutoGraph-converted version of tfp.stats.kendalls_tau should be callable, right?

Full code that worked for me:

import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np

def kendalls_tau(y_true, y_pred):
    a = tf.reshape(y_true, shape=(-1,))
    b = tf.reshape(y_pred, shape=(-1,))
    kendall = tfp.stats.kendalls_tau(a, b)
    return kendall

def tf_kendalls_tau(y_true, y_pred):
    kt = tf.py_function(
        kendalls_tau,
        (y_true, y_pred),
        tf.float32
    )
    return kt

inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer="Adam", loss="mse", metrics=tf_kendalls_tau)

x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2))
model.fit(x, y)

sorensenjs commented 3 years ago

I've started looking at this, but the stats function is not suitable for use as a Keras metric, for a number of reasons. You probably want the approximate O(n) version in TensorFlow Addons instead, which was intended for this use case: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/metrics/kendalls_tau.py

The TFP version expects two 1-D tensors of shape [n], and as far as I know Keras models cannot output scalars.
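To make the idea of an approximate, streaming metric concrete, here is a minimal pure-Python sketch of a binned approximation of Kendall's tau-b. It is not the TensorFlow Addons code: the class name `BinnedKendallsTau`, the shared value range `[lo, hi]` for both inputs, and the bin count are hypothetical choices for illustration. Each update only increments one histogram cell per example, so accumulating a batch is O(n); the price is that values falling into the same bin are treated as ties.

```python
import math


class BinnedKendallsTau:
    """Hypothetical streaming approximation of Kendall's tau-b (sketch).

    Assumes both y_true and y_pred lie in a known range [lo, hi].
    """

    def __init__(self, lo, hi, bins=32):
        self.lo, self.hi, self.bins = lo, hi, bins
        # hist[i][j] counts examples whose y_true is in bin i and y_pred in bin j.
        self.hist = [[0] * bins for _ in range(bins)]
        self.count = 0

    def _bin(self, v):
        # Map v to a bin index in [0, bins - 1]; out-of-range values are clipped.
        k = int((v - self.lo) / (self.hi - self.lo) * self.bins)
        return min(max(k, 0), self.bins - 1)

    def update(self, y_true, y_pred):
        # O(len(y_true)): each example increments exactly one histogram cell.
        for t, p in zip(y_true, y_pred):
            self.hist[self._bin(t)][self._bin(p)] += 1
            self.count += 1

    def result(self):
        h, b = self.hist, self.bins
        concordant = discordant = 0
        for i in range(b):
            for j in range(b):
                if h[i][j] == 0:
                    continue
                # Only look at rows k > i so each unordered pair is counted once.
                for k in range(i + 1, b):
                    for l in range(b):
                        if l > j:
                            concordant += h[i][j] * h[k][l]
                        elif l < j:
                            discordant += h[i][j] * h[k][l]
        # Pairs sharing a y_true bin (row) or a y_pred bin (column) count as ties.
        ties_true = sum(c * (c - 1) // 2 for c in (sum(row) for row in h))
        ties_pred = sum(c * (c - 1) // 2 for c in (sum(col) for col in zip(*h)))
        n0 = self.count * (self.count - 1) // 2
        denom = math.sqrt((n0 - ties_true) * (n0 - ties_pred))
        # Fewer than two examples, or everything in one bin: tau-b is undefined.
        return (concordant - discordant) / denom if denom > 0 else math.nan


metric = BinnedKendallsTau(0.0, 1.0, bins=10)
metric.update([0.01, 0.02, 0.5], [0.02, 0.01, 0.5])
print(metric.result())  # 1.0, although the exact tau-b here is 1/3
```

Because `result()` only walks the histogram, its cost depends on the number of bins (O(bins**4) in this naive loop; cumulative sums would reduce it), not on how many examples were seen. More bins mean fewer artificial ties and a closer approximation.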

sorensenjs commented 2 years ago

I've looked into this and I think there are a few issues, but fundamentally I'm not certain the originally proposed use case makes much sense. If you use the py_function shim, you can use the SciPy version of Kendall's tau instead, but even then I'm not certain it does what one would want, since the SciPy version returns NaN for lists of length 1.

It's possible to remove the assertions and make the TFP Kendall's tau behave more like the SciPy one, but I'm still working on it.
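For comparison with the SciPy behavior discussed above, here is a minimal, exact O(n²) reference for Kendall's tau-b in plain Python. The function name `kendalls_tau_b` is hypothetical, and returning NaN for fewer than two observations or for a constant input (where the tau-b denominator is zero) is an assumption meant to mirror the SciPy behavior mentioned here; it is not TFP's code.

```python
import math
from itertools import combinations


def kendalls_tau_b(x, y):
    """Exact Kendall's tau-b by checking every pair (O(n**2) sketch)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    n = len(x)
    if n < 2:
        # No pairs to compare: return NaN instead of asserting.
        return math.nan
    concordant = discordant = ties_x = ties_y = 0
    for i, j in combinations(range(n), 2):
        # Sign of the difference in each variable: -1, 0 or +1.
        sx = (x[i] > x[j]) - (x[i] < x[j])
        sy = (y[i] > y[j]) - (y[i] < y[j])
        ties_x += int(sx == 0)
        ties_y += int(sy == 0)
        if sx * sy > 0:
            concordant += 1
        elif sx * sy < 0:
            discordant += 1
    n0 = n * (n - 1) // 2
    denom = math.sqrt((n0 - ties_x) * (n0 - ties_y))
    # A constant input leaves no untied pairs, so tau-b is undefined.
    return (concordant - discordant) / denom if denom > 0 else math.nan


print(kendalls_tau_b([1], [2]))                    # nan
print(kendalls_tau_b([1, 2, 2, 3], [1, 3, 2, 4]))  # 5 / sqrt(30), about 0.913
```

Because it checks every pair, this reference is only practical for small inputs, but it makes the edge cases (length 1, constant input, ties) easy to compare against other implementations.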

murphyja11 commented 2 years ago

Thank you for following up.

I'd love to help out on this. Let me know if there's anything I can do that would be valuable.

sorensenjs commented 2 years ago

I've created https://github.com/tensorflow/probability/pull/1455 which changes behavior to more closely match scipy's implementation. I don't know if this addresses all of the issues raised here but wanted to share it in case it helps.

sorensenjs commented 4 months ago

I think this is fixed.