MECLabTUDA / M3d-Cam

MIT License
306 stars · 40 forks

Any plans of integrating new methods? #15

Closed sarthakpati closed 3 years ago

sarthakpati commented 3 years ago

Specifically, from https://github.com/frgfm/torch-cam, methods that work across 2D/3D and for segmentation/classification?

Karol-G commented 3 years ago

Hey Sarthak,

sorry for the late reply! We intend to add new methods in the future, but we don't have any concrete plans yet.

Best Karol

sarthakpati commented 3 years ago

Hey Karol,

Cool, thanks for the update!

Cheers, Sarthak

Mushtaqml commented 2 years ago

@sarthakpati does torch-cam work for 3D classification tasks?

sarthakpati commented 2 years ago

I don't think so.

Mushtaqml commented 2 years ago

@sarthakpati any idea if there is a package I can use for a 3D classification task?

sarthakpati commented 2 years ago

Medcam supports 3D classification. If you are interested, medcam is also integrated with GaNDLF to give you an end-to-end solution.

Mushtaqml commented 2 years ago

@sarthakpati My input is 128x128x64 and the output attention map is 2x8x16; I don't know how to overlay it.

Karol-G commented 2 years ago

Hi @Mushtaqml,

after you inject your model with medcam, you can pass your image as a separate argument when calling your model's forward. Medcam will then automatically create an overlay for you.

Example code:

```python
# Import M3d-CAM
from medcam import medcam

# Init your model and dataloader
model = MyCNN()
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Inject model with M3d-CAM
model = medcam.inject(model, output_dir="attention_maps", save_maps=True)

# Continue to do what you're doing...
# In this case inference on some new data
model.eval()
for batch in data_loader:
    # Every time forward is called, attention maps are generated and saved in "attention_maps"
    output = model(batch, raw_input=batch)  # <---------- The relevant line for the overlay
    # more of your code...
```

The relevant line is marked with "<---------- The relevant line for the overlay". Medcam adds the raw_input parameter to your model's forward, which you can use to pass the non-normalized / non-standardized image (tensor or numpy array).

Best Karol

Mushtaqml commented 2 years ago

Hi @Karol-G,

Thank you for the detailed response.

I am training a 3D ResNet model on an MRI dataset.

```python
test_dataset = MRIDataset(data=sample, mri_type=type_, is_train=False)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)

x = next(iter(test_dl))
print(x['image'].shape)
```

The input MRI shape:

torch.Size([1, 1, 256, 256, 64])

The input MRI sample:

[image]

Model inference code:

```python
# Import M3d-CAM
from medcam import medcam

model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
device = torch.device("cuda")
model.to(device)

# Inject model with M3d-CAM
model = medcam.inject(model, output_dir="attention_maps", save_maps=True)

model.eval()

model(img, raw_input=img)
```

Output:

[image]

I thought the raw image should not contain the batch dimension, so I ran it again with a small change:

```python
# Import M3d-CAM
from medcam import medcam

# Inject model with M3d-CAM
model = medcam.inject(model, output_dir="attention_maps", save_maps=True)

model.eval()

model(img, raw_input=img[0])
```

Output:

tensor([[0.0421]], device='cuda:0', grad_fn=)

It ran successfully, but the attention map saved is of shape 8x2x16 and it looks like this:

[image]

I tried to visualize it using ITK-SNAP. Can you please tell me what I am doing wrong here?

Thanks

Karol-G commented 2 years ago

Hey,

sorry, I forgot: I only implemented the overlay for 2D, not for 3D, as there is no common software for opening 3D images with an overlay. So for 3D you can only save the attention maps without an overlay.

By default medcam uses the last conv layer for extracting the attention. For this layer the feature map seems to be very small, with a size of (8, 2, 16), and the attention map will have the same size. One usually resizes the attention map back to the original image size to see the relation between image regions and attention.
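The resizing step can be sketched like this (a minimal illustration using scipy rather than medcam itself; the array values and the shapes, taken from this thread, are just examples):

```python
# Minimal sketch: upsample a small 3D attention map back to the input volume
# size with linear interpolation. Illustrative code, not part of medcam.
import numpy as np
from scipy.ndimage import zoom

attention = np.random.rand(8, 2, 16)   # attention map saved by medcam
target_shape = (128, 128, 64)          # original input volume size

# Per-axis zoom factor so the map matches the input volume
factors = [t / s for t, s in zip(target_shape, attention.shape)]
resized = zoom(attention, factors, order=1)  # order=1 -> linear interpolation

print(resized.shape)  # (128, 128, 64)
```

The resized map can then be saved with the same geometry as the input scan, so a viewer such as ITK-SNAP can display both together.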

Best Karol

Mushtaqml commented 2 years ago

@Karol-G Thank you for your kind support.

One last question: do you have any suggestions for how I can interpret or understand my 3D network using methods other than Grad-CAM?

Thanks