Closed calebweinreb closed 11 months ago
Is this the same issue @calebweinreb ? I'm running your lab's 3D dataset, but didn't realize I'd need multiple GPUs :(
You shouldn't need multiple GPUs. Just use "mixed_map_iters" as described here: https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#troubleshooting
You're a wizard, thank you again! 😀 I'll give it a try
Summary
This PR introduces the following changes/features, which are explained in more detail below.
New logic for syllable indexing
Until now, the "extract_results" step of keypoint-MoSeq saved syllable sequences in their original indexing (as they were represented during modeling) along with a "reindexed" version in which syllables were re-labeled by frequency (so syllable "0" was the most frequent, and so on). But this approach had a fatal flaw: when a fitted model was applied to new data, the syllable frequencies could differ, which would lead to a slightly different re-labeling, so that e.g. syllable "0" would refer to one state in one subset of recordings and a different state in another subset.
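The flaw is easy to reproduce with a toy example. The sketch below is not code from keypoint-MoSeq: `reindex_by_frequency` and the two label sequences are hypothetical, chosen only to show how relabeling each dataset by its own frequencies gives the same underlying state different labels.

```python
from collections import Counter

def reindex_by_frequency(labels):
    """Relabel syllables so that 0 is the most frequent, 1 the next, and so on.

    Returns the relabeled sequence and the mapping {original_label: new_label}.
    Ties are broken by the original label so the result is deterministic.
    """
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    mapping = {old: new for new, old in enumerate(ordered)}
    return [mapping[s] for s in labels], mapping

# Hypothetical syllable sequences in the model's original indexing.
training_data = [7, 7, 7, 7, 3, 3, 3, 5, 5]  # state 7 is most frequent
new_data = [3, 3, 3, 3, 7, 7, 5, 5, 5]       # state 3 is most frequent

_, train_map = reindex_by_frequency(training_data)
_, new_map = reindex_by_frequency(new_data)

print(train_map)  # {7: 0, 3: 1, 5: 2}
print(new_map)    # {3: 0, 5: 1, 7: 2}
```

Here the original state 7 is called syllable "0" in the first dataset and syllable "2" in the second, which is exactly the inconsistency described above.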
To prevent this issue, we now reindex syllables directly inside the model object. That way, if the model is later used to generate syllables for new data, the resulting labels will always be consistent. See https://github.com/dattalab/keypoint-moseq/issues/72 for details. Concretely, this means that:
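1) The standard modeling pipeline now includes a new step after model fitting but before extracting results: a call to kpms.reindex_syllables_in_checkpoint (the same function used in the conversion code further down).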
2) The results files no longer include separate "syllables" and "syllables_reindexed" fields (see below for more details).
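To make the idea of reindexing "inside the model" concrete, here is a minimal sketch using a toy transition matrix. It is an illustration under stated assumptions, not the implementation in keypoint-MoSeq: `frequency_index`, `reindex_model`, and all numbers are hypothetical. The point is that the permutation is computed once from the training data and baked into the model parameters, so it no longer depends on the frequencies in whatever data the model sees later.

```python
from collections import Counter

def frequency_index(labels, num_states):
    """Return `index` with index[new_label] = old_label, ordered by
    decreasing frequency in `labels` (ties broken by old label)."""
    counts = Counter(labels)
    return sorted(range(num_states), key=lambda s: (-counts[s], s))

def reindex_model(trans_probs, index):
    """Permute rows and columns of a transition matrix so that state
    index[i] of the old model becomes state i of the new one."""
    return [[trans_probs[old_i][old_j] for old_j in index] for old_i in index]

# Toy 3-state model (hypothetical numbers) and its training labels.
trans_probs = [[0.8, 0.1, 0.1],
               [0.2, 0.7, 0.1],
               [0.3, 0.3, 0.4]]
training_labels = [2, 2, 2, 0, 0, 1]

# Computed once from the training data, then frozen with the model.
index = frequency_index(training_labels, num_states=3)
reindexed_trans_probs = reindex_model(trans_probs, index)

print(index)                     # [2, 0, 1]
print(reindexed_trans_probs[0])  # [0.4, 0.3, 0.3]
```

After the permutation, state 0 of the model is the state that was most frequent during training, and any labels the model produces for new data are already in that fixed indexing.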
New format for results and checkpoint files
This PR introduces a new format for the results.h5 and checkpoint.p files saved during modeling. This is a breaking change, meaning that results/checkpoints generated with a previous version of the code will no longer work. Below we explain the changes and provide code for converting to the new format.
How the formats have changed
From a user perspective, the main change is that the results.h5 no longer contains separate "syllables" and "syllables_reindexed" fields.
For the results.h5 files, we have removed some fields and renamed others.
Previously the format was
Now the format is
The checkpoint.p files have changed more substantively. They are now saved as hdf5 files (rather than joblib) and their internal organization has changed.
Converting to the new format
The following code converts results and checkpoint files to the new format. Given a project directory and model name, a new project directory is generated with the updated files. As part of the reformatting, syllables are reindexed inside the model (see previous section) and a list of the resulting syllable name-changes is printed.
Make sure you are using the most up-to-date version of keypoint_moseq before running.
import os
import shutil
import joblib
import keypoint_moseq as kpms

# NOTE: old_project_dir, new_project_dir and model_name must be set, and
# update_checkpoint_format / update_results_format must be defined, before
# running the steps below.

# create the new project directory and model subdirectory
os.makedirs(new_project_dir)
os.makedirs(os.path.join(new_project_dir, model_name))

# copy over the PCA outputs and the config
for filename in ['pcs-xy.pdf', 'pca_scree.pdf', 'config.yml', 'pca.p']:
    src_path = os.path.join(old_project_dir, filename)
    if os.path.exists(src_path):
        shutil.copy(src_path, new_project_dir)

# convert the checkpoint from joblib (checkpoint.p) to hdf5 (checkpoint.h5)
old_checkpoint_path = os.path.join(old_project_dir, model_name, 'checkpoint.p')
new_checkpoint_path = os.path.join(new_project_dir, model_name, 'checkpoint.h5')

old_checkpoint = joblib.load(old_checkpoint_path)
new_checkpoint = update_checkpoint_format(old_checkpoint)
kpms.save_hdf5(new_checkpoint_path, new_checkpoint)

# reindex syllables inside the model and report the label changes
index = kpms.reindex_syllables_in_checkpoint(new_project_dir, model_name)
for i, j in enumerate(index):
    print(f'Syllable {j} is now labeled {i}')

# convert the results file
old_results_path = os.path.join(old_project_dir, model_name, 'results.h5')
new_results_path = os.path.join(new_project_dir, model_name, 'results.h5')

old_results = kpms.load_hdf5(old_results_path)
new_results = update_results_format(old_results, index)
kpms.save_hdf5(new_results_path, new_results)

# regenerate the CSVs, trajectory plots and grid movies
config = lambda: kpms.load_config(new_project_dir)
keypoint_data_path = 'path/to/data' # modify as needed
coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')
results = kpms.load_results(new_project_dir, model_name)
kpms.save_results_as_csv(results, new_project_dir, model_name)
kpms.generate_trajectory_plots(coordinates, new_results, new_project_dir, model_name, **config())
kpms.generate_grid_movies(new_results, new_project_dir, model_name, coordinates=coordinates, **config())
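The printed message "Syllable {j} is now labeled {i}" implies that index[i] = j, i.e. index maps new labels to old ones. Relabeling stored sequences therefore needs the inverse permutation. The sketch below shows that step on plain Python lists. It is not the update_results_format function from this PR: `invert_index`, `relabel_recordings`, and the toy results layout are hypothetical, and the field names of the actual new format are not reproduced here.

```python
def invert_index(index):
    """Given index[new_label] = old_label, return a list with
    old_to_new[old_label] = new_label."""
    old_to_new = [0] * len(index)
    for new_label, old_label in enumerate(index):
        old_to_new[old_label] = new_label
    return old_to_new

def relabel_recordings(old_results, index):
    """Hypothetical helper: relabel each recording's "syllables" with the
    new indexing and drop the redundant "syllables_reindexed" field."""
    old_to_new = invert_index(index)
    new_results = {}
    for recording, fields in old_results.items():
        kept = {k: v for k, v in fields.items() if k != "syllables_reindexed"}
        kept["syllables"] = [old_to_new[s] for s in fields["syllables"]]
        new_results[recording] = kept
    return new_results

# Toy example: old state 2 becomes label 0, old 0 becomes 1, old 1 becomes 2.
index = [2, 0, 1]
old_results = {
    "recording_1": {
        "syllables": [2, 2, 0, 1],
        "syllables_reindexed": [0, 0, 1, 2],
    }
}

new_results = relabel_recordings(old_results, index)
print(new_results["recording_1"])  # {'syllables': [0, 0, 1, 2]}
```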
from jax_moseq.utils import set_mixed_map_iters
set_mixed_map_iters(4) # adjust as needed

from jax_moseq.utils import set_mixed_map_gpus
set_mixed_map_gpus(2)
def mixed_map(fun, in_axes=None, out_axes=None):
    """
    Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.
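As we understand it, the idea behind set_mixed_map_iters is to trade speed for memory by processing a batch in several serial pieces instead of all at once. The pure-Python sketch below illustrates only that general idea. It is not the jax_moseq implementation (which, per the docstring above, combines jax.pmap, jax.vmap and jax.lax.map), and `chunked_map` and `square_all` are hypothetical names.

```python
def chunked_map(fun, items, n_iters):
    """Apply a batch function `fun` to `items` in at most `n_iters` serial
    chunks and concatenate the outputs.

    `fun` takes a list and returns a list of the same length. Only one
    chunk is handed to `fun` at a time, so the peak batch size drops from
    len(items) to roughly len(items) / n_iters.
    """
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")
    if not items:
        return []
    chunk_size = -(-len(items) // n_iters)  # ceiling division
    out = []
    for start in range(0, len(items), chunk_size):
        out.extend(fun(items[start:start + chunk_size]))
    return out

seen_sizes = []  # records how many items each call to the batch function saw

def square_all(xs):
    seen_sizes.append(len(xs))
    return [x * x for x in xs]

items = list(range(10))
chunked = chunked_map(square_all, items, n_iters=4)

print(chunked)     # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
print(seen_sizes)  # [3, 3, 3, 1]
```

The output is identical to a single call over the full batch, but no individual call ever sees more than 3 of the 10 items.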
# generate grid movies in the x/y plane
kpms.generate_grid_movies(
    results, project_dir, name, coordinates=coordinates,
    keypoints_only=True, use_dims=[0,1], **config())