dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

compute data log likelihood over iterations of fitting #50

r-shruthi11 closed this issue 1 year ago

r-shruthi11 commented 1 year ago

Is it possible to add P(Y | model) to the history when we save it?
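For illustration, here is a minimal sketch of what the requested feature could look like, assuming a generic iterative fitting loop. `fit_with_history`, `update_fn`, and `log_lik_fn` are hypothetical names, not keypoint-moseq or jax_moseq APIs; the toy usage fits the mean of a Gaussian so that the block runs on its own.

```python
import math
from typing import Any, Callable, Dict


def fit_with_history(
    state: Any,
    update_fn: Callable[[Any], Any],
    log_lik_fn: Callable[[Any], float],
    num_iters: int,
    save_every: int = 10,
) -> Dict[int, Dict[str, Any]]:
    """Run a generic iterative fit, storing log P(Y | model) next to each snapshot.

    update_fn (one fitting step) and log_lik_fn (data log-likelihood) are
    hypothetical stand-ins, not functions from keypoint-moseq.
    """
    history = {0: {"state": state, "log_Y_given_model": float(log_lik_fn(state))}}
    for it in range(1, num_iters + 1):
        state = update_fn(state)
        if it % save_every == 0:
            # Record the log-likelihood at save time so it never has to be
            # recomputed retroactively from the checkpoint.
            history[it] = {"state": state, "log_Y_given_model": float(log_lik_fn(state))}
    return history


# Toy usage: "fit" the mean of a unit-variance Gaussian to three observations.
data = [1.0, 2.0, 3.0]


def update_fn(mu):
    # Toy fitting step: move halfway toward the sample mean.
    return mu + 0.5 * (sum(data) / len(data) - mu)


def log_lik_fn(mu):
    # Sum over observations of the log-density under N(mu, 1).
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (y - mu) ** 2 for y in data)


history = fit_with_history(0.0, update_fn, log_lik_fn, num_iters=20, save_every=5)
```

Evaluating the likelihood only at snapshot iterations keeps the extra cost proportional to the number of saved states rather than to the total number of iterations.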

calebweinreb commented 1 year ago

Hey! Yes, that's a good idea; we should add that feature. In the meantime, you can calculate it retroactively using the code below (replacing project_dir and name with your own values, of course).

from jax_moseq.models.keypoint_slds import model_likelihood
import jax, os, joblib, tqdm, matplotlib.pyplot as plt

# Replace with your own project directory and model name
project_dir = 'demo_project'
name = '2023_05_04-11_34_10'

# Load the saved checkpoint and move the fixed quantities onto the device
checkpoint = joblib.load(os.path.join(project_dir, name, 'checkpoint.p'))
hypparams = jax.device_put(checkpoint['hypparams'])
noise_prior = jax.device_put(checkpoint['noise_prior'])
data = jax.device_put({'Y': checkpoint['Y'], 'mask': checkpoint['mask']})
saved_iters = sorted(checkpoint['history'].keys())

# Evaluate the log-likelihood terms at every saved iteration
log_Y_and_model = []
log_Y_given_model = []
for ix in tqdm.tqdm(saved_iters):
    states = jax.device_put(checkpoint['history'][ix]['states'])
    params = jax.device_put(checkpoint['history'][ix]['params'])
    ll = model_likelihood(data, states, params, hypparams, noise_prior)
    # Sum of all terms = log P(Y, model states); the 'Y' term alone = log P(Y | model states)
    log_Y_and_model.append(sum([v.item() for v in ll.values()]))
    log_Y_given_model.append(ll['Y'].item())

fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].plot(saved_iters, log_Y_given_model)
axs[1].plot(saved_iters, log_Y_and_model)
axs[0].set_ylabel('log P(Y | model)')
axs[1].set_ylabel('log P(Y, model)')
axs[1].set_xlabel('Iteration')
plt.show()

This code calculates the log of the conditional probability

P(Y | model states) = P(Y | pose trajectory, heading, centroid, noise scales)

and the log of the joint probability

P(Y, model states) = P(Y | model states) P(model states)

where

P(model states) = P(pose trajectory | syllable sequence) P(syllable sequence)
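To make the factorization above concrete, here is a toy, pure-Python sketch under heavily simplified assumptions: a scalar "pose" placed along the heading direction from a 2D centroid, syllable-specific AR(1) pose dynamics, and a two-state Markov chain over syllables. It is not keypoint-moseq's actual model (which maps a low-dimensional pose onto many keypoints), heading, centroid, and noise scales are treated as fixed inputs with no prior, and the dictionary returned by `model_likelihood` may contain more terms than the three here. All function names and numbers are hypothetical; the point is only that in log space the products become sums.

```python
import math


def log_normal(y, mean, var):
    """Log-density of a 1D Gaussian N(mean, var) evaluated at y."""
    return -0.5 * (math.log(2 * math.pi * var) + (y - mean) ** 2 / var)


def log_p_syllables(z, init, trans):
    """log P(syllable sequence) for a first-order Markov chain."""
    lp = math.log(init[z[0]])
    for prev, cur in zip(z[:-1], z[1:]):
        lp += math.log(trans[prev][cur])
    return lp


def log_p_pose_given_syllables(x, z, ar_coef, ar_var):
    """log P(pose trajectory | syllables) with syllable-specific scalar AR(1) dynamics."""
    lp = log_normal(x[0], 0.0, 1.0)  # assumed N(0, 1) initial pose
    for t in range(1, len(x)):
        lp += log_normal(x[t], ar_coef[z[t]] * x[t - 1], ar_var[z[t]])
    return lp


def log_p_obs_given_states(Y, x, h, v, s):
    """log P(Y | pose, heading, centroid, noise scales) for one toy 2D keypoint."""
    lp = 0.0
    for t in range(len(Y)):
        # Predicted keypoint: centroid plus the pose offset rotated by the heading.
        mx = v[t][0] + x[t] * math.cos(h[t])
        my = v[t][1] + x[t] * math.sin(h[t])
        lp += log_normal(Y[t][0], mx, s[t]) + log_normal(Y[t][1], my, s[t])
    return lp


# Hypothetical three-frame example
z = [0, 0, 1]
x = [0.5, 0.4, -0.2]
h = [0.0, 0.1, 0.2]
v = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.1)]
s = [0.1, 0.1, 0.2]
Y = [(0.6, 0.0), (0.5, 0.05), (0.0, 0.05)]
init = [0.5, 0.5]
trans = [[0.9, 0.1], [0.2, 0.8]]
ar_coef = [0.9, -0.5]
ar_var = [0.05, 0.05]

log_Y_given_model = log_p_obs_given_states(Y, x, h, v, s)
log_model = log_p_pose_given_syllables(x, z, ar_coef, ar_var) + log_p_syllables(z, init, trans)
log_Y_and_model = log_Y_given_model + log_model  # log of a product = sum of logs
```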
r-shruthi11 commented 1 year ago

Thanks, @calebweinreb! Closing this one out.