datvuthanh / HybridNets

HybridNets: End-to-End Perception Network
MIT License

The loss doesn't converge when training segmentation head only. #54

Closed bigsquirrel18 closed 1 year ago

bigsquirrel18 commented 2 years ago

I changed the backbone to EfficientNet-B0 and reduced the number of BiFPN layers from 6 to 1 in order to cut down inference time. After training for 200 epochs with the segmentation head frozen, I tried to train the model with the backbone and detection head frozen. But I found that the training loss of the segmentation head does not seem to converge, while the validation loss and the mIoU are both decreasing at the same time, which doesn't make sense. Apart from that, I also found that when the segmentation head is frozen, the segmentation loss is not set to 0 in the code, which I suppose can affect the weight updates in the backbone.
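
A minimal sketch of that two-phase freezing scheme, using hypothetical module names (backbone, detection_head, segmentation_head) rather than HybridNets' actual attributes:

import torch.nn as nn

# Generic stand-ins for the real HybridNets components.
model = nn.ModuleDict({
    "backbone": nn.Conv2d(3, 16, 3),
    "detection_head": nn.Conv2d(16, 8, 1),
    "segmentation_head": nn.Conv2d(16, 3, 1),
})

def set_trainable(module: nn.Module, flag: bool) -> None:
    for p in module.parameters():
        p.requires_grad = flag

# Phase 1: train backbone + detection head with the segmentation head frozen.
set_trainable(model["segmentation_head"], False)

# Phase 2: freeze backbone + detection head, train the segmentation head only.
set_trainable(model["backbone"], False)
set_trainable(model["detection_head"], False)
set_trainable(model["segmentation_head"], True)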

bigsquirrel18 commented 2 years ago

And I found that the inference speed is much lower than YOLOP when taking EfficientNet-B3 as the backbone and using 6 BiFPN layers to fuse features (28 FPS vs. 74 FPS on the same machine).

xoiga123 commented 2 years ago

About that...

But I found that the training loss of the segmentation head does not seem to converge, while the validation loss and the mIoU are both decreasing at the same time, which doesn't make sense.

There is a bug in random_perspective in utils.py: it applies an affine transform to the segmentation labels but unintentionally creates gray labels (0 < value < 255). Please add flags=cv2.INTER_NEAREST so the labels remain black and white:

seg[seg_class] = cv2.warpAffine(seg[seg_class], M[:2], dsize=(width, height), borderValue=0, flags=cv2.INTER_NEAREST)

But if you're training in MULTICLASS_MODE, this doesn't really matter, because I only use the white pixels to create the segmentation label tensor: segmentation[seg_class == 255] = seg_index + 1. In my retraining, the segmentation head does seem to converge, but the result is still lower than the published one (even with all the bug fixes, wtf). So maybe there's a bug somewhere else too, or it needs more tuning. My mIoU stabilized around epoch 50 in phase 2.
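
A quick standalone check of that interpolation detail (illustrative only, not the repo's code):

import cv2
import numpy as np

# Warping a binary mask with the default bilinear interpolation produces
# intermediate gray values; INTER_NEAREST keeps it strictly 0/255.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255
M = cv2.getRotationMatrix2D((32, 32), 10, 1.0)

gray = cv2.warpAffine(mask, M, (64, 64), borderValue=0)  # default INTER_LINEAR
clean = cv2.warpAffine(mask, M, (64, 64), borderValue=0, flags=cv2.INTER_NEAREST)

print(np.unique(gray))   # contains values between 0 and 255
print(np.unique(clean))  # only [0, 255]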

Apart from that, I also found that when the segmentation head is frozen, the segmentation loss is not set to 0 in the code, which I suppose can affect the weight updates in the backbone.

I thought I did this in the latest commit, no?

seg_loss = seg_loss.mean() if not opt.freeze_seg else torch.tensor(0)

Btw, use torch.tensor(0, device="cuda") if you're on CUDA.
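
As a small aside on that device detail, one way to avoid hardcoding "cuda" is to derive the zero from an existing tensor (a sketch with stand-in names, not the repo's code):

import torch

# Stand-in for the computed segmentation loss.
seg_loss = torch.randn(8, device="cuda" if torch.cuda.is_available() else "cpu")
freeze_seg = True

# new_zeros inherits dtype and device from seg_loss, so this works on CPU and GPU.
seg_loss = seg_loss.mean() if not freeze_seg else seg_loss.new_zeros(())
print(seg_loss)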

And I found that the inference speed is much lower than YOLOP when taking EfficientNet-B3 as the backbone and using 6 BiFPN layers to fuse features (28 FPS vs. 74 FPS on the same machine).

Yeah, depthwise convolution is inherently slow in PyTorch, but not on deployment platforms like TensorRT... I did achieve 30 FPS on a Jetson Xavier NX at 384x256. In our paper, we benchmarked against YOLOP on a V100 because it's old and cannot exploit CUDA and cuDNN the way newer GPUs can in current PyTorch, which YOLOP benefits from very well. There's not much we can do on the PyTorch side.
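
For context, a rough micro-benchmark sketch of the depthwise-vs-regular convolution gap (results vary a lot by GPU, dtype, and cuDNN version; the shapes here are arbitrary):

import time
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 128, 96, 160, device=device)

regular = nn.Conv2d(128, 128, 3, padding=1).to(device)
depthwise = nn.Conv2d(128, 128, 3, padding=1, groups=128).to(device)

@torch.no_grad()
def throughput(module, n=100):
    for _ in range(10):  # warm-up
        module(x)
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        module(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return n / (time.perf_counter() - t0)

print(f"regular:   {throughput(regular):.0f} it/s")
print(f"depthwise: {throughput(depthwise):.0f} it/s")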

All the changes I listed above are heavily tied to the branch I'm working on, which has the bug fixes. It's all experimental stuff, so it could break; I'm retraining for validation. Honestly, I'm as confused and hopeless as you are right now.

bigsquirrel18 commented 2 years ago

Thanks a lot for the detailed reply, I really appreciate it! Btw, the change seg_loss = seg_loss.mean() if not opt.freeze_seg else torch.tensor(0) was not synchronized into train_ddp.py. And I think the reason my segmentation head didn't converge is that the backbone was pruned too much, making it pretty difficult to extract useful features from the original images.

xoiga123 commented 2 years ago

I'll look into train_ddp.py later, thank you!

And what do you mean by pruning? As in reducing BiFPN layers from 6 to 1? Apart from that, I don't think we did any pruning at all.

bigsquirrel18 commented 2 years ago

Yeah, by pruning I mean reducing the BiFPN layers from 6 to 1 and changing the backbone to EfficientNet-B0 (my English is not so good ⁄(⁄ ⁄ ⁄ω⁄ ⁄ ⁄)⁄

xoiga123 commented 2 years ago

It's alright, @datvuthanh what do you think about his network configuration?

Daigo9449 commented 2 years ago

Hi! Are there any special considerations for training the network on a custom dataset, and is it possible at all? I am interested in training the object detection head to detect more classes on the road.

bigsquirrel18 commented 2 years ago

Good news! I've found why the seg loss trained in multiclass mode didn't converge!

I tested the Tversky loss and the focal loss separately, and found that only the focal loss didn't converge. So I checked the code in the focal loss section and discovered a bug!

The probability y_pred is computed by the softmax function in multiclass mode, and its focal loss is supposed to be -gt * (1 - y_pred)^gamma * log(y_pred). But in the code, the loss was computed with a BCE-style loss function and summed over all classes, making the final loss function wrong. Take one pixel in a 384x640 image for example, and suppose this pixel belongs to the background, so its ground truth is [1, 0, 0] in one-hot format. The loss should be -(1 - y_pred_0)^gamma * log(y_pred_0). But in the code, the final loss was computed as -(1 - y_pred_0)^gamma * log(y_pred_0) - y_pred_1^gamma * log(1 - y_pred_1) - y_pred_2^gamma * log(1 - y_pred_2) (ignoring the alpha weight factor), where y_pred_0, y_pred_1, y_pred_2 are the probabilities of background, road, and lane respectively.
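
A minimal numeric check of the two formulations for that example pixel (gamma = 2, alpha omitted; the probabilities are made up):

import torch

probs = torch.tensor([0.7, 0.2, 0.1])  # softmax output for one background pixel
gamma = 2.0

# Correct multiclass focal loss: only the ground-truth class contributes.
correct = -(1 - probs[0]) ** gamma * torch.log(probs[0])

# Buggy BCE-style sum: the negative classes add extra -p^gamma * log(1 - p) terms.
buggy = correct - (probs[1] ** gamma * torch.log(1 - probs[1])
                   + probs[2] ** gamma * torch.log(1 - probs[2]))

print(correct.item(), buggy.item())  # the buggy loss is strictly larger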

After I modified the corresponding part of the code, the loss began to converge! In detail, I commented out lines 347 to 355 in loss.py and added the code below:

y_true = F.one_hot(y_true, num_classes).permute(0, 3, 1, 2)
loss = -(1 - y_pred).pow(2) * torch.log(y_pred) * y_true
loss = loss.mean()
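
For reference, a self-contained sketch of that modification as a standalone function (the function name and the epsilon guard against log(0) are my additions):

import torch
import torch.nn.functional as F

def multiclass_focal_loss(y_pred, y_true, gamma=2.0, eps=1e-6):
    # y_pred: softmax probabilities (B, C, H, W); y_true: class indices (B, H, W).
    num_classes = y_pred.shape[1]
    y_true = F.one_hot(y_true, num_classes).permute(0, 3, 1, 2).float()
    # Only the ground-truth class contributes: -(1 - p_gt)^gamma * log(p_gt).
    loss = -(1 - y_pred).pow(gamma) * torch.log(y_pred + eps) * y_true
    return loss.mean()

logits = torch.randn(1, 3, 384, 640)  # background, road, lane
probs = logits.softmax(dim=1)
target = torch.randint(0, 3, (1, 384, 640))
print(multiclass_focal_loss(probs, target))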

xoiga123 commented 2 years ago

@bigsquirrel18 Hi, that's really great :+1:. We took the stock Focal Loss with no modifications (well, apart from this tiny thing). I'm in no position to validate your change :dizzy_face:, so I'm tagging @datvuthanh because he knows this stuff better. Can you report the final result after the change? I'll try training with it too.

cagataysari commented 2 years ago

Hello, we came across the same problem with the seg loss, and we also reproduced that only the focal seg loss would not converge. We were able to make the segmentation loss converge with Tversky alone. We will try the patch that bigsquirrel suggests and report the results here soon.

bigsquirrel18 commented 2 years ago

Hi @cagataysari, actually I ran into another problem after modifying the focal loss and haven't figured it out. The modified focal loss did start to converge, but after 2 epochs the metrics stopped updating. At first I guessed it was because y_pred was too small, making log(y_pred) go toward -inf, but the problem persisted after I applied F.nll_loss(). Hope you can make a step forward!

cagataysari commented 2 years ago

Hello @bigsquirrel18, thank you for the support. We were also able to make the loss converge by commenting out FocalSegLoss and using MONAI's focal loss instead. You might want to try that: https://docs.monai.io/en/stable/losses.html#focalloss
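
A hedged sketch of that drop-in (constructor arguments can differ across MONAI versions, so check the linked docs):

import torch
from monai.losses import FocalLoss

# to_onehot_y=True lets us pass per-pixel integer class labels of shape (B, 1, H, W).
criterion = FocalLoss(to_onehot_y=True, gamma=2.0)

logits = torch.randn(2, 3, 384, 640)            # raw (B, C, H, W) network output
target = torch.randint(0, 3, (2, 1, 384, 640))  # per-pixel class indices
print(criterion(logits, target))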

bigsquirrel18 commented 2 years ago

Thanks a lot for the advice @cagataysari, I'll try it later!

datvuthanh commented 1 year ago

Sorry for my late reply,

To be clear, our approach calculates a focal loss for each class independently (vs. the background), so the final focal loss is the sum of the per-class binary focal losses, and this makes the model a bit harder to converge.

Thank you for raising this issue, though.

You can check my new focal loss.

import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional
from torch.nn.modules.loss import _Loss

# Mode constants, matching the ones used by the repo's segmentation losses
BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE = "binary", "multiclass", "multilabel"

class FocalLossSegV2(_Loss):
    def __init__(
        self,
        mode: str,
        num_class: int,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        reduction: Optional[str] = "mean",
    ):   
        """Compute Focal loss
        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having positive value in target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss computation.
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
                should use `reduction="sum"`.
        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.smooth = 1e-4
        self.gamma = gamma
        self.reduction = reduction
        if alpha is None:
            self.alpha = torch.ones(num_class, )
        elif isinstance(alpha, (int, float)):
            self.alpha = torch.as_tensor([alpha] * num_class)
        elif isinstance(alpha, (list, np.ndarray)):
            self.alpha = torch.as_tensor(alpha)
        if self.alpha.shape[0] != num_class:
            raise RuntimeError('the length not equal to number of class')

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        B, C, H, W = y_pred.shape

        alpha = self.alpha.to(y_pred.device)
        prob = F.softmax(y_pred, dim=1)

        if prob.dim() > 2:
            prob = prob.view(B, C, -1)  # B,C,H,W => B,C,H*W
            prob = prob.permute(0, 2, 1).contiguous()  # B,C,H*W => B,H*W,C
            prob = prob.view(-1, C)  # B,H*W,C => B*H*W,C
        y_true = y_true.view(-1, 1).long()  # B,H,W => B*H*W,1; gather needs int64

        prob = prob.gather(1, y_true).view(-1) + self.smooth # avoid nan
        logpt = torch.log(prob)

        alpha_class = alpha[y_true.squeeze().long()]

        loss = - alpha_class * (1 - prob) ** self.gamma * logpt

        if self.reduction == "mean":
            loss = loss.mean()
        if self.reduction == "sum":
            loss = loss.sum()
        if self.reduction == "batchwise_mean":
            loss = loss.sum(0)    

        return loss    

if __name__ == '__main__':
    y_pred = torch.randn(1, 3, 200, 400)  # For example: background, lane, road
    y_true = torch.randint(0, 3, (1, 200, 400))  # integer class indices per pixel
    focal_loss = FocalLossSegV2(mode='multiclass', num_class=3, alpha=0.25, gamma=2)
    print(focal_loss(y_pred, y_true))
cagataysari commented 1 year ago

Hello @datvuthanh,

Will you replace FocalLossSeg with this new FocalLossSegV2 on the main branch?

Thanks in advance

datvuthanh commented 1 year ago

Hi @cagataysari,

I don't want to change the current version of the focal loss, as it may lead to results different from those in the paper I have published. In my case, using focal loss for binary classification helps the model predict which class a pixel belongs to more accurately. Its weakness, in my opinion, is that the independent calculation of the loss can lead to a lack of connectivity between pixels.

For example, with FocalSegLoss V1, at pixel location (i,j) the softmax function might output: 0.1, 0.4, 0.5. You can see that the values for class 2 and class 3 are very close, leading to confusion in the final prediction. With FocalSegLoss V2, at pixel location (i,j), the softmax function may output: 0.1, 0.2, 0.7. You can see that the winning value is much larger, so the prediction is easier. However, FocalSegLoss V2 may lead to a higher false positive rate than FocalSegLoss V1.

Thank you