
RET-CLIP: A Retinal Image Foundation Model Pre-trained with Clinical Diagnostic Reports

inaccurate zero-shot inference #5

Open askerlee opened 1 month ago

askerlee commented 1 month ago

synfundus-sel.zip

I tested the 5 images in the zip archive above. These images were selected from the SynFundus (https://github.com/parap1uie-s/SynFundus-1M) synthetic dataset, and each image comes with a caption.

The corresponding meta-information of the 5 images is:

file_md5    file_name   eyes_side   is_abnormal is_amd  is_aon  is_crp  is_dm   is_dme  dr_grade    is_em   is_gc   is_md   is_mh   is_htr  has_lesion  is_pm   is_rvo  is_tessellated  is_treated  is_fundus   is_macular_readable is_optic_disc_readable  is_retinal_region_readable  retinal_region_quality_score    diagnosticAdvice
8294bb6cfa7aeb46a5e5c7bf88680289    4899.png    left    0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0   1   1   1   1   9   Suspected macular lesion.
2c7070104be097670fe6e2b1fb67e185    4923.png    left    0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0   1   1   1   1   9   Hemorrhagic spots visible.
4c1149d2ec10cff917b7802ce190f10d    4977.png    left    0   0   0   1   0   0   0   0   0   1   0   0   0   0   0   0   0   1   1   1   1   9   Suspected chorioretinopathy, suspected macular lesion.
82162645745bc1371cc2aea9ce923590    5048.png    left    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   1   1   1   5   Suspected hypertensive retinopathy (moderate).
108247fbe1947933a3022b74154d8010    5053.png    left    0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0   0   1   1   1   1   9   Suspected glaucoma.

The label file is:

Normal Healthy 正常,健康
Diabetic Retinopathy 糖尿病视网膜病变
Age-related Macular Degeneration 年龄相关性黄斑变性
Anomalies of the Optic Nerve 视神经异常
Choroidal Retinal Vascular 脉络膜视网膜血管
Diabetic Macular Edema 糖尿病黄斑水肿
Epimacular Membrane 黄斑上膜
Glaucoma 青光眼
Hypertensive Retinopathy 高血压视网膜病变
Myopia 近视
Retinal Vein Occlusion 视网膜静脉阻塞
黄斑区域病变 (Macular Lesion)
可见出血斑 (Visible Hemorrhagic Spots)
脉络膜视网膜病变 (Chorioretinopathy)

However, all 5 images were best matched with "Myopia 近视". This is odd: the is_pm ("is pathological myopia") flag is 0 for all 5 images, and none of the captions mention myopia.
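For context, the zero-shot matching presumably works the usual CLIP way: encode the image and each label prompt, then pick the label with the highest cosine similarity. A minimal sketch with dummy features (the tensor names and dimensions here are illustrative stand-ins, not RET-CLIP's actual API):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_labels, dim = 14, 512                 # 14 label prompts, CLIP-style embedding dim
image_features = torch.randn(5, dim)      # stand-in for encoded test images
text_features = torch.randn(num_labels, dim)  # stand-in for encoded label prompts

# Normalize so that a dot product equals cosine similarity.
image_features = F.normalize(image_features, dim=-1)
text_features = F.normalize(text_features, dim=-1)

similarity = image_features @ text_features.T  # (5, 14) cosine similarities
best_label = similarity.argmax(dim=-1)         # predicted label index per image
```

If the image features are not in the same space as the text features (the issue discussed below in this thread), the argmax can collapse onto an arbitrary label for every image.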

sStonemason commented 1 month ago

Sorry for the delayed reply. I have two things to share with you:

  1. I would recommend using a single image as both img_l and img_r, i.e. img_features = model(imgs, imgs, None). Also, the text features should be obtained via text_projection at the patient level. This matches the design of the training phase, and in my tests on multiple downstream tasks it gives the best results, far better than using monocular-level features alone. However, please understand that I cannot make these experimental results public for now.
  2. Note that the image and text data you used were generated by AI models. After testing, I confirmed that the issue you describe does exist. Besides our model, I also tested other foundation models designed for fundus images, and they likewise failed at this zero-shot task. You can try the method mentioned in 1. and see whether the accuracy improves. However, I think AI-generated data has inherent problems and is not suitable for training or testing models.
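The recommendation in point 1 can be sketched as follows. Since loading the real RET-CLIP checkpoint is outside the scope of this snippet, a stub module mimics the model's (global, left, right) return signature; the class name, dimensions, and return layout are assumptions for illustration only:

```python
import torch

class StubRETCLIP(torch.nn.Module):
    """Stand-in for a loaded RET-CLIP model, returning
    (patient-level, left-eye, right-eye) features."""
    def forward(self, img_l, img_r, text):
        b = img_l.shape[0]
        return torch.randn(b, 512), torch.randn(b, 512), torch.randn(b, 512)

model = StubRETCLIP()
imgs = torch.randn(5, 3, 224, 224)  # the 5 test fundus images

# Key point: pass the SAME image as both img_l and img_r so the
# binocular (patient-level) pathway is used, matching training.
global_feat, _, _ = model(imgs, imgs, None)
```

With the real model, `global_feat` would then be compared against patient-level text features obtained via text_projection, as the maintainer describes.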
sStonemason commented 1 month ago

I think the core of the problem is that, for monocular input, the image features output by the model are not passed through the linear mapping layer, even though the model was trained with that layer (see model.py for details). As a result, in the zero-shot task the image features are not well aligned with the text features. So besides the method mentioned above, you can also use monocular features alone, but that requires a small rewrite of the code, as below. This is a flaw in our design: for the downstream linear-probing and fine-tuning tasks, whether or not the image features pass through the linear layer makes no difference to the final result, so we did not take it into account. Sorry.

def encode_image(self, img_l, img_r, mask_ratio=0):
    # Fix: for monocular input, apply the eye-specific feature mapping so the
    # returned features live in the space the text features were aligned to.
    if img_r is None:
        if isinstance(self.visual, ModifiedResNet):
            # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
            vision_feature = self.visual(img_l.type(self.dtype))
            return self.left_feature_mapping(vision_feature)
        vision_feature = self.visual(img_l.type(self.dtype), mask_ratio)
        return self.left_feature_mapping(vision_feature)
    if img_l is None:
        if isinstance(self.visual, ModifiedResNet):
            # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
            vision_feature = self.visual(img_r.type(self.dtype))
            return self.right_feature_mapping(vision_feature)
        vision_feature = self.visual(img_r.type(self.dtype), mask_ratio)
        return self.right_feature_mapping(vision_feature)
    if isinstance(self.visual, ModifiedResNet):
        # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
        left_feature = self.visual(img_l.type(self.dtype))
        right_feature = self.visual(img_r.type(self.dtype))
        vision_feature = torch.cat(
            (left_feature, right_feature), dim=1)
        return self.global_feature_mapping(vision_feature), self.single_feature_mapping(
            left_feature), self.single_feature_mapping(right_feature)
    left_feature = self.visual(img_l.type(self.dtype), mask_ratio)
    right_feature = self.visual(img_r.type(self.dtype), mask_ratio)
    vision_feature = torch.cat(
        (left_feature, right_feature), dim=1)
    return self.global_feature_mapping(vision_feature), self.left_feature_mapping(
        left_feature), self.right_feature_mapping(right_feature)
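To see why the mapping matters, here is a tiny runnable sketch of the monocular path in the patched encode_image: backbone features pass through the eye-specific linear mapping before being compared with text features. The backbone and dimensions are stand-ins, not the repo's exact modules:

```python
import torch

class MonocularEncoder(torch.nn.Module):
    """Illustrative monocular image path, assuming a 768-d backbone
    feature and a 512-d shared image/text embedding space."""
    def __init__(self, feat_dim=768, embed_dim=512):
        super().__init__()
        self.visual = torch.nn.Linear(3 * 224 * 224, feat_dim)  # stand-in backbone
        self.left_feature_mapping = torch.nn.Linear(feat_dim, embed_dim)

    def encode_image(self, img_l):
        vision_feature = self.visual(img_l.flatten(1))
        # The fix: project into the shared embedding space, as in training.
        return self.left_feature_mapping(vision_feature)

enc = MonocularEncoder()
img = torch.randn(2, 3, 224, 224)
feat = enc.encode_image(img)  # shape (2, 512), same space as text features
```

Without the final projection, the 768-d backbone output would be compared against 512-d text features, which is exactly the misalignment the patch above addresses.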