askerlee opened this issue 1 month ago
Sorry for the delayed reply. I have two things to share with you:
I think the core of the problem is that, when monocular input is given, the image features output by the model are not passed through the linear (feature-mapping) layer, even though the model was trained with that layer (see model.py for details). As a result, in the zero-shot setting the image features are not well aligned with the text features. So, in addition to the method mentioned above, you can also try using only monocular features, which requires a small rewrite of the code as shown below. This is a flaw in our design: for the downstream tasks of linear probing and fine-tuning, whether or not the image features pass through the linear layer makes no difference to the final result, so we didn't take this case into account, sorry.
def encode_image(self, img_l, img_r, mask_ratio=0):
    # Monocular input: only the left image is given.
    if img_r is None:
        if isinstance(self.visual, ModifiedResNet):
            # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
            vision_feature = self.visual(img_l.type(self.dtype))
            return self.left_feature_mapping(vision_feature)
        vision_feature = self.visual(img_l.type(self.dtype), mask_ratio)
        # The monocular feature now also passes through the linear mapping layer.
        return self.left_feature_mapping(vision_feature)

    # Monocular input: only the right image is given.
    if img_l is None:
        if isinstance(self.visual, ModifiedResNet):
            # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
            vision_feature = self.visual(img_r.type(self.dtype))
            return self.right_feature_mapping(vision_feature)
        vision_feature = self.visual(img_r.type(self.dtype), mask_ratio)
        return self.right_feature_mapping(vision_feature)

    # Binocular input: encode both eyes and return global + per-eye features.
    if isinstance(self.visual, ModifiedResNet):
        # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
        left_feature = self.visual(img_l.type(self.dtype))
        right_feature = self.visual(img_r.type(self.dtype))
        vision_feature = torch.cat((left_feature, right_feature), dim=1)
        return (self.global_feature_mapping(vision_feature),
                self.single_feature_mapping(left_feature),
                self.single_feature_mapping(right_feature))

    left_feature = self.visual(img_l.type(self.dtype), mask_ratio)
    right_feature = self.visual(img_r.type(self.dtype), mask_ratio)
    vision_feature = torch.cat((left_feature, right_feature), dim=1)
    return (self.global_feature_mapping(vision_feature),
            self.left_feature_mapping(left_feature),
            self.right_feature_mapping(right_feature))
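For reference, here is a minimal usage sketch (not from the repository) of how the rewritten encode_image could be used for monocular zero-shot classification. The encode_text method, logit_scale, and the tokenizer call are assumptions based on the usual CLIP-style interface; adjust the names to whatever model.py actually exposes.

import torch

@torch.no_grad()
def zero_shot_monocular(model, tokenizer, img_l, class_prompts, device="cuda"):
    # img_r=None takes the monocular branch, which now applies left_feature_mapping.
    image_feat = model.encode_image(img_l.to(device), None)
    image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)

    # One text prompt per candidate class, encoded and normalized the same way.
    tokens = tokenizer(class_prompts).to(device)
    text_feat = model.encode_text(tokens)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

    # Cosine similarity scaled by the learned temperature -> class probabilities.
    logits = model.logit_scale.exp() * image_feat @ text_feat.t()
    return logits.softmax(dim=-1)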
synfundus-sel.zip
I tested the 5 images in the zip archive above. These images were selected from the SynFundus (https://github.com/parap1uie-s/SynFundus-1M) synthetic dataset, and each image comes with a caption.
The corresponding meta-information of the 5 images is:
The label file is:
However, all 5 images were best matched with "Myopia 近视". This is weird, as the pathological-myopia label (is_pm) is 0 for all 5 images, and their captions don't mention myopia.
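In case it helps reproduce the check, here is a small hypothetical sketch of how the is_pm labels of the selected images can be verified. The CSV file name, the filename column, and the image names are placeholders; only the is_pm column comes from the label file mentioned above.

import pandas as pd

# Placeholder names; substitute the actual label file and image file names.
labels = pd.read_csv("synfundus_labels.csv")
selected_names = ["image_1.png", "image_2.png", "image_3.png", "image_4.png", "image_5.png"]

selected = labels[labels["filename"].isin(selected_names)]
print(selected[["filename", "is_pm"]])
# All 5 selected images should have is_pm == 0 (no pathological myopia).
assert (selected["is_pm"] == 0).all()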