mahmoodlab / CONCH

A vision-language foundation model for computational pathology - Nature Medicine

MI-Zero for Zero-Shot NSCLC/RCC Subtyping #16

Closed Dootmaan closed 2 months ago

Dootmaan commented 2 months ago

Thank you for your great work! Recently I was trying to test the zero-shot performance of CONCH on my own NSCLC/RCC dataset split. However, the AUC of both test datasets is below 0.6. My code for inference is presented as follows:

with torch.inference_mode():
    text = tokenize(tokenizer, ['squamous cell carcinoma','adenocarcinoma']).to(device)
    text_features = model.encode_text(text)                                    # "model" is CONCH, initialized by create_model_from_pretrained from conch.open_clip_custom
    image_features = model_mizero(inputs_tensor.permute(0,2,1)).squeeze(-1)    # "inputs_tensor" is the slide embedding generated by CONCH, shaped as B x N x 512

    tPredict0=image_features @ text_features.t()
gSlidePred=tPredict0.softmax(dim=-1)

Is there anything wrong with my zero-shot inference code? I have also tried replacing model_mizero with a simple nn.AdaptiveAvgPool1d or nn.AdaptiveMaxPool1d, but the AUC still won't go beyond 0.6.

fedshyvana commented 2 months ago

Hi, I see several potential reasons for the poor performance. Assume text_features is C x D (C classes, D embedding dimensions).

  1. It's not clear what "model_mizero" in your code does without more context. It should first apply the contrastive projector head to inputs_tensor and then L2-normalize the result; the output image_features should still be N x D.
  2. Next, take the dot product of image_features and text_features to get N x C similarity scores, then apply top-K pooling for each class to get 1 x C logits.
  3. Finally, take the softmax (for AUC calculation). We also scale the logits by the learned temperature in the softmax operation.
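The three steps above can be sketched roughly as follows. This is a minimal, framework-agnostic NumPy illustration and not the CONCH API: the function names, the default k, and the logit_scale value are hypothetical. It assumes the patch features have already been passed through the contrastive projector head, and that top-K pooling means averaging the K highest similarity scores per class (one common choice).

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    """Scale each row of x to unit L2 norm."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)

def topk_pooled_probs(patch_features, text_features, k=5, logit_scale=1.0):
    """Slide-level class probabilities from patch-level similarities.

    patch_features: N x D, assumed to be already projected into the
                    shared image-text space (step 1).
    text_features:  C x D, one embedding per class.
    k, logit_scale: hypothetical values; in practice logit_scale would be
                    the model's learned temperature.
    """
    patch_features = l2_normalize(patch_features)   # step 1: L2-norm
    text_features = l2_normalize(text_features)
    sims = patch_features @ text_features.T         # step 2: N x C scores
    k = min(k, sims.shape[0])
    topk = np.sort(sims, axis=0)[-k:]               # K highest scores per class
    logits = logit_scale * topk.mean(axis=0)        # pooled logits, shape (C,)
    logits = logits - logits.max()                  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()                          # step 3: softmax

# Toy usage: two classes, three patches in a 2-D embedding space.
text = np.array([[1.0, 0.0], [0.0, 1.0]])
patches = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
probs = topk_pooled_probs(patches, text, k=2, logit_scale=10.0)
```

Note that plain average or max pooling of the patch embeddings before the dot product is not the same operation: here the pooling is applied to the per-class similarity scores, not to the features.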

Also, for optimal performance you should ensemble multiple templates and classnames (currently you use a single classname per class and no class templates).
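As a rough illustration of what ensembling means here: each class gets several prompts (every template filled with every classname variant), each prompt is embedded and L2-normalized, and the embeddings are averaged into a single class vector. The sketch below is an assumption-laden example, not the repository's implementation: build_class_embeddings, toy_embed, the templates, and the synonyms are all hypothetical, and toy_embed merely stands in for a real text encoder.

```python
import numpy as np

def build_class_embeddings(class_synonyms, templates, embed_fn):
    """Return a C x D matrix with one ensembled text embedding per class.

    class_synonyms: list of lists, the name variants for each class.
    templates:      list of strings containing a "{}" placeholder.
    embed_fn:       callable mapping a list of prompts to a P x D array
                    (hypothetical stand-in for the text encoder).
    """
    class_embeddings = []
    for synonyms in class_synonyms:
        prompts = [t.format(name) for name in synonyms for t in templates]
        feats = np.asarray(embed_fn(prompts), dtype=float)             # P x D
        feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)   # per-prompt L2-norm
        mean = feats.mean(axis=0)                                      # average over prompts
        class_embeddings.append(mean / np.linalg.norm(mean))           # re-normalize
    return np.stack(class_embeddings)

def toy_embed(prompts, dim=16):
    """Deterministic fake encoder, used only to make the sketch runnable."""
    rows = []
    for p in prompts:
        seed = sum(ord(ch) for ch in p)
        rows.append(np.random.default_rng(seed).normal(size=dim))
    return np.stack(rows)

# Hypothetical templates and classname variants for the NSCLC example.
templates = ["an image of {}.", "a histopathology image showing {}."]
class_synonyms = [
    ["squamous cell carcinoma", "lung squamous cell carcinoma"],
    ["adenocarcinoma", "lung adenocarcinoma"],
]
text_features = build_class_embeddings(class_synonyms, templates, toy_embed)
```

The resulting C x D matrix would take the place of the single-prompt text_features in the similarity computation.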

Hope this helps conceptually - practically, you can refer to the example here (we provide API for MI-Zero inference): https://github.com/mahmoodlab/CONCH/blob/main/notebooks/MI-zeroshot_classification_example_ensemble.ipynb

Dootmaan commented 2 months ago

Thank you for your timely reply!