Hi, thanks for taking an interest in our work. So you want to get the embeddings from the Mammo-CLIP encoders so that you can test them on a new dataset, is my understanding correct? Do you want the embeddings from the vision encoder or the text encoder?
If you want to get the vision embeddings, you can refer here.
Also, you can follow the workflow for the linear probe:
python ./src/codebase/train_classifier.py \
--data-dir '/restricted/projectnb/batmanlab/shawn24/PhD/RSNA_Breast_Imaging/Dataset' \
--img-dir 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/images_png' \
--csv-file 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/vindr_detection_v1_folds.csv' \
--clip_chk_pt_path "/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar" \
--data_frac 1.0 \
--dataset 'ViNDr' \
--arch 'upmc_breast_clip_det_b5_period_n_lp' \
--label "Mass" \
--epochs 30 \
--batch-size 8 \
--num-workers 0 \
--print-freq 10000 \
--log-freq 500 \
--running-interactive 'n' \
--n_folds 1 \
--lr 5.0e-5 \
--weighted-BCE 'y' \
--balanced-dataloader 'n'
This script will get the embeddings from the encoders and train a linear classifier at the same time. If you go to the experiments.py file (lines 296-297) and breast_clip_classifier.py (lines 53-56), you get the embeddings. From breast_clip_classifier.py, you can return the embedding directly and save it in experiments.py.
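For instance, a minimal sketch of such a modification, following the batch format used in the notebook later in this thread (data['x']); the attribute name image_encoder is a placeholder, so check breast_clip_classifier.py (lines 53-56) for the real one:

import numpy as np
import torch

def extract_embeddings(model, loader, device):
    """Collect per-image embeddings from the classifier's vision backbone.
    Assumes the classifier exposes it as `model.image_encoder` (placeholder name)."""
    model.eval()
    feats = []
    with torch.no_grad():
        for data in loader:
            x = data['x'].to(device).squeeze(1).permute(0, 3, 1, 2)
            feats.append(model.image_encoder(x).cpu().numpy())
    return np.concatenate(feats, axis=0)

# embeddings = extract_embeddings(classifier, valid_loader, device)
# np.save('embeddings.npy', embeddings)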
@GuilhermeJC13 You can use this:
import torch
import gc
import time
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
import sys
sys.path.append('/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/')
from Classifiers.models.breast_clip_classifier import BreastClipClassifier
from Datasets.dataset_utils import get_dataloader_RSNA
from breastclip.scheduler import LinearWarmupCosineAnnealingLR
from metrics import pfbeta_binarized, pr_auc, compute_auprc, auroc, compute_accuracy_np_array
from utils import seed_all, AverageMeter, timeSince
from breastclip.model.modules import load_image_encoder, LinearClassifier
class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'
# Create an instance of the Args class
args = Args()
# Now you can use args just like you would in your script
print(args.tensorboard_path)
# /restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log
args.model_base_name = 'efficientnetb5'
args.data_dir = Path(args.data_dir)
args.df = pd.read_csv(args.data_dir / args.csv_file)
args.df = args.df.fillna(0)
args.cur_fold = 0
args.train_folds = args.df[
    (args.df['fold'] == 1) | (args.df['fold'] == 2)].reset_index(drop=True)
args.valid_folds = args.df[args.df['fold'] == args.cur_fold].reset_index(drop=True)
print(f"train_folds shape: {args.train_folds.shape}")
print(f"valid_folds shape: {args.valid_folds.shape}")
# train_folds shape: (27258, 15)
# valid_folds shape: (13682, 15)
ckpt = torch.load(args.clip_chk_pt_path, map_location="cpu")
args.image_encoder_type = ckpt["config"]["model"]["image_encoder"]["name"]
train_loader, valid_loader = get_dataloader_RSNA(args)
print(f'train_loader: {len(train_loader)}, valid_loader: {len(valid_loader)}')
# Compose([
# HorizontalFlip(p=0.5),
# VerticalFlip(p=0.5),
# Affine(p=0.5, interpolation=1, mask_interpolation=0, cval=0.0, mode=0, scale={'x': (0.8, 1.2), 'y': (0.8, 1.2)}, translate_percent={'x': (0.1, 0.1), 'y': (0.1, 0.1)}, translate_px=None, rotate=(20.0, 20.0), fit_output=False, shear={'x': (20.0, 20.0), 'y': (20.0, 20.0)}, cval_mask=0.0, keep_ratio=False, rotate_method='largest_box', balanced_scale=False),
# ElasticTransform(p=0.5, alpha=10.0, sigma=15.0, interpolation=1, border_mode=4, value=None, mask_value=None, approximate=False, same_dxdy=False),
# ], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)
# None
# train_loader: 3407, valid_loader: 1711
n_class = 1
print(ckpt["config"]["model"]["image_encoder"])
config = ckpt["config"]["model"]["image_encoder"]
image_encoder = load_image_encoder(ckpt["config"]["model"]["image_encoder"])
image_encoder_weights = {}
for k in ckpt["model"].keys():
    if k.startswith("image_encoder."):
        image_encoder_weights[".".join(k.split(".")[1:])] = ckpt["model"][k]
image_encoder.load_state_dict(image_encoder_weights, strict=True)
image_encoder_type = ckpt["config"]["model"]["image_encoder"]["model_type"]
image_encoder = image_encoder.to(args.device)
print(image_encoder_type)
print(config["name"].lower())
# cnn
# tf_efficientnet_b5_ns-detect
progress_iter = tqdm(enumerate(valid_loader), desc="[tutorial]", total=len(valid_loader))
for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)
    image_features = image_encoder(inputs)
    print(image_features.shape)
    break
# torch.Size([1, 2048])
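To dump embeddings for the whole split instead of a single batch, the loop above can be extended along these lines (same encoder and tensors, just wrapped in torch.no_grad and concatenated with numpy):

all_feats = []
with torch.no_grad():
    for step, data in tqdm(enumerate(valid_loader), total=len(valid_loader)):
        inputs = data['x'].to(args.device)
        inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
        all_feats.append(image_encoder(inputs).cpu().numpy())
all_feats = np.concatenate(all_feats, axis=0)  # (N, 2048) for the b5 encoder
np.save('valid_image_embeddings.npy', all_feats)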
@shantanu-ai Is clip_chk_pt_path the path to the model available on Hugging Face?
Hello @shantanu-ai,
Every time I try to run the scripts, I keep getting the error "KeyError: filename 'storages' not found" when I try to load the model with ckpt = torch.load(args.clip_chk_pt_path). I'm not sure what's causing this error. Could it be due to corrupted or improperly formatted models, or am I doing something wrong in the execution?
I unzipped the Hugging Face file and turned it into a .tar.
Hi @GuilhermeJC13, the files are good. I think you probably set the path incorrectly. For b5, download only this file and follow this notebook. There was a bug in the above notebook in the call to the encoder; I fixed it and tested it, so it is good now. Thanks for pointing it out.
You only need the b5-model-best-epoch-7.tar file.
This checkpoint is also available on Google Drive.
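A quick sanity check that the downloaded file is intact: load it directly with torch.load (no unzipping or repackaging needed) and inspect the top-level keys that are used later in this thread.

import torch

ckpt = torch.load("b5-model-best-epoch-7.tar", map_location="cpu")  # adjust the path to where you saved it
print(list(ckpt.keys()))                                 # should include 'config' and 'model'
print(ckpt["config"]["model"]["image_encoder"]["name"])  # e.g. 'tf_efficientnet_b5_ns-detect'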
If the problem persists, can you please share your code?
Let me know if you have further issues. If not, let me know if I can close the issue.
I am closing the issue. If you have further queries, let us know.
Hi @shantanu-ai,
Thanks, this helped!
I plan to do the same thing I did with image encoding, but this time with vision and text encoders together. In short, I want to extract the actual embedding from the CLIP model. Do you know if you already have a script to get these embeddings?
I really appreciate your attention!
Hi @GuilhermeJC13 Can you clarify what you mean by "extract the actual embedding from the CLIP model"? Do you want to do that for the text encoder of Mammo-CLIP, or from the original CLIP?
def save_rsna_text_emb(clip_model, args):
    prompts = create_rsna_mammo_prompts()
    sentences_list_unique = save_sent_dict_rsna(args, sent_level=True)
    idx = 0
    text_embeddings_list = []
    with torch.no_grad():
        with tqdm(total=len(sentences_list_unique)) as t:
            for sent in sentences_list_unique:
                text_token = clip_model["tokenizer"](
                    sent, padding="longest", truncation=True, return_tensors="pt", max_length=256)
                text_emb = clip_model["model"].encode_text(text_token.to(args.device))
                text_emb = clip_model["model"].text_projection(text_emb) if clip_model["model"].projection else text_emb
                text_emb = text_emb / torch.norm(text_emb, dim=1, keepdim=True)
                text_emb = text_emb.detach().cpu().numpy()
                text_embeddings_list.append(text_emb)
                t.set_postfix(batch_id='{0}'.format(idx + 1))
                t.update()
                idx += 1
    text_emb_np = np.concatenate(text_embeddings_list, axis=0)
    print(f"Sent list shape: {len(sentences_list_unique)}")
    print(f"Text embedding shape: {text_emb_np.shape}")
    np.save(args.save_path / f"sent_emb_word_ge_{args.report_word_ge}.npy", text_emb_np)
    print(f"files saved at: {args.save_path}")
Note: I copied this code from another project of mine and that codebase is messy, so it may contain trivial errors, which you can fix.
We compared our model with CLIP as a baseline, so we did not save the embeddings of CLIP. If you want to set up that baseline, refer to this issue, and then you can use the code I shared earlier.
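Stripped of the project-specific helpers (create_rsna_mammo_prompts, save_sent_dict_rsna), a minimal sketch of the same idea is below; it assumes you already have a clip_model dict with 'tokenizer' and 'model' entries matching the interface used in save_rsna_text_emb above.

# Sketch: embed a few free-text sentences with the Mammo-CLIP text encoder.
# Assumes clip_model = {"tokenizer": ..., "model": ...} is already constructed,
# with the same interface as in save_rsna_text_emb above.
sentences = [
    "There is a mass in the left breast.",
    "Scattered benign calcifications are noted.",
]
with torch.no_grad():
    tokens = clip_model["tokenizer"](
        sentences, padding="longest", truncation=True, return_tensors="pt", max_length=256)
    text_emb = clip_model["model"].encode_text(tokens.to(args.device))
    if clip_model["model"].projection:
        text_emb = clip_model["model"].text_projection(text_emb)
    text_emb = text_emb / torch.norm(text_emb, dim=1, keepdim=True)  # L2-normalize, as above
print(text_emb.shape)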
Is it possible to evaluate the model provided for downstream tasks on Hugging Face? I don't know what I am doing wrong, but I am defining the model as
n_class = 1
model = BreastClipClassifier(args, ckpt=ckpt, n_class=n_class)
model.load_state_dict(torch.load(args.clf_chk_pr_path)["model"])
model = model.to(args.device)
model.eval()
where the classifier path is Downstream_evalualtion_b5_fold0/classification/Models/Classifier/fine_tune/mass/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_acc_cancer_ver084.pth, and doing the usual prediction:
for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)
    with torch.cuda.amp.autocast(enabled=True):
        y_preds = model(inputs)  # Get raw model outputs (logits)
    # Apply sigmoid activation to get probabilities
    probabilities = torch.sigmoid(y_preds)
    # Compare probabilities with threshold 0.5
    predictions = (probabilities >= 0.5).float()
    # Display predictions with labels
    for i, pred in enumerate(predictions):
        label = "Cancer" if pred == 1 else "No Cancer"
        print(f"Sample {i}: {label} (Probability: {probabilities[i].item():.4f})")
and testing this on the RSNA folder, I am getting 18% correct; I don't know what I am doing wrong.
Sure, I will give it a try! Thanks for the quick response.
Should I keep the args as defined in the notebook, or do I need to change them for the fold-0 downstream weights when initializing the BreastClipClassifier class?
class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.clf_chk_pr_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/temp/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_aucroc_ver084.pth'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'
@Al-Dai The args here and the argparse arguments in train_classifier.py are the same. If you follow it, it will go to the BreastClipClassifier, so set them accordingly. What is important is to do the RSNA preprocessing. Also, as a sanity check, you can try it on VinDr for mass, calcification, and density classification. For VinDr, you don't need to perform preprocessing; we uploaded the preprocessed files directly here.
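One way to keep the transforms consistent with training is to reuse the repo's own dataloader, exactly as in the notebook above, and only change the path arguments (the paths below are placeholders); the RSNA DICOM-to-PNG preprocessing itself still has to follow the repo's instructions.

# Reuse the repo's dataloader so resize/normalization match training;
# only the path arguments change. All paths below are placeholders.
args.data_dir = Path('/path/to/your/Dataset')
args.img_dir = 'your_images_png'
args.csv_file = 'your_folds.csv'  # same column layout as train_folds.csv
args.df = pd.read_csv(args.data_dir / args.csv_file).fillna(0)
args.cur_fold = 0
args.train_folds = args.df[args.df['fold'] != args.cur_fold].reset_index(drop=True)  # adjust the fold logic to your split
args.valid_folds = args.df[args.df['fold'] == args.cur_fold].reset_index(drop=True)
train_loader, valid_loader = get_dataloader_RSNA(args)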
Thanks, I applied the correct transformation and it worked! Thanks.
Another side question: if I want to train Mammo-CLIP from scratch, say with torchrun --nproc_per_node=4 ./src/codebase/train.py --config-name pre_train_b5_clip.yaml, I would need the UPMC dataset, right? Or is it possible to train with just VinDr or RSNA alone?
Training from scratch requires the UPMC dataset. Legally, we cannot release that dataset.
@Al-Dai As mentioned by Kayhan, the UPMC dataset is a private one. VinDr and RSNA do not have reports. To train Mammo-CLIP, you need at least one image+text dataset. The results will be better if you mix an image+label dataset (e.g., RSNA or VinDr) with the image+text dataset. So, if you have any image+text dataset, you can train Mammo-CLIP; just follow the settings for UPMC for your own image+text data. Here, "text" means radiology reports.
I understand. I went through your work, and I think it's excellent; I wanted to thank you.
One last question: how is the location identified with text? I read this line in the paper: '...With Mammo-FActOR, Mammo-CLIP vision encoder excels in localization tasks, accurately identifying findings like masses and calcifications using descriptive sentences, without relying on ground truth bounding boxes.' I was wondering how these lines are generated, or if there is code for it.
@Al-Dai So, this is weak localization using text; that is the Mammo-FActOR part of the paper. Read Section 2.3. TL;DR: we have templated sentences constructed with the help of a radiologist. We use these sentences and the vision encoder of the trained Mammo-CLIP to train a lightweight projector (Eq. 3 in the paper) to learn the mapping, i.e., which activation unit (neuron) in the Mammo-CLIP representation corresponds to a mammography finding (mass or calcification).
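Purely as a conceptual sketch (not the paper's implementation; see Section 2.3 and Eq. 3 for the actual formulation): a lightweight linear projector maps the sentence embedding of a templated finding description into the vision feature space, and the per-unit alignment scores indicate which activation units respond to that finding.

import torch
import torch.nn as nn

# Conceptual sketch only; the real Mammo-FActOR projector follows Eq. 3 in the paper.
text_dim, vision_dim = 512, 2048   # assumed dims for illustration (b5 image features are 2048-d)
projector = nn.Linear(text_dim, vision_dim)

def unit_scores(text_emb, vision_feats):
    """text_emb: (1, text_dim) embedding of a templated sentence (e.g. describing a mass).
    vision_feats: (B, vision_dim) frozen Mammo-CLIP image features.
    Returns per-activation-unit alignment scores of shape (vision_dim,)."""
    proj = projector(text_emb)                # (1, vision_dim)
    return (proj * vision_feats).mean(dim=0)  # element-wise alignment, averaged over the batch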
Congratulations on your impressive work with Mammo-CLIP. I am currently interested in testing this model on a new dataset and analyzing the resulting embeddings. Please let me know if there is an existing script available for this purpose.
Thank you for your time and assistance.