MECLabTUDA / M3d-Cam


3D CNN Prediction #21

Open stevenagl12 opened 2 years ago

stevenagl12 commented 2 years ago

I work with a different form of a predict function: I simply call the model on my patches, save those individual patches, and then have another script stitch them back together. Is there a way that M3d-Cam can just save the volumetric patch gradcam values so that I can stitch those together to create the attention maps, or will I need to recreate my inference/predict script? Also, what is the difference between gradcam and gradcam++?

Karol-G commented 2 years ago

Hi,

Is there a way that M3d-Cam can just save the volumetric patch gradcam values so that I can stitch those together to create the attention maps, or will I need to recreate my inference/predict script?

No, sadly not. You have to stitch them back together yourself. However, I would recommend using TorchIO for this, which has a GridSampler and a GridAggregator just for this purpose. The grid sampler samples every patch from the input image and the grid aggregator stitches the predictions back together. This can be done in a few lines of code, and no intermediate saving of the patches is required.

Example from TorchIO:

import torch
import torch.nn as nn
import torchio as tio

patch_overlap = 4, 4, 4  # or just 4
patch_size = 88, 88, 60
subject = tio.datasets.Colin27()  # sample subject shipped with TorchIO

# The grid sampler yields overlapping patches covering the whole volume
grid_sampler = tio.inference.GridSampler(
    subject,
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
# The aggregator stitches the per-patch outputs back into a full volume
aggregator = tio.inference.GridAggregator(grid_sampler)
model = nn.Identity().eval()  # placeholder for your trained model
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['t1'][tio.DATA]
        locations = patches_batch[tio.LOCATION]
        logits = model(input_tensor)
        labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
        aggregator.add_batch(labels, locations)
output_tensor = aggregator.get_output_tensor()

Here is the link: https://torchio.readthedocs.io/patches/patch_inference.html#
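
Applied to M3d-Cam, the loop would look roughly like this. This is an untested sketch: I am assuming you wrap your already-trained model and subject (not defined here) with medcam.inject using return_attention=True, so that the forward pass returns the logits together with a patch-sized attention map that the aggregator can stitch.

import torch
import torchio as tio
from medcam import medcam

# Assumption: `model` is your trained 3D CNN and `subject` a tio.Subject,
# as in the example above.
model = medcam.inject(model, backend='gcam', layer='auto',
                      return_attention=True, save_maps=False)
model.eval()

grid_sampler = tio.inference.GridSampler(subject, (88, 88, 60), (4, 4, 4))
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = tio.inference.GridAggregator(grid_sampler)

# Note: no torch.no_grad() here, the gradient-based backends need the
# backward pass to compute the attention maps.
for patches_batch in patch_loader:
    input_tensor = patches_batch['t1'][tio.DATA]
    locations = patches_batch[tio.LOCATION]
    logits, attention = model(input_tensor)  # attention: (B, 1, D, H, W)
    aggregator.add_batch(attention, locations)

attention_volume = aggregator.get_output_tensor()  # stitched attention map

Since the attention values are continuous, you may also want to pass overlap_mode='average' when creating the GridAggregator so that overlapping patch borders blend smoothly instead of being cropped.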

Also, what is the difference between gradcam and gradcam++?

I don't really remember the difference anymore; you will have to look it up in the papers yourself. Both papers are located under M3d-Cam/Papers.

Best, Karol

stevenagl12 commented 2 years ago

I was able to get the image blocks out and stitch them up myself. How do you typically show the color maps from the classification example? I am curious to see if I can create cloud-like Gaussian fields of attention, similar to those, but in 3D.
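
For context, this is the kind of overlay I am experimenting with (just a rough sketch with dummy data; the Gaussian filter is my own idea for getting the cloud-like look):

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# Dummy stand-ins; replace with your stitched image/attention volumes
# of shape (D, H, W).
image = np.random.rand(60, 88, 88)
attention = np.random.rand(60, 88, 88)

z = image.shape[0] // 2                         # middle axial slice
smoothed = gaussian_filter(attention, sigma=3)  # cloud-like smoothing

plt.imshow(image[z], cmap='gray')
plt.imshow(smoothed[z], cmap='jet', alpha=0.4)  # translucent color map
plt.axis('off')
plt.show()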