paninski-lab / lightning-pose

Accelerated pose estimation and tracking using semi-supervised convolutional networks.

Difference between backbones #197

Open ReetKaur15 opened 1 week ago

ReetKaur15 commented 1 week ago

Hi @themattinthehatt,

Could you please explain the main differences between these backbones: resnet50_human_jhmdb | resnet50_human_res_rle | resnet50_human_top_res?

Additionally, please share the paths to these backbone models.

Many thanks in advance! Cheers, Reet :)

Shmuel-columbia commented 1 week ago

Hi, Thank you for the question.
The team is currently on vacation and will be back next week.

themattinthehatt commented 3 days ago

@ReetKaur15 you can find more detailed descriptions of the different backbones here: https://lightning-pose.readthedocs.io/en/latest/source/user_guide/config_file.html#model-training-parameters

and you can find the hard-coded paths to the network weights here: https://github.com/paninski-lab/lightning-pose/blob/47ee289110fb2ef2519091b49a5658fab07b4bf4/lightning_pose/models/backbones/torchvision.py#L18

please let me know if you have any further questions!
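Roughly speaking, that file just maps each backbone name to a checkpoint URL and loads the weights into a torchvision trunk. A minimal sketch of the idea (illustrative only; the dictionary entry and URL below are placeholders, not the actual hard-coded paths):

```python
import torch
import torchvision.models as models

# illustrative mapping only -- see the torchvision.py file linked above for the real entries
backbone_checkpoints = {
    "resnet50_human_top_res": "https://example.com/path/to/pose_resnet_50.pth",  # placeholder URL
}

def load_backbone(name: str) -> torch.nn.Module:
    """Build a ResNet-50 trunk and load pose-pretrained weights if a URL is registered."""
    net = models.resnet50(weights=None)
    ckpt_url = backbone_checkpoints.get(name)
    if ckpt_url is not None:
        state_dict = torch.hub.load_state_dict_from_url(ckpt_url)
        net.load_state_dict(state_dict, strict=False)  # ignore keys from the pose head
    return net
```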

ReetKaur15 commented 3 days ago

Hi @themattinthehatt, thanks for the response.

There is no information about "resnet50_human_top_res" there. Could you please share more details about it?

themattinthehatt commented 3 days ago

Ah yes I see I missed that one, apologies.

resnet50_human_res_rle: a regression-based ResNet-50 pretrained on the MPII dataset (Andriluka et al. 2014, 2D Human Pose Estimation: New Benchmark and State of the Art Analysis)

resnet50_human_top_res: a heatmap-based ResNet-50 pretrained on the MPII dataset (Xiao et al. 2018, Simple Baselines for Human Pose Estimation and Tracking)

Both will be compatible with all Lightning Pose model types (supervised and context, with or without unsupervised losses), but I have not tested either of them so I cannot say more. If you end up testing both I'd be curious to hear how they compare!
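For intuition on the regression vs heatmap distinction, here is a rough schematic in plain PyTorch (not the actual Lightning Pose head implementations; num_keypoints and the layer sizes are placeholders):

```python
import torch
import torch.nn as nn
import torchvision.models as models

num_keypoints = 17  # placeholder

# ResNet-50 trunk without the avgpool/fc layers -> (B, 2048, H/32, W/32) features
trunk = nn.Sequential(*list(models.resnet50(weights=None).children())[:-2])

# regression-style head: pool the features and predict (x, y) per keypoint directly
regression_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(2048, 2 * num_keypoints),
)

# heatmap-style head (a la Simple Baselines): upsample with deconvs and predict one
# spatial heatmap per keypoint; keypoint locations come from the heatmap peaks
heatmap_head = nn.Sequential(
    nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, num_keypoints, kernel_size=1),
)

x = torch.randn(1, 3, 256, 256)
feats = trunk(x)                 # (1, 2048, 8, 8)
coords = regression_head(feats)  # (1, 2 * num_keypoints)
heatmaps = heatmap_head(feats)   # (1, num_keypoints, 32, 32)
```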

ReetKaur15 commented 3 days ago

@themattinthehatt, thank you for the quick response.

As per our previous discussion, I trained the vision transformer with the "vit_b" architecture, passing the same parameters and hyperparameters as set in your code (given below). Unfortunately, the results are not as good as those from the ResNet-152 backbone.

elif "vit_b_sam" in backbone_arch: ckpt_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" state_dict = torch.hub.load_state_dict_from_url(ckpt_url) new_state_dict = {} for key in state_dict: new_key = key.replace('image_encoder.', '') new_key = new_key.replace('mask_decoder.', '') new_state_dict[new_key] = state_dict[key] encoder_embed_dim = 768 encoder_depth = 12 encoder_num_heads = 12 encoder_global_attn_indexes = (2, 5, 8, 11) prompt_embed_dim = 256 finetune_image_size = image_size image_size = 1024 vit_patch_size = 16 base = ImageEncoderViT_FT( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, finetune_img_size=finetune_image_size, mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), num_heads=encoder_num_heads, patch_size=vit_patch_size, qkv_bias=True, use_rel_pos=False, global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, )

I am using 8,300 frames from 83 videos (100 frames/video) and testing on unseen baby videos, training for 1,569 epochs (around 700,000 iterations) with batch_size=16.

To improve the results, I tried changing vit_patch_size to 8 and window_size to 7, but this raises an error saying these hyperparameters cannot be adjusted because the pretrained model was trained with a patch size of 16.

Which hyperparameters or parameters can be adjusted to improve the results? Alternatively, could the results be limited by insufficient training data?

themattinthehatt commented 3 days ago

Yeah, unfortunately some of those ViT parameters are fixed by the pretrained weights we're using.
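A quick illustration of why the patch size in particular is locked in (a sketch, not the actual loading code): the patch-embedding conv kernel has the same spatial size as the patch, so the pretrained 16x16 weights simply cannot be loaded into an 8x8 patch embedding:

```python
import torch.nn as nn

# SAM ViT-B embeds 16x16 patches with a conv whose kernel equals the patch size,
# so its pretrained patch-embedding weight tensor has shape (768, 3, 16, 16)
pretrained_patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)

# a model built with patch_size=8 expects a (768, 3, 8, 8) weight tensor instead
new_patch_embed = nn.Conv2d(3, 768, kernel_size=8, stride=8)

try:
    new_patch_embed.load_state_dict(pretrained_patch_embed.state_dict())
except RuntimeError as err:
    print(err)  # size mismatch for weight: [768, 3, 8, 8] vs [768, 3, 16, 16]
```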

We have also found that the ViT doesn't clearly outperform ResNets on the handful of datasets we've tried it on.

Part of the issue might be in the training: the learning rate and learning rate scheduler are somewhat optimized for ResNets and might need to be adjusted for the ViT. We played around with this a bit (changed the learning rate, tried cosine annealing) but again weren't able to get any easy wins.
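If you want to keep experimenting on your end, the kind of thing we tried looks roughly like this (a sketch with placeholder values, not our exact training setup; the dummy model just stands in for the ViT-based pose model):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder for the ViT-based pose model

# illustrative values only; a lower LR plus cosine annealing is one thing worth sweeping
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300, eta_min=1e-6)

for epoch in range(300):
    # ... run one training epoch here ...
    scheduler.step()
```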

Insufficient training data could certainly be a problem, given the large number of parameters in the transformer. How does the validation loss look during training? Does it continue to decrease even after 1k+ epochs? There might be some overfitting going on.
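If it does look like overfitting, one generic guard (a sketch using plain PyTorch Lightning, not a Lightning Pose-specific recipe; the monitored metric name is a placeholder) is to stop early on the validation loss instead of training for a fixed 1,569 epochs:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# "val_loss" and the patience value are placeholders; use whatever your run actually logs
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=50)
trainer = Trainer(max_epochs=1569, callbacks=[early_stop])
# trainer.fit(model, datamodule=datamodule)  # plug in your existing model / datamodule
```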