TuringLang / SSMProblems.jl

Common abstractions for state-space models
http://turinglang.org/SSMProblems.jl/
MIT License

Predict and update API for inference #39

Open THargreaves opened 1 month ago

THargreaves commented 1 month ago

It seems like we're starting to converge on the idea that, on top of the transition and observation methods defined for SSMs, it would be helpful to give inference algorithms a consistent structure by introducing predict and update methods (named after the two steps of the Kalman filter).

These two APIs would be largely orthogonal in design, with the inference methods calling the SSM methods.

I feel that we need to be careful to implement these in a flexible way, and will provide some examples of what the interface could look like below.

Initial discussions of the issue are linked directly below.

THargreaves commented 1 month ago

Use in bootstrap PF and SIR

For the bootstrap filter, we would simply have predict call transition.

For general SIR, predict would use the supplied proposal function and then the update method would call transition_logdensity for reweighting.
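A minimal sketch of the bootstrap case. It assumes a hypothetical scalar linear-Gaussian model and a hypothetical `observation_logdensity` helper, neither of which is SSMProblems.jl API. Here `predict` only calls `transition`, and `update` adds the observation log-likelihood to each particle's log-weight. Resampling is omitted.

```julia
using Random

# Hypothetical scalar model: x_k = a x_{k-1} + N(0, q),  y_k = x_k + N(0, r)
struct ScalarLinearGaussian
    a::Float64
    q::Float64
    r::Float64
end

# SSM-level methods: sample the transition, evaluate the observation log-density
transition(rng::AbstractRNG, m::ScalarLinearGaussian, x::Float64) =
    m.a * x + sqrt(m.q) * randn(rng)
observation_logdensity(m::ScalarLinearGaussian, x::Float64, y::Float64) =
    -0.5 * (log(2π * m.r) + (y - x)^2 / m.r)

struct BootstrapFilter end

# Bootstrap predict: move every particle through the model's own transition
predict(rng::AbstractRNG, xs::Vector{Float64}, m, ::BootstrapFilter) =
    [transition(rng, m, x) for x in xs]

# Bootstrap update: reweight by the observation likelihood (log-weights add)
update(logws::Vector{Float64}, xs::Vector{Float64}, m, y::Float64, ::BootstrapFilter) =
    logws .+ [observation_logdensity(m, x, y) for x in xs]

# One predict/update cycle on 100 particles (resampling omitted)
rng = MersenneTwister(1)
model = ScalarLinearGaussian(0.9, 0.1, 0.5)
particles = predict(rng, randn(rng, 100), model, BootstrapFilter())
logweights = update(zeros(100), particles, model, 0.3, BootstrapFilter())
```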

Use in Kalman Filter

Suppose that the transition method is written with sufficient generality that it works on both a particle and a GaussianDistribution (this requires transition to return a distribution which is then sampled from).

function transition(dyn, state, args...)  # further arguments (e.g. step, extra) elided
    return dyn.A * state + Gaussian(zeros(dyn.d), dyn.Q)
end

This transition method can then be used in both the Kalman filter and a regular particle filter by dispatching on whether state is a float (i.e. a particle) or a Gaussian.
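The sketch below shows how a single `transition` could serve both cases. It assumes a hypothetical hand-rolled `Gaussian` type (not the GaussianDistributions.jl one) along with hypothetical `LinearGaussian` and `draw` definitions. The Gaussian arithmetic is exact for linear maps and independent additive noise. A point state therefore yields the transition distribution to sample from, and a Gaussian state yields the propagated belief.

```julia
using LinearAlgebra, Random

# Minimal hypothetical Gaussian type
struct Gaussian{T}
    μ::Vector{T}
    Σ::Matrix{T}
end

# Linear map of a Gaussian: A x ~ N(A μ, A Σ Aᵀ)
Base.:*(A::AbstractMatrix, g::Gaussian) = Gaussian(A * g.μ, A * g.Σ * A')
# Sum of independent Gaussians: means and covariances add
Base.:+(g::Gaussian, h::Gaussian) = Gaussian(g.μ + h.μ, g.Σ + h.Σ)
# Shifting a Gaussian by a fixed vector only moves its mean
Base.:+(x::AbstractVector, h::Gaussian) = Gaussian(x + h.μ, h.Σ)

struct LinearGaussian
    A::Matrix{Float64}
    Q::Matrix{Float64}
    d::Int
end

# Written once; returns a Gaussian whether `state` is a point or a Gaussian
transition(dyn::LinearGaussian, state) = dyn.A * state + Gaussian(zeros(dyn.d), dyn.Q)

# Hypothetical helper: draw one sample from a Gaussian (for particle filters)
draw(rng::AbstractRNG, g::Gaussian) = g.μ + cholesky(Symmetric(g.Σ)).L * randn(rng, length(g.μ))

dyn = LinearGaussian([1.0 0.1; 0.0 1.0], 0.01 * Matrix(1.0I, 2, 2), 2)
belief = transition(dyn, Gaussian([0.0, 1.0], Matrix(1.0I, 2, 2)))  # Kalman-style
particle = draw(MersenneTwister(0), transition(dyn, [0.0, 1.0]))     # PF-style
```

The point-state call still returns a distribution, and the particle filter then draws from it, as the parenthetical above requires.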

Use in Hidden Markov Models

A similar approach can be taken as for the Kalman filter if we define a Multinomial(p) with methods

# Assuming A is column-stochastic, i.e. A[i, j] = P(x_k = i | x_{k-1} = j)
Base.:*(A::Matrix, d::Multinomial) = Multinomial(A * d.p)

It's not clear how you would make this work with an integer state though (i.e. how you'd run a particle filter with this transition function).
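One possible way to handle the integer-state case, offered only as a sketch: keep the distribution-valued method for exact forward prediction, and add an integer-state method that samples the next state from the corresponding column of A. The `Categorical` type, the `HMM` struct and the `draw` helper are hypothetical. The extra `rng` argument in the integer method is also an assumption about how sampling would be threaded through.

```julia
using Random

# Hypothetical categorical distribution over K hidden states
struct Categorical
    p::Vector{Float64}
end

# Column-stochastic convention: A[i, j] = P(x_k = i | x_{k-1} = j)
Base.:*(A::AbstractMatrix, d::Categorical) = Categorical(A * d.p)

# Inverse-CDF draw of a single state index with probabilities p
function draw(rng::AbstractRNG, p::AbstractVector)
    u = rand(rng)
    c = 0.0
    for (i, w) in enumerate(p)
        c += w
        u <= c && return i
    end
    return length(p)  # guard against round-off in the cumulative sum
end

struct HMM
    A::Matrix{Float64}
end

# Distribution-valued state: exact forward prediction
transition(m::HMM, d::Categorical) = m.A * d
# Integer-valued state: sample the next state from column x of A
transition(rng::AbstractRNG, m::HMM, x::Integer) = draw(rng, view(m.A, :, x))

m = HMM([0.9 0.2; 0.1 0.8])
predicted = transition(m, Categorical([1.0, 0.0]))  # p ≈ [0.9, 0.1]
next_state = transition(MersenneTwister(1), m, 1)   # 1 or 2
```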

THargreaves commented 1 month ago

Had a more in-depth think about how this would apply to the Kalman filter, and it seems like these four methods (transition, observation, predict and update) alone are insufficient.

Naturally we would have

function transition(dyn::LinearGaussian, state::Gaussian)
    return dyn.A * state + Gaussian(zeros(dyn.d), dyn.Q)
end

function observation(dyn::LinearGaussian, state::Gaussian)
    return dyn.H * state + Gaussian(zeros(size(dyn.H, 1)), dyn.R)  # noise lives in observation space
end

The prediction step (the estimate of p(x_k | y_{1:k-1})) is straightforward:

function predict(estimate::Gaussian, dyn::LinearGaussian, alg::KalmanFilter)
    return transition(dyn, estimate)
end

Recall the update step for the Kalman filter. This is less obvious.
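For reference, these are the standard Kalman update equations for a predicted estimate $\mathcal{N}(\mu_{k|k-1}, \Sigma_{k|k-1})$ under the linear-Gaussian observation model above. The innovation distribution $\mathcal{N}(v_k, S_k)$ only supplies the first two lines. The gain and covariance update also need the cross-covariance $\Sigma_{k|k-1} H^\top$, which is where $H$ re-enters.

```latex
\begin{aligned}
v_k &= y_k - H \mu_{k|k-1} && \text{(innovation)} \\
S_k &= H \Sigma_{k|k-1} H^\top + R && \text{(innovation covariance)} \\
K_k &= \Sigma_{k|k-1} H^\top S_k^{-1} && \text{(Kalman gain)} \\
\mu_{k|k} &= \mu_{k|k-1} + K_k v_k \\
\Sigma_{k|k} &= (I - K_k H)\, \Sigma_{k|k-1}
\end{aligned}
```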

Computing y_i - observation(dyn, estimate), where estimate is the result of predict, gives us the innovation distribution.

The trouble is, we then need to compute the Kalman gain and then update the posterior covariance. Both of these operations require the H matrix of the model, which is not available through this interface.

And if you were going to make this available, you might as well make A, b, Q, R, etc. all available, and we're back at the old interface.

I feel like this sort of problem is only going to become more common with more complex filters or smoothing algorithms.

FredericWantiez commented 1 month ago

I don't think we can make this completely general, but could we maybe impose more structure on the LatentDynamics and ObservationProcess types?

struct AdditiveNoiseLinearObservationProcess <: ObservationProcess
  Q::Matrix
  H::Matrix
end 

noise_covariance(obs::AdditiveNoiseLinearObservationProcess) = obs.Q
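Here is a minimal sketch of how such structure could be consumed: a Kalman update written only against accessor functions, so that any process implementing them can be filtered. `observation_matrix` and `kalman_update` are hypothetical names. `ObservationProcess` is redefined locally so the block runs on its own.

```julia
using LinearAlgebra

abstract type ObservationProcess end  # local stand-in for the SSMProblems type

# Mirrors the struct proposed above; Q is the observation-noise covariance
struct AdditiveNoiseLinearObservationProcess <: ObservationProcess
    Q::Matrix{Float64}
    H::Matrix{Float64}
end

# Accessors a structured process could promise to implement
noise_covariance(obs::AdditiveNoiseLinearObservationProcess) = obs.Q
observation_matrix(obs::AdditiveNoiseLinearObservationProcess) = obs.H  # hypothetical

# Kalman update that relies only on the accessors, not on field names
function kalman_update(μ::Vector, Σ::Matrix, obs::ObservationProcess, y::Vector)
    H, R = observation_matrix(obs), noise_covariance(obs)
    v = y - H * μ          # innovation
    S = H * Σ * H' + R     # innovation covariance
    K = Σ * H' / S         # Kalman gain Σ Hᵀ S⁻¹
    return μ + K * v, (I - K * H) * Σ
end

obs = AdditiveNoiseLinearObservationProcess(fill(1.0, 1, 1), [1.0 0.0])
μ_post, Σ_post = kalman_update([0.0, 0.0], Matrix(1.0I, 2, 2), obs, [2.0])
```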
charlesknipp commented 1 month ago

Quoting @THargreaves: "Suppose that the transition method is written with sufficient generality that it works on both a particle and a GaussianDistribution (this requires transition to return a distribution which is then sampled from)."

I think this implies an inherent difference in the meaning behind each dispatch of transition. The Gaussian dispatch propagates the entire filtered distribution, not just an individual state. I would instead dispatch predict and update on the sample, and leave transition and observation unused for the Kalman filter.

function predict(states::Gaussian, proc::LinearGaussianLatentDynamics, ::KalmanFilter)
    @unpack A, Q = proc

    predicted_states = let μ = states.μ, Σ = states.Σ
        Gaussian(A*μ, A*Σ*A' + Q)
    end

    return predicted_states
end
function update(predicted_states::Gaussian, proc::LinearGaussianObservationProcess, observation, ::KalmanFilter)
    @unpack H, R = proc

    states, residual, S = GaussianDistributions.correct(
        predicted_states,
        Gaussian(observation, R), H
    )

    log_marginal = logpdf(
        Gaussian(zero(residual), Symmetric(S)),
        residual
    )

    return states, log_marginal
end
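Below is a dependency-free sketch of the same two methods. It assumes minimal stand-in types, which are not the SSMProblems.jl or GaussianDistributions.jl definitions. Instead of delegating to `GaussianDistributions.correct`, it spells out the gain and the log marginal likelihood $\log \mathcal{N}(\text{residual}; 0, S)$.

```julia
using LinearAlgebra

# Minimal stand-ins for the types used above (assumptions, not library types)
struct Gaussian
    μ::Vector{Float64}
    Σ::Matrix{Float64}
end
struct LinearGaussianLatentDynamics
    A::Matrix{Float64}
    Q::Matrix{Float64}
end
struct LinearGaussianObservationProcess
    H::Matrix{Float64}
    R::Matrix{Float64}
end
struct KalmanFilter end

# Predict: N(μ, Σ) ↦ N(Aμ, AΣAᵀ + Q)
function predict(states::Gaussian, proc::LinearGaussianLatentDynamics, ::KalmanFilter)
    A, Q = proc.A, proc.Q
    return Gaussian(A * states.μ, A * states.Σ * A' + Q)
end

# Update: condition on y and return the log marginal likelihood of y
function update(pred::Gaussian, proc::LinearGaussianObservationProcess, y::Vector{Float64}, ::KalmanFilter)
    H, R = proc.H, proc.R
    residual = y - H * pred.μ      # innovation
    S = H * pred.Σ * H' + R        # innovation covariance
    K = pred.Σ * H' / S            # Kalman gain
    states = Gaussian(pred.μ + K * residual, (I - K * H) * pred.Σ)
    C = cholesky(Symmetric(S))     # log N(residual; 0, S) via Cholesky
    log_marginal = -0.5 * (length(y) * log(2π) + logdet(C) + dot(residual, C \ residual))
    return states, log_marginal
end

# Scalar random walk observed with unit noise
dyn = LinearGaussianLatentDynamics(fill(1.0, 1, 1), fill(1.0, 1, 1))
obs = LinearGaussianObservationProcess(fill(1.0, 1, 1), fill(1.0, 1, 1))
pred = predict(Gaussian([0.0], fill(1.0, 1, 1)), dyn, KalmanFilter())
post, ll = update(pred, obs, [1.0], KalmanFilter())
```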

I personally don't see anything wrong with this design pattern, but maybe there's something I'm missing.

THargreaves commented 1 month ago

I agree that both of the above approaches are sensible.

To add some context from a previous PR: the motivation for this issue was noticing that the draft implementation of the Kalman filter didn't actually use the transition or observation methods from the SSM, instead requesting the model matrices and vectors directly [1].

The above three comments seem to suggest that it's much cleaner to not force the use of transition and observation, even though they are often useful.

Regardless of this, the predict/update structure for inference algorithms could be useful.

[1] Note: these were supplied by calc_A, calc_b, etc. rather than by accessing fields, in order to (1) allow these matrices/vectors to depend on control variables (required for Rao-Blackwellisation), and (2) allow for alternative parameterisations; e.g. OU dynamics are parameterised by a single mean-reversion parameter from which A can be generated.
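To illustrate point (2), here is a sketch of hypothetical `calc_A`, `calc_b` and `calc_Q` methods for scalar OU dynamics $dx = \theta(m - x)\,dt + \sigma\,dW$ observed every $\Delta t$. They use the exact discretisation. The struct and method signatures are assumptions, and the real methods may also take a step or control argument.

```julia
# Hypothetical scalar Ornstein-Uhlenbeck dynamics dx = θ(m - x) dt + σ dW
struct OUDynamics
    θ::Float64   # mean-reversion rate
    m::Float64   # long-run mean
    σ::Float64   # diffusion coefficient
    Δt::Float64  # time between observations
end

# Exact discretisation: x_k = A x_{k-1} + b + N(0, Q)
calc_A(d::OUDynamics) = exp(-d.θ * d.Δt)
calc_b(d::OUDynamics) = d.m * (1 - calc_A(d))
calc_Q(d::OUDynamics) = d.σ^2 * (1 - exp(-2 * d.θ * d.Δt)) / (2 * d.θ)

d = OUDynamics(0.5, 2.0, 1.0, 1.0)
```

A filter that only calls `calc_A`, `calc_b` and `calc_Q` never needs to know that the model is parameterised by θ.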

FredericWantiez commented 1 month ago

A small caveat is that update would most likely require introducing something like an AbstractFilterAlgorithm type to allow for dispatch.

yebai commented 1 month ago

Quoting @THargreaves: "Suppose that the transition method is written with sufficient generality that it works on both a particle and a GaussianDistribution (this requires transition to return a distribution which is then sampled from)."

I slightly lean towards supporting transition for both sampling and distributional representations of states. Julia's multiple dispatch can then be used to implement different algorithms.

Quoting @THargreaves: "The trouble is, we then need to compute the Kalman gain, and then update the posterior covariance. Both of these operations require the H matrix of the model, which is not available."

@FredericWantiez's approach works.

In addition, is there any reason we can't pass the additional inference-specific information (e.g. H) via the extra variable?

FredericWantiez commented 1 month ago

Looking at @charlesknipp's example makes me think that the predict/update API probably sits somewhere in an AbstractFilter program that subscribes to the SSMProblems API.

Here for example:


function predict(
        ::AbstractRNG,
        particles::Gaussian,
        model::LinearGaussianModel,
        ::KalmanFilter,
        step::Integer,
        extra
    )
    # this could be replaced with linear_system(model.latent_dynamics, step, extra)
    @unpack A, Q = model.latent_dynamics

    predicted_particles = let μ = particles.μ, Σ = particles.Σ
        Gaussian(A*μ, A*Σ*A' + Q)
    end

    return predicted_particles
end

We could leverage the SSM structure:

function predict(rng::AbstractRNG, state::Gaussian, model, ::KalmanFilter, step::Integer, extra)
    # Propagate the mean through the SSM's own transition method
    proposed_state = transition(model.latent_dynamics, state, step, extra)
    # Or use the calc_XYZ variants to get the corresponding latent parameters
    A, Q = linear_system(model.latent_dynamics, step, extra)
    # Use the covariance of the incoming estimate so that A and Q are applied only once
    predicted_state = let μ = proposed_state.μ, Σ = state.Σ
        # In practice, for a more general Kalman filter
        # we'd probably expect some sort of `update_covariance` / `update_mean` handlers
        Gaussian(μ, A*Σ*A' + Q)
    end
    return predicted_state
end

THargreaves commented 1 month ago

Quoting @FredericWantiez: "Looking at @charlesknipp's example makes me think that the predict/update API probably sits somewhere in an AbstractFilter program that subscribes to the SSMProblems API."

Definitely agree with this. You can then define step as calling predict followed by update.

This allows for nice decomposability: I could run (predict, update) four times as four pieces of data come in, and then predict three steps ahead by calling predict alone.
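Here is a sketch of that decomposition for a hypothetical scalar Kalman filter. `AbstractFilterAlgorithm` follows the name floated earlier in the thread. `filter_step` stands in for `step` and is renamed here only to avoid shadowing `Base.step`. None of these are existing SSMProblems.jl API.

```julia
# Hypothetical scalar model: x_k = a x_{k-1} + N(0, q),  y_k = h x_k + N(0, r)
struct ScalarModel
    a::Float64
    q::Float64
    h::Float64
    r::Float64
end

abstract type AbstractFilterAlgorithm end
struct ScalarKalman <: AbstractFilterAlgorithm end

# Filter state is a (mean, variance) tuple
function predict(m::ScalarModel, state, ::ScalarKalman)
    μ, v = state
    return (m.a * μ, m.a^2 * v + m.q)
end

function update(m::ScalarModel, state, y, ::ScalarKalman)
    μ, v = state
    s = m.h^2 * v + m.r   # innovation variance
    k = v * m.h / s       # Kalman gain
    return (μ + k * (y - m.h * μ), (1 - k * m.h) * v)
end

# Generic step: works for any algorithm that defines predict and update
filter_step(m, state, y, alg::AbstractFilterAlgorithm) = update(m, predict(m, state, alg), y, alg)

model = ScalarModel(1.0, 1.0, 1.0, 1.0)
ys = (1.0, 0.5, 1.5, 1.0)  # four observations arrive
posterior = foldl((s, y) -> filter_step(model, s, y, ScalarKalman()), ys; init = (0.0, 1.0))
# Then forecast three steps ahead by calling predict alone
forecast = foldl((s, _) -> predict(model, s, ScalarKalman()), 1:3; init = posterior)
```

With unit parameters, the posterior variance after four updates is 34/55. Each data-free `predict` then adds one unit of variance and leaves the mean unchanged.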

I think the two things we need to check before committing to this are: