cilix-ai / on-the-fly-guidance

[MICCAI 2024] On-the-Fly Guidance Training for Medical Image Registration. Preprint available at the link below.
https://arxiv.org/abs/2308.15216
Creative Commons Attribution 4.0 International

Loss calculation related #12

Closed: domadaaaa closed this issue 8 months ago

domadaaaa commented 8 months ago

Your code in train.py:

loss_ncc = criterion_ncc(output[0], y)
loss_reg = criterion_reg(output[1], y)
loss = loss_ncc + loss_reg
loss_vals = [loss_ncc, loss_reg]
loss_all.update(loss.item(), y.numel())

criterion_reg is defined as losses.Grad3d(penalty='l2'), and I'm confused: output[1] represents the flow after the STN, but y represents the fixed image. Shouldn't the gradient loss be calculated using the fixed label and the warped label instead?
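For context, the snippet above adds two different kinds of terms: an image-similarity term that compares the warped moving image output[0] with the fixed image y, and a smoothness regulariser on the flow output[1] that needs no target at all. The following is a minimal sketch of that split, not the repository's code. `global_ncc_loss` and `smoothness_l2` are hypothetical helpers: the first is a simple global normalised cross-correlation that may well differ from the repository's criterion_ncc, and the second mirrors the arithmetic of Grad3d(penalty='l2'). The tensor shapes and the stand-in `output` tuple are also assumptions.

```python
import torch


def global_ncc_loss(warped, fixed, eps=1e-5):
    """Hypothetical stand-in: negative global NCC between two (N, 1, D, H, W) images."""
    w = warped.flatten(1) - warped.flatten(1).mean(dim=1, keepdim=True)
    f = fixed.flatten(1) - fixed.flatten(1).mean(dim=1, keepdim=True)
    ncc = (w * f).sum(dim=1) / (w.norm(dim=1) * f.norm(dim=1) + eps)
    return -ncc.mean()  # -1 means perfectly correlated, so lower is better


def smoothness_l2(flow):
    """Hypothetical helper: squared finite differences of a (N, 3, D, H, W) flow."""
    dy = (flow[:, :, 1:] - flow[:, :, :-1]) ** 2        # along D
    dx = (flow[:, :, :, 1:] - flow[:, :, :, :-1]) ** 2  # along H
    dz = (flow[..., 1:] - flow[..., :-1]) ** 2          # along W
    return (dx.mean() + dy.mean() + dz.mean()) / 3.0


torch.manual_seed(0)
y = torch.rand(2, 1, 16, 16, 16)  # fixed image
# Assumed model output: (warped moving image, flow). Here a perfect, zero-flow match.
output = (y.clone(), torch.zeros(2, 3, 16, 16, 16))

loss_ncc = global_ncc_loss(output[0], y)  # similarity: needs the fixed image
loss_reg = smoothness_l2(output[1])       # regulariser: depends only on the flow
loss = loss_ncc + loss_reg
print(loss.item())  # close to -1.0: identical images and a perfectly smooth flow
```

Only the similarity term consumes y; the regulariser is a prior on the deformation itself, which is why it has no natural "target".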

yuelinxin commented 8 months ago

Hi @domadaaaa, Grad3d is only a regularisation term on output[1] (the deformation field); its implementation is taken from TransMorph:

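import torch


class Grad3d(torch.nn.Module):
    """
    N-D gradient loss.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super(Grad3d, self).__init__()
        self.penalty = penalty      # 'l1' (absolute) or 'l2' (squared) differences
        self.loss_mult = loss_mult  # optional scalar weight on the loss

    def forward(self, y_pred, y_true):
        # y_pred: displacement field of shape (N, 3, D, H, W).
        # y_true is never read; it only keeps a (prediction, target) call signature.
        # Forward finite differences along each of the three spatial axes.
        dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])

        if self.penalty == 'l2':
            dy = dy * dy
            dx = dx * dx
            dz = dz * dz

        # Average the per-axis mean penalties.
        d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz)
        grad = d / 3.0

        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad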

As you can see, the second argument is not used; the y in loss_reg = criterion_reg(output[1], y) is just a formality, and you can remove it if you want.
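A quick way to confirm that the second argument plays no role is to call the loss with and without a target. The sketch below re-declares Grad3d in a slightly condensed form with the same arithmetic, but gives `y_true` a default of `None`; that default is our assumption, not the repository's code. With the signature shown above, simply deleting `y` from the call would raise a `TypeError`, because `forward` still requires two arguments, so the signature has to change too.

```python
import torch


class Grad3d(torch.nn.Module):
    """Smoothness penalty on a 3-D displacement field of shape (N, 3, D, H, W)."""

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true=None):
        # Assumed change: y_true is optional, since it is never read.
        dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
        if self.penalty == 'l2':
            dy, dx, dz = dy * dy, dx * dx, dz * dz
        grad = (dx.mean() + dy.mean() + dz.mean()) / 3.0
        if self.loss_mult is not None:
            grad = grad * self.loss_mult
        return grad


torch.manual_seed(0)
flow = torch.randn(1, 3, 8, 8, 8)   # stand-in for output[1]
fixed = torch.randn(1, 1, 8, 8, 8)  # stand-in for y
reg = Grad3d(penalty='l2')

with_target = reg(flow, fixed)
without_target = reg(flow)
print(torch.equal(with_target, without_target))  # True: the target is ignored

# A spatially constant field is perfectly smooth, so its penalty is zero.
print(reg(torch.ones(1, 3, 8, 8, 8)).item())  # 0.0
```

Keeping the unused argument (as the original does) has one practical advantage: every criterion in the training loop can be called uniformly as `criterion(pred, target)`.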

yuelinxin commented 8 months ago

@domadaaaa We are in the process of refining our code; it currently contains some implementation details that are not intended in our paper. Apologies for the confusion.

domadaaaa commented 8 months ago

Alright. Have you ever tested the performance of this model on multi-modal registration in your work, i.e. feeding in one pair of images (a fixed and a moving image) at a time rather than using a fixed image as an atlas?

yuelinxin commented 8 months ago

I'm not sure exactly what multi-modal registration you are referring to, but we did run some preliminary experiments on MRI-to-CT registration.

domadaaaa commented 8 months ago

Thanks. How is your preliminary research on MRI-to-CT registration going? Will the code and method be published?

yuelinxin commented 8 months ago

As I've mentioned, we are in the process of refactoring and updating our code. This part may become available in a future version; you are welcome to open another issue about it.

domadaaaa commented 8 months ago

Thank you very much. I hope to see your work on multi-modal registration soon.