mazurowski-lab / segmentation-guided-diffusion

[MICCAI 2024] Easy diffusion models (optionally with segmentation guidance) for medical images and beyond.
https://arxiv.org/abs/2402.05210

Discrepancy Between Paper Description and Code Implementation of Class Ablation #15


noureddinekhiati commented 1 day ago

Hello Authors,

Thank you for sharing the impactful work in your recent paper. I noticed a discrepancy between the class ablation strategy described in the paper and its implementation in the provided code.

Issue:

The paper describes class ablation as sampling from a Bernoulli distribution, which suggests that each class is independently removed with some probability. However, the code appears to use a uniform distribution instead (function ablate_masks in eval.py, line 333, which calls torch.rand). This may not match the described method.

Suggestion:

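
```python
import torch

def bernoulli_ablate_masks(segs, ablation_prob=0.5):
    # Assumes integer class labels 1..K, with 0 as background
    segs = segs.clone()  # avoid modifying the caller's tensor in place
    num_classes = int(segs.max().item())  # int() so range() also works for float masks
    for class_idx in range(1, num_classes + 1):
        # Independent Bernoulli(ablation_prob) draw for each class
        if torch.bernoulli(torch.tensor([ablation_prob])).item() == 1:
            segs[segs == class_idx] = 0  # ablate this class to background
    return segs
```
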
nickk124 commented 1 day ago

Hi, our implementation is simply another way to sample from a Bernoulli distribution with parameter $p = 0.5$. The Bernoulli sample $\delta$ is obtained by checking whether a uniform random sample $u \sim U[0,1]$ is less than 0.5. Since $P(u < p) = p$ for a uniform $u$, this event has probability exactly 0.5. The truth value of $u < 0.5$ is then converted to 0 or 1 to give $\delta$, so $\delta \sim \mathrm{Bernoulli}(0.5)$, drawn independently for each mask class. See https://stats.stackexchange.com/questions/240338/given-bernoulli-probability-how-to-draw-a-bernoulli-from-a-uniform-distribution for example.
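To illustrate the equivalence described above, here is a minimal sketch, assuming PyTorch. It compares thresholding a `torch.rand` draw against calling `torch.bernoulli` directly. The names `uniform_bernoulli` and `ablate_classes_uniform` are hypothetical and are not functions from the repository. Because $P(u < p) = p$ for $u \sim U[0,1)$, both approaches should agree up to sampling noise.

```python
import torch

def uniform_bernoulli(num_samples, p=0.5, generator=None):
    # Inverse-transform trick: for u ~ U[0, 1), P(u < p) = p,
    # so (u < p) cast to int is a Bernoulli(p) sample.
    u = torch.rand(num_samples, generator=generator)
    return (u < p).int()

def ablate_classes_uniform(segs, num_classes, p=0.5, generator=None):
    # Hypothetical helper (not the repo's ablate_masks): one independent
    # Bernoulli(p) draw per non-background class 1..num_classes.
    remove = uniform_bernoulli(num_classes, p, generator).bool()
    out = segs.clone()  # leave the input mask untouched
    for class_idx in range(1, num_classes + 1):
        if remove[class_idx - 1]:
            out[out == class_idx] = 0  # set the ablated class to background
    return out

if __name__ == "__main__":
    g = torch.Generator().manual_seed(0)
    n = 200_000
    via_uniform = uniform_bernoulli(n, 0.5, g).float().mean().item()
    via_bernoulli = torch.bernoulli(torch.full((n,), 0.5), generator=g).mean().item()
    # Both empirical frequencies should be close to 0.5
    print(f"uniform threshold: {via_uniform:.4f}, torch.bernoulli: {via_bernoulli:.4f}")
```

Drawing one uniform value per class and thresholding it at $p$ yields the same distribution as one `torch.bernoulli` call per class. The two differ only in how the random draw is expressed, not in whether classes are ablated independently.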