dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Applying trained model to new data from SLEAP #75

Closed lencriv closed 12 months ago

lencriv commented 12 months ago

Hello,

I am trying to apply a model I have trained to new data from a SLEAP .h5 file. When I run the following code:

import os 
import numpy as np
import keypoint_moseq as kpms
project_dir = 'npix_kpms_4kp'

checkpoint = kpms.load_checkpoint(project_dir=project_dir, name='2023_07_23-21_35_50')
config = lambda: kpms.load_config(project_dir)

new_data = 'C:/Users/LEncR/keypoint-moseq/proj/whole_vids/mr3' # can be a file, a directory, or a list of files
coordinates, confidences, bodyparts = kpms.load_sleap_results(new_data)
results = kpms.apply_model(coordinates=coordinates, confidences=confidences,
                           project_dir=project_dir, pca=kpms.load_pca(project_dir),
                           **config(), **checkpoint)

I receive the following ValueError:

Applying model:   0%|                                                                           | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[67], line 1
----> 1 results = kpms.apply_model(coordinates=coordinates, confidences=confidences,
      2                            project_dir=project_dir, pca=kpms.load_pca(project_dir),
      3                            **config(), **checkpoint)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\keypoint_moseq\fitting.py:371, in apply_model(params, coordinates, confidences, num_iters, ar_only, save_results, verbose, project_dir, name, results_path, **kwargs)
    369 with tqdm.trange(num_iters, desc='Applying model') as pbar:
    370     for iteration in pbar:
--> 371         try: model = _wrapped_resample(
    372                 data, model, pbar=pbar, ar_only=ar_only, 
    373                 states_only=True, verbose=verbose)
    374         except StopResampling: break
    376 return extract_results(
    377     **model, labels=labels, save_results=save_results, name=name,
    378     project_dir=project_dir, results_path=results_path)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\keypoint_moseq\fitting.py:34, in _wrapped_resample(data, model, pbar, **resample_options)
     32 def _wrapped_resample(data, model, pbar=None, **resample_options):
     33     try: 
---> 34         model = resample_model(data, **model, **resample_options)
     35     except KeyboardInterrupt: 
     36         print('Early termination of fitting: user interruption')

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\gibbs.py:330, in resample_model(data, seed, states, params, hypparams, noise_prior, ar_only, states_only, skip_noise, fix_heading, verbose, **kwargs)
    291 def resample_model(data, seed, states, params, hypparams,
    292                    noise_prior, ar_only=False, states_only=False,
    293                    skip_noise=False, fix_heading=False, verbose=False,
    294                    **kwargs):
    295     """
    296     Resamples the Keypoint SLDS model given the hyperparameters,
    297     data, noise prior, current states, and current parameters.
   (...)
    328         updated seed, states, and parameters of the model.
    329     """
--> 330     model = arhmm.resample_model(data, seed, states, params,
    331                                  hypparams, states_only, verbose=verbose)
    332     if ar_only:
    333         model['noise_prior'] = noise_prior

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\gibbs.py:211, in resample_model(data, seed, states, params, hypparams, states_only, verbose, **kwargs)
    206     params['Ab'], params['Q']= resample_ar_params(
    207         seed, **data, **states, **params, 
    208         **hypparams['ar_hypparams'])
    210 if verbose: print('Resampling z (discrete latent states)')
--> 211 states['z'] = resample_discrete_stateseqs(
    212     seed, **data, **states, **params)
    214 return {'seed': seed,
    215         'states': states, 
    216         'params': params, 
    217         'hypparams': hypparams}

    [... skipping hidden 14 frame]

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\gibbs.py:59, in resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs)
     56 num_samples = mask.shape[0]
     58 log_likelihoods = jax.lax.map(partial(ar_log_likelihood, x), (Ab, Q))
---> 59 _, z = jax.vmap(sample_hmm_stateseq, in_axes=(0,na,0,0))(
     60     jr.split(seed, num_samples),
     61     pi,
     62     jnp.moveaxis(log_likelihoods,0,-1),
     63     mask.astype(float)[:,nlags:])
     64 return convert_data_precision(z)

    [... skipping hidden 2 frame]

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\api.py:1731, in _mapped_axis_size(tree, vals, dims, name, kws)
   1723       sizes[x.shape[d]].append(i)
   1724   lines2 = ["{} {} {} {} to be mapped of size {}".format(
   1725               "args" if len(idxs) > 1 else "arg",
   1726               ", ".join(map(str, idxs)),
   (...)
   1729               size)
   1730             for size, idxs in sizes.items()]
-> 1731   raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2))) from None
   1732 else:
   1733   sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]

ValueError: vmap got inconsistent sizes for array axes to be mapped:
arg 0 has shape (46, 2) and axis 0 is to be mapped
arg 1 has shape (100, 100) and axis None is to be mapped
arg 2 has shape (19, 10027, 100) and axis 0 is to be mapped
arg 3 has shape (46, 10027) and axis 0 is to be mapped
so
args 0, 3 have axes to be mapped of size 46
arg 2 has an axis to be mapped of size 19

calebweinreb commented 12 months ago

Hmm that's odd. I am happy to suggest some debugging steps but the easiest thing would be for you to send me the files necessary to run the above code on my end... Would that be OK? Just send to calebsw@gmail.com (any method you prefer).

lencriv commented 12 months ago

Thanks! I emailed you

lencriv commented 12 months ago

Hi Caleb,

Note that I fixed this by passing the model parameters explicitly instead of unpacking the whole checkpoint, as shown below:

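# `name` is not defined in this snippet; it is assumed to have been set earlier in the session.
results = kpms.apply_model(
    coordinates=coordinates,
    confidences=confidences,
    project_dir=project_dir,
    name=name,
    results_path=new_data,
    pca=kpms.load_pca(project_dir),
    # Pass only the learned parameters and hyperparameters, not the checkpoint's states.
    params=checkpoint['params'],
    hypparams=checkpoint['hypparams'],
    **config())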

Issue resolved! I appreciate your prompt reply.

calebweinreb commented 12 months ago

Glad it's resolved! The issue was that the states stored in the checkpoint were being passed to "apply_model" (i.e., a bug in how we wrote the function and the example). It's addressed in https://github.com/dattalab/keypoint-moseq/commit/ab627d83525c4e0d462b1660b38b83583d57ca01, which will be part of the next release.
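For readers on a release without that commit, the workaround amounts to dropping everything except the learned parameters before calling "apply_model". This matches the traceback, where one array has a leading axis of 19 and the others 46, as would happen if states saved from 19 training recordings met 46 new ones. Below is a minimal sketch of that filtering step. The helper `select_model_kwargs` and the toy checkpoint are hypothetical illustrations, not part of keypoint_moseq, and the real checkpoint may contain other keys.

```python
# Hypothetical helper (not part of keypoint_moseq): keep only the checkpoint
# entries that describe the trained model, dropping per-recording states.
def select_model_kwargs(checkpoint, keep=("params", "hypparams")):
    """Return a new dict containing only the entries listed in `keep`."""
    missing = [key for key in keep if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing expected keys: {missing}")
    return {key: checkpoint[key] for key in keep}


# Toy stand-in for a loaded checkpoint; real values would be arrays.
toy_checkpoint = {
    "params": {"Ab": "...", "Q": "...", "pi": "..."},
    "hypparams": {"ar_hypparams": "..."},
    "states": {"x": "states tied to the training recordings"},
}

model_kwargs = select_model_kwargs(toy_checkpoint)
print(sorted(model_kwargs))  # ['hypparams', 'params']
```

The resulting dict could then be unpacked into the call in place of `**checkpoint`, which is equivalent to passing `params=` and `hypparams=` by hand as in the fix above.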

Also, when using "apply_model", make sure you use "syllables" and not "syllables_reindexed" in the results that get output, and set "use_reindexed=False" for all the visualization functions. This is because of a subtle issue with reindexing (see https://github.com/dattalab/keypoint-moseq/issues/72). We're going to phase out "syllables_reindexed" in the next release.
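As a rough illustration of that advice, the sketch below pulls the "syllables" sequence for each recording out of a results dictionary. It assumes that results maps each recording name to a dictionary holding both "syllables" and "syllables_reindexed"; that layout, the helper `pick_syllables`, and the toy data are assumptions made for illustration, so check them against the output of your installed version.

```python
# Hypothetical helper: collect one syllable sequence per recording.
# Assumes `results` looks like {recording_name: {"syllables": ..., ...}}.
def pick_syllables(results, key="syllables"):
    """Map each recording name to the sequence stored under `key`."""
    return {name: rec[key] for name, rec in results.items()}


# Toy stand-in for the output of apply_model; real values would be arrays.
toy_results = {
    "recording_a": {"syllables": [3, 3, 7], "syllables_reindexed": [0, 0, 1]},
    "recording_b": {"syllables": [7, 2, 2], "syllables_reindexed": [1, 2, 2]},
}

syllables = pick_syllables(toy_results)
print(syllables["recording_a"])  # [3, 3, 7]
```

Defaulting `key` to "syllables" keeps the non-reindexed labels as the path of least resistance, so downstream code does not pick up "syllables_reindexed" by accident.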