dotnet / infer

Infer.NET is a framework for running Bayesian inference in graphical models
https://dotnet.github.io/infer/
MIT License
1.54k stars 229 forks source link

Multivariate hidden markov model questions #436

Closed anthonypuppo closed 9 months ago

anthonypuppo commented 1 year ago

I've attempted to create a multivariate hidden markov model using the code from here as a base.

Few questions regarding my implementation:

  1. Features are standardized to their z-scores (mean zero and variance one) before being passed to the model. Am I creating the mean and precision priors correctly? I can't seem to figure out what shape and scale refer to, or whether the values being passed in are appropriate for my standardized data.
  2. Prediction, how would this be done? Create a new instance of the HMM using the priors from the previously trained one, add the next potential emission, save the log odds, continue through all possible emissions, and pick the best based on the log odds? Is there an easier way to do this with Infer.NET? Sampling would just give results similar to MCMC and not the most "likely" emission.
  3. How would I adjust the model so the state probability takes into account the last n states instead of only the current one? I know this goes against the Markov property, but I'm curious how it would be implemented.

Appreciate any help!

/// <summary>
/// Hidden Markov model with multivariate Gaussian emissions, defined with the Infer.NET modelling API.
/// Each hidden state owns a mean vector and a precision matrix; observations are the rows of the
/// data matrix passed to the constructor.
/// </summary>
/// <typeparam name="T">Floating-point element type of the input data; every value is converted to <see cref="double"/>.</typeparam>
public class HiddenMarkovModel<T> where T : IFloatingPointIeee754<T>
{
    // Ranges: k = hidden states, t = time steps (rows of the data), e = features (columns of the data).
    protected readonly Range k;
    protected readonly Range t;
    protected readonly Range e;

    // Latent state that precedes the first time step, the per-step states, and the observed emissions.
    protected readonly Variable<int> zeroState;
    protected readonly VariableArray<int> states;
    protected readonly VariableArray<Vector> emissions;

    // Model parameters: distribution of zeroState, per-state transition rows, per-state emission mean and precision.
    protected readonly Variable<Vector> probInit;
    protected readonly VariableArray<Vector> cptTrans;
    protected readonly VariableArray<Vector> emitMean;
    protected readonly VariableArray<PositiveDefiniteMatrix> emitPrec;

    // Priors over the parameters above. They are left unobserved by the constructor and must be
    // given observed values (see SetUniformPriors) before inference is run.
    protected readonly Variable<Dirichlet> probInitPrior;
    protected readonly VariableArray<Dirichlet> cptTransPrior;
    protected readonly VariableArray<VectorGaussian> emitMeanPrior;
    protected readonly VariableArray<Wishart> emitPrecPrior;

    // Bernoulli(0.5) switch that gates the whole model; its posterior is exposed as ModelEvidencePosterior.
    protected readonly Variable<bool> modelEvidence;

    protected readonly InferenceEngine engine;

    /// <summary>Posterior over the distribution of the state preceding the first time step.</summary>
    public Dirichlet ProbInitPosterior { get; private set; }

    /// <summary>Posterior over each row of the transition table, indexed by source state.</summary>
    public Dirichlet[] CPTTransPosterior { get; private set; }

    /// <summary>Posterior over the emission mean of each state.</summary>
    public VectorGaussian[] EmitMeanPosterior { get; private set; }

    /// <summary>Posterior over the emission precision matrix of each state.</summary>
    public Wishart[] EmitPrecPosterior { get; private set; }

    /// <summary>Posterior marginal over the hidden state at each time step.</summary>
    public Discrete[] StatesPosterior { get; private set; }

    /// <summary>Posterior of the evidence switch that gates the model.</summary>
    public Bernoulli ModelEvidencePosterior { get; private set; }

    /// <summary>
    /// Builds the model graph and observes the emissions. Priors are not set here.
    /// </summary>
    /// <param name="data">Observation matrix: one row per time step, one column per feature.</param>
    /// <param name="numStates">Number of hidden states; must be at least 1.</param>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numStates"/> is less than 1.</exception>
    /// <exception cref="ArgumentException"><paramref name="data"/> has no rows or no columns.</exception>
    public HiddenMarkovModel(T[,] data, int numStates)
    {
        ArgumentNullException.ThrowIfNull(data);

        if (numStates < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(numStates), numStates, "At least one hidden state is required.");
        }

        var numObservations = data.GetLength(0);
        var numFeatures = data.GetLength(1);

        if (numObservations == 0 || numFeatures == 0)
        {
            throw new ArgumentException("Data must contain at least one observation and one feature.", nameof(data));
        }

        // Convert every row of the generic matrix into an Infer.NET Vector of doubles.
        var dataVectors = new Vector[numObservations];

        for (var i = 0; i < numObservations; i++)
        {
            var row = data.SliceRow(i).Select((x) => Double.CreateSaturating(x)).ToArray();

            dataVectors[i] = Vector.FromArray(row);
        }

        modelEvidence = Variable.Bernoulli(0.5).Named("evidence");

