JARVIS-MoCap / JARVIS-HybridNet

JARVIS Markerless 3D Motion Capture Pytorch Library
https://jarvis-mocap.github.io/jarvis-docs/
GNU Lesser General Public License v2.1

training hybridnet on multiple datasets with different number of cameras #8

Open timsainb opened 1 year ago

timsainb commented 1 year ago

Hey Timo,

I'm looking into pretraining a keypoint network again and I'm having trouble figuring out how to follow the instructions you gave via email (shown below).

I found the best way to use a different dataset for pretraining is as follows:

  1. Train a model on the pretraining dataset (in your case the combined rodent annotations). You can just use the train all function for this step.
  2. Retrain the full HybridNet using Training Mode 'all' WITHOUT retraining the EfficientNet 2D detector separately. It might be worth experimenting with the learning rate in this step (lower values might work better).
  3. Also train the CenterDetect network, initializing it with the pretrained weights.

    This works because the structure of the 3D network does not depend on the number of cameras used, so loading the pretrained HybridNet weights shouldn't be a problem, regardless of the camera configuration. With this training strategy I got very solid results on a hand-tracking task with only 50 labeled frame-sets, even though it used only 7 cameras (instead of the 12 in the pretraining dataset) and lighting and background were completely different.

I have three types of datasets:

I'm trying to do inference on new data from the 5-camera rig, but since I have all of this other labeled data, I want to use it all together for training (at least for the 2D network).

So my understanding is that there are three steps in the Jarvis pipeline (center detect, 2D keypoints, 3D hybridnet). For your first step (train a model on the pretraining dataset), I first made a training set of all of the datasets combined. I then train a center detect network:

train_interface.train_efficienttrack(
    'CenterDetect',
    project_name,
    num_epochs,
    weights=weights
)

Then I train a 2D keypoint network:

train_interface.train_efficienttrack(
    'KeypointDetect',
    project_name,
    num_epochs,
    weights='latest',
)

Both of these work fine and I get pretty good performance. The issue is when I then try to train the hybridnet:

train_interface.train_hybridnet(
    project_name=project_name,
    num_epochs=num_epochs,
    weights_keypoint_detect=weights_keypoint_detect,
    weights=weights_hybridnet,
    mode=mode,
    finetune=finetune
)

Training starts to run smoothly for a few batches:

Successfully loaded project 23-05-02-train-chronic-only.
[Info] Training HybridNet on project 23-05-02-train-chronic-only for 100 epochs!
[Info] Successfully loaded weights: /n/groups/datta/tim_sainburg/projects/JARVIS-HybridNet/projects/23-05-02-train-chronic-only/models/KeypointDetect/Run_20230504-002624/EfficientTrack-medium_final.pth
Epoch: 1/100. Loss: 298.4349. Acc: 68.30:   1%|▏         | 6/479 [00:02<03:30,  2.24it/s]

Then I get the following error:

IndexError: Caught IndexError in DataLoader worker process 6.
Original Traceback (most recent call last):
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/n/groups/datta/tim_sainburg/projects/JARVIS-HybridNet/jarvis/dataset/dataset3D.py", line 204, in __getitem__
    centerHM[frame_idx] = np.array([center_x, center_y])
IndexError: index 5 is out of bounds for axis 0 with size 5

I believe what is happening is that, since the Dataset3D object only has one version of self.reproTools, the network always expects the same reprojection parameters (despite different parameters existing for each training set, inside calib_params).
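A minimal NumPy repro of that suspected shape mismatch (illustrative numbers only, not JARVIS code) would look like this: centerHM is sized from a single global camera count, but the frameset being fetched comes from a calibration with more cameras.

```python
import numpy as np

# Illustrative: self.num_cameras ends up as 5 (from the last calibration
# loaded), but the frameset being fetched has 6 cameras.
num_cameras_global = 5
frameset_cameras = 6

centerHM = np.full((num_cameras_global, 2), 128, dtype=int)
try:
    for frame_idx in range(frameset_cameras):
        centerHM[frame_idx] = np.array([100, 100])
except IndexError as err:
    print(err)  # index 5 is out of bounds for axis 0 with size 5
```

This reproduces the same "index 5 is out of bounds for axis 0 with size 5" message as in the traceback above.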

Any help on this method would be appreciated!

Thanks

timsainb commented 1 year ago

I think I've localized the issue. The problem is that with multiple calibrations, the number of cameras is reset for each calibration you add, so it's only the last calibration that counts.

Create a training set (https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/train_interface.py#L15):

training_set = Dataset3D(project.cfg, set='train',
                         cameras_to_use=camera_list)

In the training set, a calibration/ReprojectionTool is created for each dataset
(https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/dataset/dataset3D.py#L50):

self.reproTools = {}
for calibParams in self.dataset['calibrations']:
    calibPaths = {}
    for cam in self.dataset['calibrations'][calibParams]:
        if self.cameras_to_use == None or cam in self.cameras_to_use:
            calibPaths[cam] = self.dataset['calibrations'] \
                        [calibParams][cam]
    self.reproTools[calibParams] = ReprojectionTool(
                os.path.join(cfg.PARENT_DIR, self.root_dir), calibPaths)
    self.num_cameras = self.reproTools[calibParams].num_cameras  # <-- overwritten on every iteration
    self.reproTools[calibParams].resolution = [width, height]

So later, when we grab a sample from the dataset, it's the last calibration's camera count that it expects (https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/dataset/dataset3D.py#L189):

centerHM = np.full((self.num_cameras, 2), 128, dtype = int)
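One possible direction for a fix (a sketch only; num_cameras_per_calib and make_centerHM are hypothetical names, not part of the JARVIS API) would be to keep one camera count per calibration rather than overwriting a single self.num_cameras in the loop, and size centerHM from the calibration the sample actually belongs to:

```python
import numpy as np

# Hypothetical sketch: store a camera count per calibration instead of a
# single global value that gets overwritten on every loop iteration.
calibrations = {
    'rig_5cam': ['cam%d' % i for i in range(5)],
    'rig_7cam': ['cam%d' % i for i in range(7)],
}

num_cameras_per_calib = {name: len(cams) for name, cams in calibrations.items()}

def make_centerHM(calib_name):
    # __getitem__ would look up the calibration of the requested sample
    # and allocate centerHM with that calibration's camera count.
    return np.full((num_cameras_per_calib[calib_name], 2), 128, dtype=int)

print(make_centerHM('rig_7cam').shape)  # (7, 2)
```

With this layout, a 5-camera frameset and a 7-camera frameset each get a correctly sized centerHM, so the IndexError above can't occur.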