batmanlab / Mammo-CLIP

Official PyTorch implementation of the MICCAI 2024 paper (early accept, top 11%) Mammo-CLIP: A Vision Language Foundation Model to Enhance Data Efficiency and Robustness in Mammography
https://shantanu-ai.github.io/projects/MICCAI-2024-Mammo-CLIP/
Creative Commons Attribution 4.0 International

Script for Testing Mammo-CLIP Model #10

Closed GuilhermeJC13 closed 1 month ago

GuilhermeJC13 commented 3 months ago

Congratulations on your impressive work with the Mammo-CLIP. I am currently interested in testing this model on a new dataset and analyzing the resulting embeddings. Please let me know if there is an existing script available for this purpose.

Thank you for your time and assistance.

shantanu-ai commented 3 months ago

Hi, thanks for taking an interest in our work. So you want to get the embeddings from the Mammo-CLIP encoders so that you can test them on new data, is my understanding correct? Do you want the embeddings from the vision encoder or the text encoder?

shantanu-ai commented 3 months ago

If you want to get the vision embeddings, you can refer here.

Also, you can follow the workflow for the linear probe:

  python ./src/codebase/train_classifier.py \
    --data-dir '/restricted/projectnb/batmanlab/shawn24/PhD/RSNA_Breast_Imaging/Dataset' \
    --img-dir 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/images_png' \
    --csv-file 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/vindr_detection_v1_folds.csv' \
    --clip_chk_pt_path "/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar" \
    --data_frac 1.0 \
    --dataset 'ViNDr' \
    --arch 'upmc_breast_clip_det_b5_period_n_lp' \
    --label "Mass" \
    --epochs 30 \
    --batch-size 8 \
    --num-workers 0 \
    --print-freq 10000 \
    --log-freq 500 \
    --running-interactive 'n' \
    --n_folds 1 \
    --lr 5.0e-5 \
    --weighted-BCE 'y' \
    --balanced-dataloader 'n' 

This script gets the embeddings from the encoders and trains a linear classifier at the same time. If you go to the experiments.py file (lines 296-297) and breast_clip_classifier.py (lines 53-56), you get the embeddings. From the breast_clip_classifier.py file, you can return the embedding directly and save it in the experiments.py file.
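
If you prefer not to edit the repository files, an alternative sketch is to capture the encoder output with a forward hook on the classifier. Here `model`, `valid_loader`, and `args` are assumed to come from the training setup, and the submodule name `image_encoder` is an assumption you should check against BreastClipClassifier:

import numpy as np
import torch

features = []

def save_features(module, inputs, output):
    # Store the encoder output of each batch on the CPU.
    features.append(output.detach().float().cpu())

# "image_encoder" is a hypothetical attribute name; verify it in BreastClipClassifier.
handle = model.image_encoder.register_forward_hook(save_features)

model.eval()
with torch.no_grad():
    for data in valid_loader:
        x = data['x'].to(args.device).squeeze(1).permute(0, 3, 1, 2)
        _ = model(x)  # the hook captures the embeddings as a side effect

handle.remove()
embeddings = torch.cat(features).numpy()
np.save("valid_embeddings.npy", embeddings)
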

shantanu-ai commented 3 months ago

@GuilhermeJC13 You can use this:

import torch
import gc
import time
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
import sys
sys.path.append('/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/')

from Classifiers.models.breast_clip_classifier import BreastClipClassifier
from Datasets.dataset_utils import get_dataloader_RSNA
from breastclip.scheduler import LinearWarmupCosineAnnealingLR
from metrics import pfbeta_binarized, pr_auc, compute_auprc, auroc, compute_accuracy_np_array
from utils import seed_all, AverageMeter, timeSince
from breastclip.model.modules import load_image_encoder, LinearClassifier

class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'

# Create an instance of the Args class
args = Args()

# Now you can use args just like you would in your script
print(args.tensorboard_path) 
# /restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log
args.model_base_name = 'efficientnetb5'
args.data_dir = Path(args.data_dir)
args.df = pd.read_csv(args.data_dir / args.csv_file)
args.df = args.df.fillna(0)
args.cur_fold = 0
args.train_folds = args.df[
                (args.df['fold'] == 1) | (args.df['fold'] == 2)].reset_index(drop=True)
args.valid_folds = args.df[args.df['fold'] == args.cur_fold].reset_index(drop=True)

print(f"train_folds shape: {args.train_folds.shape}")
print(f"valid_folds shape: {args.valid_folds.shape}")
# train_folds shape: (27258, 15)
# valid_folds shape: (13682, 15)

ckpt = torch.load(args.clip_chk_pt_path, map_location="cpu")
args.image_encoder_type = ckpt["config"]["model"]["image_encoder"]["name"]
train_loader, valid_loader = get_dataloader_RSNA(args)
print(f'train_loader: {len(train_loader)}, valid_loader: {len(valid_loader)}')
# Compose([
#   HorizontalFlip(p=0.5),
#   VerticalFlip(p=0.5),
#   Affine(p=0.5, interpolation=1, mask_interpolation=0, cval=0.0, mode=0, scale={'x': (0.8, 1.2), 'y': (0.8, 1.2)}, translate_percent={'x': (0.1, 0.1), 'y': (0.1, 0.1)}, translate_px=None, rotate=(20.0, 20.0), fit_output=False, shear={'x': (20.0, 20.0), 'y': (20.0, 20.0)}, cval_mask=0.0, keep_ratio=False, rotate_method='largest_box', balanced_scale=False),
#   ElasticTransform(p=0.5, alpha=10.0, sigma=15.0, interpolation=1, border_mode=4, value=None, mask_value=None, approximate=False, same_dxdy=False),
# ], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)
# None
# train_loader: 3407, valid_loader: 1711

n_class = 1
print(ckpt["config"]["model"]["image_encoder"])
config = ckpt["config"]["model"]["image_encoder"]
image_encoder = load_image_encoder(ckpt["config"]["model"]["image_encoder"])
image_encoder_weights = {}
for k in ckpt["model"].keys():
    if k.startswith("image_encoder."):
        image_encoder_weights[".".join(k.split(".")[1:])] = ckpt["model"][k]
image_encoder.load_state_dict(image_encoder_weights, strict=True)
image_encoder_type = ckpt["config"]["model"]["image_encoder"]["model_type"]
image_encoder = image_encoder.to(args.device)

print(image_encoder_type)
print(config["name"].lower()) 
# cnn
# tf_efficientnet_b5_ns-detect

progress_iter = tqdm(enumerate(valid_loader), desc=f"[tutorial]",
                     total=len(valid_loader))
for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)

    image_features = image_encoder(inputs)
    print(image_features.shape)
    # torch.Size([1, 2048])
    break
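
If you want embeddings for the whole validation split rather than a single batch, a straightforward extension of the loop above would be the following (the output filename is arbitrary):

all_embeddings = []
image_encoder.eval()
with torch.no_grad():
    for step, data in tqdm(enumerate(valid_loader), total=len(valid_loader)):
        inputs = data['x'].to(args.device)
        inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
        image_features = image_encoder(inputs)
        all_embeddings.append(image_features.detach().cpu().numpy())

all_embeddings = np.concatenate(all_embeddings, axis=0)  # [N, 2048] for the b5 encoder
np.save(Path(args.output_path) / "rsna_valid_image_embeddings.npy", all_embeddings)
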
GuilhermeJC13 commented 2 months ago

@shantanu-ai Is the clip_chk_pt_path, the path to the model available on huggingface?

shantanu-ai commented 2 months ago

https://huggingface.co/shawn24/Mammo-CLIP/tree/main/Pre-trained-checkpoints

GuilhermeJC13 commented 2 months ago

Hello @shantanu-ai,

Every time I try to run the scripts, I keep getting the error "KeyError: filename 'storages' not found" when I try to load the model using "ckpt = torch.load(args.clip_chk_pt_path)". I'm not sure what's causing this error. Could it be due to corrupted or improperly formatted models? Am I doing something wrong in the execution?

I unzipped the Hugging Face file and turned it into a .tar

shantanu-ai commented 2 months ago

Hi @GuilhermeJC13, the files are good. I think you probably set the path incorrectly. For b5, download only this file and follow this notebook. There was a bug in the above notebook in the call to the encoder; I fixed and tested it, so now it is good. Thanks for pointing it out.

You only need the b5-model-best-epoch-7.tar file.
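
To be concrete, the .tar is already a torch.save archive, so it should be loaded as-is, without unzipping or re-archiving; roughly:

import torch

# Load the downloaded checkpoint directly. Repackaging it yourself breaks the
# internal zip layout and produces errors like "KeyError: filename 'storages' not found".
ckpt = torch.load("b5-model-best-epoch-7.tar", map_location="cpu")
print(ckpt.keys())  # expect keys such as "model" and "config"
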

Also, this checkpoint is available on Google Drive.

If the problem persists, can you please share the code?

shantanu-ai commented 2 months ago

Also, under the hood, the notebook is calling this function.

If you want to customize anything, you can modify the forward function of the above method.

Also, we uploaded a tutorial notebook on setting up a classifier using the Mammo-CLIP vision encoder. You can take a look at it as well.

shantanu-ai commented 2 months ago

Let me know if you have further issues. If not, let me know if I can close the issue.

shantanu-ai commented 2 months ago

I am closing the issue. If you have further queries, let us know.

GuilhermeJC13 commented 1 month ago

Hi @shantanu-ai,

Thanks, this helped!

I plan to do the same thing I did with image encoding, but this time with vision and text encoders together. In short, I want to extract the actual embedding from the CLIP model. Do you know if you already have a script to get these embeddings?

I really appreciate your attention!

shantanu-ai commented 1 month ago

Hi @GuilhermeJC13, can you clarify what you mean by "extract the actual embedding from the CLIP model"? Do you want to do that for the text encoder of Mammo-CLIP, or do you want it from the original CLIP?

For text embeddings from Mammo-CLIP:

def save_rsna_text_emb(clip_model, args):
    prompts = create_rsna_mammo_prompts()
    sentences_list_unique = save_sent_dict_rsna(args, sent_level=True)
    idx = 0
    text_embeddings_list = []
    with torch.no_grad():
        with tqdm(total=len(sentences_list_unique)) as t:
            for sent in sentences_list_unique:
                text_token = clip_model["tokenizer"](
                    sent, padding="longest", truncation=True, return_tensors="pt", max_length=256)

                text_emb = clip_model["model"].encode_text(text_token.to(args.device))
                text_emb = clip_model["model"].text_projection(text_emb) if clip_model["model"].projection else text_emb
                text_emb = text_emb / torch.norm(text_emb, dim=1, keepdim=True)
                text_emb = text_emb.detach().cpu().numpy()
                text_embeddings_list.append(text_emb)

                t.set_postfix(batch_id='{0}'.format(idx + 1))
                t.update()
                idx += 1

    text_emb_np = np.concatenate(text_embeddings_list, axis=0)
    print(f"Sent list shape: {len(sentences_list_unique)}")
    print(f"Text embedding shape: {text_emb_np.shape}")
    np.save(args.save_path / f"sent_emb_word_ge_{args.report_word_ge}.npy", text_emb_np)
    print(f"files saved at: {args.save_path}")

Note: I copied this code from another project of mine, and that codebase is messy, so it may contain trivial errors that you can fix.
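
If you want to use the image and text embeddings together (e.g., for zero-shot retrieval), a minimal sketch is a cosine-similarity matrix in the shared space. The encode_image and image_projection calls below are assumed method names mirroring the encode_text/text_projection calls above, so check them against the actual model class:

# Hypothetical sketch: cosine similarity between image and text embeddings.
with torch.no_grad():
    img_emb = clip_model["model"].encode_image(images.to(args.device))  # images: preprocessed batch [B, 3, H, W]
    img_emb = clip_model["model"].image_projection(img_emb)             # assumed projection into the shared space
    img_emb = img_emb / torch.norm(img_emb, dim=1, keepdim=True)

    text_emb = torch.from_numpy(text_emb_np).float().to(args.device)    # saved by save_rsna_text_emb above
    similarity = img_emb @ text_emb.t()                                 # [B, n_sentences]
    best_sentence_idx = similarity.argmax(dim=1)
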

For extracting embeddings from CLIP:

We compared our model with CLIP as a baseline, so we did not save the CLIP embeddings. If you want to set up the baseline, refer to this issue and then you can use the code I shared earlier.

Al-Dai commented 2 weeks ago

Is it possible to evaluate the model provided for downstream tasks on Hugging Face? I don't know what I am doing wrong, but I am defining the model as follows:

  n_class = 1
  model = BreastClipClassifier(args, ckpt=ckpt, n_class=n_class)
  model.load_state_dict(torch.load(args.clf_chk_pr_path)["model"])
  model = model.to(args.device)
  model.eval()

where the classifier checkpoint path is Downstream_evalualtion_b5_fold0/classification/Models/Classifier/fine_tune/mass/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_acc_cancer_ver084.pth, and then doing the usual prediction:

  for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)
    with torch.cuda.amp.autocast(enabled=True):
        y_preds = model(inputs)  # Get raw model outputs (logits)

        # Apply sigmoid activation to get probabilities
        probabilities = torch.sigmoid(y_preds)

        # Compare probabilities with threshold 0.5
        predictions = (probabilities >= 0.5).float()

        # Display predictions with labels
        for i, pred in enumerate(predictions):
            label = "Cancer" if pred == 1 else "No Cancer"
            print(f"Sample {i}: {label} (Probability: {probabilities[i].item():.4f})")

I am testing this on the RSNA folder and getting about 18% correct; I don't know what I am doing wrong.

shantanu-ai commented 2 weeks ago

@Al-Dai Can you use the valid_fn() in this file? Also, did you preprocess the RSNA images with this script? And make sure the transforms are correct.
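
If you want a quick standalone check before wiring up valid_fn(), a minimal sketch that accumulates probabilities and labels over the loader and reports AUROC with scikit-learn would look like this (the label key 'y' in the batch dict is an assumption; check the dataset class):

import numpy as np
import torch
from sklearn.metrics import roc_auc_score

all_probs, all_labels = [], []
model.eval()
with torch.no_grad():
    for step, data in enumerate(valid_loader):
        inputs = data['x'].to(args.device).squeeze(1).permute(0, 3, 1, 2)
        with torch.cuda.amp.autocast(enabled=True):
            logits = model(inputs)
        all_probs.append(torch.sigmoid(logits).float().cpu().numpy().ravel())
        all_labels.append(data['y'].numpy().ravel())  # 'y' is an assumed label key

probs = np.concatenate(all_probs)
labels = np.concatenate(all_labels)
# AUROC is threshold-free and more informative than 0.5-threshold accuracy on the imbalanced RSNA labels.
print("AUROC:", roc_auc_score(labels, probs))
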

Al-Dai commented 2 weeks ago

Sure, I will give it a try! Thanks for the quick response.

Should I keep the args as defined in the notebook, or do I need to change them for the fold-0 downstream weights during initialization of the BreastClipClassifier class?

class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.clf_chk_pr_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/temp/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_aucroc_ver084.pth'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'

shantanu-ai commented 2 weeks ago

@Al-Dai The args here and the argparse in train_classifier.py are the same; if you follow it, it will go to the BreastClipClassifier, so set them accordingly. The important part is to do the preprocessing of RSNA. Also, for a sanity check, you can run it on VinDr (mass, calcification and density classification); see the example command below. For VinDr, you don't need to perform preprocessing; we uploaded the preprocessed files directly here.
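
For example, a VinDr sanity check can reuse the linear-probe command from earlier in this thread with only the --label flag changed; the label column name 'Suspicious_Calcification' below is an assumption, so verify it against vindr_detection_v1_folds.csv:

  python ./src/codebase/train_classifier.py \
    --data-dir '/restricted/projectnb/batmanlab/shawn24/PhD/RSNA_Breast_Imaging/Dataset' \
    --img-dir 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/images_png' \
    --csv-file 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/vindr_detection_v1_folds.csv' \
    --clip_chk_pt_path "/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar" \
    --data_frac 1.0 \
    --dataset 'ViNDr' \
    --arch 'upmc_breast_clip_det_b5_period_n_lp' \
    --label "Suspicious_Calcification" \
    --epochs 30 \
    --batch-size 8 \
    --num-workers 0 \
    --print-freq 10000 \
    --log-freq 500 \
    --running-interactive 'n' \
    --n_folds 1 \
    --lr 5.0e-5 \
    --weighted-BCE 'y' \
    --balanced-dataloader 'n'
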

Al-Dai commented 2 weeks ago

Thanks, I applied the correct transformation and it worked!

Another side question: if I want to train Mammo-CLIP from scratch, say with torchrun --nproc_per_node=4 ./src/codebase/train.py --config-name pre_train_b5_clip.yaml, would I need to have the UPMC dataset, or is it possible to train with VinDr or RSNA alone?

kayhan-batmanghelich commented 2 weeks ago

Training from scratch requires the UPMC dataset. Legally, we cannot release that dataset.

shantanu-ai commented 2 weeks ago

@Al-Dai As mentioned by Kayhan, the UPMC dataset is a private one. VinDr and RSNA do not have reports. To train Mammo-CLIP, you need at least one image+text dataset. The results will be better if you mix an image+label dataset (e.g., RSNA or VinDr) with the image+text dataset. So, if you have any image+text dataset, you can train Mammo-CLIP; just follow the settings for UPMC for your own image+text data. By text, we mean radiology reports.

Al-Dai commented 2 weeks ago

I understand, I went through your work, and I think it's excellent! I wanted to thank you.

One last question: how is the location identified with text? I read in the paper this line: '...With Mammo-FActOR, Mammo-CLIP vision encoder excels in localization tasks, accurately identifying findings like masses and calcifications using descriptive sentences, without relying on ground truth bounding boxes.' I was wondering how these lines are generated or if there is code for it.

shantanu-ai commented 2 weeks ago

@Al-Dai So, this is weak localization using text; that's the Mammo-FActOR part of the paper (see Section 2.3). TL;DR: we have templated sentences constructed with the help of a radiologist. We use these sentences and the vision encoder of the trained Mammo-CLIP to train a lightweight projector (Eq. 3 in the paper) to learn the mapping, i.e., which activation unit (neuron) in the Mammo-CLIP representation corresponds to a mammography finding (mass or calcification).
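
For intuition only, here is a rough, schematic sketch of that idea (not the paper's exact Eq. 3): a hypothetical linear projector maps a sentence embedding onto per-channel scores over the vision feature map, and weighting the spatial map by those scores gives a weak localization heat-map. All names and shapes here are assumptions for illustration:

import torch
import torch.nn as nn

class SentenceToChannelProjector(nn.Module):
    """Hypothetical lightweight projector: text embedding (d_text) -> channel scores (C)."""
    def __init__(self, d_text: int, n_channels: int):
        super().__init__()
        self.proj = nn.Linear(d_text, n_channels)

    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
        # text_emb: [B, d_text] -> channel scores: [B, C]
        return self.proj(text_emb)

def finding_heatmap(projector: nn.Module, feat_map: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    # feat_map: [B, C, H, W] spatial features from the vision encoder (before pooling).
    # Weight the channels by the projected sentence scores and sum to get a [B, H, W] heat-map.
    weights = torch.softmax(projector(text_emb), dim=-1)
    return (feat_map * weights[:, :, None, None]).sum(dim=1)
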