jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0

Tutorial for neural adaptive smc #11

Status: Open. fehiepsi opened this issue 1 year ago.

fehiepsi commented 1 year ago

Neural Adaptive SMC (Gu et al.) is a nice framework for training proposals for non-linear state space models. We can use the forward KL in a nested variational inference scheme because both derivations give similar gradient estimates.
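A rough illustration of the forward-KL update, as a minimal sketch in JAX. In NASMC the particles z_i and normalized importance weights w_i from an SMC sweep approximate the posterior, and the proposal gradient is estimated as -sum_i w_i * grad log q(z_i). Treating the weights and particles as constants (via `jax.lax.stop_gradient`) turns this into a surrogate loss that `jax.grad` can differentiate. The single-Gaussian `log_q` and the name `forward_kl_surrogate` are hypothetical stand-ins for the LSTM proposal and are not part of coix.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_q(params, z):
    # Hypothetical proposal: one Gaussian with learnable loc and log-scale.
    loc, log_scale = params
    return norm.logpdf(z, loc, jnp.exp(log_scale))


def forward_kl_surrogate(params, particles, log_weights):
    # Self-normalized importance weights; no gradient flows through them.
    w = jax.nn.softmax(jax.lax.stop_gradient(log_weights))
    # Particles are treated as fixed samples (no reparameterization).
    z = jax.lax.stop_gradient(particles)
    # The gradient of this loss is -sum_i w_i * grad log q(z_i),
    # i.e. the weighted-score estimate of the forward-KL gradient.
    return -jnp.sum(w * log_q(params, z))


params = (jnp.array(0.0), jnp.array(0.0))
particles = jnp.array([-1.0, 0.5, 2.0])
log_weights = jnp.array([0.1, -0.3, 0.7])
grads = jax.grad(forward_kl_surrogate)(params, particles, log_weights)
```

With a real proposal, `params` would be the network weights, and the particles and log-weights would come from the SMC run at each training iteration.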

For state space models, we typically don't have a reverse kernel because the state dimension grows over time. This example will nicely illustrate how to deal with growing-dimensional variables in JAX. The trick is to preallocate a full-dimensional variable and perform an index update in each SMC step.
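A minimal sketch of the preallocation trick, assuming N particles stored in an (N, T) array: each step writes column t with `.at[:, t].set(...)`, so the carried array keeps a fixed shape inside `jax.lax.fori_loop` and under `jax.jit`. To keep it short, the example only propagates particles through the prior transition of the model discussed in this thread (no weighting or resampling), and `simulate_particles` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def transition_mean(z_prev, t):
    # f(z_{t-1}, t) from the ssm() model in this thread
    return z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)


def simulate_particles(key, num_particles, num_steps):
    # Preallocate the full (N, T) buffer: under jit, shapes must be static,
    # so the trajectory z_{0:t} cannot grow from one step to the next.
    zs = jnp.zeros((num_particles, num_steps))
    key, sub = jax.random.split(key)
    # z_0 ~ N(0, 5), reading 5 as a variance (assumption)
    zs = zs.at[:, 0].set(jnp.sqrt(5.0) * jax.random.normal(sub, (num_particles,)))

    def step(t, carry):
        zs, key = carry
        key, sub = jax.random.split(key)
        loc = transition_mean(zs[:, t - 1], t)
        z_t = loc + jnp.sqrt(10.0) * jax.random.normal(sub, (num_particles,))
        # Functional index update: a new buffer with column t filled in;
        # columns after t are still zero placeholders.
        return zs.at[:, t].set(z_t), key

    zs, _ = jax.lax.fori_loop(1, num_steps, step, (zs, key))
    return zs


simulate = jax.jit(simulate_particles, static_argnums=(1, 2))
zs = simulate(jax.random.PRNGKey(0), 8, 50)
```

Because `.at[...].set` returns a new array instead of mutating in place, the same pattern works inside `lax.scan` or a jitted SMC step.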

deoxyribose commented 1 month ago

Hi @fehiepsi,

I'd like to take a stab at this, but could use a little help getting started. Given the model from Section 5.1 of the paper:

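import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def ssm(xs=None, T_max=1000):
    # z_0 ~ N(0, 5); the paper gives 5 as a variance, so the scale is sqrt(5)
    z_0 = numpyro.sample("z_0", dist.Normal(0, jnp.sqrt(5)))
    z_t_m1 = z_0
    for t in range(1, T_max):
        # Nonlinear transition mean f(z_{t-1}, t); process noise has variance 10
        z_t_loc = z_t_m1 / 2 + 25 * z_t_m1 / (1 + z_t_m1 ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample(f"z_{t}", dist.Normal(z_t_loc, jnp.sqrt(10)))
        # Observation x_t ~ N(z_t^2 / 20, 1); xs[t - 1] holds x_t when observed
        x_t = numpyro.sample(
            f"x_{t}",
            dist.Normal(z_t ** 2 / 20, 1),
            obs=xs[t - 1] if xs is not None else None,
        )
        z_t_m1 = z_t
    return x_t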

I figure what needs to be implemented is an LSTM-based mixture density network that parametrizes q(z_t | z_1:t-1, x_1:t) (or q(v_t | z_1:t-1, x_1:t, f(z_t-1, t)), since that works better according to the paper). Then I'd make a list of proposals, one for each z_t, each of which is sampled and used to update the full-dimensional variable via zs.at[t].set(sample)? Would the targets simply be the model above, conditioned on x_1:t?
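As a rough sketch of the mixture-density part, assuming a K-component 1-D Gaussian mixture and a plain linear head on top of a hidden state `h` (the LSTM that would produce `h` from x_1:t and z_1:t-1 is left out): the functions below compute the mixture parameters, sample from the mixture, and evaluate its log-density. `propose_z_t` follows the residual variant mentioned above, sampling v_t and setting z_t = f(z_t-1, t) + v_t; because this is a pure shift, log q(z_t) equals log q(v_t). All names here are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def transition_mean(z_prev, t):
    # f(z_{t-1}, t) from the model above
    return z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)


def mdn_head(h, W, b):
    # Hypothetical linear head: maps a hidden state of size H to the
    # parameters of a K-component mixture (W has shape (H, 3K)).
    logits, locs, log_scales = jnp.split(h @ W + b, 3)
    return logits, locs, jnp.exp(log_scales)


def mdn_log_prob(v, logits, locs, scales):
    # log q(v) = logsumexp_k [log pi_k + log N(v; mu_k, sigma_k)]
    return logsumexp(jax.nn.log_softmax(logits) + norm.logpdf(v, locs, scales))


def mdn_sample(key, logits, locs, scales):
    # Ancestral sampling: choose a component, then draw from its Gaussian.
    k_key, n_key = jax.random.split(key)
    k = jax.random.categorical(k_key, logits)
    return locs[k] + scales[k] * jax.random.normal(n_key)


def propose_z_t(key, h, W, b, z_prev, t):
    # Residual variant: sample v_t, then z_t = f(z_{t-1}, t) + v_t.
    # A pure shift has unit Jacobian, so log q(z_t) = log q(v_t).
    params = mdn_head(h, W, b)
    v_t = mdn_sample(key, *params)
    return transition_mean(z_prev, t) + v_t, mdn_log_prob(v_t, *params)


H, K = 16, 3
k_w, k_h, k_s = jax.random.split(jax.random.PRNGKey(0), 3)
W = 0.1 * jax.random.normal(k_w, (H, 3 * K))
b = jnp.zeros(3 * K)
h = jax.random.normal(k_h, (H,))  # placeholder for the LSTM hidden state
z_t, log_q = propose_z_t(k_s, h, W, b, z_prev=0.5, t=1)
```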

I will try to code something up, but some guidance would be very helpful!

fehiepsi commented 1 month ago

Great to hear that you are interested in this issue, @deoxyribose! The main theme of coix is to define subprograms and then combine them. Each subprogram is written in a PPL, e.g. NumPyro.

Your model is already in "combined" form. You can factor it into subprograms: init_proposal, proposal_t, and target_t. Here target_t is the body of your for loop and proposal_t is your LSTM-based model: target_t defines the joint distribution p(z_t, x_t | ...), while proposal_t defines q(z_t | ...). Let's walk through this step first. Please let me know if you have any questions.
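A minimal sketch of this factorization, written as plain JAX log-density functions rather than NumPyro programs so it has no coix dependency: `target_t` is the loop body scoring log p(z_t, x_t | z_{t-1}), `init_target` scores z_0, and `init_proposal` / `proposal_t` return a sample together with its log q. The Gaussian `proposal_t` centred on f(z_{t-1}, t) is a hypothetical placeholder for the LSTM proposal, and reading N(0, 5) as a variance is an assumption. The check at the end compares summing `target_t` step by step with one vectorized call over all steps; both give the log joint of the full model, since it factorizes into p(z_0) times the per-step terms.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def transition_mean(z_prev, t):
    # f(z_{t-1}, t): deterministic part of the transition in ssm()
    return z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)


def init_proposal(key):
    # Hypothetical init_proposal: draw z_0 from its prior, return (z_0, log q).
    z_0 = jnp.sqrt(5.0) * jax.random.normal(key)
    return z_0, norm.logpdf(z_0, 0.0, jnp.sqrt(5.0))


def proposal_t(key, z_prev, t, scale):
    # Hypothetical stand-in for the LSTM proposal q(z_t | ...).
    loc = transition_mean(z_prev, t)
    z_t = loc + scale * jax.random.normal(key, jnp.shape(z_prev))
    return z_t, norm.logpdf(z_t, loc, scale)


def init_target(z_0):
    # log p(z_0)
    return norm.logpdf(z_0, 0.0, jnp.sqrt(5.0))


def target_t(z_prev, z_t, x_t, t):
    # Body of the for loop: log p(z_t | z_{t-1}) + log p(x_t | z_t)
    log_trans = norm.logpdf(z_t, transition_mean(z_prev, t), jnp.sqrt(10.0))
    log_obs = norm.logpdf(x_t, z_t ** 2 / 20, 1.0)
    return log_trans + log_obs


T = 10
keys = jax.random.split(jax.random.PRNGKey(1), T)
z, _ = init_proposal(keys[0])
trajectory = [z]
for t in range(1, T):
    z, _ = proposal_t(keys[t], z, t, 3.0)
    trajectory.append(z)
zs = jnp.stack(trajectory)
xs = zs[1:] ** 2 / 20  # noise-free stand-in observations for the check

loop_total = init_target(zs[0]) + sum(
    target_t(zs[t - 1], zs[t], xs[t - 1], t) for t in range(1, T)
)
vec_total = init_target(zs[0]) + jnp.sum(
    target_t(zs[:-1], zs[1:], xs, jnp.arange(1, T))
)
```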

The next step is to combine those programs. You can use the algorithm in coix.algo.nasmc or, even better, combine them in your own way. But let's discuss that later.
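For the "combine them in your own way" route, here is a minimal hand-rolled SMC step in JAX; it is not the coix.algo.nasmc implementation. It resamples whole trajectories from the previous weights, extends each one with a proposal, writes the new state into a preallocated (N, T) buffer, and returns the incremental log-weight log p(z_t, x_t | z_{t-1}) - log q(z_t | ...). The Gaussian proposal with a fixed `prop_scale`, the resample-every-step schedule, and the placeholder observations are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def transition_mean(z_prev, t):
    return z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)


def log_target_t(z_prev, z_t, x_t, t):
    # log p(z_t, x_t | z_{t-1}) for the model in this thread
    return (norm.logpdf(z_t, transition_mean(z_prev, t), jnp.sqrt(10.0))
            + norm.logpdf(x_t, z_t ** 2 / 20, 1.0))


@jax.jit
def smc_step(key, zs, log_w, x_t, t, prop_scale):
    # zs: (N, T) particle buffer; columns after t are unused placeholders.
    n = zs.shape[0]
    key_res, key_prop = jax.random.split(key)
    # Multinomial resampling of whole trajectories using the previous weights.
    ancestors = jax.random.categorical(key_res, log_w, shape=(n,))
    zs = zs[ancestors]
    z_prev = zs[:, t - 1]
    # Hypothetical proposal: Gaussian centred on f(z_{t-1}, t).
    loc = transition_mean(z_prev, t)
    z_t = loc + prop_scale * jax.random.normal(key_prop, (n,))
    log_q = norm.logpdf(z_t, loc, prop_scale)
    # After resampling the weights are uniform, so the new log-weight is
    # just the incremental weight: target minus proposal.
    new_log_w = log_target_t(z_prev, z_t, x_t, t) - log_q
    return zs.at[:, t].set(z_t), new_log_w


N, T = 256, 25
key = jax.random.PRNGKey(0)
xs = jnp.linspace(0.0, 5.0, T - 1)  # placeholder observations x_1..x_{T-1}
key, sub = jax.random.split(key)
zs = jnp.zeros((N, T)).at[:, 0].set(jnp.sqrt(5.0) * jax.random.normal(sub, (N,)))
log_w = jnp.zeros(N)
for t in range(1, T):
    key, sub = jax.random.split(key)
    zs, log_w = smc_step(sub, zs, log_w, xs[t - 1], t, 3.0)
```

Resampling at every step keeps the code short; an effective-sample-size threshold is a common alternative. In a NASMC-style training loop, the particles and weights produced at each step would feed the weighted score update of the proposal parameters.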

deoxyribose commented 1 month ago

Thanks @fehiepsi! I've tried to do what you suggest here: https://github.com/deoxyribose/nasmc/blob/master/nasmc.ipynb, but I don't have a good handle yet on what it's supposed to look like. I'm not sure whether init_proposal and the z_0 sampling should be separate from ssm_proposal and ssm_target, respectively. In any case, I don't know how to get past the current error message, but I figure the code probably reveals some misconceptions that you could clear up :)

Edit 10/10/24: I can now run training with JIT compilation turned off, but judging by the metrics it's not very stable and eventually crashes. I think I need to break the problem down into smaller tests than just running training, but I'm not sure what those could be.

fehiepsi commented 1 month ago

Sorry for the late response, @deoxyribose! I'll look into your notebook tomorrow.

fehiepsi commented 1 week ago

Hi @deoxyribose, I addressed several issues in your notebook to match the paper, and it seems to work: https://gist.github.com/fehiepsi/b7def6a77bf9ca150cf2f17f2ba1a2b5

Let me know if something is unclear.

deoxyribose commented 1 week ago

Thanks so much @fehiepsi! I got busy, but will have plenty of time to finish this in 2 weeks, if not before.