dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Restructure results/checkpoint + new features (analysis tools, large dataset support, 3D viz, and more) #79

calebweinreb closed this pull request 11 months ago

calebweinreb commented 11 months ago

# Summary

This PR introduces the following changes/features, which are explained in more detail below.

# New logic for syllable indexing

Until now, the `extract_results` function of keypoint-MoSeq saved syllable sequences in their original indexing (as they were represented during modeling), along with a "reindexed" version in which syllables were relabeled by frequency (so syllable "0" was the most frequent, and so on). But this approach had a fatal flaw: when a fitted model was applied to new data, the syllable frequencies could differ, leading to a slightly different relabeling. As a result, syllable "0" (for example) could refer to one state in one subset of recordings and to a different state in another subset.

To prevent this issue, we now reindex syllables directly inside the model object. That way, if the model is later used to generate syllables for new data, the resulting labels will always be consistent. See https://github.com/dattalab/keypoint-moseq/issues/72 for details. Concretely, this means that:

1) The standard modeling pipeline now includes a new step after model fitting but before extracting results:

   ```python
   kpms.reindex_syllables_in_checkpoint(project_dir, model_name);
   ```

2) The results files no longer include separate "syllables" and "syllables_reindexed" fields (see below for more details).
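The relabeling itself is just a permutation of labels. The NumPy sketch below illustrates frequency-based relabeling and why recomputing it on new data is unsafe. `frequency_reindex` is a hypothetical helper, not part of the keypoint-MoSeq API, and the sketch does not show how `reindex_syllables_in_checkpoint` permutes the model's parameters internally. It follows the same convention as the conversion code in this PR: `index[i]` is the original label that becomes label `i`, and `np.argsort(index)` maps old labels to new ones.

```python
import numpy as np

def frequency_reindex(syllables, num_states):
    """Hypothetical helper: relabel syllables so that 0 is the most frequent.

    Returns `index`, where index[i] is the original label that becomes
    new label i, plus a function that applies this relabeling to any array.
    """
    counts = np.bincount(np.concatenate(syllables), minlength=num_states)
    # Stable sort on negated counts: most frequent first, ties keep label order
    index = np.argsort(-counts, kind="stable")
    # Inverse permutation: lookup[old_label] == new_label
    lookup = np.argsort(index)
    return index, lambda z: lookup[z]

# Frequencies in the data used to define the labels
train = [np.array([0, 1, 1, 1, 2, 2])]
index, relabel = frequency_reindex(train, num_states=3)
print(index)                         # [1 2 0]
print(relabel(np.array([0, 1, 2])))  # [2 0 1]

# New data with different frequencies would yield a different index
new = [np.array([0, 0, 0, 2])]
print(frequency_reindex(new, num_states=3)[0])  # [0 2 1]
```

Storing one fixed permutation and reusing it (the approach this PR takes by reindexing inside the model) keeps a given state's label the same across datasets. Recomputing the index per dataset does not.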

# New format for results and checkpoint files

This PR introduces a new format for the results.h5 and checkpoint.p files saved during modeling. This is a breaking change, meaning that results/checkpoints generated with a previous version of the code will no longer work. Below we explain the changes and provide code for converting to the new format.

### How the formats have changed

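From a user perspective, the main change is that results.h5 no longer contains separate `syllables` and `syllables_reindexed` fields. Instead there is a single `syllable` field, whose labels follow the indexing stored in the model.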

For the results.h5 files, we have removed some fields and renamed others. Previously the format was

    results.h5
    ├──session_name1
    │  ├──estimated_coordinates  # denoised coordinates
    │  ├──syllables_reindexed    # syllables reindexed by frequency
    │  ├──syllables              # non-reindexed syllables labels (z)
    │  ├──latent_state           # inferred low-dim pose state (x)
    │  ├──centroid               # inferred centroid (v)
    │  └──heading                # inferred heading (h)
    ⋮

Now the format is

    results.h5
    ├──recording_name1
    │  ├──syllable      # syllable labels (z)
    │  ├──latent_state  # inferred low-dim pose state (x)
    │  ├──centroid      # inferred centroid (v)
    │  └──heading       # inferred heading (h)
    ⋮

The checkpoint files have changed more substantively. They are now saved as HDF5 files named checkpoint.h5 (rather than as joblib-serialized checkpoint.p files), and their internal organization has changed.
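For orientation: the conversion code in this PR builds each new checkpoint as a dictionary with `data`, `metadata` and `model_snapshots` entries, where `model_snapshots` maps iteration numbers (stored as strings) to model dictionaries. Below is a minimal sketch of picking out the most recent snapshot. It assumes the checkpoint has been loaded back into that nested-dictionary form (for example via `kpms.load_hdf5`), and `latest_snapshot` is a hypothetical helper.

```python
def latest_snapshot(checkpoint):
    """Hypothetical helper: return (iteration, model) for the newest snapshot.

    Assumes checkpoint['model_snapshots'] maps iteration numbers stored as
    strings to dicts with 'seed', 'noise_prior', 'params', 'states' and
    'hypparams', as produced by update_checkpoint_format.
    """
    snapshots = checkpoint["model_snapshots"]
    # Keys are strings, so compare them numerically (otherwise "90" > "100")
    last = max(snapshots, key=int)
    return int(last), snapshots[last]
```

Usage might look like `iteration, model = latest_snapshot(kpms.load_hdf5(checkpoint_path))`, with `checkpoint_path` pointing at a model's checkpoint.h5.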

### Converting to the new format

The following code converts results and checkpoint files to the new format. Given a project directory and model name, a new project directory is generated with the updated files. As part of the reformatting, syllables are reindexed inside the model (see previous section) and a list of the resulting syllable name-changes is printed.

Make sure you are using the most up-to-date version of keypoint_moseq before running.

```python
import keypoint_moseq as kpms
import numpy as np
import os, shutil
import joblib

# Paths for the existing project and the new (converted) project; modify as needed
old_project_dir = 'path/to/old/project'
new_project_dir = 'path/to/new/project'
model_name = 'my_model_name'

def update_checkpoint_format(checkpoint):
    model = {k:checkpoint[k] for k in ['seed','noise_prior','params','states','hypparams']}
    model_snapshots = {str(checkpoint['iteration']): model}

    for i,hist in checkpoint['history'].items():
        model_snapshots[str(i)] = {
            'noise_prior': checkpoint['noise_prior'],
            'hypparams': checkpoint['hypparams'],
            'states': hist['states'],
            'params': hist['params'],
            'seed': hist['seed']
        }

    data = {'Y': checkpoint['Y'], 'conf':checkpoint['conf'], 'mask':checkpoint['mask']}
    keys = [l[0] for l in checkpoint['labels']]
    bounds = np.array([l[1:] for l in checkpoint['labels']])
    new_checkpoint = {'data':data, 'metadata':(keys, bounds), 'model_snapshots':model_snapshots}
    return new_checkpoint

def update_results_format(results, index=None):
    for k,v in results.items():
        if 'estimated_coordinates' in v:
            v['est_coords'] = v['estimated_coordinates']
            del v['estimated_coordinates']

        if 'syllables' in v:
            if index is None:
                v['syllable'] = v['syllables']
            else:
                # Map each original label to its new (reindexed) label
                v['syllable'] = np.argsort(index)[v['syllables']]
            del v['syllables']

        if 'syllables_reindexed' in v:
            del v['syllables_reindexed']
    return results

# Create the new project directory and copy over files that need no conversion
os.makedirs(new_project_dir)
os.makedirs(os.path.join(new_project_dir, model_name))

for filename in ['pcs-xy.pdf', 'pca_scree.pdf', 'config.yml', 'pca.p']:
    src_path = os.path.join(old_project_dir, filename)
    if os.path.exists(src_path):
        shutil.copy(src_path, new_project_dir)

# Convert saved checkpoint to new format
old_checkpoint_path = os.path.join(old_project_dir, model_name, 'checkpoint.p')
new_checkpoint_path = os.path.join(new_project_dir, model_name, 'checkpoint.h5')

old_checkpoint = joblib.load(old_checkpoint_path)
new_checkpoint = update_checkpoint_format(old_checkpoint)
kpms.save_hdf5(new_checkpoint_path, new_checkpoint)

# Reindex syllables in the model checkpoint
index = kpms.reindex_syllables_in_checkpoint(new_project_dir, model_name)
for i, j in enumerate(index):
    print(f'Syllable {j} is now labeled {i}')

# Convert saved results to new format
old_results_path = os.path.join(old_project_dir, model_name, 'results.h5')
new_results_path = os.path.join(new_project_dir, model_name, 'results.h5')

old_results = kpms.load_hdf5(old_results_path)
new_results = update_results_format(old_results, index)
kpms.save_hdf5(new_results_path, new_results)

# Regenerate visualizations
config = lambda: kpms.load_config(new_project_dir)
keypoint_data_path = 'path/to/data'  # modify as needed
coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')

results = kpms.load_results(new_project_dir, model_name)
kpms.save_results_as_csv(results, new_project_dir, model_name)
kpms.generate_trajectory_plots(coordinates, new_results, new_project_dir, model_name, **config())
kpms.generate_grid_movies(new_results, new_project_dir, model_name, coordinates=coordinates, **config())
```


# New analysis tools

This PR introduces a new set of analysis widgets and a tutorial notebook (`analysis.ipynb`) for using them. These widgets ingest results in the updated format described above. **So make sure to run the conversion code before applying the analysis pipeline to an existing project!**

# Support for large datasets

Until now, it was not possible to model large datasets on a GPU without incurring out-of-memory (OOM) errors. To address this problem, we have created a framework for mixed serial/parallel computation and added multi-GPU support.

### Partial serialization

By default, modeling is parallelized across the full dataset. Here we introduce a new option for mixed parallel/serial computation, in which the data is split into batches that are processed one at a time. To enable this option, run the following code *before fitting the model* (if you have already started fitting the model, restart the kernel first):

```python
from jax_moseq.utils import set_mixed_map_iters
set_mixed_map_iters(4)  # adjust as needed
```

This will split the data into 4 batches, which should reduce memory requirements roughly 4-fold at the cost of a roughly 4-fold slowdown. The number of batches can be adjusted as needed.
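To see what the split means concretely, here is a small pure-Python sketch of the padding arithmetic described in the `mixed_map` docstring later in this PR. `batch_layout` is a hypothetical helper, not part of jax_moseq. The sketch also assumes that peak memory tracks the number of items vectorized at once on each device, which is only a rough approximation.

```python
import math

def batch_layout(n_items, n_iters=1, n_gpus=1):
    """Hypothetical helper mirroring the padding rule described for mixed_map.

    The mapped axis (size n_items) is padded up to a multiple of
    n_iters * n_gpus, then split into n_iters serial batches, each of
    which is shared across n_gpus devices.
    """
    chunk = n_iters * n_gpus
    padded = math.ceil(n_items / chunk) * chunk
    per_batch = padded // n_iters     # items processed in one serial step
    per_device = per_batch // n_gpus  # items vectorized on each device
    return padded, per_batch, per_device

print(batch_layout(10))                       # (10, 10, 10)
print(batch_layout(10, n_iters=4))            # (12, 3, 3)
print(batch_layout(10, n_iters=4, n_gpus=2))  # (16, 4, 2)
```

Increasing `n_iters` shrinks the per-step workload but adds serial steps, which is the memory/speed trade-off described above.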

### Multi-GPU support

To use multiple GPUs, run the following code *before fitting the model* (if you have already started fitting the model, restart the kernel first):

```python
from jax_moseq.utils import set_mixed_map_gpus
set_mixed_map_gpus(2)
```

This will split the computation across two GPUs. 
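Before requesting several GPUs, it can help to check how many devices JAX actually sees. Here is a minimal sketch using `jax.local_device_count()`. `choose_num_gpus` is a hypothetical helper. On a machine without a GPU backend, JAX falls back to CPU, where the count is typically 1.

```python
import jax

def choose_num_gpus(requested):
    """Hypothetical helper: cap the requested device count at what JAX can see."""
    # Counts devices on JAX's default backend (GPUs when available, else CPU)
    available = jax.local_device_count()
    return max(1, min(requested, available))

print(jax.devices())  # list the devices JAX detected
n_gpus = choose_num_gpus(2)
```

The resulting `n_gpus` could then be passed to `set_mixed_map_gpus` in place of a hard-coded value.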

### Additional info on implementation

Both of the above options (multi-GPU support and partial serialization) rely on a new utility called `mixed_map` that we added to the jax_moseq package. Below is a copy of its docstring:

```python
def mixed_map(fun, in_axes=None, out_axes=None):
    """
    Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.

    This function is similar to `jax.vmap`, except that it mixes together
    `jax.pmap`, `jax.vmap` and `jax.lax.map` to prevent OOM errors and allow
    for parallelization across multiple GPUs. The behavior is determined by
    the global variables `_MIXED_MAP_ITERS` and `_MIXED_MAP_GPUS`, which can be
    set using :py:func:`jax_moseq.utils.set_mixed_map_iters` and
    :py:func:`jax_moseq.utils.set_mixed_map_gpus` respectively.

    Given an axis size of N to map, the data is padded such that the axis size
    is a multiple of `_MIXED_MAP_ITERS * _MIXED_MAP_GPUS`. The data is then
    processed serially in chunks, where the number of chunks is determined by
    `_MIXED_MAP_ITERS`. Each chunk is processed in parallel using jax.pmap to
    distribute across `_MIXED_MAP_GPUS` devices and jax.vmap to parallelize
    within each device.
    """
```
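To make the serial/parallel mix concrete, here is a simplified sketch of the single-device part of the idea. It is not the jax_moseq implementation. It pads the mapped axis, splits it into chunks that `jax.lax.map` processes one after another, and applies `jax.vmap` within each chunk. `serial_vmap` is a hypothetical name. It handles a single array argument mapped over axis 0, and it omits the `jax.pmap` (multi-GPU) level, pytrees, and `in_axes`/`out_axes`.

```python
import jax
import jax.numpy as jnp

def serial_vmap(fun, n_iters=1):
    """Hypothetical sketch: map `fun` over axis 0 in n_iters serial chunks."""
    def mapped(x):
        n = x.shape[0]
        pad = (-n) % n_iters
        # Zero-pad the mapped axis so it divides evenly into n_iters chunks
        x_pad = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        chunks = x_pad.reshape((n_iters, -1) + x.shape[1:])
        # lax.map runs the vmapped function on one chunk at a time
        out = jax.lax.map(jax.vmap(fun), chunks)
        # Merge the chunk and within-chunk axes, then drop padded entries
        return out.reshape((-1,) + out.shape[2:])[:n]
    return mapped

x = jnp.arange(20.0).reshape(10, 2)
row_sums = serial_vmap(lambda row: row.sum(), n_iters=4)(x)
```

Because padded rows are zeros, `fun` must accept all-zero inputs; their outputs are discarded. Only one chunk's intermediate arrays need to exist at a time, which is where the memory saving comes from.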

# 3D plotting tools

- In addition to 2D projections of 3D keypoints, `plot_pcs` and `generate_trajectory_plots` now produce interactive 3D visualizations. These are rendered in the notebook and can also be viewed offline in a browser using the saved .html files. 

- It is now possible to generate grid movies for 3D keypoints, although they will only show 2D projections of the keypoints and not the underlying video. To generate grid movies from 3D data, include the flag ``keypoints_only=True`` and set the desired projection plane with the ``use_dims`` argument, e.g.

```python
# Generate grid movies in the x/y plane
kpms.generate_grid_movies(
    results, project_dir, name, coordinates=coordinates,
    keypoints_only=True, use_dims=[0,1], **config())
```
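For intuition about `use_dims`, here is a minimal NumPy sketch of selecting a projection plane. It assumes that `coordinates` is a dictionary mapping recording names to arrays of shape (frames, keypoints, dimensions) and that `use_dims` indexes the last axis. `project_coordinates` is hypothetical and is not the library's internal code.

```python
import numpy as np

def project_coordinates(coordinates, use_dims=(0, 1)):
    """Hypothetical helper: keep only the spatial axes listed in use_dims."""
    # Index the last (spatial) axis: (0, 1) -> x/y plane, (0, 2) -> x/z plane
    return {name: arr[..., list(use_dims)] for name, arr in coordinates.items()}

coords_3d = {"rec1": np.random.rand(100, 8, 3)}  # 100 frames, 8 keypoints, xyz
xy = project_coordinates(coords_3d, use_dims=[0, 1])
print(xy["rec1"].shape)  # (100, 8, 2)
```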

bainro commented 5 months ago

Is this the same issue @calebweinreb ? I'm running your lab's 3D dataset, but didn't realize I'd need multiple GPUs :(


calebweinreb commented 5 months ago (Sat, Feb 3, 2024)

You shouldn't need multiple GPUs. Just use "mixed_map_iters" as described in here https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#troubleshooting

bainro commented 5 months ago

You're a wizard, thank you again! 😀 I'll give it a try