handley-lab / lsbi

Linear Simulation Based Inference
MIT License
2 stars 0 forks source link

Autodiff likelihoods from lsbi #43

Open yallup opened 1 month ago

yallup commented 1 month ago

It would be useful for testing numerical inference algorithms to have differentiable likelihoods in lsbi. In principle the whole package could switch to JAX; however, things like random-number generation work quite differently there and would require some excavation (see #41).

The basic requirement is the ability to furnish the distributions with a JAX `log_prob` function. The most useful one would be the likelihood, which can be done fairly simply, as shown below.

from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np

d = 100  # data dimension (rows of M)
t = 5  # parameter dimension (columns of M)
# No noise covariance is passed, so the model's default `C` is used; the
# JAX likelihood below reads it back as `model._C`.
model = LinearModel(M=np.random.randn(d, t))
# One draw from the joint: the first t entries are parameters, the rest data.
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)

def log_prob(theta):
    """Return the JAX log-likelihood of `true_data` given parameters `theta`.

    The mean is the linear map ``m + M @ theta`` (batched over any leading
    axes of `theta`) and the covariance is the model's ``_C``. Because it is
    written with ``jax.numpy``, it can be passed to ``jax.grad``/``vmap``.
    """
    # Broadcast theta against a length-n vector before contracting with M.
    theta_full = theta * jnp.ones(model.n)
    mean = model.m + jnp.einsum("...ja,...a->...j", model._M, theta_full)
    return multivariate_normal.logpdf(true_data, mean=mean, cov=model._C)

from jax import random
from jax import vmap, value_and_grad

rng = random.PRNGKey(0)

# 100 random parameter vectors at which to compare the two implementations.
theta_samples = random.normal(rng, (100, t))
np_log_prob = model.likelihood(theta_samples).logpdf(true_data)
jax_log_prob = log_prob(theta_samples)

# Per-sample value and gradient of the JAX log-likelihood.
value, grad = vmap(value_and_grad(log_prob))(theta_samples)

# Report the worst-case disagreement: a signed mean lets errors of opposite
# sign cancel and look like agreement. Small non-zero values are expected
# when JAX runs in its default 32-bit precision.
print(np.abs(np_log_prob - jax_log_prob).max())
print(np.abs(np_log_prob - value).max())

And for basic mixtures:

from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp

d = 100  # data dimension
t = 5  # parameter dimension
k = 3  # number of mixture components
# No noise covariance is passed, so the model's default `C` is used.
# Single-component alternative: model = LinearModel(M=np.random.randn(d, t))
model = MixtureModel(M=np.random.randn(k, d, t))
# One draw from the joint: the first t entries are parameters, the rest data.
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)

def log_prob(theta):
    """Return the JAX mixture log-likelihood of `true_data` given `theta`.

    Component j has mean ``m + M_j @ theta`` and covariance ``C_j``. The
    component densities are combined with the normalised mixture weights
    ``w_j = exp(logw_j) / sum_i exp(logw_i)``::

        log p(D | theta) = logsumexp_j(logw_j + log N_j) - logsumexp_j(logw_j)

    With all-equal weights this reduces to the plain average over components.

    NOTE(review): the weights are treated as independent of theta. That holds
    when every component shares the same prior over theta. If the components
    have different priors, each weight should also include that component's
    prior density at theta. TODO: confirm against ``MixtureModel.likelihood``.
    """
    # Insert a component axis so theta broadcasts against M (built above with
    # shape (k, d, t)), giving one mean per component: (..., k, d).
    theta_k = jnp.expand_dims(theta, -2) * jnp.ones(model.n)
    mu = model.m + jnp.einsum("...ja,...a->...j", model._M, theta_k)
    # Per-component log-densities; the last axis indexes the components.
    log_component = multivariate_normal.logpdf(true_data, mean=mu, cov=model._C)
    logw = model.logw * jnp.ones(model.k)
    # Weight each component, then divide by the total weight (in log space).
    return logsumexp(logw + log_component, axis=-1) - logsumexp(logw, axis=-1)

from jax import random
from jax import vmap, value_and_grad

rng = random.PRNGKey(0)

# 100 random parameter vectors at which to compare the two implementations.
theta_samples = random.normal(rng, (100, t))
np_log_prob = model.likelihood(theta_samples).logpdf(true_data)
jax_log_prob = log_prob(theta_samples)

# Per-sample value and gradient of the JAX log-likelihood.
value, grad = vmap(value_and_grad(log_prob))(theta_samples)

# Report the worst-case disagreement: a signed mean lets errors of opposite
# sign cancel and look like agreement. Small non-zero values are expected
# when JAX runs in its default 32-bit precision.
print(np.abs(np_log_prob - jax_log_prob).max())
print(np.abs(np_log_prob - value).max())

I'm not sure whether this can be integrated elegantly, but I'll put it here for now as it may be useful for other projects.

