r-shruthi11 closed this issue 1 year ago.
Hey! Yes, that's a good idea; we should add that feature. In the meantime, you can calculate it retroactively using the code below (replacing `project_dir` and `name` with your own values).
```python
import os

import jax
import joblib
import matplotlib.pyplot as plt
import tqdm
from jax_moseq.models.keypoint_slds import model_likelihood

# Replace these with your own project directory and model name
project_dir = 'demo_project'
name = '2023_05_04-11_34_10'

# Load the saved checkpoint and move its arrays onto the JAX device
checkpoint = joblib.load(os.path.join(project_dir, name, 'checkpoint.p'))
hypparams = jax.device_put(checkpoint['hypparams'])
noise_prior = jax.device_put(checkpoint['noise_prior'])
data = jax.device_put({'Y': checkpoint['Y'], 'mask': checkpoint['mask']})

# Iterations at which states and params were saved
saved_iters = sorted(checkpoint['history'].keys())

log_Y_and_model = []    # log P(Y, model states) at each saved iteration
log_Y_given_model = []  # log P(Y | model states) at each saved iteration
for ix in tqdm.tqdm(saved_iters):
    states = jax.device_put(checkpoint['history'][ix]['states'])
    params = jax.device_put(checkpoint['history'][ix]['params'])
    # Dict of log-probability terms; ll['Y'] is log P(Y | model states)
    ll = model_likelihood(data, states, params, hypparams, noise_prior)
    log_Y_and_model.append(sum(v.item() for v in ll.values()))
    log_Y_given_model.append(ll['Y'].item())

fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].plot(saved_iters, log_Y_given_model)
axs[1].plot(saved_iters, log_Y_and_model)
axs[0].set_ylabel('log P(Y | model)')
axs[1].set_ylabel('log P(Y, model)')
axs[1].set_xlabel('Iteration')
plt.show()
```
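This code calculates the conditional probability
P(Y | model states) = P(Y | pose trajectory, heading, centroid, noise scales)
and the joint probability
P(Y, model states) = P(Y | model states) P(model states)
where
P(model states) = P(pose trajectory | syllable sequence) P(syllable sequence)
Both are computed as log-probabilities, so the products become sums: `ll['Y']` is log P(Y | model states), and summing every term in `ll` gives log P(Y, model states).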
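To make the factorization concrete, here is a toy sketch in plain Python. It is not keypoint-MoSeq's model: it assumes a 1-D pose trajectory `x` with per-syllable AR(1) dynamics, Gaussian observations `Y`, and a Markov chain over syllables `z`. The helper `toy_model_likelihood` and its `'x'`/`'z'` keys are hypothetical (only the `'Y'` key comes from the code above). Like the loop above, it returns one log-probability term per component, so the conditional is the `'Y'` term and the joint is the sum of all terms.

```python
import math


def normal_logpdf(value, mean, std):
    """Log density of a univariate normal distribution."""
    resid = (value - mean) / std
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * resid * resid


def toy_model_likelihood(Y, x, z, init_probs, trans, ar_coefs, obs_std):
    """Hypothetical toy analogue of model_likelihood: one log-prob term per component.

    Y: observations, x: latent pose trajectory, z: integer syllable labels.
    x[0] is treated as a fixed initial condition.
    """
    # log P(syllable sequence): initial state plus Markov transitions
    log_z = math.log(init_probs[z[0]])
    log_z += sum(math.log(trans[z[t - 1]][z[t]]) for t in range(1, len(z)))

    # log P(pose trajectory | syllable sequence): per-syllable AR(1), unit noise
    log_x = sum(
        normal_logpdf(x[t], ar_coefs[z[t]] * x[t - 1], 1.0) for t in range(1, len(x))
    )

    # log P(Y | pose trajectory): Gaussian observation noise around the pose
    log_Y = sum(normal_logpdf(y, xt, obs_std) for y, xt in zip(Y, x))

    return {"Y": log_Y, "x": log_x, "z": log_z}


# Illustrative toy inputs, not real data
ll = toy_model_likelihood(
    Y=[0.1, 0.4, 0.2],
    x=[0.0, 0.5, 0.25],
    z=[0, 0, 1],
    init_probs=[0.5, 0.5],
    trans=[[0.9, 0.1], [0.2, 0.8]],
    ar_coefs=[1.0, 0.5],
    obs_std=0.5,
)
log_Y_given_model = ll["Y"]         # log P(Y | model states)
log_Y_and_model = sum(ll.values())  # log P(Y, model states)
print(log_Y_given_model, log_Y_and_model)
```

Because each term is a log-probability, adding them corresponds to multiplying the factors in the equations above.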
Thanks, @calebweinreb! Closing this one out.
Is it possible to add P(Y | model) to the history when we save it?
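Until that's built in, here is a minimal sketch of what storing the likelihood at save time could look like. It assumes the history is a dict keyed by iteration whose entries hold `'states'` and `'params'`, as in the checkpoint loaded in the answer above; the functions `save_history_entry` and `likelihood_traces` and the `'log_likelihood'` key are hypothetical, not part of keypoint-MoSeq or jax_moseq. In practice `ll` would be the dict returned by `model_likelihood(data, states, params, hypparams, noise_prior)` for the iteration being saved.

```python
def save_history_entry(history, iteration, states, params, ll):
    """Hypothetical: store the log-likelihood terms next to states and params."""
    history[iteration] = {
        "states": states,
        "params": params,
        # Convert each term to a plain float so the entry is small and pickles cleanly
        "log_likelihood": {k: float(v) for k, v in ll.items()},
    }


def likelihood_traces(history):
    """Read back log P(Y | model) and log P(Y, model) for every saved iteration."""
    iters = sorted(history.keys())
    log_Y_given_model = [history[i]["log_likelihood"]["Y"] for i in iters]
    log_Y_and_model = [sum(history[i]["log_likelihood"].values()) for i in iters]
    return iters, log_Y_given_model, log_Y_and_model


# Usage with placeholder values standing in for real model output
history = {}
save_history_entry(history, 0, states={}, params={}, ll={"Y": -120.0, "x": -30.0})
save_history_entry(history, 25, states={}, params={}, ll={"Y": -95.0, "x": -28.0})
iters, cond, joint = likelihood_traces(history)
print(iters, cond, joint)
```

Storing plain floats at save time avoids reloading every saved state and recomputing the likelihood afterwards, which is what the retroactive loop has to do.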