dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

resume fitting from an existing checkpoint but with new data #45

Closed · r-shruthi11 closed this issue 1 year ago

r-shruthi11 commented 1 year ago

I have more data than can fit in GPU memory to train my model with. If I reload a checkpoint, can I resume fitting with a different Y and mask?

calebweinreb commented 1 year ago

Hey, great questions! The dataset size limit imposed by GPUs is a big drag and something we're hoping to address in the future. In the meantime, here's a discussion on the FAQ page about big datasets.
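
For context, a common workaround from that discussion is to fit on a random subset of recordings that fits in GPU memory and then apply the trained model to the remaining recordings. Here is a minimal sketch of the subsetting step; the path, subset size, and variable handling are hypothetical:

import random
import keypoint_moseq as kpms

# hypothetical sketch: keep a random subset of recordings small enough for the GPU
coordinates, confidences, bodyparts = kpms.load_deeplabcut_results('path/to/dlc_results')
subset_size = 20  # number of recordings to keep; tune to your GPU memory
keys = random.sample(sorted(coordinates), subset_size)  # recording names are dict keys
coordinates = {k: coordinates[k] for k in keys}
confidences = {k: confidences[k] for k in keys}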

Regarding your specific question: yes, you can resume fitting from an existing checkpoint but with new data. Here's a recipe:

import keypoint_moseq as kpms

project_dir = 'project/directory'
config = lambda: kpms.load_config(project_dir)
name = 'name_of_model'  # e.g. '2023_03_16-15_50_11'

# load and format new data (e.g. from DeepLabCut)
dlc_results_directory = 'path/to/dlc_results'  # placeholder: directory with DLC output files
coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(dlc_results_directory)
data, labels = kpms.format_data(coordinates, confidences=confidences, **config())

# load previously saved PCA and model checkpoint
pca = kpms.load_pca(project_dir)
checkpoint = kpms.load_checkpoint(project_dir=project_dir, name=name)

# initialize a new model using saved parameters
model = kpms.init_model(data, pca=pca, params=checkpoint['params'], **config())

# continue fitting, now with the new data
model, history, name = kpms.fit_model(model, data, labels, num_iters=20, project_dir=project_dir)
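
If you want the model to eventually see all of the data, one hedged extension of the same recipe (the batch handling below is a sketch, not a library feature) is to cycle through batches of recordings, reloading the latest checkpoint between rounds:

# hypothetical sketch: resume fitting on successive batches, continuing from the recipe above
for batch_dir in ['dlc_batch_1', 'dlc_batch_2']:  # hypothetical batch directories
    coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(batch_dir)
    data, labels = kpms.format_data(coordinates, confidences=confidences, **config())
    checkpoint = kpms.load_checkpoint(project_dir=project_dir, name=name)
    model = kpms.init_model(data, pca=pca, params=checkpoint['params'], **config())
    model, history, name = kpms.fit_model(model, data, labels, num_iters=20, project_dir=project_dir)

Each round only exposes the model to one batch at a time, so it is worth checking that syllable statistics stay stable across rounds before trusting the final model.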
r-shruthi11 commented 1 year ago

Thanks @calebweinreb. I'm running my jobs on NVIDIA A100 GPUs (~80 GB of GPU RAM), but my datasets are very large, on the order of ~2.7 million frames of video. I'll let you know how the recipe goes.
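
For scale, a quick back-of-envelope (assuming ~15 bodyparts in 2D at float32; the actual count depends on the DLC model) suggests the raw coordinates are only a few hundred megabytes, so the memory pressure presumably comes from the per-frame states and intermediates allocated during fitting:

frames = 2_700_000     # ~2.7 million frames
bodyparts = 15         # assumption; depends on the DLC model
dims = 2               # x and y
bytes_per_value = 4    # float32
raw_gb = frames * bodyparts * dims * bytes_per_value / 1e9
print(f"raw coordinate array: ~{raw_gb:.2f} GB")  # ~0.32 GB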

calebweinreb commented 1 year ago

Closing for now. If there are any issues with the above code, let me know!