        // Everything defined inside this block is conditioned on the evidence switch.
        using (Variable.If(modelEvidence))
        {
            k = new Range(numStates).Named(nameof(k));
            t = new Range(numObservations).Named(nameof(t));
            e = new Range(numFeatures).Named(nameof(e));

            probInitPrior = Variable.New<Dirichlet>().Named(nameof(probInitPrior));
            probInit = Variable<Vector>.Random(probInitPrior).Named(nameof(probInit));
            probInit.SetValueRange(k);

            cptTransPrior = Variable.Array<Dirichlet>(k).Named(nameof(cptTransPrior));
            cptTrans = Variable.Array<Vector>(k).Named(nameof(cptTrans));
            cptTrans[k] = Variable<Vector>.Random(cptTransPrior[k]);
            cptTrans.SetValueRange(k);

            // NOTE(review): the value range of the emission parameters is set to k (number of states)
            // although the vectors/matrices are sized by e (number of features) - confirm this is intended.
            emitMeanPrior = Variable.Array<VectorGaussian>(k).Named(nameof(emitMeanPrior));
            emitMean = Variable.Array<Vector>(k).Named(nameof(emitMean));
            emitMean[k] = Variable<Vector>.Random(emitMeanPrior[k]);
            emitMean.SetValueRange(k);

            emitPrecPrior = Variable.Array<Wishart>(k).Named(nameof(emitPrecPrior));
            emitPrec = Variable.Array<PositiveDefiniteMatrix>(k).Named(nameof(emitPrec));
            emitPrec[k] = Variable<PositiveDefiniteMatrix>.Random(emitPrecPrior[k]);
            emitPrec.SetValueRange(k);

            zeroState = Variable.Discrete(probInit).Named("z0");
            states = Variable.Array<int>(t);
            emissions = Variable.Array<Vector>(t);

            using (var block = Variable.ForEach(t))
            {
                var i = block.Index;
                var previousState = states[i - 1];

                // First time step: transition out of the extra initial state z0.
                using (Variable.If((i == 0).Named("Initial")))
                {
                    using (Variable.Switch(zeroState))
                    {
                        // NOTE(review): indexed by the range t here but by the loop index i below - presumably equivalent inside ForEach(t); confirm.
                        states[t] = Variable.Discrete(cptTrans[zeroState]);
                    }
                }

                // Later time steps: transition out of the previous state.
                using (Variable.If((i > 0).Named("Transition")))
                {
                    using (Variable.Switch(previousState))
                    {
                        states[i] = Variable.Discrete(cptTrans[previousState]);
                    }
                }

                // Emission: Gaussian with the mean and precision selected by the current state.
                using (Variable.Switch(states[i]))
                {
                    emissions[i] = Variable.VectorGaussianFromMeanAndPrecision(emitMean[states[i]], emitPrec[states[i]]);
                }
            }
        }

        // Initialise every state marginal to a point mass on a randomly drawn state.
        var zInit = Variable.Array<Discrete>(t).Named("zInit");

        zInit.ObservedValue = Util.ArrayInit(t.SizeAsInt, (_) => Discrete.PointMass(Rand.Int(k.SizeAsInt), k.SizeAsInt));
        states[t].InitialiseTo(zInit[t]);

        engine = new InferenceEngine(new ExpectationPropagation())
        {
            NumberOfIterations = 15,
            ShowProgress = false,
        };
        engine.Compiler.WriteSourceFiles = true;
        engine.Compiler.AddComments = false;
        engine.Compiler.CompilerChoice = Microsoft.ML.Probabilistic.Compiler.CompilerChoice.Roslyn;

        emissions.ObservedValue = dataVectors;
    }

    /// <summary>
    /// Observes the default priors: uniform Dirichlet priors for the initial and transition
    /// probabilities, a zero-mean vector Gaussian with precision 0.01 on the diagonal for every
    /// emission mean, and a Wishart with shape 100 and scale 0.01 * identity for every emission precision.
    /// </summary>
    public void SetUniformPriors()
    {
        var numStates = k.SizeAsInt;
        var numFeatures = e.SizeAsInt;

        probInitPrior.ObservedValue = Dirichlet.Uniform(numStates);

        // Build a separate distribution object for every state so that the array slots never alias
        // one shared mutable instance.
        cptTransPrior.ObservedValue = Util.ArrayInit(numStates, (_) => Dirichlet.Uniform(numStates));
        emitMeanPrior.ObservedValue = Util.ArrayInit(numStates, (_) => VectorGaussian.FromMeanAndPrecision(Vector.Zero(numFeatures), PositiveDefiniteMatrix.IdentityScaledBy(numFeatures, 0.01)));
        // NOTE(review): presumably chosen so the prior mean precision (shape * scale) is the identity,
        // matching standardized features - confirm against the Wishart parameterization.
        emitPrecPrior.ObservedValue = Util.ArrayInit(numStates, (_) => Wishart.FromShapeAndScale(100, PositiveDefiniteMatrix.IdentityScaledBy(numFeatures, 0.01)));
    }

    /// <summary>
    /// Runs inference and stores the posterior of every parameter, of the hidden states and of the
    /// evidence switch in the corresponding properties.
    /// </summary>
    public void InferPosteriors()
    {
        CPTTransPosterior = engine.Infer<Dirichlet[]>(cptTrans);
        ProbInitPosterior = engine.Infer<Dirichlet>(probInit);
        EmitMeanPosterior = engine.Infer<VectorGaussian[]>(emitMean);
        EmitPrecPosterior = engine.Infer<Wishart[]>(emitPrec);
        StatesPosterior = engine.Infer<Discrete[]>(states);
        ModelEvidencePosterior = engine.Infer<Bernoulli>(modelEvidence);
    }

    /// <summary>
    /// Convenience entry point: sets the default priors and then infers all posteriors.
    /// </summary>
    public void Compute()
    {
        SetUniformPriors();
        InferPosteriors();
    }
}
tminka commented 1 year ago
  1. You are creating the priors correctly. If you want to know how the Wishart distribution is defined in Infer.NET, see the code doc for the Wishart class.
  2. If you want the most likely emission sequence, you must solve a so-called "marginal MAP" (or max-sum-product) problem. See "Variational algorithms for marginal MAP". This can be done exactly in an HMM. Infer.NET does not include algorithms for this so you would have to enumerate the options for each emission as you described.
  3. You would have multiple previousState variables and have nested switch statements on all of them. cptTrans would be a multidimensional (or jagged) array. See AddChildFromTwoParents.