zyxu1996 / Efficient-Transformer

Online! An application of an efficient transformer, improved from the Swin Transformer, to remote sensing image segmentation.
Apache License 2.0
72 stars 7 forks source link

Loss #11

Open PeterLuoCoder opened 1 year ago

PeterLuoCoder commented 1 year ago

"Hello, I'm studying your code and have a few questions. I apologize for the interruption, but could you please explain the following? In train.py, does setting use_mixup and use_edge to 0 or 1 mean False and True, respectively? When both use_mixup and use_edge are set to 0 (the default), is the loss calculated as CrossEntropyLoss() + Edge_loss()? If I want to calculate the loss using only CrossEntropyLoss, how should I modify the code? Thank you."

class FullModel(nn.Module):
    """Wrap a segmentation network together with its training loss.

    In training mode ``forward`` returns the loss; in evaluation mode it
    returns the network's prediction.

    ``args2.use_mixup`` and ``args2.use_edge`` are only tested for
    truthiness here, so ``0`` disables and ``1`` enables each feature:

    * ``use_mixup`` -- when truthy, training batches are handed to
      ``Mixup`` together with the model and the loss function(s), and its
      return value is used as the loss.
    * ``use_edge`` -- when truthy and the model returns a list/tuple, the
      last element is supervised by ``Edge_loss`` against
      ``edge_contour(label)``. Every other element is supervised by
      cross-entropy. When falsy, cross-entropy alone is summed over all
      outputs, so the loss contains no edge term. Mixup then also receives
      only the cross-entropy criterion.
    """

    def __init__(self, model, args2):
        super().__init__()
        self.model = model
        self.use_mixup = args2.use_mixup
        self.use_edge = args2.use_edge

        # NOTE: Edge_weak_loss() was left commented out here in the original
        # as an alternative criterion to CrossEntropyLoss.
        self.ce_loss = CrossEntropyLoss()

        # Built even when use_edge is off, so the module's attributes do not
        # depend on the flags.
        self.edge_loss = Edge_loss()

        if self.use_mixup:
            self.mixup = Mixup(use_edge=args2.use_edge)

    def forward(self, input, label=None, train=True):
        """Compute the training loss, or return predictions at eval time.

        Args:
            input: batch passed unchanged to ``self.model``.
            label: ground-truth targets; required when ``train`` is true.
            train: if true, return the loss; otherwise return the prediction.

        Returns:
            In training mode, the summed loss (or the value returned by
            ``self.mixup`` when mixup is enabled). In evaluation mode, the
            model output, or its first element if the model returns a
            list/tuple.

        Raises:
            ValueError: if ``train`` is true and ``label`` is ``None``.
        """
        if not train:
            output = self.model(input)
            if isinstance(output, (list, tuple)):
                return output[0]
            return output

        if label is None:
            # Fail fast: otherwise the loss functions are called with a None
            # target and fail later with a much less helpful error.
            raise ValueError("label is required when train=True")

        if self.use_mixup:
            if self.use_edge:
                loss_fns = [self.ce_loss, self.edge_loss]
            else:
                loss_fns = self.ce_loss
            return self.mixup(input, label, loss_fns, self.model)

        output = self.model(input)
        if not isinstance(output, (list, tuple)):
            # A single prediction gets cross-entropy only, regardless of use_edge.
            return self.ce_loss(output, label)

        if not self.use_edge:
            # Summed cross-entropy over every output head.
            return sum(self.ce_loss(out, label) for out in output)

        # With use_edge, the last output is supervised by the edge loss and
        # all preceding outputs by cross-entropy.
        seg_outputs, edge_output = output[:-1], output[-1]
        losses = sum(self.ce_loss(out, label) for out in seg_outputs)
        return losses + self.edge_loss(edge_output, edge_contour(label).long())