lnsmith54 / CFL


Some bugs in using CFL for training a segmentation task #4

Open whuhxb opened 2 years ago

whuhxb commented 2 years ago

Hi @lnsmith54

When I used CFL to train a segmentation task, the loss dropped from 11 to 0. I have no idea why this happens. Can CFL be used for a segmentation task? And how should I set and tune the learning rate or learning schedule?

Thanks a lot.

lnsmith54 commented 2 years ago

I do not understand why a decreasing loss is a problem. It is desirable for training to decrease the loss, and a loss of 0 means there is no error.

As for using CFL for a segmentation task, I have never tried it. Since focal loss could be used for segmentation, there is no reason why CFL would not work.
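
For what it's worth, one common way to apply a per-sample classification loss like this to segmentation (a sketch with assumed shapes, not code from this repository) is to flatten the spatial dimensions so that every pixel is treated as one sample:

    import torch

    # Assumed shapes: logits (N, C, H, W) from the segmentation head,
    # labels (N, H, W) holding an integer class index per pixel.
    N, C, H, W = 2, 5, 64, 64
    logits = torch.randn(N, C, H, W)
    labels = torch.randint(0, C, (N, H, W))

    # Flatten to (N*H*W, C) and (N*H*W,) so a (batch, num_classes) loss applies per pixel.
    logits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)
    labels_flat = labels.reshape(-1)
    # loss = criterion(logits_flat, labels_flat, epoch)  # criterion could be focal loss or CFL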

whuhxb commented 2 years ago

@lnsmith54

It's strange that it only takes 10 epochs for the loss to drop to 0. When the loss is 0, the gradient to backpropagate is 0, so the network learns nothing from that point on. So I suspect there is some problem with the loss for the segmentation task.

lnsmith54 commented 2 years ago

It is indeed strange that the loss decreases to 0 in only 10 epochs. And yes, once the loss = 0, backprop does nothing and the network no longer learns.

whuhxb commented 2 years ago

@lnsmith54 I have changed the code from the PyTorch version into a TensorFlow version. I will paste both versions of CFL below. Could you please help check them? Thanks a lot.

PyTorch version:

    import torch
    import torch.nn as nn

    class Cyclical_FocalLoss(nn.Module):
        ''' This loss is intended for single-label classification problems '''
        def __init__(self, gamma_pos=0, gamma_neg=4, gamma_hc=0, eps: float = 0.1,
                     reduction='mean', epochs=200, factor=2):
            super(Cyclical_FocalLoss, self).__init__()

            self.eps = eps
            self.logsoftmax = nn.LogSoftmax(dim=-1)
            self.targets_classes = []
            self.gamma_hc = gamma_hc
            self.gamma_pos = gamma_pos
            self.gamma_neg = gamma_neg
            self.reduction = reduction
            self.epochs = epochs
            self.factor = factor  # factor=2 for cyclical, 1 for modified

        def forward(self, inputs, target, epoch):
            '''
            "inputs" dimensions: (batch_size, number_classes)
            "target" dimensions: (batch_size)
            '''
            num_classes = inputs.size()[-1]
            log_preds = self.logsoftmax(inputs)
            if len(list(target.size())) > 1:
                target = torch.argmax(target, 1)
            self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

            # Cyclical schedule for eta
            if self.factor * epoch < self.epochs:
                eta = 1 - self.factor * epoch / (self.epochs - 1)
            elif self.factor == 1.0:
                eta = 0
            else:
                eta = (self.factor * epoch / (self.epochs - 1) - 1.0) / (self.factor - 1.0)

            # ASL weights
            targets = self.targets_classes
            anti_targets = 1 - targets
            xs_pos = torch.exp(log_preds)
            xs_neg = 1 - xs_pos
            xs_pos = xs_pos * targets
            xs_neg = xs_neg * anti_targets
            asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                     self.gamma_pos * targets + self.gamma_neg * anti_targets)
            positive_w = torch.pow(1 + xs_pos, self.gamma_hc * targets)
            log_preds = log_preds * ((1 - eta) * asymmetric_w + eta * positive_w)

            if self.eps > 0:  # label smoothing
                self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

            # loss calculation
            loss = -self.targets_classes.mul(log_preds)

            loss = loss.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()

            return loss
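
A minimal usage sketch for the PyTorch class above (the shapes and numbers here are my own assumptions, not from the repo): logits are (batch_size, num_classes), targets are integer class indices, and the current epoch is passed on every call.

    criterion = Cyclical_FocalLoss(gamma_pos=0, gamma_neg=4, gamma_hc=0, epochs=200, factor=2)
    logits = torch.randn(8, 10, requires_grad=True)  # (batch_size, num_classes)
    labels = torch.randint(0, 10, (8,))              # (batch_size,) integer class indices
    loss = criterion(logits, labels, epoch=5)        # current epoch controls eta
    loss.backward()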

TensorFlow version:

    import tensorflow as tf

    class Cyclical_FocalLoss:
        ''' This loss is intended for single-label classification problems '''
        def __init__(self, gamma_pos=0, gamma_neg=4, gamma_hc=0, eps: float = 0.1,
                     reduction='mean', epochs=200, factor=2):

            self.eps = eps
            #self.logsoftmax = tf.nn.log_softmax()
            self.targets_classes = []
            self.gamma_hc = gamma_hc
            self.gamma_pos = gamma_pos
            self.gamma_neg = gamma_neg
            self.reduction = reduction
            self.epochs = epochs
            self.factor = factor

        def forward(self, inputs, target, epoch):
            '''
            "inputs" dimensions: (batch_size, number_classes)
            "target" dimensions: (batch_size)
            '''
            num_classes = inputs.size()[-1]
            #log_preds = tf.nn.log_softmax(inputs, dim=-1) # added
            log_preds = tf.nn.softmax(inputs, dim=-1) # added
            if len(list(target.size())) > 1:
                target = tf.argmax(target, 1)
            self.targets_classes = tf.zeros_like(inputs) + tf.one_hot(tf.expand_dims(target, 1), num_classes) # right?

            # Cyclical schedule for eta
            if self.factor * epoch < self.epochs:
                eta = 1 - self.factor * epoch / (self.epochs - 1)
            elif self.factor == 1.0:
                eta = 0
            else:
                eta = (self.factor * epoch / (self.epochs - 1) - 1.0) / (self.factor - 1.0)

            # ASL weights
            targets = self.targets_classes
            anti_targets = 1 - targets
            xs_pos = tf.exp(log_preds)
            xs_neg = 1 - xs_pos
            xs_pos = xs_pos * targets
            xs_neg = xs_neg * anti_targets
            asymmetric_w = tf.pow(1 - xs_pos - xs_neg,
                                  self.gamma_pos * targets + self.gamma_neg * anti_targets)
            positive_w = tf.pow(1 + xs_pos, self.gamma_hc * targets)
            log_preds = log_preds * ((1 - eta) * asymmetric_w + eta * positive_w)

            if self.eps > 0:  # label smoothing
                #self.targets_classes = tf.add(tf.multiply(self.targets_classes, 1 - self.eps), self.eps / num_classes)
                self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

            # loss calculation
            #loss = - tf.multiply(self.targets_classes, log_preds)
            loss = - self.targets_classes.mul(log_preds)

            loss = tf.reduce_sum(loss, dim=-1) # loss = tf.reduce_sum(loss, axis=1)
            if self.reduction == 'mean':
                loss = tf.reduce_mean(loss)

            return loss
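
For reference, here is a sketch of how the uncertain lines in my TensorFlow port might be written with standard TF 2.x ops (tf.nn.log_softmax with axis=, tf.one_hot on a rank-1 target, plain arithmetic instead of the PyTorch-style .mul()/.add()). This is my guess, the helper name is made up, and it is not code from the repository:

    import tensorflow as tf

    def smoothed_ce_sketch(inputs, target, eps=0.1):
        # inputs: (batch_size, num_classes) logits; target: (batch_size,) integer labels
        num_classes = inputs.shape[-1]
        # log-softmax (not softmax), so tf.exp(log_preds) gives probabilities in [0, 1]
        log_preds = tf.nn.log_softmax(inputs, axis=-1)
        # one-hot targets; no expand_dims needed when target is rank 1
        targets_classes = tf.one_hot(tf.cast(target, tf.int32), depth=num_classes)
        # label smoothing with ordinary arithmetic instead of .mul()/.add()
        targets_classes = targets_classes * (1.0 - eps) + eps / num_classes
        # cross-entropy-style loss, then mean reduction
        loss = -tf.reduce_sum(targets_classes * log_preds, axis=-1)
        return tf.reduce_mean(loss)
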
whuhxb commented 2 years ago

Hi @lnsmith54

I tried replacing the CFL loss with focal loss in my segmentation task. After training for about 10 epochs, the loss drops to 0, the same bug as with CFL. It's very strange.

lnsmith54 commented 2 years ago

Yes, that is strange. I don't have an answer for you.

whuhxb commented 2 years ago

@lnsmith54 OK. Let me check again. I'm not sure whether the bug is caused by the differences between the two versions.

whuhxb commented 2 years ago

@lnsmith54 I have examined the focal loss, and it runs OK now.

whuhxb commented 2 years ago

Hi @lnsmith54 I have checked the code, and the cross-entropy loss and focal loss now run OK. But when running CFL, the loss is NaN; I think this may be caused by an exp or divide-by-zero operation. How can I avoid too-large or too-small values in the ASL weights?

When revised with the ASL weights, the CFL loss is NaN. The settings are:

    gamma_pos = 0
    gamma_neg = 4
    mid_epochs = 300
    gamma_hc = 0
    factor = 2  # factor=2 for cyclical, 1 for modified

    # ASL weights
    xs_pos = tf.exp(log_preds) 
    xs_neg = 1 - xs_pos
    xs_pos = xs_pos * targets
    xs_neg = xs_neg * anti_targets
    asymmetric_w = tf.pow(1 - xs_pos - xs_neg, gamma_pos * targets + gamma_neg * anti_targets)
    positive_w = tf.pow(1 + xs_pos, gamma_hc * targets)
    log_preds = log_preds * ((1 - eta)* asymmetric_w + eta * positive_w)      

But with `log_preds * (1 - eta) * asymmetric_w` alone, or `log_preds * eta * positive_w` alone, the loss runs OK. And when `log_preds * ((1 - eta) * asymmetric_w + eta * positive_w)` is divided by 1000 and run on one GPU card, it also runs OK.

Do you know how to tune this? Thanks a lot.
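
One possible guard I am considering (my own sketch, not code from the repository, and the helper name is made up) is to clip the probabilities before the tf.pow calls so the base stays in a safe range. Note also that in the block I posted, log_preds comes from tf.nn.softmax rather than tf.nn.log_softmax, so tf.exp(log_preds) is no longer a probability, which can push 1 - xs_pos - xs_neg out of range.

    import tensorflow as tf

    def asl_weights_clipped(log_preds, targets, anti_targets,
                            gamma_pos=0.0, gamma_neg=4.0, gamma_hc=0.0, eps=1e-7):
        # Clip probabilities so the base of tf.pow never becomes negative or exactly 0.
        xs_pos = tf.clip_by_value(tf.exp(log_preds), eps, 1.0 - eps)
        xs_neg = 1.0 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        base = tf.clip_by_value(1.0 - xs_pos - xs_neg, eps, 1.0)
        asymmetric_w = tf.pow(base, gamma_pos * targets + gamma_neg * anti_targets)
        positive_w = tf.pow(1.0 + xs_pos, gamma_hc * targets)
        return asymmetric_w, positive_w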

lnsmith54 commented 2 years ago

As you say, there doesn't seem to be an obvious cause for the NaNs. You can try setting:

    gamma_pos = 0
    gamma_neg = 0
    gamma_hc = 0

This is the same as cross entropy with softmax, and you should not get NaNs. If this works, increase the gammas to see at what point the NaNs appear. In other words, find the border between where the code works and where it does not.