improbable-research / keanu

A probabilistic approach from an Improbabilistic company
MIT License

HalfCauchy vertex (positive support) gets a negative sampled value after NUTS #345

Closed: luke14free closed this issue 5 years ago

luke14free commented 6 years ago

Describe the bug

Implementing the 8 schools model, I noticed that a couple of things go unexpectedly:

1. The MCMC chain gets negative values for tau, which should only be able to take positive values.
2. Although the chain seems to converge to a negative value (-1.9 in my case), the most likely value assigned to the node is quite different (-5.6 in my case).
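For background on point 1, here is a minimal sketch (my own illustration, not Keanu's implementation) of why a half-Cauchy vertex should never produce a negative value: the distribution is a centered Cauchy folded onto the positive axis, so every draw is non-negative by construction.

```java
import java.util.Random;

public class HalfCauchySupport {
    // Sample a half-Cauchy(scale) by inverse-transform sampling a centered
    // Cauchy and taking the absolute value; the result is always >= 0.
    static double sampleHalfCauchy(double scale, Random rng) {
        double u = rng.nextDouble();                           // u in [0, 1)
        double cauchy = scale * Math.tan(Math.PI * (u - 0.5)); // centered Cauchy draw
        return Math.abs(cauchy);                               // fold onto positive support
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        for (int i = 0; i < 10_000; i++) {
            double s = sampleHalfCauchy(5.0, rng);
            if (s < 0.0) throw new AssertionError("negative half-Cauchy sample: " + s);
        }
        System.out.println("all samples non-negative");
    }
}
```

Any sampler that respects the support of the distribution should therefore never report a negative tau.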

To Reproduce

Steps to reproduce the behavior:

Run this model, which I adapted from the starter boilerplate (it's the famous non-centered 8 schools model; reference: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html):


import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.mcmc.NUTS;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.KeanuRandom;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.*;

import java.util.*;

public class Eight {

    public static void main(String[] args) {
        Eight model = new Eight();
        model.run();
    }

    private KeanuRandom random;
    public double results;

    Eight() {

    }

    void run() {
        random = new KeanuRandom(42);

        int J = 8;
        double[] yObs = new double[]{28., 8., -3., 7., -1., 1., 18., 12.};
        double[] sigma = new double[]{15., 10., 16., 11., 9., 11., 10., 18.};

        ConstantDoubleVertex sigma_vertex = new ConstantDoubleVertex(sigma);

        DoubleVertex mu = new GaussianVertex(0, 5);
        DoubleVertex tau = new HalfCauchyVertex(5);
        DoubleVertex theta_tilde = new GaussianVertex(new long[]{1, J}, 0, 1);
        DoubleVertex theta = mu.plus(tau.multiply(theta_tilde));
        DoubleVertex y = new GaussianVertex(theta, sigma_vertex);
        y.observe(yObs);

        BayesianNetwork bayesNet = new BayesianNetwork(
                y.getConnectedGraph()
        );

        NUTS sampler = NUTS.builder()
                .adaptCount(100)
                .random(random)
                .build();

        NetworkSamples samples = sampler.getPosteriorSamples(
                bayesNet,
                bayesNet.getLatentVertices(),
                1000
        );

        List<?> tauPosterior = samples.get(tau).asList();

        for (Object sample : tauPosterior) {
            System.out.println(sample);
        }

        results = tau.getValue().scalar();

        System.out.println("Most probable value for tau (should be around 2.7): " + results);
    }

}

Expected behavior

Convergence of the tau parameter to a positive value (roughly in the 2.5-3 range).


luke14free commented 6 years ago

HMC instead produces a value of 50 for tau (though it doesn't seem to be affected by issue 2, where the chain converges to a different value than the one reported as output).

gordoncaleb commented 6 years ago

Thanks for reporting this! I don’t see any issue with what you’re doing here. Someone will be taking a look tomorrow and we’ll get back to you.

gordoncaleb commented 6 years ago

Hi @luke14free, I've taken a look at the issue you're seeing. There's a bug in the step-size adaptation code in Keanu's NUTS implementation that can cause the step size to become either extremely large or extremely small. The stack overflow you were seeing is due to it being extremely large and jumping to a position that tries to calculate digamma(-8000), which overflows the stack before it can finish.
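To illustrate the stack-overflow mechanism (a sketch of the general numerical pattern, not Keanu's actual digamma code): the recurrence psi(x) = psi(x + 1) - 1/x needs roughly |x| applications to shift a very negative argument into the region where a series approximation applies, so a naively recursive implementation burns one stack frame per step, while a loop runs in constant stack space.

```java
public class Digamma {
    // Iterative digamma: use the recurrence psi(x) = psi(x + 1) - 1/x to shift
    // the argument above 6, then apply an asymptotic series. A recursive version
    // of the same recurrence needs O(|x|) stack frames, so an argument like
    // -8000 can exhaust the stack. (Note: negative integers are poles of
    // digamma, so -8000 itself has no finite value in any case.)
    static double digamma(double x) {
        double result = 0.0;
        while (x < 6.0) {          // loop instead of recursion: constant stack depth
            result -= 1.0 / x;
            x += 1.0;
        }
        double inv = 1.0 / x;
        double inv2 = inv * inv;
        // Asymptotic expansion: ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6)
        result += Math.log(x) - 0.5 * inv
                - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0));
        return result;
    }

    public static void main(String[] args) {
        // digamma(1) = -EulerMascheroni, approximately -0.5772156649
        System.out.println(digamma(1.0));
    }
}
```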

I've raised a PR https://github.com/improbable-research/keanu/pull/346 to fix the issue as well as add the usual sampling progress bar to NUTS. This will need to be merged and then we'll do a release in the next few days.

At the end of your example you call results = tau.getValue().scalar();, which will actually only get you the tau value at the last sample, not the most probable value for tau. Since tau's posterior isn't bimodal or otherwise tricky, you can take the mean of the samples as a point estimate. I've taken what you posted and moved things around a bit to take the means of tau, theta, and theta_tilde. The results I get match OUT [26] of https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html, which puts the mean of tau at ~3.7 rather than 2.5-3. Where in that link is 2.5-3 pointed out?
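The point about taking the posterior mean can be sketched in plain Java (a hypothetical helper for illustration, not the Keanu API): with a roughly symmetric, unimodal posterior, the sample mean is a reasonable point estimate for tau, whereas the last sample is just one arbitrary draw.

```java
import java.util.Arrays;
import java.util.List;

public class PosteriorMean {
    // Arithmetic mean of a list of scalar samples: a reasonable point estimate
    // for a unimodal, roughly symmetric posterior (hypothetical helper).
    static double mean(List<Double> samples) {
        double sum = 0.0;
        for (double s : samples) sum += s;
        return sum / samples.size();
    }

    public static void main(String[] args) {
        // Made-up sample values for illustration only
        List<Double> tauSamples = Arrays.asList(3.1, 4.2, 3.9, 3.6);
        System.out.println(mean(tauSamples)); // mean of the draws, ~3.7
    }
}
```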

Thanks for looking at this. We really appreciate the feedback. We haven't used NUTS much internally, so there's still some work to be done to make it more robust and informative.

luke14free commented 6 years ago

I don't honestly remember how I came up with 2.5-3, but it looks like I may have taken a value from a previous analysis I ran on an uncentered version of the model, so definitely stick with ~3.7 from the PyMC guide.

Happy to help. NUTS is pretty much the most used feature in Stan/PyMC (given that it's the default sampler for continuous variables), so I'm glad that Keanu is moving in that direction as well. One thing to definitely consider implementing is divergence checks and support for multiple chains, which are quite essential when working with samplers to check whether things are going south (I will open a new issue, and I'm also happy to help code-wise if needed).
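For reference, the multiple-chain diagnostic being suggested is usually the Gelman-Rubin potential scale reduction factor (R-hat). Here is a sketch of the classic formula (my own illustration, not a Keanu API): values near 1 suggest the chains agree; values well above 1 flag non-convergence.

```java
import java.util.Arrays;

public class RHat {
    // Gelman-Rubin R-hat across m chains of length n: compares between-chain
    // variance (B) to within-chain variance (W); R-hat near 1 means the chains
    // are exploring the same distribution.
    static double rHat(double[][] chains) {
        int m = chains.length;
        int n = chains[0].length;
        double[] chainMeans = new double[m];
        double[] chainVars = new double[m];
        for (int j = 0; j < m; j++) {
            chainMeans[j] = Arrays.stream(chains[j]).average().orElse(0.0);
            double ss = 0.0;
            for (double x : chains[j]) ss += (x - chainMeans[j]) * (x - chainMeans[j]);
            chainVars[j] = ss / (n - 1);               // within-chain sample variance
        }
        double grandMean = Arrays.stream(chainMeans).average().orElse(0.0);
        double b = 0.0;
        for (double mj : chainMeans) b += (mj - grandMean) * (mj - grandMean);
        b *= n / (double) (m - 1);                     // between-chain variance (B)
        double w = Arrays.stream(chainVars).average().orElse(0.0); // within (W)
        double varPlus = ((n - 1) * w + b) / n;        // pooled variance estimate
        return Math.sqrt(varPlus / w);
    }

    public static void main(String[] args) {
        // Two small made-up chains for illustration
        double[][] chains = {{1, 2, 3, 4}, {2, 3, 4, 5}};
        System.out.println(rHat(chains));
    }
}
```

Divergence checks would additionally require hooks into the integrator itself (to detect energy errors during a trajectory), which is why they need sampler support rather than post-hoc analysis.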

luke14free commented 6 years ago

This might seem like a dumb question, but how do you do this?

Since tau's posterior isn't bimodal or anything tricky, you can take the mean of the samples to find the most probable

I don't seem to be able to access the actual posterior values.

GeorgeNash commented 5 years ago

The sampler will return a NetworkSamples object. Once you've got that (let's call it posteriorSamples), you can do:

double averageValue = posteriorSamples.getDoubleTensorSamples(vertex).getAverages().scalar();

where vertex is the Vertex you want to get the average for

GeorgeNash commented 5 years ago

@gordoncaleb are we happy to close this with the recent NUTS fixes?

gordoncaleb commented 5 years ago

@GeorgeNash I think so.