mlmed / torchxrayvision

TorchXRayVision: A library of chest X-ray datasets and models. Classifiers, segmentation, and autoencoders.
https://mlmed.org/torchxrayvision
Apache License 2.0

Upload weights to TorchXrayVision github #121

Open kyungjincho opened 1 year ago

kyungjincho commented 1 year ago

I am curious about the progress of the issue you opened in the corresponding GitHub repository (https://github.com/mi2rl/CheSS/issues/23#issue-1560703975). I'm really honored that the weights of the model will be uploaded to TorchXRayVision. When will they be uploaded?

ieee8023 commented 1 year ago

Hey! We are coordinating here on who has the time to do it. Sorry for the delay! I'll get back to you with an update in a few days.

ieee8023 commented 1 year ago

Hey so some questions for you. I was able to get the model loaded by using the resnet code in your repo.

I was able to extract just the encoder_q weights and it is much smaller than the entire checkpoint (300MB->100MB). There is an fc layer in the checkpoint that predicts 128 outputs. Do you have the labels for these outputs? I can add it so the model will predict the outputs and also have a .features() function that will extract just the features from an input. It would be nice to allow people to directly compare downstream performance with your model.

I see the input image size is 512x512. Can you tell me how the normalization should work? If I give you a numpy array that is normalized between 0 and 1 how should it be normalized to go into this model?

ieee8023 commented 1 year ago

Would this be the right normalization? torchvision.transforms.Normalize(0.658, 0.221)

kyungjincho commented 1 year ago

That's right. We trained MoCo-v2, a self-supervised learning method, with ResNet50 as its backbone. MoCo-v2 uses two image encoders, one of which (encoder_k) is updated through momentum updates rather than by gradient descent. Therefore, we use the state_dict of encoder_q.

Additionally, since the model is trained with a self-supervised learning method, the FC layer that produces the 128 outputs (used for contrastive learning) should be dropped and replaced with a layer that fits the downstream task the user wants.
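A minimal sketch of the extraction described above, assuming the checkpoint follows the key layout of the reference MoCo code, where the query encoder's parameters sit under a `module.encoder_q.` prefix and the contrastive head under `fc.`. The prefix, the helper name `extract_encoder_q`, and the toy keys are assumptions for illustration and are not taken from the CheSS checkpoint; plain Python values stand in for tensors.

```python
def extract_encoder_q(state_dict, prefix="module.encoder_q.", drop=("fc.",)):
    """Keep only query-encoder entries, strip the prefix, and drop the FC head.

    `prefix` and `drop` are assumed names; inspect the real checkpoint keys first.
    """
    backbone = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # skips encoder_k, the queue, and any other entries
        short_key = key[len(prefix):]
        if short_key.startswith(drop):
            continue  # skips the 128-output contrastive head
        backbone[short_key] = value
    return backbone


# Toy checkpoint with hypothetical keys.
checkpoint = {
    "module.encoder_q.conv1.weight": "q-conv1",
    "module.encoder_q.layer4.2.bn3.bias": "q-bn3",
    "module.encoder_q.fc.weight": "q-fc",
    "module.encoder_k.conv1.weight": "k-conv1",
    "module.queue": "queue",
}
backbone_weights = extract_encoder_q(checkpoint)
print(sorted(backbone_weights))
```

A dictionary filtered this way could then be loaded into a ResNet50 backbone with `load_state_dict(..., strict=False)`, since the `fc` entries are intentionally absent.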

The images were resized to 512 x 512 and preprocessed as follows. We followed the paper (https://doi.org/10.1016/j.cmpb.2022.106705) for preprocessing and normalization.

The values in torchvision.transforms.Normalize(0.658, 0.221) are the mean and std of our training dataset, which were used for normalization. (0-255 pixel value range)


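import cv2
import pydicom
import numpy as np

# Our preprocessing: resize to 512x512, clip at the 99th percentile,
# then rescale to the 0-255 range.
def cutoff(image):
    img = image
    img = cv2.resize(img, (512, 512))
    # np.percentile returns a float, so the clipped array is floating point
    img = np.clip(img, 0, np.percentile(img, 99))
    img -= img.min()
    img /= (img.max() - img.min())
    img *= 255
    img = img.astype(np.uint8)
    return img

# img_file: path to the DICOM file to load
# (pydicom.read_file was removed in pydicom 3.0; dcmread is the replacement)
temp_dcm = pydicom.dcmread(img_file, force=True)
temp_dcm_img = temp_dcm.pixel_array
# (0028,0004) is Photometric Interpretation; MONOCHROME1 images are inverted
if temp_dcm[0x28, 0x4].value.lower() == "monochrome1":
    temp_dcm_img = np.max(temp_dcm_img) - temp_dcm_img

base_img_cutoff = cutoff(temp_dcm_img)
base_image_cutoff = cv2.cvtColor(base_img_cutoff, cv2.COLOR_GRAY2BGR) # image pixel value 0~255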

# We normalize data using torchvision.transforms.Normalize(0.658, 0.221)
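To see what the percentile clip in `cutoff` does, here is a NumPy-only sketch on synthetic data. It leaves out the resize and DICOM steps, and the helper name `clip_and_rescale` and the synthetic array are illustrative assumptions rather than part of the CheSS code.

```python
import numpy as np


def clip_and_rescale(image, percentile=99):
    """Clip to [0, percentile value], then rescale to uint8 in 0-255."""
    img = np.asarray(image, dtype=np.float64)
    img = np.clip(img, 0, np.percentile(img, percentile))
    img = img - img.min()
    span = img.max()
    if span > 0:  # guard against a constant image
        img = img / span
    return (img * 255).astype(np.uint8)


# Synthetic 12-bit-like image with one saturated outlier pixel.
rng = np.random.default_rng(0)
raw = rng.integers(0, 4096, size=(64, 64)).astype(np.uint16)
raw[0, 0] = 65535

out = clip_and_rescale(raw)
print(out.dtype, out.min(), out.max())
```

Without the clip, the single outlier would set the maximum and compress every other pixel into the bottom of the 0-255 range; with it, the outlier is simply mapped to 255 along with the brightest 1% of pixels.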

ieee8023 commented 1 year ago

I have a PR open here: https://github.com/mlmed/torchxrayvision/pull/122 The model weights will download automatically.

I'm not sure the processing is correct because a UMAP of the representations doesn't look good.

I made a demo notebook showing things. If you edit the paths in dataset_utils.py you can load the datasets. https://github.com/mlmed/torchxrayvision/blob/21a94e6d884910ac741a65457a2755fe6ec61c35/scripts/xray_representations2.ipynb

Here is what I expect them to look like: https://github.com/mlmed/torchxrayvision/blob/master/scripts/xray_representations.ipynb

kyungjincho commented 1 year ago

I have looked at the code you showed me. I have a couple of concerns:

1) CheSS is a model trained with a self-supervised learning method, so downstream tasks such as supervised learning require label information.
2) CheSS takes input in the range of -2.5 to 2.5 after torchvision's Normalize, but it seems that it now takes input values between 0 and 255 through the transform_from_xrv() function.

ieee8023 commented 1 year ago

So if I do this with torchvision.transforms.Normalize(0.658, 0.221)

x -= x.min()
x /= (x.max() - x.min())
x = self.normalize(x)

The values appear to be between -3 and 3 but the UMAP plots still look wrong. The UMAP has the label information for 4 targets. Each CXR only has a single pathology to avoid issues with multiple positive labels. It is essentially like looking at a KNN classifier.

[image: UMAP plot]
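The KNN comparison can be made concrete with a leave-one-out 1-nearest-neighbour score on the feature vectors. The sketch below uses synthetic features only; the helper name `loo_1nn_accuracy` and the generated data are assumptions for illustration and are not part of the notebook.

```python
import numpy as np


def loo_1nn_accuracy(features, labels):
    """Fraction of samples whose nearest other sample shares their label."""
    features = np.asarray(features, dtype=np.float64)
    labels = np.asarray(labels)
    sq_norms = (features ** 2).sum(axis=1)
    # Pairwise squared Euclidean distances.
    dist2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * features @ features.T
    np.fill_diagonal(dist2, np.inf)  # a sample may not match itself
    nearest = dist2.argmin(axis=1)
    return float((labels[nearest] == labels).mean())


rng = np.random.default_rng(0)
labels = np.repeat(np.arange(4), 50)  # 4 targets, one label per sample
centers = rng.normal(scale=10.0, size=(4, 16))
clustered = centers[labels] + rng.normal(size=(200, 16))
unstructured = rng.normal(size=(200, 16))

acc_clustered = loo_1nn_accuracy(clustered, labels)
acc_unstructured = loo_1nn_accuracy(unstructured, labels)
print(acc_clustered, acc_unstructured)
```

Features that separate the targets score close to 1, while features carrying no label information stay near the chance level of 1 in 4, which is the same signal the UMAP plots show visually.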

If I scale them between 0-255 before the normalize the values are between -3 and 1150 and the UMAP looks like this:

x -= x.min()
x /= (x.max() - x.min())
x *= 255
x = self.normalize(x)

[image: UMAP plot]
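The two value ranges can be checked by hand, since Normalize(mean, std) computes (x - mean) / std per element. The sketch below uses plain NumPy as a stand-in for torchvision.transforms.Normalize; the helper name `normalize` is an assumption for illustration.

```python
import numpy as np

MEAN, STD = 0.658, 0.221


def normalize(x, mean=MEAN, std=STD):
    """Same arithmetic as a single-channel Normalize: (x - mean) / std."""
    return (np.asarray(x, dtype=np.float64) - mean) / std


# Extremes of an input scaled to [0, 1].
unit_range = normalize([0.0, 1.0])
# Extremes of an input scaled to [0, 255].
byte_range = normalize([0.0, 255.0])

print(unit_range)
print(byte_range)
```

For a [0, 1] input the bounds come out at roughly -2.98 and 1.55, and for a [0, 255] input at roughly -2.98 and 1150.9, which matches the "-3 and 1150" range reported for the second variant.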

Any ideas? Can you make a small end to end script that will take in an image file and compute the embedding? Then I can figure out where the code diverges.

kyungjincho commented 1 year ago

We modified our image classification module in (https://github.com/mi2rl/CheSS/blob/main/downstream/classification/datasets.py). You can check if the dataset preprocessing is correct.

I believe that our method will be helpful for users who train a separate downstream task, but I don't think that the current weights, trained with the contrastive learning method alone, will produce clustering results like the ones you showed me before.

Is there anything else I can help you with?

ieee8023 commented 1 year ago

Ok! I added the normalize after the 0,255 scale and I believe it is matched to your code now: https://github.com/mlmed/torchxrayvision/pull/122/files#diff-4e3fee5eb94fc6335266cbd042c570aec49d5894e54b0ae480ecbc1046273b07R49

If you agree then I will prepare to merge it in and add documentation. Do you want to try it out to make sure it is correct?

kyungjincho commented 1 year ago

Okay. Where can I find the added documentation?

ieee8023 commented 1 year ago

Nothing fancy but I added it to the readme in the PR here: https://github.com/mlmed/torchxrayvision/tree/mi2rl-chess

kyungjincho commented 1 year ago

Great, I hope there will be some research or work that we can do together later.

ieee8023 commented 1 year ago

Were you able to test and confirm that the model outputs are correct? I have a fear that something is not correct and it will hurt the performance of your model when people benchmark against it. I'd prefer to double check now rather than later.

kyungjincho commented 1 year ago

Unlike the existing supervised learning models, a self-supervised learning model requires fine-tuning on a labeled dataset, even if it is a small one.

I observed that the loss decreases well when fine-tuning the model you added to baseline_models.

CheSS: Chest X-Ray Pre-trained Model via Self-supervised Contrastive Learning
Outputs a 2048 dimensional feature vector
model = xrv.baseline_models.mi2rl.CheSS()

[image: plot]

ieee8023 commented 1 year ago

Were you able to achieve the same results as in your paper? If there is something wrong with the code it will likely train properly but the performance wouldn't be that high. In the paper I see a 0.808 AUC number for chexpert. Is that the data you are testing on in the plot above? The train AUC should be even higher right?

kyungjincho commented 1 year ago

I didn't experiment with chexpert dataset. I think it will take some time to set up the experiment because I removed the chexpert dataset from NAS.