mahmoodlab / MMP

Multimodal prototyping for cancer survival prediction - ICML 2024

visualization #3

Open aletolia opened 1 month ago

aletolia commented 1 month ago

Thank you for your wonderful work!

I encountered an issue while running mmp_visualization.ipynb. Although I successfully ran MMP and obtained the results, the notebook expects a single file, h5_feats_fpath = f'<path/to/features>/{slide_id}.h5', that holds both the coordinates and the features of the WSI. I used CLAM to segment and extract WSI features, so I ended up with an .h5 file containing the coordinates and a .pt file containing the features. I tried to merge these two files into a single .h5 file as input, but I received an error indicating that the model needs a three-dimensional tensor as input. Do you know how I can obtain this three-dimensional tensor? Thank you!
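For reference, this is roughly what the two CLAM outputs look like before merging (a minimal sketch; the paths are placeholders):

import h5py
import torch

# Placeholder paths, for illustration only
coords_h5 = '<path/to/patches>/slide.h5'   # CLAM patching output
feats_pt = '<path/to/feats_pt>/slide.pt'   # CLAM feature-extraction output

with h5py.File(coords_h5, 'r') as f:
    print(list(f.keys()))       # expect 'coords' here
    print(f['coords'].shape)    # (N, 2) patch coordinates

feats = torch.load(feats_pt)
print(feats.shape)              # (N, d) patch features, i.e. 2-D rather than 3-D

The notebook code that I actually ran (and that raises the error) is below: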

import h5py
import openslide
import torch

slide_id = 'TCGA-RZ-AB0B-01Z-00-DX1.0DF1A3A6-3030-4988-AC2C-CAA0F2EBAEB2'
slide_fpath = f'../DATA_DIRECTORY/{slide_id}.svs'
h5_feats_fpath = '../output_file.h5'   # merged file with 'coords' and 'features'
wsi = openslide.open_slide(slide_fpath)
h5 = h5py.File(h5_feats_fpath, 'r')

coords = h5['coords'][:]                 # (N, 2) patch coordinates
feats = torch.Tensor(h5['features'][:])  # (N, d) patch features
custom_downsample = 2
patch_size = h5['coords'].attrs['patch_size'] * custom_downsample

with torch.inference_mode():
    out, qqs = panther_encoder.representation(feats).values()  
    tokenizer = PrototypeTokenizer(p=16, out_type='allcat')  
    mus, pis, sigmas = tokenizer.forward(out)  
    mus = mus[0].detach().cpu().numpy()  
    qq = qqs[0,:,:,0].cpu().numpy()  
    global_cluster_labels = qq.argmax(axis=1)  

cat_map = visualize_categorical_heatmap(
    wsi,
    coords, 
    global_cluster_labels, 
    label2color_dict=color_map,
    vis_level=wsi.get_best_level_for_downsample(128),
    patch_size=(patch_size, patch_size),
    alpha=0.4,
)  

display(cat_map.resize((cat_map.width//4, cat_map.height//4)))  
display(get_mixture_plot(mus, colors=list(color_map_hex.values())))  

ValueError                                Traceback (most recent call last)
Cell In[4], line 15
     13 
     14 with torch.inference_mode():
---> 15     out, qqs = panther_encoder.representation(feats).values()
     16     tokenizer = PrototypeTokenizer(p=16, out_type='allcat')
     17     mus, pis, sigmas = tokenizer.forward(out)

File ~/documents/code_notes/visualization/MMP/src/visualization/../mil_models/model_PANTHER.py:29, in PANTHER.representation(self, x)
     28 def representation(self, x):
---> 29     out, qqs = self.panther(x)
     30     return {'repr': out, 'qq': qqs}

File ~/anaconda3/envs/mmp/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/mmp/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/documents/code_notes/visualization/MMP/src/visualization/../mil_models/PANTHER/layers.py:49, in PANTHERBase.forward(self, S, mask)
     48 def forward(self, S, mask=None):
---> 49     B, N_max, d = S.shape
     51     if mask is None:
     52         mask = torch.ones(B, N_max).to(S)

ValueError: not enough values to unpack (expected 3, got 2)

Awakuhe commented 1 month ago

try this?

with torch.inference_mode():
    out, qqs = panther_encoder.representation(feats.unsqueeze(dim=0)).values()
    tokenizer = PrototypeTokenizer(p=16, out_type='allcat')
    mus, pis, sigmas = tokenizer.forward(out)
    mus = mus[0].detach().cpu().numpy()
    qq = qqs[0,:,:,0].cpu().numpy()
    global_cluster_labels = qq.argmax(axis=1)
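For context: PANTHERBase.forward unpacks B, N_max, d = S.shape (see the traceback above), so it expects a batched 3-D input, while the features read from the .h5 are only 2-D, (N, d). The unsqueeze simply adds the missing batch dimension:

# feats read from the .h5 is (N, d); PANTHER expects (B, N_max, d)
feats_batched = feats.unsqueeze(dim=0)   # -> (1, N, d)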

andrewsong90 commented 1 month ago

Hello @aletolia,

Thanks for the inquiry - there might be some leftover kinks I need to iron out in the viz code. I am currently traveling to present this work, but let me get back to this as soon as I return. Thank you for your patience!

aletolia commented 1 month ago

> try this? with torch.inference_mode(): out, qqs = panther_encoder.representation(feats.unsqueeze(dim=0)).values() [...]

Thank you! It works for me!

H-Q-N commented 1 month ago

> try this? with torch.inference_mode(): out, qqs = panther_encoder.representation(feats.unsqueeze(dim=0)).values() [...]
>
> Thank you! It works for me!

Hello, can you run the visualization code successfully? I keep getting errors when running this part of the code. If so, could you share the mmp_visualization.ipynb file that runs successfully for you? Thank you very much!

aletolia commented 1 month ago

> Hello, can you run the visualization code successfully? I keep getting errors when running this part of the code. If so, could you share the mmp_visualization.ipynb file that runs successfully for you? Thank you very much!

In fact, I am not sure whether I have "truly" replicated the results described in their paper. Here is how I proceeded. First, I used CLAM for segmentation/patching and the pre-trained UNI model for feature extraction. For any WSI (whole-slide image) processed this way, you obtain an example.h5 file (storing the patch coordinates) and a corresponding example.pt file (storing the features of the patches at those coordinates). Then I used the following code to merge them:

import torch
import h5py

def merge_files(h5_file_path, pt_file_path, output_h5_file_path):
    # Load the patch features saved by CLAM (.pt) as a torch tensor
    features = torch.load(pt_file_path)

    with h5py.File(h5_file_path, 'r') as h5_file:
        with h5py.File(output_h5_file_path, 'w') as new_h5_file:
            # Copy every existing dataset (e.g. 'coords') into the new file
            for item in h5_file:
                h5_file.copy(item, new_h5_file)

            # Add the features under a new 'features' dataset
            new_h5_file.create_dataset('features', data=features.numpy())

    print(f"Features from {pt_file_path} have been added to {output_h5_file_path} under the key 'features'.")

h5_file_path = '../patches/TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09.h5'
pt_file_path = '../features/uvm/uni_extracted_mag20x_patch512_fp/feats_pt/TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09.pt'
output_h5_file_path = '../output_file1.h5'

merge_files(h5_file_path, pt_file_path, output_h5_file_path)
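A quick sanity check on the merged file (using the same paths as above) is to confirm that the two datasets line up:

# Both datasets should have one row per patch
with h5py.File(output_h5_file_path, 'r') as f:
    print(f['coords'].shape)     # (N, 2)
    print(f['features'].shape)   # (N, d)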

and then ran:

slide_id = 'TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09'  
slide_fpath = f'../DATA_DIRECTORY/TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09.svs'  
h5_feats_fpath = f'../output_file1.h5'  
wsi = openslide.open_slide(slide_fpath)  
h5 = h5py.File(h5_feats_fpath, 'r')  

coords = h5['coords'][:]  
feats = torch.Tensor(h5['features'][:]) 
custom_downsample = 1  
patch_size = h5['coords'].attrs['patch_size'] * custom_downsample 
with torch.inference_mode():
    out, qqs = panther_encoder.representation(feats.unsqueeze(dim=0)).values()
    tokenizer = PrototypeTokenizer(p=16, out_type='allcat')
    mus, pis, sigmas = tokenizer.forward(out)
    mus = mus[0].detach().cpu().numpy()
    qq = qqs[0,:,:,0].cpu().numpy()
    global_cluster_labels = qq.argmax(axis=1)

cat_map = visualize_categorical_heatmap(
    wsi,
    coords, 
    global_cluster_labels, 
    label2color_dict=color_map,
    vis_level=wsi.get_best_level_for_downsample(128),
    patch_size=(patch_size, patch_size),
    alpha=0.4,
)  

display(cat_map.resize((cat_map.width, cat_map.height)))  
display(get_mixture_plot(mus, colors=list(color_map_hex.values())))  

I did get a "heatmap", but I'm not quite sure it matches what the heatmap in the paper is meant to convey; the results are presented somewhat differently there.

Richarizardd commented 2 weeks ago

Hi @aletolia @H-Q-N - Regarding the visualization, are you able to reproduce the prototypical assignment visualization in PANTHER? https://github.com/mahmoodlab/PANTHER?tab=readme-ov-file#step-3-visualization

The prototypical assignment maps between PANTHER and MMP should not change (calculating the mixture components should be the same).

aletolia commented 2 weeks ago

@Richarizardd Thank you for your response! I successfully reproduced the prototypical assignment visualization in PANTHER, and the results are identical to those in MMP. However, one thing that confuses me is that I used WSIs from TCGA-UVM for the visualization. In terms of the image results, the c13 prototype is concentrated in the tissue area near the tumor, while the c15 prototype is concentrated within the tumor tissue (as shown in the figure below).

[Figure: prototypical assignment heatmap for TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09]

However, in the bar chart of the per-prototype counts (as I understand it), the proportions of the c13 and c15 prototypes do not differ much from those of the other prototypes. This is also why I was previously uncertain about my visualization results. The image below is the mixture_plot corresponding to the heatmap above (by the way, I changed the color mapping for the c13 and c15 prototypes to color_map[13] = (220,20,60) and color_map[15] = (65,105,225)).

[Figure: mixture plot (mus_plot) for TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09]
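One rough sanity check I can think of (just a sketch, not from the notebook) is to compare the empirical fraction of patches assigned to each prototype against the mixture proportions pis returned by the tokenizer, which I take to be the per-prototype weights:

import numpy as np

# Fraction of patches assigned to each prototype in the assignment map
counts = np.bincount(global_cluster_labels, minlength=16)
print(counts / counts.sum())

# Mixture proportions from tokenizer.forward(out), for comparison
# (check the exact shape/ordering of pis; assumed here to be [1, p])
print(pis[0])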

Richarizardd commented 1 week ago

@aletolia what is this TCGA slide?

aletolia commented 1 week ago

> @aletolia what is this TCGA slide?

It's TCGA-V3-A9ZX-01Z-00-DX1.B4A9BC5E-3AFA-4765-BB79-25F3B0497B09.svs, and I used a pre-trained UNI model to extract features using CLAM. Then I directly ran PANTHER on the extracted features to obtain WSI prototypes.

Richarizardd commented 1 week ago

@aletolia - I will be away from keyboard and will look back on this on 8/31.