patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Bayesian NNs with Equinox #541

Open cdagher opened 11 months ago

cdagher commented 11 months ago

This is really an exploratory question about implementing a Bayesian layer in Equinox. My research requires Bayesian neural networks, and I've been playing with Equinox to see how it stacks up to Flax -- I'm definitely impressed with the ease of use.

I'm thinking that it shouldn't be too hard to add this functionality with another Module in nn, but I figured I'd throw the question out there and see if there are any other thoughts or suggestions. Maybe a library like NumPyro could be used alongside Equinox?
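For concreteness, here is a rough sketch of the kind of layer I have in mind -- entirely hypothetical, not an existing equinox.nn API -- holding a variational mean and log-std per weight and sampling with the reparameterisation trick:

import jax
import jax.numpy as jnp
import equinox as eqx

class BayesianLinear(eqx.Module):
    # variational posterior parameters for the weights
    w_mean: jax.Array
    w_logstd: jax.Array

    def __init__(self, in_size, out_size, key):
        self.w_mean = 0.1 * jax.random.normal(key, (out_size, in_size))
        self.w_logstd = jnp.full((out_size, in_size), -3.0)

    def __call__(self, x, *, key):
        # reparameterisation trick: w = mean + std * eps with eps ~ N(0, I)
        eps = jax.random.normal(key, self.w_mean.shape)
        w = self.w_mean + jnp.exp(self.w_logstd) * eps
        return w @ x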

patrick-kidger commented 11 months ago

Thanks, I'm glad you like Equinox! I've not tried that particular combination myself, so I don't have a ready-made example to hand. If anyone else has then maybe they can chime in.

For what it's worth, Equinox basically operates "at the same level" as JAX -- it's all just pytrees and eqx.Module is really just a neat syntax for a parameterised function -- so I suspect this is probably doable. If you figure it out then I'd be happy to host an example in the docs.
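To make that concrete, here's a minimal sketch (the names are illustrative) of what "a Module is just a pytree" buys you -- it composes directly with JAX transformations:

import jax
import jax.numpy as jnp
import equinox as eqx

class Affine(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        return self.weight @ x + self.bias

model = Affine(weight=jnp.eye(2), bias=jnp.zeros(2))
print(jax.tree_util.tree_leaves(model))  # just the two arrays: a Module is a pytree
loss = lambda m, x: jnp.sum(m(x))
grads = jax.grad(loss)(model, jnp.ones(2))  # gradients arrive as another Affine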

lockwo commented 11 months ago

I've only used numpyro a little, but as Patrick noted, equinox is pretty compatible with most jax libraries because it operates (or you can get it to operate) at a level very similar to jax itself. As an example, I did a simple adaptation of https://num.pyro.ai/en/0.7.1/examples/bnn.html (the equinox usage is simple, but it highlights some possibilities, since there wasn't really a specific question).

import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import equinox as eqx

matplotlib.use("Agg")  # noqa: E402

# the non-linearity we use in our neural network
def nonlin(x):
    return jnp.tanh(x)

# a two-layer bayesian neural network with computational flow
# given by D_X => D_H => D_H => D_Y where D_H is the number of
# hidden units. (note we indicate tensor dimensions in the comments)
class Model(eqx.Module):
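    # the unit-normal priors are stored directly as Module fields; numpyro
    # distributions are registered as pytrees, so Equinox treats them like
    # any other leaf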
    w1: dist.Normal
    w2: dist.Normal

    def __init__(self, D_X, D_H):
        self.w1 = dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H)))
        self.w2 = dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H)))

    def __call__(self, X, Y, D_H, D_Y=1):
        N, D_X = X.shape

        # sample first layer (we put unit normal priors on all weights)
        w1 = numpyro.sample("w1", self.w1)
        assert w1.shape == (D_X, D_H)
        z1 = nonlin(jnp.matmul(X, w1))  # <= first layer of activations
        assert z1.shape == (N, D_H)

        # sample second layer
        w2 = numpyro.sample("w2", self.w2)
        assert w2.shape == (D_H, D_H)
        z2 = nonlin(jnp.matmul(z1, w2))  # <= second layer of activations
        assert z2.shape == (N, D_H)

        # sample final layer of weights and neural network output
        w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
        assert w3.shape == (D_H, D_Y)
        z3 = jnp.matmul(z2, w3)  # <= output of the neural network
        assert z3.shape == (N, D_Y)

        # we put a prior on the observation noise
        prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
        sigma_obs = 1.0 / jnp.sqrt(prec_obs)

        # observe data
        with numpyro.plate("data", N):
            # note we use to_event(1) because each observation has shape (1,)
            numpyro.sample("Y", dist.Normal(z3, sigma_obs).to_event(1), obs=Y)

# helper function for HMC inference
def run_inference(model, args, rng_key, X, Y, D_H):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y, D_H)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()

# helper function for prediction
def predict(model, rng_key, samples, X, D_H):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(model).get_trace(X=X, Y=None, D_H=D_H)
    return model_trace["Y"]["value"]

# create artificial regression dataset
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    D_Y = 1  # create 1d outputs
    np.random.seed(0)
    X = jnp.linspace(-1, 1, N)
    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * np.random.randn(D_X)
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

    return X, Y, X_test

def main(args):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    model = Model(D_X, D_H)
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)

    # predict Y_test at inputs X_test
    vmap_args = (
        samples,
        random.split(rng_key_predict, args.num_samples * args.num_chains),
    )
    predictions = vmap(
        lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H)
    )(*vmap_args)
    predictions = predictions[..., 0]

    # compute mean prediction and confidence interval around median
    mean_prediction = jnp.mean(predictions, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # plot training data
    ax.plot(X[:, 1], Y[:, 0], "kx")
    # plot 90% confidence level of predictions
    ax.fill_between(
        X_test[:, 1], percentiles[0, :], percentiles[1, :], color="lightblue"
    )
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, "blue", ls="solid", lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    plt.savefig("bnn_plot_eqx.pdf")
    plt.show()

if __name__ == "__main__":
    # the original numpyro script builds these options with argparse;
    # a plain dataclass stands in for the parsed arguments here
    from dataclasses import dataclass

    @dataclass
    class A:
        num_samples: int = 2000
        num_warmup: int = 1000
        num_chains: int = 1
        num_data: int = 100
        num_hidden: int = 5
        device: str = "gpu"

    args = A()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)

Yields:

[screenshot: plot of the training data with the posterior mean prediction and 90% CI band]