tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0
4.16k stars 1.08k forks source link

tfp.math.scan_associative doesn't work for all associative functions (it should be using `vmap` for `lowered_fn`) #1812

Open Joshuaalbert opened 4 weeks ago

Joshuaalbert commented 4 weeks ago

Here is a simple example of an associative function that `scan_associative` fails to handle, because it assumes the associative op broadcasts over batches of stacked elements.

The solution is to use `jax.vmap` to distribute the elements in `lowered_fn` here, rather than relying on broadcasting.

MVCE

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

def explicit_verify_associative(op, elems, *, rtol=1e-5, atol=1e-6):
    """Check that ``op`` gives the same result under both groupings.

    Compares ``op(op(e0, e1), e2)`` with ``op(e0, op(e1, e2))`` using the
    first three entries of ``elems``, printing both results.

    Args:
        op: Binary function to check.
        elems: Indexable sequence/array with at least three elements.
        rtol: Relative tolerance used when comparing the two groupings.
        atol: Absolute tolerance used when comparing the two groupings.

    Raises:
        AssertionError: If the two groupings differ beyond the tolerances.
    """
    first, second, third = elems[0], elems[1], elems[2]
    output_1 = op(op(first, second), third)
    output_2 = op(first, op(second, third))
    print(output_1, output_2)
    # Floating-point arithmetic is not exactly associative, so an exact `==`
    # can fail for a mathematically associative op; compare with a tolerance.
    # assert_allclose raises AssertionError and, unlike a bare `assert`, is
    # not stripped when Python runs with -O.
    np.testing.assert_allclose(output_1, output_2, rtol=rtol, atol=atol)

def main():
    """Reproduce the ``scan_associative`` failure on a non-broadcasting op.

    Draws three scalar samples and confirms the op is associative when it is
    applied one element at a time. It then hands the op to
    ``tfp.math.scan_associative``. The shape assertions inside the op fire if
    the scan calls it on stacked (batched) operands instead of single
    elements.
    """
    key = jax.random.PRNGKey(0)
    samples = jax.random.normal(key, shape=(3,))

    # Shape of one unbatched element; () for a 1-D array of scalars.
    expected_shape = jax.tree.map(lambda leaf: np.shape(leaf[0]), samples)

    def associative_op(x, y):
        # Log operand shapes so the offending call is visible in the output.
        print(f"x.shape={np.shape(x)}, y.shape={np.shape(y)}")
        for operand in (x, y):
            assert np.shape(operand) == expected_shape
        # Each operand is reduced with jnp.sum before combining. On a stacked
        # batch this would sum across the batch, which is why the op must not
        # be fed broadcast inputs.
        return jnp.sum(x) + jnp.sum(y)

    explicit_verify_associative(associative_op, samples)

    _ = tfp.math.scan_associative(associative_op, samples)

# Run the reproduction only when executed as a script, not on import.
if __name__ == '__main__':
    main()