MrGiovanni / ContinualLearning

[MICCAI 2023] Continual Learning for Abdominal Multi-Organ and Tumor Segmentation
https://www.cs.jhu.edu/~alanlab/Pubs23/zhang2023continual.pdf

training process #14

Open JcWang20 opened 5 months ago

JcWang20 commented 5 months ago

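When I set breakpoints while debugging training, I found that the image shape was (n, c, h, w, d). So why does `proj_out`, called in the SwinUNETR forward pass, assume a different shape?

```python
import torch.nn.functional as F
from einops import rearrange

def proj_out(self, x, normalize=False):
    if normalize:
        x_shape = x.size()
        if len(x_shape) == 5:
            n, ch, d, h, w = x_shape
            x = rearrange(x, "n c d h w -> n d h w c")  # move channels last
            x = F.layer_norm(x, [ch])                    # normalize over the channel axis only
            x = rearrange(x, "n d h w c -> n c d h w")  # move channels back
    return x  # return added so the excerpt runs on its own
```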
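One way to check whether the axis naming matters: einops matches axes by position, so the letters in a pattern such as `"n c d h w"` are only labels, and `F.layer_norm(x, [ch])` normalizes each voxel's channel vector independently. The sketch below assumes PyTorch; `channel_layer_norm` is a hypothetical helper that swaps the two `rearrange` calls for the equivalent `permute` calls, and the tensor sizes are made up. It compares normalizing an (n, c, h, w, d) tensor before and after moving the depth axis.

```python
import torch
import torch.nn.functional as F


def channel_layer_norm(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same ops as the 5-D branch of proj_out."""
    ch = x.shape[1]
    x = x.permute(0, 2, 3, 4, 1)     # same as rearrange "n c d h w -> n d h w c"
    x = F.layer_norm(x, [ch])        # statistics over the channel axis only
    return x.permute(0, 4, 1, 2, 3)  # same as rearrange "n d h w c -> n c d h w"


torch.manual_seed(0)
x_hwd = torch.randn(2, 4, 5, 6, 7)   # feature map laid out as (n, c, h, w, d)

# Normalize in the (n, c, h, w, d) layout, then reorder to (n, c, d, h, w) ...
a = channel_layer_norm(x_hwd).permute(0, 1, 4, 2, 3)
# ... or reorder to (n, c, d, h, w) first and normalize afterwards.
b = channel_layer_norm(x_hwd.permute(0, 1, 4, 2, 3))

print(torch.allclose(a, b, atol=1e-6))  # True
```

Because the statistics are taken over channels only, this normalization gives the same values whichever spatial order the input uses; in this branch the `d, h, w` names act only as labels.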

This implies that x has shape (n, c, d, h, w), but I got (n, c, h, w, d). Is this right? What is more, here is the dice loss function:

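```python
import torch

def binary_diceloss(self, predict, target):
    # flatten everything after the first dimension: one row per predict.shape[0]
    predict = predict.contiguous().view(predict.shape[0], -1)
    target = target.contiguous().view(target.shape[0], -1)

    num = torch.sum(torch.mul(predict, target), dim=1)
    den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

    dice_score = 2*num / den
    dice_loss = 1 - dice_score
    # average only over rows whose first target element is not -1
    dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
    return dice_loss_avg
```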

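The `predict` shape is (h, w, d), but `view(predict.shape[0], -1)` reshapes it to (h, w*d), so the dice is computed separately for each slice along the first axis instead of over the whole volume. That makes the loss function look wrong to me. Also, `target[:,0]!=-1` then only looks at the first element of each row, so regardless of whether the second dimension is h*w or w*d, this check seems meaningless. If I used `view(1, -1)` instead, it would look meaningful. Could you please help me with this problem?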
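The shape argument above can be checked with a small sketch, assuming (as the comment describes) that `binary_diceloss` receives a single (h, w, d) volume for one organ. `dice_loss_rows` is a hypothetical helper that reproduces the `num`/`den` arithmetic of the quoted code without the `-1` filter; the sizes and random tensors are made up for illustration.

```python
import torch


def dice_loss_rows(predict: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: soft dice loss, one value per row of a 2-D view."""
    num = torch.sum(torch.mul(predict, target), dim=1)
    den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + smooth
    return 1 - 2 * num / den


torch.manual_seed(0)
h, w, d = 4, 3, 2
predict = torch.rand(h, w, d)                 # one organ channel, values in [0, 1)
target = (torch.rand(h, w, d) > 0.5).float()  # binary mask for the same volume

# view(predict.shape[0], -1) on an (h, w, d) tensor gives h rows of w*d voxels.
per_slice = dice_loss_rows(predict.view(h, -1), target.view(h, -1))
print(per_slice.shape)   # torch.Size([4]): one dice loss per slice along h

# view(1, -1) gives a single row holding the whole volume.
per_volume = dice_loss_rows(predict.view(1, -1), target.view(1, -1))
print(per_volume.shape)  # torch.Size([1]): one volume-level dice loss

# The mean of per-slice losses generally differs from the volume-level loss.
print(per_slice.mean().item(), per_volume.item())

# With h rows, target[:, 0] is just the first voxel of each slice, so a
# "!= -1" filter would inspect one voxel per slice, not a per-volume flag.
print(target.view(h, -1)[:, 0].shape)  # torch.Size([4])
```

In this setting `view(1, -1)` (or adding a leading batch dimension before the call) yields one dice value for the whole volume, while `view(predict.shape[0], -1)` yields one per slice along the first spatial axis.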