openmedlab / PathoDuet


get_feature #2

Closed 464hee closed 1 year ago

464hee commented 1 year ago

What method should I use to obtain features, similar to what CTransPath returns, from the P1 or P2 models?

464hee commented 1 year ago

@zhangxiaofan101 @Shentl @hsymm

Shentl commented 1 year ago

1. Get and load the model: if you run our training code and save the checkpoint, the keys in the state_dict will start with 'module.base_encoder' or 'module.decoder', so you should clean the model with state_dict[k[len("module.base_encoder."):]] = state_dict[k]. If you download our released model (already cleaned), just load it as shown in the README.md.

2. Define model.head: if you want to get features, use model.head = nn.Identity(); if you want to perform a downstream classification task, use model.head = nn.Linear(768, args.num_classes).

3. Get the result: once model.head is correctly defined, use _, output = model(images) to get the feature or classification result.
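The key cleaning in step 1 can be sketched in plain Python; the dict below is a toy stand-in for a real torch.load result:

```python
# Strip the 'module.base_encoder.' prefix left by distributed training,
# keeping only the encoder weights (toy values stand in for real tensors).
def clean_state_dict(state_dict, prefix="module.base_encoder."):
    cleaned = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            cleaned[k[len(prefix):]] = v  # drop the prefix
    return cleaned

raw = {
    "module.base_encoder.patch_embed.proj.weight": 1,
    "module.base_encoder.head.weight": 2,
    "module.decoder.blocks.0.attn.qkv.weight": 3,  # decoder weights are dropped
}
print(clean_state_dict(raw))
```

The cleaned dict can then be passed to model.load_state_dict().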

Note1: Considering the gap between pathological images and natural images, we do not use a normalization function in data augmentation during training. Therefore, if data normalization is used in downstream tasks, performance will drop significantly.

No data normalization! No data normalization! No data normalization!

Note2: Do not use main_cls.py, which comes from moco-v3. We will update this file in a few days.

Note3: Do not take the output of model.forward_features() and average the second dim yourself. When we set global_pool='avg', model.fc_norm is not nn.Identity() as in the normal ViT, so manual averaging would skip the normalization applied after pooling.
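A toy illustration of the ordering issue in Note3, in plain Python with a made-up zero-mean function standing in for the real fc_norm (an nn.LayerNorm): averaging the token features yourself gives only the pooled values and skips the normalization the model applies after pooling.

```python
# Two tokens with 2-dim features; pooling-then-normalizing differs from
# pooling alone, which is all a manual mean over forward_features() gives.
tokens = [[1.0, 3.0], [2.0, 6.0]]

def mean_pool(toks):
    return [sum(col) / len(toks) for col in zip(*toks)]

def fc_norm(v):  # toy zero-mean normalization, standing in for nn.LayerNorm
    m = sum(v) / len(v)
    return [x - m for x in v]

pooled = mean_pool(tokens)    # [1.5, 4.5]
with_norm = fc_norm(pooled)   # what the model effectively returns
without_norm = pooled         # manual averaging skips fc_norm
print(with_norm, without_norm)
```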

Note4: We recommend using "model = VisionTransformerMoCo(bridge_token=True, global_pool='avg')", which includes our pretext_token in the model. However, you can also use "model = timm.create_model('vit_base_patch16_224', pretrained=False, global_pool='avg')" and directly load the parameters into a regular ViT-Base/16 model. We have compared these two methods and found little difference in downstream task performance, with the first approach showing a slight advantage over the second.

464hee commented 1 year ago

I followed the execution in the sample code and got the following error

super().__init__(**kwargs)
TypeError: __init__() got an unexpected keyword argument 'bridge_token'

VisionTransformerMoCo has no bridge_token parameter.

I removed this parameter; hopefully that is a correct approach.

I've been able to get the features, thanks.

##########################################################

from vits import VisionTransformerMoCo
import torch.nn as nn
import torch
from PIL import Image
import torchvision.transforms as transforms

model = VisionTransformerMoCo(global_pool='avg')
model.head = nn.Identity()
checkpoint = torch.load('checkpoint_p2.pth', map_location="cpu")
model.load_state_dict(checkpoint, strict=False)

transforms_img = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
img = Image.open('1103.png').convert('RGB')
img_new = transforms_img(img).unsqueeze(0)
print(img_new.shape)  # torch.Size([1, 3, 224, 224])
_, output_feat = model(img_new)
print(output_feat.shape)  # torch.Size([1, 768])

464hee commented 1 year ago

No data normalization! No data normalization! No data normalization! Thanks for repeating the reminder. I was wondering if, following this, I could do only simple processing of the image so that the model produces better output:

###############
transforms_img = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

Shentl commented 1 year ago

No data normalization! No data normalization! No data normalization! Thanks for repeating the reminder. I was wondering if, following this, I could do only simple processing of the image so that the model produces better output:

###############
transforms_img = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

Sure, you could use any data augmentation method other than data normalization, such as RandomResizedCrop, RandomHorizontalFlip, etc.

Shentl commented 1 year ago

I followed the execution in the sample code and got the following error

super().__init__(**kwargs)
TypeError: __init__() got an unexpected keyword argument 'bridge_token'

This is our fault; the bridge_token comes from our original code. We will update the model immediately.

Shentl commented 1 year ago

Thanks for your issue. We are uploading new models to fix the bug. If you don't want to download the new models again, we recommend replacing the key 'bridge_token' with 'pretext_token' to fix it.
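The suggested key replacement can be sketched in plain Python; the dict below is a hypothetical stand-in for a loaded checkpoint:

```python
# Rename every key containing 'bridge_token' to use 'pretext_token'
# (toy values stand in for real tensors).
def rename_bridge_token(state_dict):
    renamed = {}
    for k, v in state_dict.items():
        renamed[k.replace("bridge_token", "pretext_token")] = v
    return renamed

old = {"bridge_token": 0, "patch_embed.proj.weight": 1}
print(rename_bridge_token(old))
```

The renamed dict can then be loaded with model.load_state_dict() as usual.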

464hee commented 1 year ago

OK, thank you for your reply.