deepinsight / insightface

State-of-the-art 2D and 3D Face Analysis Project
https://insightface.ai
23.45k stars 5.42k forks source link

(arcface_torch) how to add two images' loss and backward? #1485

Open Light-- opened 3 years ago

Light-- commented 3 years ago

Dear @anxiangsir :

i changed the code into this (add the gradient of two classification loss then backward), but find the result is very bad...and finally Nan encountered!!

may i ask what's wrong in my revision? the input of my network is image pairs, in each pair, two images share one same label.

if partial_fc.py, i changed it into


    def forward_backward(self, label, features, optimizer):
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)

        # copy features to different gpus
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        # logits.backward(grad)
        return grad, logits,loss_v, total_features

in main() of train.py, i changed it into:


def main(args):
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    dist_url = "tcp://{}:{}".format(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
    dist.init_process_group(backend='nccl', init_method=dist_url, rank=rank, world_size=world_size)
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)

    if not os.path.exists(cfg.output) and rank is 0:
        os.makedirs(cfg.output)
    else:
        time.sleep(2)

    log_root = logging.getLogger()
    init_logging(log_root, rank, cfg.output)
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=trainset, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=0, pin_memory=True, drop_last=True,
    worker_init_fn = worker_init_fn)  # for reproducibility 

    dropout = 0.4 if cfg.dataset is "webface" else 0
    backbone = eval("backbones.{}".format(args.network))(False, dropout=dropout, fp16=cfg.fp16).to(local_rank)

    if args.resume:
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
            if rank is 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            logging.info("resume fail, backbone init successfully!")

    for ps in backbone.parameters():
        dist.broadcast(ps, 0)
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
    backbone.train()

    margin_softmax = eval("losses.{}".format(args.loss))()
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=args.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)

    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=cfg.lr_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=cfg.lr_func)

    start_epoch = 0
    total_step = int(len(trainset) / cfg.batch_size / world_size * cfg.num_epoch)
    if rank is 0: logging.info("Total Step is: %d" % total_step)

    interv = 400 
    callback_verification = CallBackVerification(interv, rank, cfg.val_targets, cfg.rec) # 150 for debug, self.frequent = 1000
    callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None) # verbose = 50
    callback_checkpoint = CallBackModelCheckpoint(interv, rank, cfg.output)

    loss1 = AverageMeter()
    loss2 = AverageMeter()

    global_step = 0

    grad_scaler = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            global_step += 1
            # features = F.normalize(backbone(img))

            img1, img2 = img[:,:,:,:112], img[:,:,:,112:]
            assert not cfg.fp16

            # branch 1
            fc1Features1 = F.normalize(backbone(img1))
            fc7grad1, logits1, loss_v1, total_features1 = module_partial_fc.forward_backward(label, fc1Features1, opt_pfc)

            fc1Features2 = F.normalize(backbone(img2))
            fc7grad2, logits2, loss_v2, total_features2 = module_partial_fc.forward_backward(label, fc1Features2, opt_pfc)

            fc7addGrad = fc7grad1  + fc7grad2 
            logits1.backward(fc7addGrad)

            if total_features1.grad is not None:
                total_features1.grad.detach_()
            x_grad1: torch.Tensor = torch.zeros_like(fc1Features1, requires_grad=True)
            # feature gradient all-reduce
            dist.reduce_scatter(x_grad1, list(total_features1.grad.chunk(world_size, dim=0)))
            x_grad1 = x_grad1 * world_size
            fc1Features1.backward(x_grad1)

            clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
            opt_backbone.step()
            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss1.update(loss_v1, 1)
            loss2.update(loss_v2, 1)

            if global_step > interv and global_step % interv == 0 and dist.get_rank()==0:
                lr = opt_backbone.state_dict()['param_groups'][0]['lr']
                logging.info('LR %s'%lr)

            callback_logging(global_step, loss1, loss2, epoch, cfg.fp16, grad_scaler)
            callback_verification(global_step, backbone)
            callback_checkpoint(global_step, backbone, module_partial_fc)

        scheduler_backbone.step()
        scheduler_pfc.step()

    dist.destroy_process_group()

the results:

Training: 2021-04-29 17:16:38,932-Speed 2689.27 samples/sec   Loss1 43.6098   Loss2 44.0652   Epoch: 3   Global Step: 2400   Required: 3 hours
testing verification..
(12000, 512)
infer time 19.456014000000057
Training: 2021-04-29 17:17:01,106-[lfw][2400]XNorm: 22.560986
Training: 2021-04-29 17:17:01,107-[lfw][2400]Accuracy-Flip: 0.49983+-0.00050
Training: 2021-04-29 17:17:01,107-[lfw][2400]Accuracy-Highest: 0.50400
testing verification..
(14000, 512)
infer time 22.705947999999964
Training: 2021-04-29 17:17:26,860-[cfp_fp][2400]XNorm: 25.310066
Training: 2021-04-29 17:17:26,860-[cfp_fp][2400]Accuracy-Flip: 0.50014+-0.00119
Training: 2021-04-29 17:17:26,860-[cfp_fp][2400]Accuracy-Highest: 0.50314
testing verification..
(12000, 512)
infer time 19.476336999999983
Training: 2021-04-29 17:17:49,061-[agedb_30][2400]XNorm: 27.397241
Training: 2021-04-29 17:17:49,061-[agedb_30][2400]Accuracy-Flip: 0.50050+-0.00107
Training: 2021-04-29 17:17:49,061-[agedb_30][2400]Accuracy-Highest: 0.51433

