rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx

Hidden Markov Models #49

Closed rlouf closed 3 years ago

rlouf commented 3 years ago

DRAFT

Please comment if you see issues with this design, have ideas, or know of use cases I did not think about!

Finding the right abstraction: the HMM distribution

We would like to simplify the expression of hidden Markov models (HMMs) in MCX. The underlying idea is that HMMs are built from a unit that is repeated over time.

Simple HMM

In their simplest form:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
y[t-1]      y[t]      y[t+1]

And the elementary unit is:

x[t-1] ---> x[t]
             |
            y[t]

Let us see if it is possible to build a model from the expression of a single unit as a generative model:

def hmm_unit(x_prev):
    x <~ Categorical(x_probs[x_prev])
    y <~ Bernoulli(y_probs[x])
    return y

Knowing the previous value of x and the observation y, we can compute the posterior distribution of x. How do we combine the units? We can create a new distribution!

class HMM(mcx.Distribution):
    def __init__(self, unit):
        pass

    def sample(self, rng_key):
        pass  # to be defined

    def logpdf(self, x, y):
        pass  # to be defined

Let us assume this HMM distribution exists. A simple HMM would thus be loosely expressed in MCX as:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, num_units))

    @mcx.model
    def hmm_unit(x_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x])
        return y

    x_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x_init)

    return obs

HMM where observations depend on previous observations

To challenge this abstraction, let us assume a more complex model:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
y[t-1] ---> y[t] ---> y[t+1]

The elementary unit becomes:

x[t-1] ---> x[t]
             |           
y[t-1] ---> y[t]

So the model can be written:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, 2, num_units))

    @mcx.model
    def hmm_unit(x_prev, y_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x, y_prev])
        return y

    x_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x_init)

    return obs

Factorial HMM

What if we have a Factorial HMM instead:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
  v          v           v
y[t-1]     y[t]       y[t+1]
  ^          ^           ^
  |          |           |
v[t-1] ---> v[t] ---> v[t+1]

The elementary unit is:

x[t-1] ---> x[t]
             | 
             v 
            y[t]
             ^  
             |  
v[t-1] ---> v[t]

And in code:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    v_probs <~ Dirichlet(0.3 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, hidden_dims, num_units))

    @mcx.model
    def hmm_unit(x_prev, v_prev):
        x <~ Categorical(x_probs[x_prev])
        v <~ Categorical(v_probs[v_prev])
        y <~ Bernoulli(y_probs[x, v])
        return y

    x_init = np.zeros(num_units)
    v_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x=x_init, v=v_init)

    return obs

The abstraction seems to be robust.

Implementing the HMM distribution

We need to provide an implementation for the sample and logpdf methods of the HMM distribution.

Sample

When parsing the model to compile it into a sampling function, MCX will transform hmm_unit into the sample_hmm_unit function below:

def sample_hmm_unit(rng_key, x):
    # split the key so the two draws use independent randomness
    key_x, key_y = jax.random.split(rng_key)
    x_new = Categorical(x_probs[x]).sample(key_x)
    y_new = Bernoulli(y_probs[x_new]).sample(key_y)
    return x_new, y_new

HMM.sample(rng_key) should return prior samples of both x and y over the whole sequence. We can achieve this with:

def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_units)
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)
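
For concreteness, here is a minimal self-contained sketch of this scan-based prior sampling. It uses jax.random.categorical and jax.random.bernoulli directly in place of MCX distributions, and the parameters (3 hidden states, 10 steps, the particular x_probs and y_probs values) are made up for the example:

import jax
import jax.numpy as jnp

num_steps = 10
x_probs = jnp.array([[0.8, 0.1, 0.1],
                     [0.1, 0.8, 0.1],
                     [0.1, 0.1, 0.8]])  # toy transition matrix
y_probs = jnp.array([0.1, 0.5, 0.9])    # toy emission probability per hidden state

def sample_hmm_unit(rng_key, x):
    key_x, key_y = jax.random.split(rng_key)
    x_new = jax.random.categorical(key_x, jnp.log(x_probs[x]))
    y_new = jax.random.bernoulli(key_y, y_probs[x_new])
    return x_new, y_new

def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_steps)
x_init = jnp.array(0)
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)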

Logpdf

When parsing the model to compile it into a loglikelihood, MCX will transform hmm_unit into the logpdf_hmm_unit function below:

def logpdf_hmm_unit(x_prev, x, y):
    loglikelihood = 0
    loglikelihood += Categorical(x_probs[x_prev]).logpdf(x)
    loglikelihood += Bernoulli(y_probs[x]).logpdf(y)
    return loglikelihood

HMM.logpdf(x, y) should return the loglikelihood of the model given the values of x_probs and y_probs (taken from the enclosing model), x, and y=obs. We could use:

def scan_update(x_prev, xy):
    x, y = xy
    loglikelihood = logpdf_hmm_unit(x_prev, x, y)
    return x, loglikelihood

_, loglikelihoods = jax.lax.scan(scan_update, x_init, (x, obs))
loglikelihood = np.sum(loglikelihoods)
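
Under the same toy parameters as the sampling sketch above, the corresponding loglikelihood computation could look like this (logpdf_hmm_unit written out with jnp directly, reusing x_samples and y_samples as stand-ins for the latent values and observations being evaluated):

def logpdf_hmm_unit(x_prev, x, y):
    # log P(x | x_prev) for the Categorical transition
    log_trans = jnp.log(x_probs[x_prev, x])
    # log P(y | x) for the Bernoulli emission
    p = y_probs[x]
    log_emit = jnp.log(jnp.where(y, p, 1.0 - p))
    return log_trans + log_emit

def scan_update(x_prev, xy):
    x, y = xy
    return x, logpdf_hmm_unit(x_prev, x, y)

_, loglikelihoods = jax.lax.scan(scan_update, x_init, (x_samples, y_samples))
loglikelihood = jnp.sum(loglikelihoods)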

Note: How do you implement time-dependent transitions?
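
One possible answer, as an assumption rather than part of the design: thread the per-step parameters (or simply the time index) through the xs argument of lax.scan so the unit can look them up at each step. A minimal sketch, reusing the toy values above and a hypothetical stack x_probs_t of per-step transition matrices:

# hypothetical per-step transition matrices, shape (num_steps, hidden, hidden)
x_probs_t = jnp.stack([x_probs] * num_steps)

def scan_update(x, inputs):
    rng_key, probs_t = inputs
    key_x, key_y = jax.random.split(rng_key)
    x_new = jax.random.categorical(key_x, jnp.log(probs_t[x]))
    y_new = jax.random.bernoulli(key_y, y_probs[x_new])
    return x_new, (x_new, y_new)

_, (x_samples_t, y_samples_t) = jax.lax.scan(scan_update, x_init, (keys, x_probs_t))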

ericmjl commented 3 years ago

Just leaving a note: lax.scan is rad! :smile:

rlouf commented 3 years ago

I can probably JIT-compile the unit for really fast forward passes through the model.
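
For illustration, with the sample_hmm_unit sketch from the draft above this would be a one-liner:

fast_unit = jax.jit(sample_hmm_unit)
x_new, y_new = fast_unit(jax.random.PRNGKey(1), jnp.array(0))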

ericmjl commented 3 years ago

Yes. We did similar things with the unirep model!

rlouf commented 3 years ago

This is what is happening in the sampler and makes it fast: https://github.com/rlouf/mcx/blob/95bd8331e8d5a72b7db8287217416a24afbd35f1/mcx/sampling.py#L351

What do you think of this design draft for HMMs?

ericmjl commented 3 years ago

@rlouf I like it! There are a few points I don't understand though, would you be kind enough to clarify?

In the section "HMM where observations depend on previous observations", I see the syntax:

y <~ Bernoulli(y_probs[x, y_prev])

Is y_probs indexable that way? A Beta distribution is continuous rather than discrete, so I wasn't sure how this would work out in terms of the implementation. The syntax is definitely readable, but I think a seasoned Python reader might be confused by the indexing.

rlouf commented 3 years ago

Look at the batch size: y_probs is a 3D array of independent Beta-distributed random variables, and x and y_prev index into that array. I first came across that on your blog, actually :)
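
For example, a plain numpy sketch of the idea (not MCX code): a batch of Beta draws is just an array of probabilities, and x and y_prev pick an entry.

import numpy as np

hidden_dims, num_units = 3, 10
# one independent Beta(1, 1) draw per (hidden state, previous observation, unit)
y_probs = np.random.beta(1, 1, size=(hidden_dims, 2, num_units))

x, y_prev = 1, 0
p = y_probs[x, y_prev]  # vector of Bernoulli success probabilities, one per unit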

I believe Pyro has an expand method that does this, maybe it's more readable?

rlouf commented 3 years ago

Moved to discussions in #93 to declutter the issue tracker for real issues.