PengchengShi1220 / cbDice

[MICCAI 2024] Centerline Boundary Dice Loss for Vascular Segmentation
Apache License 2.0

y_pred and y_true size #4

Open northwill opened 1 week ago

northwill commented 1 week ago
```python
import torch

# Skeletonize and SoftSkeletonize come from the two repositories linked in the
# comments below; their import lines are not shown in the issue.


class SoftcbDiceLoss(torch.nn.Module):
    def __init__(self, iter_=10, smooth=1.):
        super(SoftcbDiceLoss, self).__init__()
        self.smooth = smooth

        # Topology-preserving skeletonization: https://github.com/martinmenten/skeletonization-for-gradient-based-optimization
        self.t_skeletonize = Skeletonize(probabilistic=False, simple_point_detection='EulerCharacteristic')

        # Morphological skeletonization: https://github.com/jocpae/clDice/tree/master/cldice_loss/pytorch
        self.m_skeletonize = SoftSkeletonize(num_iter=iter_)

    def forward(self, y_pred, y_true, t_skeletonize_flage=False):
        breakpoint()  # pauses in the debugger (pdb) on every call
        if len(y_true.shape) == 4:    # (batch, channel, y, x)
            dim = 2
        elif len(y_true.shape) == 5:  # (batch, channel, z, y, x)
            dim = 3
        else:
            raise ValueError("y_true should be 4D or 5D tensor.")
        # ... rest of forward() not included in the issue
```

When I call this function, I get an error. Both y_pred and y_true have shape (1, 2, 128, 128, 128), i.e. (batch, cls, z, y, x). What shapes should the inputs have?
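
For reference, the rank check in the quoted forward() only looks at len(y_true.shape), so a (1, 2, 128, 128, 128) tensor takes the 3D branch (dim = 3) and does not raise there. The sketch below is a hypothetical pre-flight helper (spatial_dims and check_pair are illustrative names, not part of cbDice) that mirrors that check and compares the two shapes before the loss is called. It works on plain tuples or torch.Size objects. It deliberately skips the channel axis, because the thread does not say whether y_true should be one-hot encoded or a label map.

```python
from typing import Dict, List, Sequence


def spatial_dims(shape: Sequence[int]) -> int:
    """Mirror the rank check in the quoted forward(): 4D -> 2, 5D -> 3."""
    if len(shape) == 4:
        return 2
    if len(shape) == 5:
        return 3
    raise ValueError(f"expected a 4D or 5D shape, got {len(shape)}D: {tuple(shape)}")


def check_pair(pred_shape: Sequence[int], true_shape: Sequence[int]) -> List[str]:
    """Return a list of problems with a (y_pred, y_true) shape pair; empty if none.

    Hypothetical helper: checks rank, batch size and spatial size only.
    """
    problems: List[str] = []
    dims: Dict[str, int] = {}
    for name, shape in (("y_pred", pred_shape), ("y_true", true_shape)):
        try:
            dims[name] = spatial_dims(shape)
        except ValueError as err:
            problems.append(f"{name}: {err}")
    if problems:
        return problems
    if dims["y_pred"] != dims["y_true"]:
        return [f"rank mismatch: y_pred is {len(pred_shape)}D, y_true is {len(true_shape)}D"]
    if pred_shape[0] != true_shape[0]:
        problems.append(f"batch mismatch: {pred_shape[0]} vs {true_shape[0]}")
    # Axis 1 (channels) is intentionally not compared.
    if tuple(pred_shape[2:]) != tuple(true_shape[2:]):
        problems.append(
            f"spatial mismatch: {tuple(pred_shape[2:])} vs {tuple(true_shape[2:])}"
        )
    return problems


print(spatial_dims((1, 2, 128, 128, 128)))  # 3
print(check_pair((1, 2, 128, 128, 128), (1, 2, 128, 128, 128)))  # []
print(check_pair((1, 2, 128, 128, 128), (1, 128, 128, 128)))
# ['rank mismatch: y_pred is 5D, y_true is 4D']
```

With real tensors, pass y_pred.shape and y_true.shape; an empty list means the shapes at least get past the rank check shown in the issue.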

PengchengShi1220 commented 1 week ago

Hello! (1, 2, 128, 128, 128) is a 5D tensor, so it should be fine. Could you share the exact error message?
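
To get the exact error message requested here, the full traceback can be captured as text and pasted into the issue. A minimal sketch using only the standard library's traceback module; capture_traceback and fake_loss are hypothetical names, and fake_loss merely stands in for a call such as SoftcbDiceLoss()(y_pred, y_true).

```python
import traceback
from typing import Any, Callable, Optional


def capture_traceback(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Optional[str]:
    """Call fn(*args, **kwargs); return the formatted traceback if it raises, else None."""
    try:
        fn(*args, **kwargs)
    except Exception:
        # format_exc() gives the same text Python prints for an uncaught
        # exception, ending with the "ExceptionType: message" line.
        return traceback.format_exc()
    return None


def fake_loss(y_pred, y_true):
    # Hypothetical stand-in that raises the error from the quoted forward().
    raise ValueError("y_true should be 4D or 5D tensor.")


report = capture_traceback(fake_loss, None, None)
print(report.splitlines()[-1])  # ValueError: y_true should be 4D or 5D tensor.
```

Note that the breakpoint() call in the posted forward() pauses in pdb on every call; setting the environment variable PYTHONBREAKPOINT=0 disables it without editing the code.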