probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Make HMM learning work with variable length time series #99

Open slinderman opened 1 year ago

slinderman commented 1 year ago

I don't think the current hmm_fit_sgd function is using the length of the time series as we hoped. At least with the default loss function, the length just scales the loss. Really, we need to change marginal_log_prob to only compute the log probability of the observations up to the specified length.

Following up on our slack conversation, I see two ways of doing that:

  1. Pad the time series with nan's and modify _conditional_logliks to put zeros wherever the emission is nan. That way the hmm_filter will still compute the marginal log prob of just the observed data. I think this trick should also leave the hmm_smoother computations unchanged. A cool added benefit is that it would allow us to interpolate over chunks of missing data.

It would look like this:

        # Perform a nested vmap over timesteps and states
        f = lambda emission: \
            vmap(lambda state: \
                self.emission_distribution(state).log_prob(emission))(
                    jnp.arange(self.num_states)
                )

        lls = vmap(f)(emissions)
        return jnp.where(jnp.isnan(lls), 0, lls)

I tested this out and the only problem is that we can't take gradients back through this function with respect to the model parameters. They nan out because one of the paths through the where is nan. See https://github.com/google/jax/issues/1052.
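To make the failure mode concrete, here is a toy reproduction (a made-up scalar Gaussian log-density, not dynamax code) showing that jnp.where masks NaNs in the forward pass but not in the backward pass:

```python
import jax
import jax.numpy as jnp

# Toy reproduction: jnp.where hides NaNs in the forward pass, but the
# gradient still flows through the NaN branch (0 * nan = nan) and
# contaminates the result.
def masked_ll(mu, x):
    ll = -0.5 * (x - mu) ** 2              # NaN wherever x is NaN
    return jnp.sum(jnp.where(jnp.isnan(ll), 0.0, ll))

x = jnp.array([1.0, jnp.nan, 3.0])
print(masked_ll(0.0, x))                   # finite: the mask works forward
print(jax.grad(masked_ll)(0.0, x))         # nan: the backward pass is poisoned
```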

There's a somewhat clunky fix, which is to find the nan's first, replace them with a default value in the emissions, compute the log likelihoods, and then zero out the entries that were originally nan. That would look something like this:

        bad = jnp.any(jnp.isnan(emissions), axis=1)
        tmp = jnp.where(jnp.broadcast_to(bad[:, None], emissions.shape), 0.0, emissions)
        lls = vmap(f)(tmp)
        return jnp.where(jnp.broadcast_to(bad[:, None], lls.shape), 0.0, lls)

It's not the prettiest, but it works.
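Putting the two snippets together, here's a self-contained sketch of the double-where trick, with a stand-in Gaussian emission in place of self.emission_distribution (none of these names are dynamax API):

```python
import jax.numpy as jnp
from jax import grad, vmap

def emission_log_prob(mu, emission):
    # Stand-in per-state emission log-density (unit-variance Gaussian)
    return jnp.sum(-0.5 * (emission - mu) ** 2)

def safe_conditional_logliks(mus, emissions):
    # mus: (K, D) per-state means; emissions: (T, D), NaN rows mark padding
    bad = jnp.any(jnp.isnan(emissions), axis=1)
    # Replace NaN rows with a dummy value so the forward pass is finite...
    tmp = jnp.where(jnp.broadcast_to(bad[:, None], emissions.shape), 0.0, emissions)
    f = lambda em: vmap(lambda mu: emission_log_prob(mu, em))(mus)
    lls = vmap(f)(tmp)                                  # (T, K), finite everywhere
    # ...then zero out the rows that were originally NaN.
    return jnp.where(jnp.broadcast_to(bad[:, None], lls.shape), 0.0, lls)

mus = jnp.array([[0.0], [1.0]])
emissions = jnp.array([[1.0], [jnp.nan], [3.0]])
loss = lambda mus: -jnp.sum(safe_conditional_logliks(mus, emissions))
print(grad(loss)(mus))    # finite gradients; the NaN row contributes nothing
```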

  2. Alternatively, we could pass the length of the time series to the underlying inference functions like hmm_filter. Then those functions would need to use a while loop to dynamically stop the message passing once the length has been reached. (I tried implementing this by calling filter on a dynamic slice of the data, but JAX barfed on that...) This approach is totally doable, but it would lead to lots of extra logic in the inference code.
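For what it's worth, the length can also be folded into a fixed-size lax.scan without a while loop: mask the per-step log-likelihoods to zero beyond the valid length. Because the transition rows sum to one, the extra steps marginalize out and the full scan returns the marginal of just the prefix. A sketch (not dynamax code):

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp

def hmm_marginal_ll(log_init, log_trans, log_liks, length):
    # log_liks: (T, K) per-timestep emission log-likelihoods.
    # Steps t >= length get zero log-likelihood, so a full-length scan
    # returns the marginal likelihood of just the first `length` steps.
    T, _ = log_liks.shape
    mask = jnp.arange(T) < length
    lls = jnp.where(mask[:, None], log_liks, 0.0)

    def step(log_alpha, ll):
        # Standard forward message in log space
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + ll
        return log_alpha, None

    log_alpha, _ = lax.scan(step, log_init + lls[0], lls[1:])
    return logsumexp(log_alpha)
```

Running it on a padded sequence with `length=3` agrees with running it on the 3-step prefix alone.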

I'm working on a demo of approach 1 right now. Will keep you posted!

murphyk commented 1 year ago

See also https://github.com/probml/dynamax/issues/50

slinderman commented 1 year ago

Just commenting here to note that this request (or variants of it) has come up multiple times in the past few weeks. A simple change would be to make the low-level inference code allow missing data, and then update the model-based code when time allows.

The HMM inference code is simple enough: you can indicate missing data by passing zeros to the corresponding rows of log_likelihoods. The *GSSM code could handle missing data by similarly "zeroing out" potentials (making the emission covariance effectively infinite) if the emissions are nan.
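For the linear-Gaussian case, skipping the measurement update on an all-NaN emission is exactly the limit of an infinite emission covariance. A minimal sketch of a Kalman update with that behavior (a generic textbook update, not the dynamax implementation):

```python
import jax.numpy as jnp

def kalman_update(mu_pred, Sigma_pred, H, R, y):
    # Standard Kalman measurement update that falls back to the predicted
    # state when the emission y is missing (all NaN) -- the limit R -> inf.
    missing = jnp.all(jnp.isnan(y))
    y_safe = jnp.where(jnp.isnan(y), 0.0, y)       # keep the arithmetic finite
    S = H @ Sigma_pred @ H.T + R                   # innovation covariance
    K = Sigma_pred @ H.T @ jnp.linalg.inv(S)       # Kalman gain
    mu_new = mu_pred + K @ (y_safe - H @ mu_pred)
    Sigma_new = Sigma_pred - K @ H @ Sigma_pred
    mu = jnp.where(missing, mu_pred, mu_new)
    Sigma = jnp.where(missing, Sigma_pred, Sigma_new)
    return mu, Sigma
```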

murphyk commented 1 year ago

If we pass the valid length of each sequence, we can lax.scan only over that prefix. Missing data at random times could be handled with an if statement for a conditional update, or with a local evidence vector that is all 1s for the missing time steps.

KeAWang commented 1 year ago

I actually have a fork of dynamax that handles missingness for the EKF (as well as allowing time-varying transitions and emissions): https://github.com/KeAWang/dynamax/commit/a991219873358f42b282af05c9e666d3b52ecf56. Though it's not for the HMM, I'm happy to open a PR for it.

slinderman commented 1 year ago

Sure, that would be great Alex!