...

Training: 2021-04-29 17:52:21,601-Speed 2653.84 samples/sec   Loss1 43.9708   Loss2 43.6432   Epoch: 8   Global Step: 5600   Required: 3 hours
Training: 2021-04-29 17:52:44,205-[lfw][5600]XNorm: 21.012763
Training: 2021-04-29 17:52:44,206-[lfw][5600]Accuracy-Flip: 0.49983+-0.00050
Training: 2021-04-29 17:52:44,206-[lfw][5600]Accuracy-Highest: 0.50400
Training: 2021-04-29 17:53:10,282-[cfp_fp][5600]XNorm: 17.821486
Training: 2021-04-29 17:53:10,282-[cfp_fp][5600]Accuracy-Flip: 0.50043+-0.00181
Training: 2021-04-29 17:53:10,282-[cfp_fp][5600]Accuracy-Highest: 0.50314
Training: 2021-04-29 17:53:32,873-[agedb_30][5600]XNorm: 19.934401
Training: 2021-04-29 17:53:32,873-[agedb_30][5600]Accuracy-Flip: 0.50000+-0.00236
Training: 2021-04-29 17:53:32,873-[agedb_30][5600]Accuracy-Highest: 0.51433
Training: 2021-04-29 17:53:57,506-Speed 688.18 samples/sec   Loss1 43.8332   Loss2 43.9438   Epoch: 8   Global Step: 5650   Required: 3 hours
Training: 2021-04-29 17:54:21,778-Speed 2719.32 samples/sec   Loss1 43.8296   Loss2 43.8148   Epoch: 8   Global Step: 5700   Required: 3 hours
Training: 2021-04-29 17:54:46,199-Speed 2702.61 samples/sec   Loss1 44.0131   Loss2 43.6397   Epoch: 8   Global Step: 5750   Required: 3 hours
Training: 2021-04-29 17:55:10,660-Speed 2698.16 samples/sec   Loss1 43.9690   Loss2 43.6218   Epoch: 8   Global Step: 5800   Required: 3 hours
Training: 2021-04-29 17:55:35,123-Speed 2698.04 samples/sec   Loss1 43.6915   Loss2 43.9429   Epoch: 8   Global Step: 5850   Required: 3 hours
Training: 2021-04-29 17:55:59,642-Speed 2691.88 samples/sec   Loss1 43.9401   Loss2 43.6547   Epoch: 8   Global Step: 5900   Required: 3 hours
Training: 2021-04-29 17:56:25,154-Speed 2587.05 samples/sec   Loss1 43.7101   Loss2 43.9016   Epoch: 9   Global Step: 5950   Required: 3 hours
Training: 2021-04-29 17:56:49,175-LR 0.025781250000000002
Training: 2021-04-29 17:56:49,176-Speed 2747.59 samples/sec   Loss1 nan   Loss2 nan   Epoch: 9   Global Step: 6000   Required: 3 hours
yjhuasheng commented 3 years ago

您好,请问您解决这个问题了吗,我也有同样的问题

Light-- commented 3 years ago

您好,请问您解决这个问题了吗,我也有同样的问题

没有啊,要不留个联系方式交流一下?

yjhuasheng commented 3 years ago

您好,请问您解决这个问题了吗,我也有同样的问题

没有啊,要不留个联系方式交流一下?

阔以啊,怎么联系?

zhangluustb commented 3 years ago

您好,你们问题解决了吗?

Light-- commented 3 years ago

您好,你们问题解决了吗? @zhangluustb

没有。加个微信?

yjhuasheng commented 3 years ago

你微信多少啊

------------------ 原始邮件 ------------------ 发件人: "deepinsight/insightface" @.>; 发送时间: 2021年6月2日(星期三) 中午12:59 @.>; @.**@.>; 主题: Re: [deepinsight/insightface] (arcface_torch) how to add two images' loss and backward? (#1485)

您好,你们问题解决了吗? @zhangluustb

没有。加个微信?

— You are receiving this because you commented. Reply to this email directly, view it on GitHub, or unsubscribe.

zhangluustb commented 3 years ago

您好,你们问题解决了吗? @zhangluustb

没有。加个微信?

lu6017512