Tencent / MedicalNet

Many studies have shown that the performance of deep learning is significantly affected by the volume of training data. The MedicalNet project provides a series of 3D-ResNet pre-trained models and related code.

Utilizing resnet_50.pth for 3D Feature Map Extraction #83

Open aeinkoupaei opened 5 months ago

aeinkoupaei commented 5 months ago

Hi, I want to use the pre-trained resnet_50.pth encoder to extract 3D feature maps from medical images. Is the method below correct? It seems strange that the width, height, depth, and number of channels can be set manually. Wasn't resnet_50.pth trained with a specific architecture and a specific input depth, height, width, and channel count? If so, shouldn't the input used to extract 3D feature maps have the same dimensions as the inputs the model saw during training?

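    import torch
    from collections import OrderedDict
    from models.resnet import resnet50  # resnet.py from the MedicalNet repo

    # Use a different name than the factory function so it isn't shadowed
    model = resnet50(sample_input_D=32, sample_input_H=256, sample_input_W=256,
                     shortcut_type='B', no_cuda=True, num_seg_classes=1)
    # Load the weights from the pretrained file (mapped to CPU since no_cuda=True)
    pretrain = torch.load("pretrain/resnet_50.pth", map_location="cpu")
    pretrained_dict = pretrain['state_dict']
    new_state_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[7:]  # Remove 'module.'
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict, strict=False)

    model.eval()  # use the stored BatchNorm statistics
    with torch.no_grad():
        A_img_feature_map = model(A_img)  # A_img: tensor of shape (N, 1, D, H, W)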
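The `k[7:]` slice assumes every key starts with `module.` (the prefix that `nn.DataParallel` adds when a wrapped model is saved); a key without that prefix would silently lose its first seven characters. A minimal sketch of a safer version, where `strip_module_prefix` is a hypothetical helper name, not part of MedicalNet:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged instead of
    being truncated, unlike a blind k[7:] slice.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
```

It would replace the loop above, e.g. `new_state_dict = strip_module_prefix(pretrain['state_dict'])`.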

Ram2314 commented 2 months ago

Hi @aeinkoupaei! Did you figure out whether this is the correct approach for getting a feature map from an image? I also need to use this to make a feature map. Is there any reason you set num_seg_classes=1?

aeinkoupaei commented 2 months ago

Hi @Ram2314,

To use a pre-trained ResNet model for extracting 3D feature maps, you'll need to modify the ResNet class in the resnet.py file:

1- In the __init__ function of the ResNet class, delete the entire self.conv_seg block. This removes the segmentation head, which isn't needed for feature extraction.
2- In the forward function of the ResNet class, delete the line x = self.conv_seg(x). The model then skips the final segmentation prediction and instead outputs the feature map produced just before that stage.
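Those two edits make forward() stop after layer4. A sketch of getting the same output without editing resnet.py, assuming the stage attribute names and order used in MedicalNet's ResNet.forward (conv1, bn1, relu, maxpool, layer1 to layer4, then conv_seg); `extract_feature_map` and `BACKBONE_STAGES` are hypothetical names:

```python
# Assumed stage names, in the order ResNet.forward applies them,
# stopping before the segmentation head (conv_seg).
BACKBONE_STAGES = ("conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4")


def extract_feature_map(model, x, stages=BACKBONE_STAGES):
    """Run only the encoder stages of `model` on `x` and return the layer4 output.

    conv_seg still exists on the model but is never called, so resnet.py
    can stay unmodified.
    """
    for name in stages:
        x = getattr(model, name)(x)
    return x
```

Call it on the unmodified model after `model.eval()` and inside `with torch.no_grad():`, passing a tensor of shape (N, 1, D, H, W).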

Here's an example of how to use a pre-trained ResNet-10 model for feature extraction:

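    import torch
    from collections import OrderedDict
    from models.resnet import resnet10  # MedicalNet's resnet.py, with conv_seg removed

    # In the stock resnet.py these four arguments have no defaults, so they must still be passed
    resnet_10 = resnet10(sample_input_D=32, sample_input_H=256, sample_input_W=256,
                         num_seg_classes=1, shortcut_type='B', no_cuda=True)
    pretrain = torch.load("resnet_10_23dataset.pth", map_location="cpu")
    pretrained_dict = pretrain['state_dict']
    new_state_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[7:]  # Remove 'module.'
        new_state_dict[name] = v

    resnet_10.load_state_dict(new_state_dict)
    resnet_10.eval()  # use the stored BatchNorm statistics when extracting features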
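`load_state_dict` is strict by default, so this call raises if the checkpoint and the model disagree on any key. The `strict=False` call in the first post skips mismatches silently instead (it returns an object whose `missing_keys` and `unexpected_keys` fields list them, which is easy to ignore), so a wrong key prefix could leave the whole encoder at random initialization without any error. A plain-Python sketch for inspecting the mismatch before loading; `compare_keys` and `keep_matching` are hypothetical helpers:

```python
def compare_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected): the key lists a strict load_state_dict would reject."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)      # entries the model expects but the checkpoint lacks
    unexpected = sorted(ckpt - model)   # checkpoint entries with no matching model entry
    return missing, unexpected


def keep_matching(state_dict, model_keys):
    """Drop checkpoint entries (e.g. weights of a removed head) the model has no slot for."""
    model_keys = set(model_keys)
    return {k: v for k, v in state_dict.items() if k in model_keys}
```

For example, `missing, unexpected = compare_keys(new_state_dict.keys(), resnet_10.state_dict().keys())`; an empty `missing` list means every encoder weight and buffer is covered by the checkpoint, and `keep_matching` then removes any leftover entries so a strict load can succeed.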

Ram2314 commented 2 months ago

@aeinkoupaei Awesome, thanks! Also, did you figure out the length, width, height issue? It seems I can set them to anything and it still works?
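On the length/width/height question: as far as I can tell from MedicalNet's models/resnet.py, `sample_input_D`, `sample_input_H` and `sample_input_W` are accepted by `ResNet.__init__` but no layer is built from them, and the encoder consists only of convolutions, batch norm and pooling. That makes it fully convolutional, so the pretrained weights apply to any sufficiently large spatial size; only the channel count is fixed, at 1, because `conv1` is a `Conv3d` with one input channel. The size of the layer4 feature map then follows ordinary convolution arithmetic. A sketch, assuming the layer settings as I read them in that file (conv1: kernel 7, stride 2, padding 3; max-pool: kernel 3, stride 2, padding 1; stride 2 in the first block of layer2; layer3 and layer4 use stride 1 with padding equal to their dilation, which preserves size). `conv_out` and `medicalnet_feature_shape` are hypothetical helpers:

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def medicalnet_feature_shape(d, h, w, bottleneck=False):
    """Estimate the (C, D, H, W) shape of the layer4 output for one input volume.

    Assumed layer settings (see above): conv1 k7/s2/p3, max-pool k3/s2/p1,
    first block of layer2 k3/s2/p1; layer3 and layer4 keep the size.
    """
    spatial = []
    for s in (d, h, w):
        s = conv_out(s, kernel=7, stride=2, padding=3)  # conv1
        s = conv_out(s, kernel=3, stride=2, padding=1)  # maxpool
        s = conv_out(s, kernel=3, stride=2, padding=1)  # layer2, first block
        spatial.append(s)
    # BasicBlock (ResNet-10/18/34) has expansion 1, Bottleneck (ResNet-50+) has 4
    channels = 512 * (4 if bottleneck else 1)
    return (channels, *spatial)
```

For the D=32, H=256, W=256 input from the first post, `medicalnet_feature_shape(32, 256, 256, bottleneck=True)` gives (2048, 4, 32, 32) for ResNet-50; ResNet-10 would give 512 channels with the same spatial size.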