NB: the weighting for mixtures with non-trivial weights is not handled correctly here; to be fixed later.

williamjameshandley commented 1 month ago

The other option here is to have analytic gradients (and Hessians). I don't know whether this would be less flexible, or whether it would be faster or slower.

yallup commented 1 month ago

Good point! That is probably better and fits the ethos more. That said, the current approach is not expensive and is relatively easy to modify, so until we know what we actually want to optimize/sample, it is probably sufficient.

Below is an example fitting a model matrix to a single joint observation:

Maximum Likelihood model_opt

Maximum Evidence model_opt

from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
import anesthetic as ns
import matplotlib.pyplot as plt

d = 100  # data dimension
t = 5  # parameter dimension
k = 3  # number of mixture components (used only by the alternative below)
# No noise covariance is passed, so the model's default `C` is used; the
# objective below reads it back as `model._C`.
model = LinearModel(M=np.random.randn(d, t))
# Mixture alternative: model = MixtureModel(M=np.random.randn(k, d, t))
# One draw from the joint: the first t entries are parameters, the rest data.
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)

def log_prob(theta_m, evidence=False):
    """Return the NEGATIVE log-probability of the observed data as a function of M.

    This is a loss for fitting a model matrix ``theta_m`` (same role as
    ``model._M``) to the single joint observation ``(true_theta, true_data)``.
    It is negated so that a minimiser such as LBFGS maximises the probability.

    Args:
        theta_m: candidate model matrix, e.g. of shape (d, t).
        evidence: if False (default), use the likelihood
            ``N(true_data; m + M true_theta, C)``. If True, use the evidence
            ``N(true_data; m + M mu, C + M Sigma M^T)``. The evidence
            marginalises over theta, so it does not depend on ``true_theta``.
            This must be a Python bool; when jitting, bind it with
            ``functools.partial`` (or mark it static) rather than tracing it.

    Returns:
        Scalar negative log-probability, differentiable with JAX.
    """
    if evidence:
        # NOTE(review): assumes the prior mean is exposed as `model.μ` — confirm.
        mu = model.m + jnp.einsum(
            "...ja,...a->...j", theta_m, model.μ * jnp.ones(model.n)
        )
        cov = model._C + jnp.einsum(
            "...ja,...ab,...kb->...jk", theta_m, model._Σ, theta_m
        )
    else:
        mu = model.m + jnp.einsum(
            "...ja,...a->...j", theta_m, true_theta * jnp.ones(model.n)
        )
        cov = model._C
    return -multivariate_normal.logpdf(true_data, mean=mu, cov=cov)

from jax import random
from jax import vmap, value_and_grad, jit
import optax
from jaxopt import LBFGS
rng = random.PRNGKey(0)
# Split the key: reusing one key for two draws returns identical arrays.
check_key, init_key = random.split(rng)

# Sanity check: evaluate the objective once at a random model matrix.
theta_m_samples = random.normal(check_key, (d, t))
jax_log_prob = log_prob(theta_m_samples)

# Random initial guess for the model matrix, refined with L-BFGS.
theta_m = random.normal(init_key, (d, t))
steps = 1000
solver = LBFGS(jit(log_prob), maxiter=steps)

# Alternative: plain Adam with optax. Jit once, outside the loop, and carry
# the updated optimiser state between iterations.
# optimizer = optax.adam(1)
# opt_state = optimizer.init(theta_m)
# loss_and_grad = jit(value_and_grad(log_prob))
# losses = []
# for i in range(steps):
#     value, grad = loss_and_grad(theta_m)
#     updates, opt_state = optimizer.update(grad, opt_state)
#     theta_m = optax.apply_updates(theta_m, updates)
#     losses.append(value)
#     print(value)

# `res` is a jaxopt OptStep: res.params holds the fitted matrix.
res = solver.run(theta_m)

# Wrap the fitted matrix in an lsbi model. jaxopt's OptStep is a NamedTuple,
# so `res.params` is the same object as `res[0]`.
surrogate_model = LinearModel(M=res.params)

# Overlay 500-draw corner plots of the surrogate and true posteriors for the
# same observed data (surrogate drawn and plotted first, as before).
surrogate_samples = ns.MCMCSamples(surrogate_model.posterior(true_data).rvs(500))
a = surrogate_samples.plot_2d(figsize=(6, 6), label="Fitted Surrogate Posterior")
true_samples = ns.MCMCSamples(model.posterior(true_data).rvs(500))
true_samples.plot_2d(a, label="True Posterior")

# One legend for the whole grid, attached to the bottom-left axes and
# anchored (presumably) above the grid's horizontal centre.
n_rows = len(a)
a.iloc[-1, 0].legend(loc="lower center", bbox_to_anchor=(n_rows / 2, n_rows))
plt.savefig("model_opt.pdf")