danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Autodiff problem with block_neural_autoregressive_flow #176

Open aseyboldt opened 1 month ago

aseyboldt commented 1 month ago

Very nice package :-)

While playing around with different optimization objectives, I ran into an autodiff issue. The following always returns exactly zeros, which I don't think is correct. This might be related to the bisection search (I think that's what's used here?), but if the while_loop in there were the problem, I would have expected an error rather than a silently incorrect result.

import flowjax.distributions
import flowjax.flows
import jax
import jax.numpy as jnp
import numpy as np

flow_key = jax.random.PRNGKey(0)
point = np.random.randn(5)
cotan = np.random.randn(5)

base_dist = flowjax.distributions.Normal(jnp.zeros(5))
flow = flowjax.flows.block_neural_autoregressive_flow(
    flow_key, base_dist=base_dist, invert=True
)

# Pull back the cotangent (cotan, 1.0) through transform_and_log_det.
out, pull_grad_fn = jax.vjp(lambda x: flow.bijection.transform_and_log_det(x), point)
pullback = pull_grad_fn((cotan, 1.0))
pullback
# (Array([0., 0., 0., 0., 0.], dtype=float32),)
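
For comparison, here is a rough central-difference check of the same quantity (the pullback above should equal the gradient of cotan @ y(x) + logdet(x) with respect to x). This is just a sketch reusing the arrays defined above, but it should give clearly nonzero values:

def scalar_out(x):
    y, logdet = flow.bijection.transform_and_log_det(x)
    return cotan @ y + logdet

eps = 1e-4
fd_grad = np.array(
    [(scalar_out(point + eps * e) - scalar_out(point - eps * e)) / (2 * eps) for e in np.eye(5)]
)
fd_grad  # generally nonzero, unlike the vjp result above
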
danielward27 commented 1 month ago

Interesting, I would initially have assumed it would error too. Here's an example that gets to the root of the problem (in both senses of the word):

import jax
import jax.numpy as jnp
from jax import lax


def _bisection_search(func, *, lower, upper, tol: float, max_iter: int):
    def cond_fn(state):
        lower, upper, iterations = state
        return jnp.logical_and((upper - lower) > 2 * tol, iterations < max_iter)

    def body_fn(state):
        lower, upper, iterations = state
        midpoint = (lower + upper) / 2
        # Halve the interval, keeping the half containing the sign change
        # (assumes func is increasing on [lower, upper]).
        sign = jnp.sign(func(midpoint))
        lower = jnp.where(sign == 1, lower, midpoint)
        upper = jnp.where(sign == 1, midpoint, upper)
        return lower, upper, iterations + 1

    init_state = (lower, upper, 0)
    lower, upper, iterations = lax.while_loop(cond_fn, body_fn, init_state)
    root = (lower + upper) / 2
    return root, iterations


def get_root(x):
    # Root of arr + x (i.e. -x), found numerically by bisection.
    return _bisection_search(
        func=lambda arr: arr + x,
        lower=-10.0,
        upper=10.0,
        tol=1e-7,
        max_iter=100,
    )[0]

Note that the gradient is actually zero everywhere it is defined, because the numerically computed root is a piecewise-constant (step) function of the input. We can see this visually by plotting it over a very small region:

x = jnp.linspace(-1e-6, 1e-6, 1000)
roots = jax.vmap(get_root)(x)

import matplotlib.pyplot as plt
plt.plot(x, roots)
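
Differentiating get_root directly gives the same silent zero as the flow example above, rather than an error (a quick check):

jax.grad(get_root)(0.3)
# 0.0 -- the sign/where ops inside the loop have zero tangents, so no error is raised
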

Regardless, I think raising an error when differentiating through _bisection_search would be better than silently returning a gradient of zero, because differentiating through it is almost certainly a mistake. I presume this could be done with jax.custom_jvp?
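
Something along these lines might work (just a sketch of the idea, and the wrapper name is made up): a custom_jvp whose rule raises, so plain evaluation still goes through the bisection search but any attempt to differentiate fails loudly.

import jax

@jax.custom_jvp
def _bisection_root(x):
    # Plain evaluation: just run the (non-differentiable) bisection search.
    return get_root(x)

@_bisection_root.defjvp
def _bisection_root_jvp(primals, tangents):
    # Invoked whenever JAX tries to differentiate through the root finder.
    raise RuntimeError(
        "Differentiating through the bisection search is not supported: "
        "its output is piecewise constant, so the gradient would be zero."
    )

_bisection_root(0.3)            # fine, approximately -0.3
jax.grad(_bisection_root)(0.3)  # raises RuntimeError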