talmolab / sleap

A deep learning framework for multi-animal pose tracking.
https://sleap.ai
Other

TD-ID model training on multi-vid set is not converging #513

Closed catubc closed 3 years ago

catubc commented 3 years ago

Hello

We have a set of 23 videos from which we sample and label frames. Talmo helped us load all 23 videos into a single GUI session, and we labeled ~470 frames across them (20-50 frames per video). We then exported the .slp file and trained the TD-ID model.

Training takes only ~30 s to set up, and the model starts to converge right away:

INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [30.8s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/600
200/200 - 15s - loss: 0.0026 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 1.3863 - ClassVectorsHead_accuracy: 0.2488 - val_loss: 0.0025 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.3863 - val_ClassVectorsHead_accuracy: 0.2955
Epoch 2/600
200/200 - 16s - loss: 0.0025 - CenteredInstanceConfmapsHead_loss: 0.0011 - ClassVectorsHead_loss: 1.3856 - ClassVectorsHead_accuracy: 0.2637 - val_loss: 0.0025 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.3845 - val_ClassVectorsHead_accuracy: 0.2500
Epoch 3/600
200/200 - 16s - loss: 0.0024 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.3825 - ClassVectorsHead_accuracy: 0.2912 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.3585 - val_ClassVectorsHead_accuracy: 0.2955
Epoch 4/600
200/200 - 16s - loss: 0.0024 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.3698 - ClassVectorsHead_accuracy: 0.3150 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.3533 - val_ClassVectorsHead_accuracy: 0.3182
Epoch 5/600
200/200 - 16s - loss: 0.0024 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.3504 - ClassVectorsHead_accuracy: 0.3400 - val_loss: 0.0023 - val_CenteredInstanceConfmapsHead_loss: 9.6060e-04 - val_ClassVectorsHead_loss: 1.3111 - val_ClassVectorsHead_accuracy: 0.4091
Epoch 6/600
200/200 - 16s - loss: 0.0023 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.3083 - ClassVectorsHead_accuracy: 0.3663 - val_loss: 0.0022 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.1463 - val_ClassVectorsHead_accuracy: 0.4318
Epoch 7/600
200/200 - 16s - loss: 0.0022 - CenteredInstanceConfmapsHead_loss: 9.9667e-04 - ClassVectorsHead_loss: 1.2362 - ClassVectorsHead_accuracy: 0.3988 - val_loss: 0.0021 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.0686 - val_ClassVectorsHead_accuracy: 0.5455
Epoch 8/600
200/200 - 14s - loss: 0.0022 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.1488 - ClassVectorsHead_accuracy: 0.4737 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.4239 - val_ClassVectorsHead_accuracy: 0.2955
Epoch 9/600
200/200 - 15s - loss: 0.0021 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.0784 - ClassVectorsHead_accuracy: 0.5163 - val_loss: 0.0021 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.0656 - val_ClassVectorsHead_accuracy: 0.5682
Epoch 10/600
200/200 - 16s - loss: 0.0020 - CenteredInstanceConfmapsHead_loss: 9.9455e-04 - ClassVectorsHead_loss: 1.0367 - ClassVectorsHead_accuracy: 0.5462 - val_loss: 0.0019 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 0.8768 - val_ClassVectorsHead_accuracy: 0.7045
Epoch 11/600
200/200 - 16s - loss: 0.0020 - CenteredInstanceConfmapsHead_loss: 9.8802e-04 - ClassVectorsHead_loss: 1.0041 - ClassVectorsHead_accuracy: 0.5675 - val_loss: 0.0018 - val_CenteredInstanceConfmapsHead_loss: 9.8947e-04 - val_ClassVectorsHead_loss: 0.8439 - val_ClassVectorsHead_accuracy: 0.6364
Epoch 12/600
200/200 - 16s - loss: 0.0020 - CenteredInstanceConfmapsHead_loss: 9.6954e-04 - ClassVectorsHead_loss: 0.9805 - ClassVectorsHead_accuracy: 0.5512 - val_loss: 0.0017 - val_CenteredInstanceConfmapsHead_loss: 9.5197e-04 - val_ClassVectorsHead_loss: 0.7849 - val_ClassVectorsHead_accuracy: 0.6591
Epoch 13/600
200/200 - 16s - loss: 0.0019 - CenteredInstanceConfmapsHead_loss: 9.6350e-04 - ClassVectorsHead_loss: 0.9671 - ClassVectorsHead_accuracy: 0.5850 - val_loss: 0.0017 - val_CenteredInstanceConfmapsHead_loss: 9.5756e-04 - val_ClassVectorsHead_loss: 0.7593 - val_ClassVectorsHead_accuracy: 0.7273
Epoch 14/600
200/200 - 14s - loss: 0.0019 - CenteredInstanceConfmapsHead_loss: 9.5260e-04 - ClassVectorsHead_loss: 0.9192 - ClassVectorsHead_accuracy: 0.5925 - val_loss: 0.0017 - val_CenteredInstanceConfmapsHead_loss: 9.1793e-04 - val_ClassVectorsHead_loss: 0.8242 - val_ClassVectorsHead_accuracy: 0.6364
Epoch 15/600
200/200 - 16s - loss: 0.0018 - CenteredInstanceConfmapsHead_loss: 9.3577e-04 - ClassVectorsHead_loss: 0.8906 - ClassVectorsHead_accuracy: 0.6275 - val_loss: 0.0017 - val_CenteredInstanceConfmapsHead_loss: 9.3701e-04 - val_ClassVectorsHead_loss: 0.7639 - val_ClassVectorsHead_accuracy: 0.7273
Epoch 16/600
200/200 - 16s - loss: 0.0018 - CenteredInstanceConfmapsHead_loss: 9.4721e-04 - ClassVectorsHead_loss: 0.8763 - ClassVectorsHead_accuracy: 0.6263 - val_loss: 0.0017 - val_CenteredInstanceConfmapsHead_loss: 9.5534e-04 - val_ClassVectorsHead_loss: 0.7206 - val_ClassVectorsHead_accuracy: 0.6591
Epoch 17/600
...
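When reading these logs, it helps to know where the ID head starts. Assuming `ClassVectorsHead` is trained with categorical cross-entropy over its softmax output (the softmax is visible in the setup log further down; the choice of loss function is an assumption), a model that has learned nothing yet scores ln 4 ≈ 1.3863 and 25% accuracy on four identities. That is exactly the epoch-1 `ClassVectorsHead_loss` above, so the later drop below 1.0 shows the ID head is genuinely learning in the 471-frame run. A small Python check of that arithmetic (`chance_cross_entropy` is a hypothetical helper, not part of SLEAP):

```python
import math


def chance_cross_entropy(n_classes):
    """Cross-entropy of a classifier that predicts 1/n for every class."""
    # -log(1/n) == log(n), independent of which class is the true one.
    return -math.log(1.0 / n_classes)


# The four identity tracks in this project: female, male, pup shaved, pup unshaved.
n_ids = 4
chance_loss = chance_cross_entropy(n_ids)
chance_accuracy = 1.0 / n_ids
print(f"chance loss = {chance_loss:.4f}, chance accuracy = {chance_accuracy:.2f}")

# ClassVectorsHead_loss at epoch 1 of the 471-frame run, copied from the log above.
epoch1_class_loss = 1.3863
print("epoch 1 at chance:", abs(epoch1_class_loss - chance_loss) < 1e-4)
```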

But after we added more labeled frames (800 in total now), the model no longer converges, and setup takes almost 10 minutes. See below:

INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [538.2s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/600
200/200 - 18s - loss: 0.0025 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 1.3378 - ClassVectorsHead_accuracy: 0.2713 - val_loss: 0.0025 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 1.3105 - val_ClassVectorsHead_accuracy: 0.3375
Epoch 2/600
200/200 - 19s - loss: 0.0025 - CenteredInstanceConfmapsHead_loss: 0.0011 - ClassVectorsHead_loss: 1.3298 - ClassVectorsHead_accuracy: 0.2825 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.3698 - val_ClassVectorsHead_accuracy: 0.3000
Epoch 3/600
200/200 - 19s - loss: 0.0024 - CenteredInstanceConfmapsHead_loss: 0.0010 - ClassVectorsHead_loss: 1.3660 - ClassVectorsHead_accuracy: 0.2637 - val_loss: 0.0023 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.2876 - val_ClassVectorsHead_accuracy: 0.3500
Epoch 4/600
200/200 - 18s - loss: 0.0024 - CenteredInstanceConfmapsHead_loss: 0.0011 - ClassVectorsHead_loss: 1.3613 - ClassVectorsHead_accuracy: 0.2475 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.3465 - val_ClassVectorsHead_accuracy: 0.3500
Epoch 5/600
200/200 - 17s - loss: 0.0033 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 2.1145 - ClassVectorsHead_accuracy: 0.2338 - val_loss: 0.0028 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 1.6744 - val_ClassVectorsHead_accuracy: 0.3250
Epoch 6/600
200/200 - 17s - loss: 0.0032 - CenteredInstanceConfmapsHead_loss: 0.0011 - ClassVectorsHead_loss: 2.1346 - ClassVectorsHead_accuracy: 0.2350 - val_loss: 0.0024 - val_CenteredInstanceConfmapsHead_loss: 0.0010 - val_ClassVectorsHead_loss: 1.3860 - val_ClassVectorsHead_accuracy: 0.3125
Epoch 7/600
200/200 - 17s - loss: 0.0063 - CenteredInstanceConfmapsHead_loss: 0.0011 - ClassVectorsHead_loss: 5.2155 - ClassVectorsHead_accuracy: 0.2313 - val_loss: 0.0026 - val_CenteredInstanceConfmapsHead_loss: 0.0013 - val_ClassVectorsHead_loss: 1.3105 - val_ClassVectorsHead_accuracy: 0.2625
Epoch 8/600
200/200 - 17s - loss: 0.0111 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 9.9087 - ClassVectorsHead_accuracy: 0.2488 - val_loss: 0.0388 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 37.5248 - val_ClassVectorsHead_accuracy: 0.3125
Epoch 9/600
...
Epoch 17/600
200/200 - 17s - loss: 154957.7656 - CenteredInstanceConfmapsHead_loss: 147142.5781 - ClassVectorsHead_loss: 7815238.0000 - ClassVectorsHead_accuracy: 0.2450 - val_loss: 78.5273 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 78526.1016 - val_ClassVectorsHead_accuracy: 0.2875
Epoch 18/600
200/200 - 17s - loss: 1687.5400 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 1687538.7500 - ClassVectorsHead_accuracy: 0.2325 - val_loss: 4930.4688 - val_CenteredInstanceConfmapsHead_loss: 0.0013 - val_ClassVectorsHead_loss: 4930467.0000 - val_ClassVectorsHead_accuracy: 0.2500
Epoch 19/600
200/200 - 17s - loss: 10122.5576 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 10122554.0000 - ClassVectorsHead_accuracy: 0.2225 - val_loss: 10031.4629 - val_CenteredInstanceConfmapsHead_loss: 0.0013 - val_ClassVectorsHead_loss: 10031461.0000 - val_ClassVectorsHead_accuracy: 0.2125
Epoch 20/600
200/200 - 17s - loss: 3817.9763 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 3817974.5000 - ClassVectorsHead_accuracy: 0.2463 - val_loss: 16.2432 - val_CenteredInstanceConfmapsHead_loss: 0.0011 - val_ClassVectorsHead_loss: 16242.0996 - val_ClassVectorsHead_accuracy: 0.2125
...
200/200 - 17s - loss: 790181305122816.0000 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 790181279889883136.0000 - ClassVectorsHead_accuracy: 0.2325 - val_loss: 7487455141822464.0000 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 7487455932096446464.0000 - val_ClassVectorsHead_accuracy: 0.2000
Epoch 31/600
200/200 - 17s - loss: 11424157740826624.0000 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 11424150112964706304.0000 - ClassVectorsHead_accuracy: 0.2338 - val_loss: 10869713970135040.0000 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 10869713678077263872.0000 - val_ClassVectorsHead_accuracy: 0.2875
.....
Epoch 600/600
200/200 - 17s - loss: 456925662062575616.0000 - CenteredInstanceConfmapsHead_loss: 0.0012 - ClassVectorsHead_loss: 456925859150034894848.0000 - ClassVectorsHead_accuracy: 0.2587 - val_loss: 509130164911734784.0000 - val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 509130143471258042368.0000 - val_ClassVectorsHead_accuracy: 0.2125
INFO:sleap.nn.training:Finished training loop. [171.3 min]
INFO:sleap.nn.training:Saving evaluation metrics to model folder...
WARNING:sleap.nn.evals:Failed to compute metrics.
INFO:sleap.nn.evals:Saved predictions: models/gerbils.multiclass_topdown_4/labels_pr.train.slp
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:505: RuntimeWarning: Mean of empty slice
  "dist.avg": np.nanmean(dists),
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:538: RuntimeWarning: Mean of empty slice.
  mPCK = mPCK_parts.mean()
/usr/local/lib/python3.7/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:632: RuntimeWarning: Mean of empty slice.
  pair_pck = metrics["pck.pcks"].mean(axis=-1).mean(axis=-1)
/usr/local/lib/python3.7/dist-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:634: RuntimeWarning: Mean of empty slice.
  metrics["oks.mOKS"] = pair_oks.mean()
WARNING:sleap.nn.evals:Failed to compute metrics.
INFO:sleap.nn.evals:Saved predictions: models/gerbils.multiclass_topdown_4/labels_pr.val.slp
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:505: RuntimeWarning: Mean of empty slice
  "dist.avg": np.nanmean(dists),
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:538: RuntimeWarning: Mean of empty slice.
  mPCK = mPCK_parts.mean()
/usr/local/lib/python3.7/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:632: RuntimeWarning: Mean of empty slice.
  pair_pck = metrics["pck.pcks"].mean(axis=-1).mean(axis=-1)
/usr/local/lib/python3.7/dist-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)
/usr/local/lib/python3.7/dist-packages/sleap/nn/evals.py:634: RuntimeWarning: Mean of empty slice.
  metrics["oks.mOKS"] = pair_oks.mean()
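Two things in this log can be checked with a little arithmetic. First, the reported total `loss` is consistent with `CenteredInstanceConfmapsHead_loss + 0.001 * ClassVectorsHead_loss`, matching the `loss_weight=0.001` on `ClassVectorsHead` in the setup log further down; for example, at epoch 17, 147142.5781 + 0.001 × 7815238 ≈ 154957.82 versus the logged 154957.7656. Second, an ID loss above the chance value ln 4 ≈ 1.386 means the head is doing worse than uniform guessing, which is the case in every epoch shown from epoch 5 onward. Below is a minimal Python sketch that parses the flattened Keras output and flags both. The regexes assume the Keras `verbose=2` line format as pasted here, the 2 × ln 4 divergence threshold is an arbitrary heuristic, and `parse_keras_log` / `check_epochs` are hypothetical helpers, not SLEAP functions:

```python
import math
import re

# Two epochs copied verbatim from the 800-frame log above.
LOG = (
    "Epoch 1/600 200/200 - 18s - loss: 0.0025 - CenteredInstanceConfmapsHead_loss: 0.0012 "
    "- ClassVectorsHead_loss: 1.3378 - ClassVectorsHead_accuracy: 0.2713 - val_loss: 0.0025 "
    "- val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 1.3105 "
    "- val_ClassVectorsHead_accuracy: 0.3375 "
    "Epoch 17/600 200/200 - 17s - loss: 154957.7656 - CenteredInstanceConfmapsHead_loss: 147142.5781 "
    "- ClassVectorsHead_loss: 7815238.0000 - ClassVectorsHead_accuracy: 0.2450 - val_loss: 78.5273 "
    "- val_CenteredInstanceConfmapsHead_loss: 0.0012 - val_ClassVectorsHead_loss: 78526.1016 "
    "- val_ClassVectorsHead_accuracy: 0.2875"
)

EPOCH_RE = re.compile(r"Epoch (\d+)/\d+")
METRIC_RE = re.compile(r"(\w+): ([-+0-9.eE]+)")


def parse_keras_log(text):
    """Return {epoch: {metric_name: value}} from flattened Keras verbose=2 output."""
    # re.split with one capturing group yields [prefix, epoch, body, epoch, body, ...].
    parts = EPOCH_RE.split(text)
    return {
        int(epoch): {name: float(value) for name, value in METRIC_RE.findall(body)}
        for epoch, body in zip(parts[1::2], parts[2::2])
    }


def check_epochs(epochs, class_loss_weight=0.001, diverge_factor=2.0):
    """For each epoch: (epoch, relative error of the weighted-sum model, diverged?)."""
    chance = math.log(4)  # uniform-guess cross-entropy over the 4 identities
    rows = []
    for epoch, m in sorted(epochs.items()):
        # Assumed: total loss = confmap loss + class_loss_weight * ID loss.
        predicted_total = (m["CenteredInstanceConfmapsHead_loss"]
                           + class_loss_weight * m["ClassVectorsHead_loss"])
        rel_err = abs(predicted_total - m["loss"]) / m["loss"]
        diverged = m["ClassVectorsHead_loss"] > diverge_factor * chance
        rows.append((epoch, rel_err, diverged))
    return rows


for epoch, rel_err, diverged in check_epochs(parse_keras_log(LOG)):
    # A small nonzero error is expected: the log rounds every value to 4 decimals.
    print(f"epoch {epoch}: weighted-sum error {rel_err:.2%}, diverged={diverged}")
```

Feeding the full pasted log into `parse_keras_log` in the same way should make the blow-up from epoch 5 onward easy to plot or threshold.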

catubc commented 3 years ago

Our best guess is that Talmo performed some intermediate step on the .slp file to prepare it for training, but it's not clear what that step was.

I see there is a "pkg" part in the first file name. Is this related to our issue? Perhaps the labels were first used to train the centroid model and then the ID model? If so, can you advise on how to do this? Should we just train the centroid model and grab one of the files it produces?

Here is the .slp file with the 471 labeled frames that Talmo exported, which works:

labels_tracked.v021.pkg.slp: https://drive.google.com/file/d/1enJU1SE0tjE0uO1yQBpkDPZVkv5Ak8xu/view?usp=sharing


Using prediction file fname:  /content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp
Skeleton: Skeleton(nodes=[nose, lefteye, righteye, leftear, rightear, spine1, spine2, spine3, spine4, spine5, tail1, tail2, tail3, tail4], edges=[spine3->spine2, spine2->spine1, spine1->lefteye, spine1->leftear, spine1->nose, spine1->righteye, spine1->rightear, spine3->spine4, spine4->spine5, spine5->tail1, tail1->tail2, tail2->tail3, tail3->tail4], symmetries=[lefteye<->righteye, leftear<->rightear])
Videos: ['/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp', '/content/drive/MyDrive/data/labels/feb_26/labels_tracked.v021.pkg.slp']
Frames (user/predicted): 471/0
Instances (user/predicted): 1,828/0
Tracks: [Track(spawned_on=1998, name='female'), Track(spawned_on=1998, name='male'), Track(spawned_on=1998, name='pup shaved'), Track(spawned_on=1998, name='pup unshaved')]
Suggestions: 4,600
Provenance: {}

And here is the .slp file with the 800 labeled frames that we exported, which no longer works:

labels.v028.slp: https://drive.google.com/file/d/1SCsBudZS5J-CMo5uYQQYq7fmFb-1YSLf/view?usp=sharing


Using prediction file fname:  /content/drive/MyDrive/data/labels/mar_12/labels.v028.slp
Skeleton: Skeleton(nodes=[nose, lefteye, righteye, leftear, rightear, spine1, spine2, spine3, spine4, spine5, tail1, tail2, tail3, tail4], edges=[spine3->spine2, spine2->spine1, spine1->lefteye, spine1->leftear, spine1->nose, spine1->righteye, spine1->rightear, spine3->spine4, spine4->spine5, spine5->tail1, tail1->tail2, tail2->tail3, tail3->tail4], symmetries=[lefteye<->righteye, leftear<->rightear])
Videos: ['2020-3-6_0day_5mins_compressedTalmo.mp4', '2020-3-6_1night_5mins_compressedTalmo.mp4', '2020-3-7_daytime_0_300sec_compressedTalmo.mp4', '2020-3-7_nighttime_0_300sec_compressedTalmo.mp4', '2020-3-8_0day_5mins_compressedTalmo.mp4', '2020-3-8_1night_5mins_compressedTalmo.mp4', '2020-3-9_1day_5mins_compressedTalmo.mp4', '2020-3-9_1night_5mins_compressedTalmo.mp4', '2020-3-10_daytime_5mins_compressedTalmo.mp4', '2020-3-10_nighttime_5mins_2_compressedTalmo.mp4', '2020-3-10_nighttime_5mins_compressedTalmo.mp4', '2020-3-11_day_5mins_compressedTalmo.mp4', '2020-3-11_night_5mins_compressedTalmo.mp4', '2020-3-12_day_5mins_compressedTalmo.mp4', '2020-3-12_night_5mins_compressedTalmo.mp4', '2020-3-13_day_5mins_compressedTalmo.mp4', '2020-3-13_night_5mins_compressedTalmo.mp4', '2020-3-14_day_5mins_compressedTalmo.mp4', '2020-3-15_day_5mins_compressedTalmo.mp4', '2020-3-15_night_5mins_compressedTalmo.mp4', '2020-3-16_day_5mins_compressedTalmo.mp4', '2020-3-16_night_5mins_compressedTalmo.mp4', '2020-3-17_night_5mins_compressedTalmo.mp4']
Frames (user/predicted): 800/3,650
Instances (user/predicted): 2,897/9,921
Tracks: [Track(spawned_on=1998, name='female'), Track(spawned_on=1998, name='male'), Track(spawned_on=1998, name='pup shaved'), Track(spawned_on=1998, name='pup unshaved')]
Suggestions: 4,527
Provenance: {}

Here is also the training setup log for the 800-frame .slp file:

INFO:sleap.nn.training:Creating training and validation splits from validation fraction: 0.1
INFO:sleap.nn.training:  Splits: Training = 720 / Validation = 80
INFO:sleap.nn.training:Setting up for training...
INFO:sleap.nn.training:Setting up pipeline builders...
INFO:sleap.nn.training:Setting up model...
INFO:sleap.nn.training:Building test pipeline...
INFO:sleap.nn.training:Loaded test example. [6.316s]
INFO:sleap.nn.training:  Input shape: (448, 448, 3)
INFO:sleap.nn.training:Created Keras model.
INFO:sleap.nn.training:  Backbone: UNet(stacks=1, filters=32, filters_rate=1.5, kernel_size=3, stem_kernel_size=7, convs_per_block=2, stem_blocks=0, down_blocks=6, middle_block=True, up_blocks=4, up_interpolate=True, block_contraction=False)
INFO:sleap.nn.training:  Max stride: 64
INFO:sleap.nn.training:  Parameters: 7,011,937
INFO:sleap.nn.training:  Heads: 
INFO:sleap.nn.training:    [0] = CenteredInstanceConfmapsHead(part_names=['nose', 'lefteye', 'righteye', 'leftear', 'rightear', 'spine1', 'spine2', 'spine3', 'spine4', 'spine5', 'tail1', 'tail2', 'tail3', 'tail4'], anchor_part='spine3', sigma=2.5, output_stride=4, loss_weight=1.0)
INFO:sleap.nn.training:    [1] = ClassVectorsHead(classes=['female', 'male', 'pup shaved', 'pup unshaved'], num_fc_layers=3, num_fc_units=256, global_pool=True, output_stride=64, loss_weight=0.001)
INFO:sleap.nn.training:  Outputs: 
INFO:sleap.nn.training:    [0] = Tensor("CenteredInstanceConfmapsHead/BiasAdd_1:0", shape=(None, 112, 112, 14), dtype=float32)
INFO:sleap.nn.training:    [1] = Tensor("ClassVectorsHead/Softmax_1:0", shape=(None, 4), dtype=float32)
INFO:sleap.nn.training:Setting up data pipelines...
INFO:sleap.nn.training:Training set: n = 720
INFO:sleap.nn.training:Validation set: n = 80
INFO:sleap.nn.training:Setting up optimization...
INFO:sleap.nn.training:  Learning rate schedule: LearningRateScheduleConfig(reduce_on_plateau=False, reduction_factor=0.1, plateau_min_delta=1e-08, plateau_patience=10, plateau_cooldown=3, min_learning_rate=1e-08)
INFO:sleap.nn.training:  Early stopping: EarlyStoppingConfig(stop_training_on_plateau=False, plateau_min_delta=1e-08, plateau_patience=20)
INFO:sleap.nn.training:Setting up outputs...
INFO:sleap.nn.training:Created run path: models/gerbils.multiclass_topdown_6
INFO:sleap.nn.training:Setting up visualization...
INFO:sleap.nn.training:Finished trainer set up. [15.3s]
 DONE SETTING UP
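The sizes in this setup log follow from a few numbers in it. Below is a minimal Python sketch, assuming each UNet down block halves and each up block doubles the spatial resolution (consistent with the logged `Max stride: 64` and `output_stride=4`, but not checked against SLEAP's source), and assuming a plain shuffled split for the 0.1 validation fraction; SLEAP's actual split and rounding may differ. `expected_shapes` and `split_indices` are hypothetical helpers:

```python
import random


def expected_shapes(input_hw=448, down_blocks=6, up_blocks=4, n_nodes=14, n_classes=4):
    """Derive the tensor sizes reported in the setup log (assumed stride rules)."""
    max_stride = 2 ** down_blocks                   # 6 halvings -> 64
    output_stride = 2 ** (down_blocks - up_blocks)  # 4 doublings back -> 4
    # Every down block halves the size, so the input must divide evenly by max stride.
    assert input_hw % max_stride == 0, "input size not divisible by max stride"
    confmap_hw = input_hw // output_stride
    return {
        "max_stride": max_stride,
        "output_stride": output_stride,
        "bottleneck_hw": input_hw // max_stride,              # 448 / 64 = 7
        "confmaps": (None, confmap_hw, confmap_hw, n_nodes),  # one channel per node
        "class_vectors": (None, n_classes),                   # softmax over identities
    }


def split_indices(n, val_fraction, seed=0):
    """Shuffle frame indices and hold out round(n * val_fraction) for validation."""
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    n_val = round(n * val_fraction)
    return idx[n_val:], idx[:n_val]


shapes = expected_shapes()
train_idx, val_idx = split_indices(800, 0.1)
print(shapes["confmaps"], shapes["class_vectors"], len(train_idx), len(val_idx))
```

The printed shapes match the two `Outputs:` tensors in the log, and the split sizes match `Training = 720 / Validation = 80`.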
arie-matsliah commented 3 years ago

Hi @catubc, thanks for the detailed report. Something is definitely off, since the loss blows up. We'll look into this and get back to you with our findings, or if we need additional information.

catubc commented 3 years ago

Yes, I'm pretty sure that I'm missing a training step.
Thanks!

talmo commented 3 years ago

Hey @catubc,

Closing this following our chat today, but feel free to reach out again if it's still not working with the notebooks I sent you.

Cheers,

Talmo