jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0

Tutorial for neural adaptive smc #11

Open fehiepsi opened 1 year ago

fehiepsi commented 1 year ago

Neural Adaptive SMC (Gu et al.) is a nice framework that allows us to train proposals for non-linear state space models. We can use the forward KL in a nested variational inference scheme because both derivations yield similar gradient estimators.

For state space models, we typically don't have a reverse kernel because the state dimension grows over time. This example will nicely illustrate how to deal with growing-dimensional variables in JAX. The trick is to preallocate a full-dimensional variable and perform an index update at each SMC step.
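A minimal sketch of that trick (not from the thread; the step function here is a stand-in for a real SMC step): JAX arrays are immutable, so the growing latent trajectory lives in a preallocated buffer that is written with a functional index update at each step.

```python
import jax
import jax.numpy as jnp

T_max = 5
zs = jnp.zeros(T_max)  # preallocated full-dimensional variable

def smc_step(zs, t):
    z_t = 2.0 * t               # stand-in for sampling a new latent state
    zs = zs.at[t].set(z_t)      # functional index update, no in-place mutation
    return zs, z_t

# carry the buffer through all steps; works under jit/scan because the
# shape of zs never changes
zs, _ = jax.lax.scan(smc_step, zs, jnp.arange(T_max))
```

Because the buffer's shape is fixed up front, the loop is compatible with `jax.lax.scan` and `jax.jit`, which require static shapes.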

deoxyribose commented 2 days ago

Hi @fehiepsi,

I'd like to take a stab at this, but could use a little help getting started. Given the model from 5.1 in the paper

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def ssm(xs=None, T_max=1000):
    z_0 = numpyro.sample("z_0", dist.Normal(0.0, 5.0))
    z_t_m1 = z_0
    for t in range(1, T_max):
        # transition: nonlinear benchmark dynamics from the paper
        z_t_loc = z_t_m1 / 2 + 25 * z_t_m1 / (1 + z_t_m1 ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample(f"z_{t}", dist.Normal(z_t_loc, jnp.sqrt(10.0)))
        # emission: observation depends on the squared latent state
        x_t = numpyro.sample(f"x_{t}", dist.Normal(z_t ** 2 / 20, 1.0), obs=xs[t - 1] if xs is not None else None)
        z_t_m1 = z_t
    return x_t

I figure what needs to be implemented is an LSTM-based mixture density network that parametrizes q(z_t | z_1:t-1, x_1:t) (or q(v_t | z_1:t-1, x_1:t, f(z_t-1, t)), since that works better according to the paper). Then make a list of proposals, one for each z_t, each of which is sampled and used to update the full-dimensional variable via zs.at[t].set(sample)? Would the targets simply be the model above, conditioned on x_1:t?

I will try to code something up, but some guidance would be very helpful!

fehiepsi commented 7 hours ago

Great to hear that you are interested in this issue, @deoxyribose! The main theme of using coix is to define subprograms, then combine them together. Each subprogram is modelled using a PPL, e.g. numpyro.

Your model is already in the "combined" form. You can factor it into subprograms: init_proposal, proposal_t, target_t. Here target_t is the body of your for loop, and proposal_t is your LSTM-based model. target_t defines the joint distribution p(z_t, x_t | ...) while proposal_t defines q(z_t | ...). Let's walk through this step first. Please let me know if you have any questions.

The next step is to combine those programs. You can use the algorithm in coix.algo.nasmc or, even better, combine them in your own way. But let's discuss this later.