amidst / toolbox

A Java Toolbox for Scalable Probabilistic Machine Learning
http://www.amidsttoolbox.com
Apache License 2.0

Exception with FactoredFrontier and VMP #73

Open rcabanasdepaz opened 7 years ago

rcabanasdepaz commented 7 years ago

When I use FactoredFrontier with VMP, I get the following exception. It does not happen when I use ImportanceSampling instead.

    Exception in thread "main" java.lang.IllegalStateException: NaN KL
    at eu.amidst.core.exponentialfamily.EF_Normal.kl(EF_Normal.java:355)
    at eu.amidst.core.inference.messagepassing.VMP.computeELBO(VMP.java:165)
    at eu.amidst.core.inference.messagepassing.VMP.lambda$computeLogProbabilityOfEvidence$90(VMP.java:144)
    at java.util.stream.ReferencePipeline$6$1.accept(ReferencePipeline.java:244)
    at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175)
    at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1374)
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:481)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:471)
    at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(ReduceOps.java:708)
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
    at java.util.stream.DoublePipeline.collect(DoublePipeline.java:476)
    at java.util.stream.DoublePipeline.sum(DoublePipeline.java:388)
    at eu.amidst.core.inference.messagepassing.VMP.computeLogProbabilityOfEvidence(VMP.java:144)
    at eu.amidst.core.inference.messagepassing.VMP.testConvergence(VMP.java:118)
    at eu.amidst.core.inference.messagepassing.MessagePassingAlgorithm.runInference(MessagePassingAlgorithm.java:195)
    at eu.amidst.dynamic.inference.FactoredFrontierForDBN.runInference(FactoredFrontierForDBN.java:144)
    at eu.amidst.tutorial.usingAmidst.inference.DynModelInference.main(DynModelInference.java:70)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)
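The trace shows that the exception is raised by `EF_Normal.kl` while `VMP.computeELBO` is evaluated during the convergence test. As a rough illustration of how a KL term can come out as NaN, here is a standalone sketch of the closed-form KL divergence between two univariate Gaussians, KL(N(m1, v1) || N(m2, v2)) = 0.5 * (log(v2 / v1) + (v1 + (m1 - m2)^2) / v2 - 1). The class `GaussianKL` and its methods are hypothetical and are not AMIDST's `EF_Normal.kl`. Whether a degenerate variance is what actually happens inside FactoredFrontier here is an assumption, not something the trace establishes.

```java
/**
 * Hypothetical illustration, NOT the AMIDST EF_Normal.kl implementation.
 * Closed-form KL divergence between two univariate Gaussians:
 * KL(N(m1, v1) || N(m2, v2)) = 0.5 * (log(v2 / v1) + (v1 + (m1 - m2)^2) / v2 - 1)
 */
public class GaussianKL {

    /** Closed-form KL with no validation of the variances. */
    public static double klRaw(double m1, double v1, double m2, double v2) {
        double diff = m1 - m2;
        return 0.5 * (Math.log(v2 / v1) + (v1 + diff * diff) / v2 - 1.0);
    }

    /** Same formula, but fails fast on NaN, mimicking the "NaN KL" check in the trace. */
    public static double kl(double m1, double v1, double m2, double v2) {
        double value = klRaw(m1, v1, m2, v2);
        if (Double.isNaN(value)) {
            throw new IllegalStateException("NaN KL");
        }
        return value;
    }

    public static void main(String[] args) {
        // Well-defined case: identical distributions give KL = 0
        System.out.println(kl(0.0, 1.0, 0.0, 1.0));
        // Negative variance (e.g. after a numerically broken update): log of a negative ratio is NaN
        System.out.println(klRaw(0.0, -1.0, 0.0, 1.0));
        // Both variances zero: 0.0 / 0.0 inside the log is NaN
        System.out.println(klRaw(0.0, 0.0, 0.0, 0.0));
        // The fail-fast version turns the NaN into an exception
        try {
            kl(0.0, -1.0, 0.0, 1.0);
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```

Running `main` prints `0.0`, `NaN`, `NaN` and `caught: NaN KL`: once a variance in either distribution becomes non-positive, the `log` term can turn the whole expression into NaN, and a fail-fast check then surfaces it as an `IllegalStateException` like the one above.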

The full code:

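    import java.util.Random;

    // AMIDST classes. VMP and FactoredFrontierForDBN appear in the stack trace;
    // the other package names are assumed from the toolbox source layout.
    import eu.amidst.core.datastream.DataStream;
    import eu.amidst.core.distribution.Distribution;
    import eu.amidst.core.inference.messagepassing.VMP;
    import eu.amidst.core.variables.Variable;
    import eu.amidst.dynamic.datastream.DynamicDataInstance;
    import eu.amidst.dynamic.inference.FactoredFrontierForDBN;
    import eu.amidst.dynamic.inference.InferenceAlgorithmForDBN;
    import eu.amidst.dynamic.io.DynamicDataStreamLoader;
    import eu.amidst.dynamic.models.DynamicBayesianNetwork;
    import eu.amidst.dynamic.variables.HashMapDynamicAssignment;
    import eu.amidst.latentvariablemodels.dynamicmodels.DynamicModel;
    import eu.amidst.latentvariablemodels.dynamicmodels.KalmanFilter;

    public class DynModelInference {

        public static void main(String[] args) throws Exception {
            Random rand = new Random(0);

            String path = "datasets/dynamic/noclassdata/";
            DataStream<DynamicDataInstance> data = DynamicDataStreamLoader.open(path + "data0.arff");

            // Define a Kalman filter with 2 hidden Gaussian variables
            DynamicModel model = new KalmanFilter(data.getAttributes())
                    .setNumHidden(2);

            // Learn the distributions from the data
            model.updateModel(data);

            // Obtain the learned dynamic BN
            DynamicBayesianNetwork dbn = model.getModel();

            // Print the dynamic BN
            System.out.println(dbn);

            // Select the inference algorithm: VMP triggers the exception,
            // new ImportanceSampling() does not
            InferenceAlgorithmForDBN infer = new FactoredFrontierForDBN(new VMP());
            infer.setModel(dbn);

            // Set the variable of interest
            Variable varTarget = dbn.getDynamicVariables().getVariableByName("gaussianHiddenVar1");

            for (int t = 0; t < 10; t++) {
                // Set the evidence for time step t
                HashMapDynamicAssignment assignment = new HashMapDynamicAssignment(2);
                assignment.setValue(dbn.getDynamicVariables().getVariableByName("GaussianVar9"), rand.nextDouble());
                assignment.setValue(dbn.getDynamicVariables().getVariableByName("GaussianVar8"), rand.nextDouble());
                assignment.setTimeID(t);

                // Run the inference (the exception is thrown from runInference())
                infer.addDynamicEvidence(assignment);
                infer.runInference();

                // Get the posterior at the current time step
                Distribution posterior_t = infer.getFilteredPosterior(varTarget);
                System.out.println("t=" + t + " " + posterior_t);

                // Get the predictive posterior one time step ahead
                Distribution posterior_t_1 = infer.getPredictivePosterior(varTarget, 1);
                System.out.println("t=" + t + "+1 " + posterior_t_1);
            }
        }
    }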