tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

tfp.mcmc.HamiltonianMonteCarlo: add example how to infer parameters of Bayesian neural network #292

Open janosh opened 5 years ago

janosh commented 5 years ago

Suggestion: add a 3rd example to tfp.mcmc.HamiltonianMonteCarlo showing how to infer the posterior parameters of a Bayesian neural network (e.g. a simple few-layer Keras Sequential model) using Hamiltonian Monte Carlo.

janosh commented 5 years ago

cc @csuter @davmre @jvdillon

davmre commented 5 years ago

We'd welcome this as a contribution!

You'd need to write a log joint function that takes the network weights as arguments and evaluates the prior plus the label likelihood on the full dataset. Assuming you wanted the network to be a Keras model, it'd take a bit of plumbing to evaluate the model with exogenously provided weights -- I'm not sure off the top of my head what the most idiomatic approach would be -- but it should certainly be doable.
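
One generic way to do that plumbing, independent of Keras, is to keep all parameters in one flat vector and carve it into per-layer weights and biases inside the log joint. The NumPy sketch below only illustrates that bookkeeping; the helper names (`layer_shapes`, `unflatten`, `flatten`, `forward`) are hypothetical, and `tfp.mcmc` kernels can also take a list of per-layer tensors directly as the chain state, which is the route the later code in this thread takes.

```python
import numpy as np

def layer_shapes(layer_sizes):
    """Weight and bias shapes of a fully connected net, e.g. layer_sizes=(4, 8, 2)."""
    shapes = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        shapes.append((n_in, n_out))  # weight matrix
        shapes.append((n_out,))       # bias vector
    return shapes

def unflatten(flat, shapes):
    """Split one flat parameter vector into per-layer arrays with the given shapes."""
    parts, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        parts.append(flat[start:start + size].reshape(shape))
        start += size
    if start != flat.size:
        raise ValueError("flat vector length does not match the layer shapes")
    return parts

def flatten(parts):
    """Inverse of unflatten: concatenate all parameters into one vector."""
    return np.concatenate([p.ravel() for p in parts])

def forward(x, parts):
    """ReLU MLP evaluated with externally supplied weights and biases."""
    weights, biases = parts[::2], parts[1::2]
    h = x
    for w, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ w + b, 0.0)
    return h @ weights[-1] + biases[-1]

shapes = layer_shapes((4, 8, 2))
n_params = sum(int(np.prod(s)) for s in shapes)  # 4*8 + 8 + 8*2 + 2 = 58
rng = np.random.default_rng(0)
theta = rng.normal(size=n_params)                # e.g. one state of an MCMC chain
out = forward(rng.normal(size=(5, 4)), unflatten(theta, shapes))
print(out.shape)                                 # (5, 2)
```

Inside a log joint, `unflatten(theta, shapes)` would run first and the prior and likelihood would then be evaluated on the resulting per-layer arrays.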

janosh commented 5 years ago

@davmre Thanks for the quick reply! It doesn't have to be a Keras model. I'm happy to use tfp internals. What would be the easiest approach in that case? tfp.trainable_distributions?

SiegeLordEx commented 5 years ago

I believe a particularly clean way to pass weights inside a Keras model would be to use the new DenseVariational layer and the Empirical distribution. Something like:

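import tensorflow_probability as tfp
tfd, tfpl = tfp.distributions, tfp.layers

def posterior_mcmc(kernel_size, bias_size=0, dtype=None):
  # mcmc_samples: chain of flattened [kernel, bias] vectors, shape [num_samples, kernel_size + bias_size]
  return tfpl.DistributionLambda(lambda t: tfd.Empirical(mcmc_samples, event_ndims=1))

model = tfpl.DenseVariational(units, posterior_mcmc, ...)  # `units`: output size (placeholder)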

# model.loss -> joint log prob

But I haven't attempted it.

Long term, my vision is that we'd have MCMC optimizers in the style of tfp.optimizer.StochasticGradientLangevinDynamics, and you'd write your BNN models using them.
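
For reference, the update an SGLD-style optimizer performs is a half step along a minibatch estimate of the log-posterior gradient plus injected Gaussian noise. The NumPy sketch below shows that rule on a toy model chosen only because its posterior is known in closed form (y_i ~ N(theta, 1), theta ~ N(0, 1)). It is not the TFP optimizer API; the function names, step size and batch size are illustrative assumptions.

```python
import numpy as np

# Toy model (assumption for illustration): y_i ~ N(theta, 1), prior theta ~ N(0, 1).
rng = np.random.default_rng(0)
N, batch_size = 100, 10
y = rng.normal(2.0, 1.0, size=N)

def minibatch_grad_log_post(theta, y_batch):
    """Unbiased estimate of d/dtheta log p(theta | y): prior term + rescaled minibatch likelihood."""
    return -theta + (N / len(y_batch)) * np.sum(y_batch - theta)

def sgld(num_steps=50_000, eps=1e-3, theta=0.0):
    samples = np.empty(num_steps)
    for t in range(num_steps):
        y_batch = y[rng.integers(0, N, size=batch_size)]
        # Half a step along the noisy gradient, plus N(0, eps) injected noise.
        theta = (theta + 0.5 * eps * minibatch_grad_log_post(theta, y_batch)
                 + rng.normal(0.0, np.sqrt(eps)))
        samples[t] = theta
    return samples

samples = sgld()[5_000:]             # drop burn-in
exact_mean = y.sum() / (N + 1)       # conjugate posterior: N(sum(y) / (N + 1), 1 / (N + 1))
exact_std = 1.0 / np.sqrt(N + 1)
print(samples.mean(), exact_mean)
print(samples.std(), exact_std)
```

With a constant step size the sampler is only approximate (here the minibatch noise slightly inflates the spread); SGLD as originally formulated decays the step size over time.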

davmre commented 5 years ago

For a simple model, you could always just write the network by hand. Something like

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def build_network(weights_list, biases_list, activation=tf.nn.relu):
  def model(X):
    # X: [N, num_features]; each weight matrix: [n_in, n_out]; each bias: [n_out]
    net = X
    for (weight, bias) in zip(weights_list[:-1], biases_list[:-1]):
      net = activation(tf.matmul(net, weight) + bias)
    net = tf.matmul(net, weights_list[-1]) + biases_list[-1]  # Final linear layer.
    return tfd.Categorical(logits=net)   # or build a trainable normal, etc.
  return model

Then the log joint function would just build the network for the given weights and evaluate the prior and likelihood on your dataset:

X, y = get_full_dataset()
weights_prior = tfd.Normal(0., 1.)
bias_prior = tfd.Normal(0., 1e6)  # near-uniform
def log_joint_fn(weights1, bias1, weights2, bias2, weights3, bias3):
  weights_list = [weights1, weights2, weights3]
  biases_list = [bias1, bias2, bias3]

  # prior log-prob
  lp = sum([tf.reduce_sum(weights_prior.log_prob(weights)) for weights in weights_list])
  lp += sum([tf.reduce_sum(bias_prior.log_prob(bias)) for bias in biases_list])

  # likelihood of predicted labels
  network = build_network(weights_list, biases_list)
  labels_dist = network(X)
  lp += tf.reduce_sum(labels_dist.log_prob(y))

  return lp

davmre commented 5 years ago

@SiegeLordEx : yeah I think that'd work! For the log prob at a single sample point, you'd probably want the Deterministic distribution rather than Empirical, right?

SiegeLordEx commented 5 years ago

I'd still write it as Empirical, as that'd let you get the sample-based posterior predictive distribution a bit more easily: after you compute the chain, you'd pass the chain into the model again, and each time you call the model it will randomly pick one element of the chain.
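
The behaviour described here (each call to the model uses one randomly chosen chain element) amounts to Monte Carlo averaging over the posterior samples. A NumPy sketch, assuming a made-up two-parameter chain for a linear model in place of real MCMC output and a hypothetical `predict_once` helper standing in for one call of the Empirical-backed model:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical chain: 500 posterior samples of (slope, intercept) for y = slope * x + intercept.
chain = np.column_stack([rng.normal(2.0, 0.1, 500), rng.normal(-1.0, 0.1, 500)])

def predict_once(x):
    """Mimic an Empirical posterior: each call uses one randomly chosen chain element."""
    slope, intercept = chain[rng.integers(len(chain))]
    return slope * x + intercept

x = np.array([0.0, 1.0, 2.0])
mc_average = np.mean([predict_once(x) for _ in range(5_000)], axis=0)
exact_average = chain[:, 0].mean() * x + chain[:, 1].mean()  # average over the whole chain
print(mc_average, exact_average)
```

Averaging many calls recovers the prediction averaged over the whole chain, while the spread of individual calls reflects posterior uncertainty in the parameters.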

janosh commented 5 years ago

@davmre I'm looking to use a Gaussian likelihood that combines aleatoric and epistemic uncertainty according to eq. (7) of What uncertainties do we need in deep learning?. I'm a first-time user of tfp so this might be a dumb question: Would I implement that by simply replacing

tfd.Categorical(logits=net)

with two trainable normals (or a batch?), i.e.

tfp.trainable_distributions.normal(
    net,
    layer_fn=tf.layers.dense,
    scale_fn=1.0,
)

one for the predictive mean y and one for the learned loss attenuation sigma, and then replace

labels_dist = network(X)
lp += tf.reduce_sum(labels_dist.log_prob(y))

with

labels_dist, loss_atten = network(X)
lp += tf.reduce_sum(labels_dist.log_prob(y)) / tf.math.square(loss_atten) + tf.math.log(loss_atten)

davmre commented 5 years ago

@janosh It's sufficient to just replace the Categorical with a trainable Normal: the log_prob of a normal distribution already incorporates both the mean and stddev. To model aleatoric uncertainty you need to ensure that the stddev is learned (coming from the network), but you don't need a separate loss_atten term.

Put differently: in TFP you don't need to write a normal density by hand; you can just use the Normal distribution.

By default the trainable_distributions.normal utility constructs a new dense layer internally, which you don't actually want; you need to control all the weights yourself so the HMC sampler can set them. It's possible to work around this but it might be simpler to just parameterize the normal directly:

# assumes you've constructed weights so that final `net` has shape [..., 2]
loc = net[..., 0]
scale = tf.nn.softplus(net[..., 1]) + 1e-6  # ensure scale is positive
return tfd.Normal(loc=loc, scale=scale)

which is what the trainable normal would do under the hood anyway.
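
The claim that a Normal's log_prob already contains the attenuation term can be checked numerically: with log variance s and scale sigma = exp(s / 2), the negative log density equals the eq. (7)-style loss 0.5 * (y - mu)^2 * exp(-s) + 0.5 * s plus the constant 0.5 * log(2 * pi). In the NumPy sketch below, `normal_log_prob` and `attenuated_loss` are written out by hand for illustration and are not TFP functions.

```python
import math
import numpy as np

def normal_log_prob(y, loc, scale):
    """Log density of N(loc, scale**2), written out by hand."""
    return -0.5 * ((y - loc) / scale) ** 2 - np.log(scale) - 0.5 * math.log(2 * math.pi)

def attenuated_loss(y, loc, log_var):
    """Per-point loss in the style of eq. (7): 0.5 * (y - loc)**2 * exp(-s) + 0.5 * s, s = log variance."""
    return 0.5 * (y - loc) ** 2 * np.exp(-log_var) + 0.5 * log_var

rng = np.random.default_rng(0)
y, loc, log_var = rng.normal(size=(3, 1000))
scale = np.exp(0.5 * log_var)  # sigma = sqrt(exp(s))

# Negative log-likelihood minus the attenuated loss is the constant 0.5 * log(2 * pi).
gap = -normal_log_prob(y, loc, scale) - attenuated_loss(y, loc, log_var)
print(np.allclose(gap, 0.5 * math.log(2 * math.pi)))  # True
```

Since the constant does not depend on the parameters, maximizing the Normal log-likelihood and minimizing the attenuated loss select the same parameters.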

janosh commented 5 years ago

@davmre Thanks for the advice. I attempted an implementation and it seems to be learning initially but stops very quickly as you can see from this plot:

[Plot: negative log-probability vs. steps]

The performance is terrible. The last 10 MSEs between predictive mean and ground truth are

[165.20711955090275, 191.70900226850847, 177.29640950234673, 192.7450666053791, 180.95249241999576, 176.5543714158893, 173.36850118140472, 209.5881465581332, 187.77771913021044]

For comparison, a comparably-sized network trained with dropout achieves an MSE of 0.2 - 0.3 on my data. So I must be doing something wrong.

Do you have some further advice on how to initialize the chain? I'm currently sampling the weights and biases from a normal distribution, but maybe that's a bad idea? Here's the code:

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

X, y = get_full_dataset()

def build_network(weights_list, biases_list, activation=tf.nn.relu):
    def model(X):
        net = X
        for (weight, bias) in zip(weights_list[:-1], biases_list[:-1]):
            net = activation(tf.matmul(net, weight) + bias)
        # split network into two heads
        pred = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        var = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        # `pred` and `var` each have size N = X.shape[0] (the number of data samples)
        # and are the model's prediction and learned loss attenuation, resp.
        scale = tf.nn.softplus(var) + 1e-6  # ensure scale is positive
        return tfd.Normal(loc=pred, scale=scale)

    return model

weights_prior = tfd.Normal(0.0, 1.0)
bias_prior = tfd.Normal(0.0, 1e6)  # near-uniform
neg_log_probs = []
mses = []

def joint_log_prob_fn(
    weights1, biases1, weights2, biases2, weights3, biases3, weights4, biases4
):
    weights_list = (weights1, weights2, weights3, weights4)
    biases_list = (biases1, biases2, biases3, biases4)

    # prior log-prob
    lp = sum(
        [tf.reduce_sum(weights_prior.log_prob(weights)) for weights in weights_list]
    )
    lp += sum([tf.reduce_sum(bias_prior.log_prob(bias)) for bias in biases_list])

    # likelihood of predicted labels
    network = build_network(weights_list, biases_list)
    labels_dist = network(X.astype("float32"))
    lp += tf.reduce_sum(labels_dist.log_prob(y))

    neg_log_probs.append(-lp.numpy())
    mse = ((labels_dist.loc.numpy() - y) ** 2).mean()
    mses.append(mse)
    return lp

def get_initial_state(layers=(X.shape[1], 50, 25, 10, 1)):
    architecture = []
    for idx in range(len(layers) - 1):
        weigths_loc = tf.zeros((layers[idx], layers[idx + 1]))
        biases_loc = tf.zeros(layers[idx + 1])
        weigths = tfd.Normal(loc=weigths_loc, scale=1.0).sample()
        biases = tfd.Normal(loc=biases_loc, scale=1.0).sample()
        architecture.extend((weigths, biases))
    return architecture

def run_hmc(num_results=100, num_burnin_steps=0, step_size=0.01):
    hmc_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=joint_log_prob_fn,
            num_leapfrog_steps=2,
            step_size=step_size,
            state_gradients_are_stopped=True,
        ),
        num_adaptation_steps=num_results + num_burnin_steps,
    )
    initial_state = get_initial_state()
    weights, kernel_results = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=hmc_kernel,
        # trace_fn=None,
    )

    print("mses", mses[-10:-1])
    plt.plot(neg_log_probs)
    plt.yscale("log")
    plt.xlabel("steps")
    plt.ylabel("negative log likelihood")
    plt.show()

run_hmc(num_results=100, num_burnin_steps=10)

davmre commented 5 years ago

You need the predictive mean and scale to be two different network outputs, i.e., the last layer should have size 2 (not 1), so you can split it into a separate mean and scale:

        # split network into two heads
        pred_and_var = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        pred = pred_and_var[..., 0]
        var = pred_and_var[..., 1]

        # `pred` and `var` each have size N = X.shape[0] (the number of data samples)
        # and are the model's prediction and learned loss attenuation, resp.
        scale = tf.nn.softplus(var) + 1e-6  # ensure scale is positive
        return tfd.Normal(loc=pred, scale=scale)

(as originally written you've computed literally the same value for 'pred' and 'var').
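
The difference between the two versions comes down to shapes, and can be checked without TensorFlow. A NumPy sketch with arbitrary random activations standing in for the last hidden layer (all sizes are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = rng.normal(size=(6, 4))  # last hidden-layer activations for N = 6 data points

# Bug pattern: a 1-unit final layer applied twice, so both "heads" are the same tensor.
w1, b1 = rng.normal(size=(4, 1)), rng.normal(size=(1,))
pred_bad = hidden @ w1 + b1
var_bad = hidden @ w1 + b1
print(np.array_equal(pred_bad, var_bad))  # True

# Fix: a 2-unit final layer whose output is split along the last axis.
w2, b2 = rng.normal(size=(4, 2)), rng.normal(size=(2,))
pred_and_var = hidden @ w2 + b2           # shape (6, 2)
pred, var = pred_and_var[..., 0], pred_and_var[..., 1]
scale = np.log1p(np.exp(var)) + 1e-6      # softplus keeps the scale positive
print(pred.shape, scale.shape)            # (6,) (6,)
```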

janosh commented 5 years ago

@davmre Oops, that was stupid. Thanks for pointing that out!

I tried some more in #356 with a lot of help from @SiegeLordEx, @brianwa84 and @csuter. Thanks to them, I have a working implementation that's able to pause and resume the computation. It works nicely for a simple Gaussian, i.e. when replacing bnn_joint_log_prob in the following with dist.log_prob where dist = tfd.Normal(0.0, 1.0). However, with bnn_joint_log_prob (the joint log-likelihood of all parameters of the Bayesian NN), sample_chain just seems to stand still. The chain is just num_results identical samples. No idea what the problem might be.

Code block removed for brevity. See updated code block below.

SiegeLordEx commented 5 years ago

That looks good @janosh. What is the get_full_dataset function? If I had it, I'd help you debug this. In general, if things are stuck, you should try reducing the step size, although the step size adaptation should have done that for you if all the HMC steps were getting rejected.
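
The symptom of a chain made of identical samples usually means every proposal is rejected, and step size is the first suspect. A minimal NumPy HMC (leapfrog integrator plus Metropolis correction) on a 1-D standard normal makes this concrete; it is written for illustration and is not TFP's implementation. With a small step size nearly every proposal is accepted, while an absurdly large one makes the leapfrog trajectory blow up, so every proposal is rejected and the chain never leaves its starting point.

```python
import numpy as np

def hmc(log_prob, grad_log_prob, q0, step_size, num_leapfrog, num_samples, seed=0):
    """Minimal HMC (leapfrog + Metropolis correction) for a 1-D target."""
    rng = np.random.default_rng(seed)
    q, samples, accepted = q0, [], 0
    for _ in range(num_samples):
        p = rng.normal()
        q_new, p_new = q, p
        # Leapfrog: half momentum step, alternating full steps, final half momentum step.
        p_new += 0.5 * step_size * grad_log_prob(q_new)
        for i in range(num_leapfrog):
            q_new += step_size * p_new
            if i < num_leapfrog - 1:
                p_new += step_size * grad_log_prob(q_new)
        p_new += 0.5 * step_size * grad_log_prob(q_new)
        # Metropolis accept/reject on the change in total energy.
        current_h = -log_prob(q) + 0.5 * p**2
        proposed_h = -log_prob(q_new) + 0.5 * p_new**2
        if np.log(rng.uniform()) < current_h - proposed_h:
            q, accepted = q_new, accepted + 1
        samples.append(q)
    return np.array(samples), accepted / num_samples

def log_prob(q):
    return -0.5 * q**2  # standard normal, up to a constant

def grad_log_prob(q):
    return -q

good, good_rate = hmc(log_prob, grad_log_prob, 1.0, step_size=0.2, num_leapfrog=10, num_samples=5000)
stuck, stuck_rate = hmc(log_prob, grad_log_prob, 1.0, step_size=50.0, num_leapfrog=10, num_samples=200)
print(good_rate, good.mean(), good.std())
print(stuck_rate, np.unique(stuck))  # every proposal rejected: the chain repeats q0
```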

janosh commented 5 years ago

@SiegeLordEx I slightly extended the above code to include a minimal working example of loading the data (attached as a ZIP file of CSVs) and attempting inference on it. It now requires numpy and sklearn. I also tried reducing the step_size by a factor of 10 and 100, but it made no difference.

import math
from datetime import datetime
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import sklearn.model_selection
import sklearn.preprocessing
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def dense(X, W, b, activation):
    return activation(tf.matmul(X, W) + b)

def build_network(weights_list, biases_list, activation=tf.nn.relu):
    def model(X):
        net = X
        for (weights, biases) in zip(weights_list[:-1], biases_list[:-1]):
            net = dense(net, weights, biases, activation)
        # final linear layer
        net = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        preds = net[:, 0]
        std_devs = tf.exp(-net[:, 1])
        # preds and std_devs each have size N = X.shape[0] (the number of data samples)
        # and are the model's predictions and (log-sqrt of) learned loss attenuations, resp.
        return tfd.Normal(loc=preds, scale=std_devs)

    return model

def get_initial_state(weight_prior, bias_prior, num_features, layers=None):
    """generate starting point for creating Markov chain
        of weights and biases for fully connected NN
    Keyword Arguments:
        layers {tuple} -- number of nodes in each layer of the network
    Returns:
        list -- architecture of FCNN with weights and bias tensors for each layer
    """
    # make sure the last layer has two nodes, so that output can be split into
    # predictive mean and learned loss attenuation (see https://arxiv.org/abs/1703.04977)
    # which the network learns individually
    if layers is not None:
        assert layers[-1] == 2
    if layers is None:
        layers = (
            num_features,
            num_features // 2,
            num_features // 5,
            num_features // 10,
            2,
        )
    else:
        layers.insert(0, num_features)

    architecture = []
    for idx in range(len(layers) - 1):
        weigths = weight_prior.sample((layers[idx], layers[idx + 1]))
        biases = bias_prior.sample((layers[idx + 1]))
        # weigths = tf.zeros((layers[idx], layers[idx + 1]))
        # biases = tf.zeros((layers[idx + 1]))
        architecture.extend((weigths, biases))
    return architecture

def bnn_joint_log_prob_fn(weight_prior, bias_prior, X, y, *args):
    weights_list = args[::2]
    biases_list = args[1::2]

    # prior log-prob
    lp = sum(
        [tf.reduce_sum(weight_prior.log_prob(weights)) for weights in weights_list]
    )
    lp += sum([tf.reduce_sum(bias_prior.log_prob(bias)) for bias in biases_list])

    # likelihood of predicted labels
    network = build_network(weights_list, biases_list)
    labels_dist = network(X.astype("float32"))
    # y has shape [N, 1]; flatten it so log_prob matches the [N] batch shape
    lp += tf.reduce_sum(labels_dist.log_prob(y[:, 0]))
    return lp

def trace_fn(current_state, results, summary_freq=100):
    step = results.step
    with tf.summary.record_if(tf.equal(step % summary_freq, 0)):
        for idx, tensor in enumerate(current_state, 1):
            count = str(math.ceil(idx / 2))
            # odd positions in current_state hold weights, even positions hold biases
            name = ("weights_" if idx % 2 == 1 else "biases_") + count
            tf.summary.histogram(name, tensor, step=tf.cast(step, tf.int64))
    return results

@tf.function
def graph_hmc(*args, **kwargs):
    """Compile static graph for tfp.mcmc.sample_chain.
    Since this is the bulk of the computation, using @tf.function here
    significantly improves performance (empirically about 5x).
    """
    return tfp.mcmc.sample_chain(*args, **kwargs)

def nest_concat(*args):
    return tf.nest.map_structure(lambda *parts: tf.concat(parts, axis=0), *args)

def run_hmc(
    target_log_prob_fn,
    step_size=0.01,
    num_leapfrog_steps=3,
    num_burnin_steps=1000,
    num_adaptation_steps=800,
    num_results=1000,
    num_steps_between_results=0,
    current_state=None,
    logdir="data/output/hmc/",
    resume=None,
):
    """Populates a Markov chain by performing `num_results` gradient-informed steps with a
    Hamiltonian Monte Carlo transition kernel to produce a Metropolis proposal. Either
    that or the previous state is appended to the chain at each step.

    Arguments:
        target_log_prob_fn {callable} -- Determines the HMC transition kernel
        and thereby the stationary distribution that the Markov chain will approximate.

    Returns:
        (chain(s), trace, final_kernel_result) -- The Markov chain(s), the trace created by `trace_fn`
        and the kernel results of the last step.
    """
    assert (current_state, resume) != (None, None)

    # Set up logging.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = logdir + stamp
    summary_writer = tf.summary.create_file_writer(logdir)

    hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn, step_size=step_size, num_leapfrog_steps=num_leapfrog_steps
    )
    adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        hmc_kernel, num_adaptation_steps=num_adaptation_steps
    )
    if resume is None:
        prev_kernel_results = adaptive_kernel.bootstrap_results(current_state)
        step = 0
    else:
        prev_chain, prev_trace, prev_kernel_results = resume
        step = len(prev_chain[0])  # number of samples already in the chain
        current_state = tf.nest.map_structure(lambda chain: chain[-1], prev_chain)

    tf.summary.trace_on(graph=True, profiler=True)
    with summary_writer.as_default():
        tf.summary.trace_export(
            name="mcmc_sample_trace", step=step, profiler_outdir=logdir
        )
        chain, trace, final_kernel_results = graph_hmc(
            kernel=adaptive_kernel,
            current_state=current_state,
            num_results=num_results,
            previous_kernel_results=prev_kernel_results,
            num_steps_between_results=num_steps_between_results,
            num_burnin_steps=num_burnin_steps,
            trace_fn=partial(trace_fn, summary_freq=20),
            return_final_kernel_results=True,
        )
    summary_writer.close()

    if resume:
        chain = nest_concat(prev_chain, chain)
        trace = nest_concat(prev_trace, trace)

    return chain, trace, final_kernel_results

def get_data(test_size=0.1, random_state=0):
    with open("features.csv") as file:
        features = np.genfromtxt(file, delimiter=",")
    with open("labels.csv") as file:
        labels = np.genfromtxt(file, delimiter=",")

    labels = np.log(labels).reshape(-1, 1)

    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )
    X_scaler = sklearn.preprocessing.StandardScaler().fit(X_train)
    y_scaler = sklearn.preprocessing.StandardScaler().fit(y_train)
    X_train = X_scaler.transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_train = y_scaler.transform(y_train)
    y_test = y_scaler.transform(y_test)

    return (X_train, X_test), (y_train, y_test), (X_scaler, y_scaler)

def plot_neg_log_likelihood(neg_log_probs):
    plt.plot(neg_log_probs)
    plt.yscale("log")
    plt.xlabel("steps")
    plt.ylabel("negative log likelihood")
    plt.show()

weight_prior = tfd.Normal(0.0, 0.1)
bias_prior = tfd.Normal(0.0, 1.0)
(X_train, X_test), (y_train, y_test), scalers = get_data()

bnn_joint_log_prob = partial(
    bnn_joint_log_prob_fn, weight_prior, bias_prior, X_train, y_train
)
num_features = X_train.shape[1]
initial_state = get_initial_state(weight_prior, bias_prior, num_features)

results = run_hmc(bnn_joint_log_prob, num_results=100, current_state=initial_state)

chain, trace, final_kernel_results = run_hmc(
    bnn_joint_log_prob, num_results=100, resume=results
)

tf.print(chain)
print("Acceptance rate:", trace.inner_results.is_accepted.numpy().mean())
target_log_probs = trace.inner_results.accepted_results.target_log_prob
plot_neg_log_likelihood(np.negative(target_log_probs))

janosh commented 5 years ago

@SiegeLordEx This seems to have been an issue with numerical stability. Replacing

std_devs = tf.exp(-net[:, 1])
scale = tf.nn.softplus(std_devs) + 1e-6  # ensure scale is positive

with

scale = tf.exp(-net[:, 1])

i.e. letting the network predict the log of the variance rather than the variance itself, combined with choosing smaller weight and bias priors (I updated the code above to include these changes):

weight_prior = tfd.Normal(0.0, 0.1)  # previously tfd.Normal(0.0, 1.0)
bias_prior = tfd.Normal(0.0, 1.0)  # previously tfd.Normal(0.0, 1e5)

With these changes, the transition kernel is now able to make progress, albeit very slowly. It also stops abruptly in most runs and then seems to bob around in a small region of the target space, as this plot suggests:

[Plot: negative log likelihood]

Perhaps I'm starting in a really bad region of parameter space. I suppose the best thing to try now is to first train a regular NN on my data and use its final configuration as the initial state for the transition kernel?

janosh commented 5 years ago

@SiegeLordEx I implemented getting the starting point for HMC from a pre-trained network. Still doesn't work properly I'm afraid. This is the code now. If you have the time, could you take a look and see if you can spot what I'm doing wrong?

from datetime import datetime
from functools import partial

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

tfd = tfp.distributions

def nest_concat(*args):
    return tf.nest.map_structure(lambda *parts: tf.concat(parts, axis=0), *args)

def get_data(
    data_dir="data/input/", samples_file="features.csv", labels_file="resistivity.csv"
):
    with open(data_dir + samples_file) as file:
        features = np.genfromtxt(file, delimiter=",")
    with open(data_dir + labels_file) as file:
        labels = np.genfromtxt(file, delimiter=",")

    labels = np.log(labels).reshape(-1, 1)

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.1, random_state=0
    )
    X_scaler = StandardScaler().fit(X_train)
    y_scaler = StandardScaler().fit(y_train)
    X_train = X_scaler.transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_train = y_scaler.transform(y_train)
    y_test = y_scaler.transform(y_test)

    return (X_train, y_train), (X_test, y_test), (X_scaler, y_scaler)

def dense(X, W, b, activation):
    return activation(tf.matmul(X, W) + b)

def build_network(weights_list, biases_list, activation=tf.nn.relu):
    def model(samples, training=True):
        net = samples
        for (weights, biases) in zip(weights_list[:-1], biases_list[:-1]):
            net = dense(net, weights, biases, activation)
        # final linear layer
        net = tf.matmul(net, weights_list[-1]) + biases_list[-1]
        y_pred, y_log_var = tf.unstack(net, axis=1)
        # y_pred and y_log_var (of size samples.shape[0]) are the model's
        # predictive mean and log variance, resp.
        if training:
            return tfd.Normal(loc=y_pred, scale=tf.sqrt(tf.exp(y_log_var)))
        else:
            return y_pred, tf.exp(y_log_var)

    return model

def robust_mse(y_true, y_pred, var):
    """
    compute mean squared error with learned loss attenuation (hence robust)
    see eq. (7) of https://arxiv.org/abs/1703.04977 for details
    """
    loss = 0.5 * tf.square(y_true - y_pred) * tf.exp(-var) + 0.5 * var
    return tf.reduce_mean(loss)

def pre_train_nn(X_train, y_train, nodes_per_layer, epochs=1000):
    """pre-train NN to get good weights for HMC initialization"""
    last_layer = nodes_per_layer.pop(-1)
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=(nodes_per_layer.pop(0),)))
    for nodes in nodes_per_layer:
        model.add(tf.keras.layers.Dense(nodes, activation="relu"))
        model.add(tf.keras.layers.Dropout(0.3))
    model.add(tf.keras.layers.Dense(last_layer, activation="linear"))

    model.compile(
        # Keras passes (y_true, y_pred); y_true has shape [N, 1], so flatten it
        loss=lambda y_true, y_pred: robust_mse(y_true[:, 0], *tf.unstack(y_pred, axis=1)),
        optimizer="adam",
    )
    model.fit(X_train, y_train, epochs=epochs, verbose=0)
    weights = model.get_weights()
    return [tf.convert_to_tensor(w) for w in weights], model

def get_nodes_per_layer(n_features, net_taper=(1, 0.5, 0.2, 0.1)):
    nodes_per_layer = [int(n_features * x) for x in net_taper]
    # ensure the last layer has two nodes so that output can be
    # split into predictive mean and variance
    nodes_per_layer.append(2)
    return nodes_per_layer

def bnn_log_prob_fn(weight_prior, bias_prior, X, y, *args):
    weights_list, biases_list = args[::2], args[1::2]

    # prior log-prob
    lp = sum([tf.reduce_sum(weight_prior.log_prob(w)) for w in weights_list])
    lp += sum([tf.reduce_sum(bias_prior.log_prob(b)) for b in biases_list])

    # likelihood of predicted labels
    network = build_network(weights_list, biases_list)
    labels_dist = network(X.astype("float32"))
    # y has shape [N, 1]; flatten it so log_prob matches the [N] batch shape
    lp += tf.reduce_sum(labels_dist.log_prob(y[:, 0]))
    return lp

@tf.function
def graph_hmc(*args, **kwargs):
    """Compile static graph for tfp.mcmc.sample_chain.
    Since this is the bulk of the computation, using @tf.function here
    significantly improves performance (empirically about 5x).
    """
    return tfp.mcmc.sample_chain(*args, **kwargs)

def run_hmc(
    target_log_prob_fn,
    step_size=0.01,
    num_leapfrog_steps=10,
    num_burnin_steps=1000,
    num_adaptation_steps=800,
    num_results=1000,
    num_steps_between_results=0,
    current_state=None,
    log_dir="data/output/hmc/",
    resume=None,
):
    assert (current_state, resume) != (None, None)
    summary_writer = tf.summary.create_file_writer(log_dir)

    hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn, step_size=step_size, num_leapfrog_steps=num_leapfrog_steps
    )
    adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        hmc_kernel, num_adaptation_steps=num_adaptation_steps
    )
    if resume is None:
        prev_kernel_results = adaptive_kernel.bootstrap_results(current_state)
        step = 0
    else:
        prev_chain, prev_trace, prev_kernel_results = resume
        step = len(prev_chain[0])  # number of samples already in the chain
        current_state = tf.nest.map_structure(lambda chain: chain[-1], prev_chain)

    tf.summary.trace_on(graph=True, profiler=False)
    chain, trace, final_kernel_results = graph_hmc(
        kernel=adaptive_kernel,
        current_state=current_state,
        num_results=num_results,
        previous_kernel_results=prev_kernel_results,
        num_steps_between_results=num_steps_between_results,
        num_burnin_steps=num_burnin_steps,
        trace_fn=partial(trace_fn, summary_freq=20),
        return_final_kernel_results=True,
    )
    with summary_writer.as_default():
        tf.summary.trace_export(name="hmc_trace", step=step)
    summary_writer.close()

    if resume:
        chain = nest_concat(prev_chain, chain)
        trace = nest_concat(prev_trace, trace)

    return chain, trace, final_kernel_results

def trace_fn(current_state, kernel_results, summary_freq=10):
    step = kernel_results.step
    with tf.summary.record_if(tf.equal(step % summary_freq, 0)):
        for idx, tensor in enumerate(current_state, 1):
            count = str((idx + 1) // 2)  # weights/biases of the same layer share a count
            name = ("weights_" if idx % 2 == 1 else "biases_") + count
            tf.summary.histogram(name, tensor, step=tf.cast(step, tf.int64))
    return kernel_results

def get_log_dir(path="data/output/hmc/"):
    # Set up logging.
    stamp = datetime.now().strftime("%m-%d@%H:%M:%S")
    return path + stamp + "/"

log_dir = get_log_dir()
weight_prior = tfd.Normal(0.0, 1e5)
bias_prior = tfd.Normal(0.0, 1e5)

(X_train, y_train), (X_test, y_test), _ = get_data()

bnn_log_prob = partial(bnn_log_prob_fn, weight_prior, bias_prior, X_train, y_train)
n_features = X_train.shape[1]
nodes_per_layer = get_nodes_per_layer(n_features)
initial_state, model = pre_train_nn(X_train, y_train, nodes_per_layer)

chains, trace, final_kernel_results = run_hmc(
    bnn_log_prob,
    num_results=1000,
    num_burnin_steps=2000,
    num_adaptation_steps=2000,
    current_state=initial_state,
    log_dir=log_dir,
)

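
A shape detail to watch in both scripts above: `get_data` reshapes the labels to `[N, 1]`, while `labels_dist` has batch shape `[N]`. Passing the `[N, 1]` labels straight to `log_prob` would broadcast to an `[N, N]` array (every predictive distribution evaluated against every label), whose sum is not the dataset log-likelihood, which is why the labels need flattening (e.g. `y[:, 0]`) first. A NumPy sketch of the effect; the hand-written `normal_log_prob` stands in for `tfd.Normal(...).log_prob`, and the sizes are arbitrary:

```python
import numpy as np

def normal_log_prob(y, loc, scale):
    """Stand-in for tfd.Normal(loc, scale).log_prob(y); broadcasts the same way."""
    return -0.5 * ((y - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(0)
N = 4
loc, scale = rng.normal(size=N), np.full(N, 0.5)  # batch shape [N], like labels_dist
y = rng.normal(size=(N, 1))                        # labels after reshape(-1, 1)

wrong = normal_log_prob(y, loc, scale)        # shape (N, N): every label vs. every prediction
right = normal_log_prob(y[:, 0], loc, scale)  # shape (N,): one term per data point
print(wrong.shape, right.shape)               # (4, 4) (4,)
print(np.allclose(np.diag(wrong), right))     # True: only the diagonal is the intended likelihood
```
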
SiegeLordEx commented 5 years ago

Thanks, @janosh for persevering on this, sorry for the lack of feedback. I tried your code and indeed it's a little puzzling. It's somewhat hard to proceed with the code as is, so what I'll do over the next few days is:

SiegeLordEx commented 5 years ago

Ok, so I played around with this code some more. I think I got something that's working minimally okay. I put it in this colab (tell me if you can't access it): https://colab.research.google.com/drive/1bWQcuR5gaBPpow6ARKnPPL-dtf2EvTae

In there I show the result of running HMC, running regular SGD to find the MAP parameter estimate and also starting HMC from a MAP estimate. The main thing necessary to get HMC working was to just run it for a long time, while adapting the step size. It's very easy to accidentally cut the chains/adaptation short, yielding a non-converged chain.

For MAP, it's nice to see that it easily overfits... but it does reach a good test set loss pretty rapidly. Starting HMC from the MAP estimate yields mixed results: the chain is not really stationary either.
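
For readers unfamiliar with the term: the MAP estimate is the parameter vector that maximizes the log joint (log prior plus log-likelihood), which is what running SGD on the negative log joint finds. The NumPy sketch below uses a linear-Gaussian regression as a stand-in for the BNN (an assumption made because it has a closed-form answer): gradient ascent on the log joint lands on the ridge-regression solution with penalty noise_var / prior_var. A BNN has no such closed form, so this only illustrates what "MAP under a Gaussian prior" means.

```python
import numpy as np

# Linear-Gaussian stand-in model (assumption): y ~ N(X w, noise_var), prior w ~ N(0, prior_var I).
rng = np.random.default_rng(0)
N, D = 200, 5
X = rng.normal(size=(N, D))
y = X @ rng.normal(size=D) + rng.normal(scale=0.5, size=N)
noise_var, prior_var = 0.25, 1.0

def grad_log_joint(w):
    """Gradient of log N(y | X w, noise_var) + log N(w | 0, prior_var) with respect to w."""
    return X.T @ (y - X @ w) / noise_var - w / prior_var

# MAP by plain (full-batch) gradient ascent on the log joint.
w = np.zeros(D)
for _ in range(2_000):
    w += 1e-4 * grad_log_joint(w)

# In this model the MAP has a closed form: ridge regression with penalty noise_var / prior_var.
w_map = np.linalg.solve(X.T @ X + (noise_var / prior_var) * np.eye(D), X.T @ y)
print(np.max(np.abs(w - w_map)))  # tiny: both routes find the same posterior mode
```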

Caveats:

skeydan commented 5 years ago

@SiegeLordEx I think the colab is not publicly accessible?

SiegeLordEx commented 5 years ago

Sorry about that, it is now public.

skeydan commented 5 years ago

thanks!

tillschulz commented 3 years ago

@SiegeLordEx @davmre Hi,

regarding your bnn_joint_log_prob_fn from your colab example:

def bnn_joint_log_prob_fn(weight_prior, bias_prior, X, y, *args):
    weights_list = args[::2]
    biases_list = args[1::2]

    # prior log-prob
    lp = sum(
        [tf.reduce_sum(weight_prior.log_prob(weights)) for weights in weights_list]
    )
    lp += sum([tf.reduce_sum(bias_prior.log_prob(bias)) for bias in biases_list])

    # likelihood of predicted labels
    network = build_network(weights_list, biases_list)
    labels_dist = network(X.astype("float32"))
    lp += tf.reduce_sum(labels_dist.log_prob(y))
    return lp

Shouldn't the likelihood be scaled to the number of data points? If we use a large dataset and sum up all terms, doesn't the prior log prob become infinitesimal in comparison to the likelihood? Would it be sensible to compute the mean of the terms?

Thanks!

SiegeLordEx commented 3 years ago

Likelihood overwhelming the prior is a standard feature of Bayesian statistics for well-specified models, so it's working as intended. Philosophically, the 'prior' represents your belief before any data has been observed, so it's natural that your beliefs drift away from your prior as you receive more data. The posteriors of BNNs are pretty fickle, however, and there you actually still see evidence of the prior choice even with huge amounts of data because certain features of BNN hyperparameters are non-identifiable. E.g. the posterior scale of the parameters empirically has been observed to match the prior scale (see https://arxiv.org/abs/2104.14421).
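
The trade-off can be seen in a model small enough to solve exactly. The sketch below uses a conjugate Gaussian model as a stand-in (it is not a BNN, so it says nothing about the non-identifiability point): with the summed log-likelihood the posterior concentrates on the data as N grows, whereas replacing the sum by a mean, as proposed in the question, gives the whole dataset the weight of a single observation. In this setup the "averaged" posterior then stays halfway between prior and data, with a fixed variance, no matter how much data arrives. `likelihood_weight` is an illustrative parameter name, not a TFP argument.

```python
import numpy as np

def posterior(y, prior_mean=0.0, prior_var=1.0, noise_var=1.0, likelihood_weight=1.0):
    """Conjugate posterior of mu for y_i ~ N(mu, noise_var), mu ~ N(prior_mean, prior_var).

    likelihood_weight multiplies the summed log-likelihood: 1.0 gives the exact posterior,
    1 / len(y) is what replacing the sum by a mean would do.
    """
    precision = 1.0 / prior_var + likelihood_weight * len(y) / noise_var
    mean = (prior_mean / prior_var + likelihood_weight * y.sum() / noise_var) / precision
    return mean, 1.0 / precision

rng = np.random.default_rng(0)
results = {}
for n in (1, 10, 1_000, 100_000):
    y = rng.normal(3.0, 1.0, size=n)  # true mean 3.0 is far from the prior mean 0.0
    results[n] = (posterior(y), posterior(y, likelihood_weight=1.0 / n))
    print(n, "summed:", results[n][0], "averaged:", results[n][1])
```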