trzy / FasterRCNN

Clean and readable implementations of Faster R-CNN in PyTorch and TensorFlow 2 with Keras.

Please support to create new backbone based on ViT #9

Open tommyngx opened 1 year ago

tommyngx commented 1 year ago

Dear @trzy,

Thank you for the great repo. I am trying to add a new ViT backbone to your source code. I am using vgg16_torch.py as a template and changed line 67 from vgg16 = torchvision.models.vgg16(weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1, dropout = dropout_probability) to ViT = torchvision.models.vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

Based on the ViT concept, the features should be computed like this:

    # Expand the class token to the full batch
    batch_class_token = vit.class_token.expand(img.shape[0], -1, -1)
    feats = torch.cat([batch_class_token, feats], dim=1)
    feats = vit.encoder(feats)
    # We're only interested in the representation of the classifier token that we appended at position 0.
    feats = feats[:, 0]

I am still lost on how to adapt the FeatureExtractor class to this concept. Please assist if possible. Many thanks!