blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Add a "How to use forward-mode differentiation instead of backward-mode differentiation" tutorial #374

Open rlouf opened 1 year ago

johannahaffner commented 1 month ago

Hi!

I'm working on implementing forward-mode automatic differentiation for a function I want to sample, and I would appreciate such a tutorial! I've been trying to define a custom JVP for my logdensity function, and I get strange errors about exceeding the maximum recursion depth.

I might have taken a flawed approach, so a best practice would be amazing!

Best

Johanna
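One common way to hit a recursion error with jax.custom_jvp (not necessarily the cause in Johanna's case, which the thread does not pin down) is a JVP rule that differentiates the decorated function itself, so every call re-enters the rule. The minimal sketch below, with a hypothetical Gaussian stand-in for the ODE-based model, keeps an undecorated implementation and differentiates that instead. It also shows that jax.grad still works: JAX derives reverse mode from the custom JVP rule.

```python
import jax
import jax.numpy as jnp

def _logdensity_impl(x):
    # Undecorated implementation; a hypothetical stand-in for the ODE-based model.
    return -0.5 * jnp.sum(x ** 2)

@jax.custom_jvp
def logdensity(x):
    return _logdensity_impl(x)

@logdensity.defjvp
def logdensity_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Differentiate the undecorated function. Calling jax.jvp(logdensity, ...)
    # here instead would re-enter this rule until Python's recursion limit is hit.
    return jax.jvp(_logdensity_impl, (x,), (x_dot,))

x = jnp.array([1.0, 2.0])
value, directional = jax.jvp(logdensity, (x,), (jnp.ones_like(x),))  # forward mode
gradient = jax.grad(logdensity)(x)  # reverse mode, derived from the JVP rule
```

In a real model the body of `logdensity_jvp` is where a forward-mode-only adjoint would be invoked.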

johannahaffner commented 1 month ago

PS: My logdensity contains an ODE solve, and I would like to use an adjoint that only supports forward-mode automatic differentiation, since that gives me better performance in all other parts of my program.

AdrienCorenflos commented 1 month ago

I'm not sure what Rémi meant by this issue. JVP vs. VJP seems like a JAX-only topic and is largely orthogonal to BlackJAX. Additionally, JAX can usually derive the VJP from the JVP automatically (by linearizing and then transposing it) at essentially no overhead, so it's not clear to me what there is to explain beyond pointing to the core JAX documentation.
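The "JVP to VJP" translation mentioned above can be made explicit with public JAX functions: jax.linearize gives the JVP as a linear map, and jax.linear_transpose turns that map into a VJP. The minimal sketch below uses an arbitrary scalar test function `f` (our own choice) and checks that the result matches jax.grad.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary scalar test function.
    return jnp.sum(jnp.sin(x) * x)

x = jnp.array([0.5, 1.0, 2.0])

# Forward mode: the value y and a linear map from input tangents to output tangents.
y, f_jvp = jax.linearize(f, x)

# Transposing that linear map gives the VJP; x only supplies shapes and dtypes here.
f_vjp = jax.linear_transpose(f_jvp, x)
(grad_from_jvp,) = f_vjp(jnp.ones_like(y))  # seed with output cotangent 1.0
```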

As for your specific requirement of using JVP rather than VJP, I'm not sure why you'd want this, but an easy way to handle it is to define a wrapper function whose custom_vjp calls the JVP of your function of interest (here via jax.jacfwd).

Example:


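import jax
import jax.numpy as jnp

def func_of_interest(x):
    return jnp.sum(x)

@jax.custom_vjp
def wrapped_function(x):
    return func_of_interest(x)

def f_fwd(x):
    # Primal output plus the full Jacobian, computed in forward mode.
    y = func_of_interest(x)
    jac = jax.jacfwd(func_of_interest)(x)
    return y, jac  # jac is saved as the residual for the backward pass

def f_bwd(jac, g):
    # Scalar output: the cotangent g simply rescales the gradient.
    return (g * jac,)

wrapped_function.defvjp(f_fwd, f_bwd)

johannahaffner commented 1 month ago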

I want this because I have to differentiate through an ODE solver, and this is faster with forward-mode automatic differentiation for my use case.

I would like to supply a logdensity function with a custom JVP to enable this, and I tried implementing something very similar to what you are doing here. Your example runs for me. My own version includes an ODE solve, and that combination gave me a "maximum recursion depth reached" error. I believe this is caused by jacfwd being implemented as a vmap of jvp (https://github.com/google/jax/discussions/19973#discussioncomment-8598622), combined with the complexity of the underlying ODE solver.

I tried this on real data straight away, so I would need to write a minimal working example (MWE) to figure out where exactly this comes from.

I would prefer not to have to specify an analytic definition of the logdensity, since this would have to be adapted to the dimensions of the model and data for each problem.
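If the vmap inside jax.jacfwd is indeed the problem (an assumption the thread has not confirmed), one variant to try is computing the gradient with one jax.jvp per coordinate, run sequentially by jax.lax.map instead of batched by vmap. The sketch below plugs that into the same custom_vjp wrapper pattern as the example above; `grad_via_sequential_jvps` and `forward_mode_gradient` are hypothetical helper names, and the toy logdensity stands in for the ODE-based one. It trades speed for avoiding vmap and costs one JVP per parameter.

```python
import jax
import jax.numpy as jnp

def grad_via_sequential_jvps(f, x):
    """Gradient of a scalar function f at x: one JVP per coordinate, no vmap."""
    n = x.size
    # One-hot tangent directions, each reshaped to x's shape.
    basis = jnp.eye(n, dtype=x.dtype).reshape((n,) + x.shape)

    def directional_derivative(v):
        _, out_tangent = jax.jvp(f, (x,), (v,))
        return out_tangent

    # jax.lax.map evaluates the JVPs one after another instead of batching them.
    return jax.lax.map(directional_derivative, basis).reshape(x.shape)

def forward_mode_gradient(logdensity):
    """Hypothetical helper: make jax.grad use the sequential JVPs above."""
    @jax.custom_vjp
    def wrapped(x):
        return logdensity(x)

    def fwd(x):
        return logdensity(x), grad_via_sequential_jvps(logdensity, x)

    def bwd(grad, g):
        # Scalar output: the cotangent g simply rescales the gradient.
        return (g * grad,)

    wrapped.defvjp(fwd, bwd)
    return wrapped

def logdensity(theta):
    # Toy stand-in for an ODE-based logdensity.
    return -0.5 * jnp.sum(theta ** 2)

theta = jnp.array([1.0, -2.0, 3.0])
grad_theta = jax.grad(forward_mode_gradient(logdensity))(theta)
```

Any data the logdensity needs is captured here as concrete constants; closing over traced values inside a custom_vjp function may need extra care.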

AdrienCorenflos commented 1 month ago

> I want this because I have to differentiate through an ODE solver, and this is faster with forward-mode automatic differentiation for my use case.

This is very surprising to me. Reverse-mode AD for ODEs is well understood by now, and implementations are available in JAX. I'm curious why JVP would be faster than these methods, which are typically based on adjoint ODEs.
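As an illustration of an adjoint-based reverse-mode implementation in JAX, the sketch below differentiates a toy Gaussian log-likelihood through jax.experimental.ode.odeint, which to our understanding computes gradients by solving an adjoint ODE backwards in time. The decay model, time grid, tolerances and synthetic data are our own assumptions. Since odeint is implemented with jax.custom_vjp, we would not expect jax.jvp or jax.jacfwd to work through it: the mirror image of an adjoint that only supports forward mode.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay(y, t, k):
    # Toy ODE dy/dt = -k * y, with solution y(t) = y0 * exp(-k * t).
    return -k * y

ts = jnp.linspace(0.0, 2.0, 5)
observed = jnp.exp(-0.7 * ts)  # synthetic noise-free data generated with k = 0.7

def loglik(k):
    y = odeint(decay, jnp.array(1.0), ts, k, rtol=1e-6, atol=1e-6)
    return -0.5 * jnp.sum((y - observed) ** 2)

# Reverse mode through the solver via odeint's adjoint-based VJP.
grad_k = jax.grad(loglik)(0.5)
```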

johannahaffner commented 1 month ago

I did some benchmarking on the data I am working with, trying different adjoints implemented in Diffrax. I posted a short recap here.

This may be due to the recursive implementation and its memory requirements: my time series are quite long and the data is noisy. I'm also doing this in parallel on hundreds of these time series, computing gradients for each of them separately. For sampling, I likewise want to run many chains in parallel.
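The "gradients for each series separately, in parallel" pattern described above maps naturally onto jax.vmap over a per-series forward-mode gradient. The sketch below assumes a hypothetical two-parameter exponential model with unit Gaussian noise on a shared time grid; with only a few parameters per series, jax.jacfwd needs only a few JVPs each. Running many chains in parallel could be batched the same way.

```python
import jax
import jax.numpy as jnp

ts = jnp.linspace(0.0, 5.0, 50)  # shared time grid (assumption)

def series_loglik(theta, y):
    # Hypothetical model: y(t) = amplitude * exp(-rate * t) plus unit Gaussian noise.
    amplitude, rate = theta
    prediction = amplitude * jnp.exp(-rate * ts)
    return -0.5 * jnp.sum((y - prediction) ** 2)

# Forward-mode gradient for one series (2 parameters, so 2 JVPs) ...
grad_one = jax.jacfwd(series_loglik, argnums=0)
# ... mapped over a batch of series, each with its own parameter vector.
grad_all = jax.vmap(grad_one)

n_series = 300
thetas = jnp.tile(jnp.array([1.0, 0.3]), (n_series, 1))
data = thetas[:, :1] * jnp.exp(-thetas[:, 1:] * ts)  # synthetic noise-free data
grads = grad_all(thetas + 0.1, data)  # evaluated away from the optimum; shape (300, 2)
```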

Since my logdensity function is based on the loglikelihood I am using during non-linear optimisation, I would not expect the autodiff performance to be any different during sampling.
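If the logdensity is the optimisation log-likelihood plus a log-prior (the Gaussian prior and its scale below are our assumptions, and `make_logdensity` is a hypothetical helper), its gradient is the log-likelihood gradient plus a cheap prior term, which is consistent with expecting similar autodiff cost during sampling. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def make_logdensity(loglik, prior_scale=10.0):
    """Hypothetical helper: unnormalised log-posterior = log-likelihood + Gaussian log-prior."""
    def logdensity(theta):
        log_prior = -0.5 * jnp.sum((theta / prior_scale) ** 2)
        return loglik(theta) + log_prior
    return logdensity

def loglik(theta):
    # Stand-in for the log-likelihood used during optimisation.
    return -0.5 * jnp.sum((theta - 1.0) ** 2)

logdensity = make_logdensity(loglik)
theta = jnp.array([0.0, 2.0])
# Forward mode on the sum: the prior only adds a cheap extra term to each JVP.
g = jax.jacfwd(logdensity)(theta)
```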