Haochen-Wang409 / U2PL

[CVPR'22 & IJCV'24] Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels & Using Unreliable Pseudo-Labels for Label-Efficient Semantic Segmentation
Apache License 2.0

Using pre-trained weights from vision transformer #89

Closed elle-miller closed 2 years ago

elle-miller commented 2 years ago

Hi there, thanks for the great tool!

Would it be possible to use pre-trained weights from a vision transformer instead of ResNet? I had a quick go but ran into a memory error.

e.g. using https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16 with the pre-trained checkpoint https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_small_p16_384_20220308-410f6037.pth

I am curious whether you think this would be possible, or if you have any thoughts, before I continue.

Thank you very much,

Elle

Haochen-Wang409 commented 2 years ago

Hi, thanks for your interest and kind words!

Since you have encountered an out-of-memory error using ViT-B/16, you could try reducing batch_size in config.yaml.
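For context, a rough sketch of why a ViT backbone can exhaust GPU memory even at a small batch size: the self-attention maps stored in the forward pass grow quadratically with the number of patch tokens, and segmentation crops are much larger than the 224×224 inputs ViT-B/16 was designed for. The crop size below (768×768) is a hypothetical example, not a value from this repo's configs; standard ViT-B/16 hyperparameters (16×16 patches, 12 heads, 12 layers, fp32) are assumed.

```python
# Rough estimate of forward-pass memory for ViT-B/16 attention maps
# at batch size 1, fp32. Assumed hyperparameters: 16x16 patches,
# 12 heads, 12 layers; crop sizes are illustrative only.
PATCH, HEADS, LAYERS, BYTES_FP32 = 16, 12, 12, 4

def attn_map_bytes(img_size, patch=PATCH, heads=HEADS, layers=LAYERS):
    tokens = (img_size // patch) ** 2 + 1          # patch tokens + [CLS]
    return heads * layers * tokens ** 2 * BYTES_FP32

print(f"224x224 input: {attn_map_bytes(224) / 2**30:.2f} GiB")  # classification-sized
print(f"768x768 input: {attn_map_bytes(768) / 2**30:.2f} GiB")  # segmentation-sized
```

At a classification-sized input this is negligible, but at a 768×768 crop the attention maps alone approach 3 GiB before counting activations, gradients, and optimizer state, which is consistent with an OOM even at batch_size=1.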

elle-miller commented 2 years ago

Thanks for the fast response! My batch size is already 1. I was wondering if I needed to add any code to u2pl/models to get it working?

Haochen-Wang409 commented 2 years ago

I checked the number of parameters of ResNet-101 and ViT-B/16, which are 85M and 86M, respectively.
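The ViT-B/16 figure can be reproduced with a back-of-envelope count, no torch required. The hyperparameters below are the standard ViT-B/16 configuration (12 layers, width 768, MLP width 3072, 16×16 patches, 224×224 input, 1000-class head), assumed here rather than taken from this repo:

```python
# Back-of-envelope parameter count for a standard ViT-B/16
# (assumed config: 12 layers, dim 768, MLP 3072, patch 16,
# 224x224 input, 1000-class classification head).
D, LAYERS, MLP, PATCH, IMG, CLASSES = 768, 12, 3072, 16, 224, 1000

tokens = (IMG // PATCH) ** 2 + 1                  # 196 patches + [CLS]
patch_embed = (3 * PATCH * PATCH) * D + D         # patch projection weight + bias
cls_and_pos = D + tokens * D                      # class token + position embedding
per_layer = (
    2 * 2 * D                                     # two LayerNorms (weight + bias)
    + 3 * (D * D + D)                             # Q, K, V projections
    + (D * D + D)                                 # attention output projection
    + (D * MLP + MLP) + (MLP * D + D)             # two-layer MLP
)
head = 2 * D + (D * CLASSES + CLASSES)            # final LayerNorm + classifier
total = patch_embed + cls_and_pos + LAYERS * per_layer + head
print(f"{total / 1e6:.1f}M parameters")           # ~86.6M
```

So the two backbones are comparable in parameter count; the memory gap comes from activations at large crop sizes, not from the weights themselves.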

Does the original ResNet-101 work well?