aipixel / AEMatter

Another matter.
GNU General Public License v2.0

How many days are needed to reproduce the results? #16

Closed: ChenyiZhang007 closed this issue 1 month ago

ChenyiZhang007 commented 1 month ago

Following the official settings, it takes me 45 days to train for 300 epochs. Is something wrong?

Windaway commented 1 month ago

In fact, 120 epochs is enough for training; I wrote 300 epochs casually when tidying up the code. However, single-GPU training is very slow, so I can provide the code for single-machine multi-GPU distributed training.

ChenyiZhang007 commented 1 month ago

OK, looking forward to the update.

Windaway commented 1 month ago
import argparse
import random
import warnings
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.nn as nn
import model
import dataset
import laploss
# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train_dist.py

parser = argparse.ArgumentParser(description='AEMatter Training')
parser.add_argument('-j',
                    '--workers',
                    default=4,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--local_rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
best_acc1 = 0

def main_worker(gpu, ngpus_per_node, args=None):
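    # One training process per GPU; torch.distributed.launch passes --local_rank to each process.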
    global best_acc1
    torch.backends.cudnn.benchmark = True
    # torch.distributed.launch sets the environment variables NCCL needs for the process group.
    dist.init_process_group(backend='nccl')
    h_dataset = dataset.BasicData(1024)
    # The DistributedSampler partitions the dataset across processes.
    h_trainsampler = torch.utils.data.distributed.DistributedSampler(h_dataset)
    train_loader = torch.utils.data.DataLoader(h_dataset, batch_size=1, num_workers=2,
                                               shuffle=False, drop_last=True, sampler=h_trainsampler)
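    # Per-pixel alpha loss over the unknown (trimap == 1) region, with a Charbonnier-style epsilon.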
    def weighted_loss(pd, gt, wl=0.9, epsilon=1e-12, tri=None):
        bs, _, h, w = pd.shape
        alpha_gt = gt.view(bs, 1, h, w)
        tri = tri.view(bs, 1, h, w)
        diff_alpha0 = (pd - alpha_gt).float() * (tri == 1)
        loss_alpha2 = torch.sqrt(diff_alpha0 * diff_alpha0 + epsilon)
        sums = (torch.sum(tri == 1) + 50.)
        loss_alpha = loss_alpha2.sum() / sums
        return loss_alpha
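    # Group parameters into a no-weight-decay set (biases and selected conv layers) and a weight-decay set.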
    def get_param(model):
        nodecay = {'params': [], 'weight_decay': 0}
        decay = {'params': [], 'weight_decay': 1e-6}
        for name, param in model.named_parameters():
            if 'start_conv' in name:
                nodecay['params'].append(param)
            elif 'bias' in name:
                nodecay['params'].append(param)
            elif 'convo' in name:
                nodecay['params'].append(param)
            elif 'conv5' in name:
                nodecay['params'].append(param)
            elif 'conv4' in name:
                nodecay['params'].append(param)
            elif 'conv3' in name:
                nodecay['params'].append(param)
            else:
                decay['params'].append(param)
        return [nodecay, decay]

    segmodel = model.AEMatter()
    segmodel = segmodel.cuda(gpu)
    segmodel = torch.nn.parallel.DistributedDataParallel(segmodel, device_ids=[args.local_rank],
                                                         find_unused_parameters=True)
    segmodel.train()
    we = torch.tensor([0.001, 1, 0.002]).cuda(gpu)  # class weights (not used in this training loop)
    optim_g = torch.optim.RAdam(get_param(segmodel), 2.0 * 1e-5, betas=(0.5, 0.999))
    sl = torch.optim.lr_scheduler.CosineAnnealingLR(optim_g, 120, 1e-7)
    idx = 0
    l1loss = nn.L1Loss().cuda(gpu)
    mloss = laploss.lap_loss().cuda(gpu)
    scaler = torch.cuda.amp.GradScaler()  # gradient scaler for mixed-precision training

    def focalc(outputs, targets):
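        # Focal loss helper; defined here but not used in the training loop below.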
        alpha = 1
        gamma = 2
        ce_loss = torch.nn.functional.cross_entropy(outputs, targets,
                                                    reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (alpha * (1 - pt) ** gamma * ce_loss).mean()
        return focal_loss

    for epoch in range(125):
        print('PreTrain_Start', epoch)
        # Reshuffle the DistributedSampler so each epoch uses a different data ordering.
        h_trainsampler.set_epoch(epoch)
        id = 0
        L = 0
        L_tri = 0
        L_alpha1 = 0
        L_alpha2 = 0

        L_edge = 0
        L_img = 0
        L_mask = 0
        for _, data in enumerate(train_loader):
            bgt2, mgt, mgt2, Tfseg, Talpha, fgt = data
            _, _, h, w = mgt.shape
            mgt = mgt.cuda(gpu, non_blocking=True)
            Talpha = Talpha.cuda(gpu, non_blocking=True)
            Tfseg = Tfseg.cuda(gpu, non_blocking=True)
            optim_g.zero_grad()

            # Mixed-precision forward pass: predict alpha and composite it with the trimap.
            with torch.cuda.amp.autocast():
                lastpred = segmodel(mgt, Tfseg)
                alpha = lastpred[:, 0:1] * Tfseg[:, 1:2] + Tfseg[:, 2:3]
                lossm = mloss(alpha, Talpha)            # Laplacian pyramid loss
                loss_alpha = l1loss(alpha, Talpha)      # global L1 loss on alpha
                loss_i = weighted_loss(alpha, Talpha, tri=Tfseg[:, 1:2])  # unknown-region loss
                loss = loss_alpha * 0.5 + loss_i * 0.5 + lossm * 0.5

            scaler.scale(loss).backward()
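            # Unscale gradients so clipping operates on true values, then step through the AMP scaler.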
            scaler.unscale_(optim_g)
            torch.nn.utils.clip_grad_norm_(segmodel.parameters(), 10.)
            scaler.step(optim_g)
            scaler.update()

            id += 1
            L += loss.item()
            L_tri += loss_alpha.item()
            L_alpha1 += loss_alpha.item()
            L_alpha2 += lossm.item()
            L_edge += lossm.item()
            L_img += loss_i.item()
            L_mask += loss_i.item()
            if id % 100 == 0 and id > 0:
                print('Epoch', epoch, 'Total_Loss', L / 100., 'Alpha1Loss', L_alpha1 / 100, 'Alpha2Loss', L_alpha2 / 100)
                L = 0
                id = 0
                L_tri = 0
                L_alpha1 = 0
                L_alpha2 = 0
                L_fg = 0
                L_bg = 0
                L_img2 = 0
                L_img = 0
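        # Save a checkpoint from local rank 0 only, once past epoch 25.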
        if gpu == 0 and epoch > 25:
            torch.save(segmodel.module.state_dict(), './ckpt/' + str(epoch) + '_' + str(gpu) + 'aem.ckpt')
        sl.step()

if __name__ == '__main__':
    args = parser.parse_args()
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    main_worker(args.local_rank, args.workers, args)

ChenyiZhang007 commented 1 month ago

Thanks for your response! I wonder whether the learning rate should be adjusted according to the batch size, and what batch size and number of GPUs were used in the main paper.

Windaway commented 1 month ago

As far as I remember, I multiplied the initial learning rate by the number of GPUs.
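
Roughly like this, if you adapt the optimizer line in the script above (a sketch only; base_lr, world_size and scaled_lr are illustrative names, and 2e-5 is the base rate from the script):

import torch.distributed as dist

base_lr = 2.0 * 1e-5                # single-GPU learning rate used in the script
world_size = dist.get_world_size()  # number of processes started by the launcher
scaled_lr = base_lr * world_size    # linear scaling with the number of GPUs
optim_g = torch.optim.RAdam(get_param(segmodel), scaled_lr, betas=(0.5, 0.999))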

ChenyiZhang007 commented 1 month ago

Thanks a lot!