samleoqh / MSCG-Net

Multi-view Self-Constructing Graph Convolutional Networks with Adaptive Class Weighting Loss for Semantic Segmentation
MIT License

Cross-entropy loss #24

Closed: czarmanu closed this issue 4 months ago

czarmanu commented 2 years ago

How do I replace the ACW loss with a cross-entropy loss? Setting

```python
criterion = torch.nn.CrossEntropyLoss()
```

in train_R50.py doesn't work.

samleoqh commented 2 years ago

You have to set `ignore_index=255` for the CE loss, because the GT files use 255 to mark invalid regions. Refer to this implementation: https://github.com/samleoqh/DDCM-Semantic-Segmentation-PyTorch/blob/master/models/loss_functions/ce_loss.py
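
As a minimal sketch of this fix (assuming the network outputs raw logits of shape `(N, C, H, W)` and labels are `(N, H, W)` class-index maps with 255 marking invalid pixels; the tensor names and sizes here are illustrative):

```python
import torch
import torch.nn as nn

# Pixels labeled 255 (invalid regions in the GT files) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

logits = torch.randn(2, 6, 64, 64)         # (N, C, H, W) raw class scores
labels = torch.randint(0, 6, (2, 64, 64))  # (N, H, W) class indices
labels[:, :8, :8] = 255                    # e.g. an invalid border region
loss = criterion(logits, labels)           # 255-labeled pixels are ignored
```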

czarmanu commented 2 years ago

Thanks!

Modified the criterion to `criterion = CrossEntropy2D().cuda()` and added the following class to acw_loss.py:

```python
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropy2D(nn.Module):
    def __init__(self, weight=None, reduction='mean', ignore_index=255):
        super(CrossEntropy2D, self).__init__()
        self.loss = nn.NLLLoss(weight, reduction=reduction, ignore_index=ignore_index)

    def forward(self, outputs, targets):
        # outputs: raw logits (N, C, H, W); targets: class indices (N, H, W)
        return self.loss(F.log_softmax(outputs, dim=1), targets)
```
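
Since `log_softmax` followed by `NLLLoss` is exactly what `nn.CrossEntropyLoss` computes, this class should agree with the built-in loss. A quick, illustrative sanity check (assuming the class definition above is in scope):

```python
import torch

ce2d = CrossEntropy2D()
ce = nn.CrossEntropyLoss(ignore_index=255)
x = torch.randn(2, 6, 8, 8)            # logits
y = torch.randint(0, 6, (2, 8, 8))     # class-index labels
assert torch.allclose(ce2d(x, y), ce(x, y))
```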

Then edited train_R50.py as follows:

```python
from lib.loss.acw_loss import CrossEntropy2D

def main():
    random_seed(train_args.seeds)
    train_args.write2txt()
    net = load_model(name=train_args.model, classes=train_args.nb_classes,
                     node_size=train_args.node_size)

    # print("next step: load pre-trained model R50")
    # net = get_net(ckpt1)
    # print("loaded pre-trained model R50")

    net, start_epoch = train_args.resume_train(net)
    net.cuda()
    net.train()

    # prepare dataset for training and validation
    train_set, val_set = train_args.get_dataset()
    train_loader = DataLoader(dataset=train_set, batch_size=train_args.train_batch,
                              num_workers=0, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=train_args.val_batch,
                            num_workers=0)

    # criterion = ACW_loss().cuda()
    # criterion = torch.nn.CrossEntropyLoss()
    criterion = CrossEntropy2D().cuda()

    params = init_params_lr(net, train_args)
    # first train with Adam for around 10 epochs, then manually switch to SGD
    # to continue training; note: resume training from the saved snapshot
    base_optimizer = optim.Adam(params, amsgrad=True)
    # base_optimizer = optim.SGD(params, momentum=train_args.momentum, nesterov=True)
    optimizer = Lookahead(base_optimizer, k=6)
    # optimizer = AdaX(params)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 60, 1.18e-6)

    new_ep = 0
    while True:
        starttime = time.time()
        train_main_loss = AverageMeter()
        aux_train_loss = AverageMeter()
        cls_train_loss = AverageMeter()

        start_lr = train_args.lr
        train_args.lr = optimizer.param_groups[0]['lr']
        num_iter = len(train_loader)
        curr_iter = ((start_epoch + new_ep) - 1) * num_iter
        print('---curr_iter: {}, num_iter per epoch: {}---'.format(curr_iter, num_iter))

        for i, (inputs, labels) in enumerate(train_loader):
            sys.stdout.flush()

            inputs, labels = inputs.cuda(), labels.cuda()
            N = inputs.size(0) * inputs.size(2) * inputs.size(3)
            optimizer.zero_grad()
            outputs, cost = net(inputs)

            main_loss = criterion(outputs, labels)
            # loss = main_loss + cost
            loss = main_loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step(epoch=(start_epoch + new_ep))

            train_main_loss.update(main_loss.item(), N)
            aux_train_loss.update(cost.item(), inputs.size(0))

            curr_iter += 1
            writer.add_scalar('main_loss', train_main_loss.avg, curr_iter)
            # writer.add_scalar('aux_loss', aux_train_loss.avg, curr_iter)
            # writer.add_scalar('cls_loss', cls_train_loss.avg, curr_iter)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], curr_iter)

            if (i + 1) % train_args.print_freq == 0:
                newtime = time.time()

                print('[epoch %d], [iter %d / %d], [loss %.5f, aux %.5f, cls %.5f], [lr %.10f], [time %.3f]' %
                      (start_epoch + new_ep, i + 1, num_iter, train_main_loss.avg, aux_train_loss.avg,
                       cls_train_loss.avg,
                       optimizer.param_groups[0]['lr'], newtime - starttime))

                starttime = newtime

        validate(net, val_set, val_loader, criterion, optimizer, start_epoch + new_ep, new_ep)

        new_ep += 1
```
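
One thing worth flagging in this edit: the auxiliary term `cost` returned by the network (presumably the self-constructing-graph regularizer, given the repo) is still logged but no longer enters the objective, since `loss = main_loss`. To keep the original training objective with the new criterion, the commented-out line would be restored instead:

```python
loss = main_loss + cost  # CE loss plus the network's auxiliary term
```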