SuperMedIntel / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License
1.09k stars 166 forks

about the model #124

Open Devil-Ideal opened 1 year ago

Devil-Ideal commented 1 year ago

Hi! I'm interested in your work and have read both papers. However, after reading the code, I'm a little confused. The code still seems to correspond to the first paper, MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model, since I couldn't find the SS-Former or the Anchor Condition. I'm guessing it is an improved version of the first paper's model, because some things also differ from that paper.

Still, I noticed something strange. The ISICDataset returns (img, mask, name); batch gets the first element and cond gets the second element [code: batch, cond, name = next(data_iter) ], and then they are concatenated along the channel dimension [code: ..., cond), dim=1) ]. So I think the variable named cond is the mask, and it sits in the last channel. In the model, UNetModel_newpreview, the function named highway_forward is supposed to encode the image, but it seems to be misused to encode the mask, since the variable named c actually represents the mask.

The code:

++++++++++++ function of the gaussian_diffusion

    def training_losses_segmentation(self, model, classifier, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start[:, -1:, ...])

        mask = x_start[:, -1:, ...]
        res = torch.where(mask > 0, 1, 0)   # merge all tumor classes into one to get a binary segmentation mask

        res_t = self.q_sample(res, t, noise=noise)     # add noise to the segmentation channel
        x_t[:, -1:, ...] = res_t.float()
        terms = {}

        if self.loss_type == LossType.MSE or self.loss_type == LossType.BCE_DICE or self.loss_type == LossType.RESCALED_MSE:

            model_output, cal = model(x_t, self._scale_timesteps(t), **model_kwargs)

++++++++ forward of the model

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        c = h[:, :-1, ...]
        anch, cal = self.highway_forward(c)
Devil-Ideal commented 1 year ago

Sorry, I misunderstood something: the variable c actually represents the condition. But one thing still confuses me. The image and the mask are both fed into the encoder of the diffusion model, which does not match the paper.
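The channel layout discussed in this thread can be checked with a small, self-contained sketch. It is only an illustration under stated assumptions: NumPy stands in for torch (np.concatenate with axis=1 plays the role of the channel-wise concatenation with dim=1), the tensor shapes and the value of alpha_bar are hypothetical, and the noising line uses the standard DDPM forward-process form rather than the repository's q_sample. It shows that when the mask is concatenated last, the slice [:, -1:, ...] selects the mask and the slice [:, :-1, ...] selects the image, so c is the image condition and only the mask channel is noised.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch: 2 RGB images (3 channels) and 2 single-channel masks.
img = rng.random((2, 3, 8, 8)).astype(np.float32)
mask = (rng.random((2, 1, 8, 8)) > 0.5).astype(np.float32)

# Channel-wise concatenation (image first, mask second):
# the mask ends up as the LAST channel.
x_start = np.concatenate((img, mask), axis=1)   # shape (2, 4, 8, 8)

# Same slicing as in training_losses_segmentation: last channel = mask.
res = np.where(x_start[:, -1:, ...] > 0, 1.0, 0.0).astype(np.float32)

# Forward diffusion applied to the mask channel only (standard DDPM form).
# alpha_bar is an illustrative value, not taken from the repository.
alpha_bar = 0.5
noise = rng.standard_normal(res.shape).astype(np.float32)
res_t = np.sqrt(alpha_bar) * res + np.sqrt(1.0 - alpha_bar) * noise

x_t = x_start.copy()
x_t[:, -1:, ...] = res_t                        # overwrite only the mask channel

# Same slicing as in forward(): everything except the last channel.
c = x_t[:, :-1, ...]

print(x_start.shape)            # (2, 4, 8, 8)
print(c.shape)                  # (2, 3, 8, 8)
print(np.array_equal(c, img))   # True: c is the untouched image, not the mask
```

Under these assumptions, the input x_t that reaches the model carries both the clean image channels and the noisy mask channel, which is consistent with the observation above that image and mask enter the network together.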

takimailto commented 1 year ago

I also could not find the code related to SSFormer. May I ask whether your problem has been solved? Besides, it seems that the code only adds the conditional feature at the first layer?

Devil-Ideal commented 1 year ago

Sorry, I haven't solved it yet. I'm trying to reproduce the experiments using the old version of the model (--version old, which is not the default).