YangChangHee / ICCV2023_SEFD_RELEASE


NO TRAINER CLASS! #5

Closed · ocissor closed this 12 months ago

ocissor commented 1 year ago

Hi, first of all, great work! I was going through your train.py, which imports Trainer from base, but base.py has no class named Trainer. Can you please help me with this?

Thank you!
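For reference, this is roughly what fails on my side (the exact import line is an assumption based on the description above):

# train.py imports the trainer from base.py (import form assumed)
from base import Trainer   # raises ImportError, since base.py currently defines no Trainer class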

YangChangHee commented 1 year ago

We apologize for the inconvenience. The training code was implemented in an overly complicated way, so I am re-writing it! Also, I am busy with another conference deadline right now, so I will revise this as soon as that is finished. Could you wait a few days?

Thank you! :)

ocissor commented 1 year ago

> We apologize for the inconvenience. The training code was implemented in an overly complicated way, so I am re-writing it! Also, I am busy with another conference deadline right now, so I will revise this as soon as that is finished. Could you wait a few days?
>
> Thank you! :)

Thank you for your response. Is it possible for you to share the trainer.set_lr() function and the trainer.optimizer? If not, I totally understand; conference submissions are very hectic. All the best with your submission.

Thank you!

YangChangHee commented 1 year ago

The Trainer class we used is as follows!

# Assumed imports: this class is meant to sit in base.py next to the existing Base class,
# and the project-local module paths below are guesses that may differ in the released repo.
import glob
import math
import os.path as osp

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.utils.data import DataLoader
from torchvision import transforms

from config import cfg                # project config object (assumed path)
from dataset import MultipleDatasets  # multi-dataset wrapper used below (assumed path)
from model import get_model           # main network constructor (assumed path)
# smpl_get_model, used in _make_model(), comes from the distillation / SMPL-overlap module (assumed).

class Trainer(Base):
    def __init__(self):
        super(Trainer, self).__init__(log_name='train_logs.txt')

    def get_optimizer(self, model):
        optimizer = torch.optim.Adam([
            {'params': model.module.backbone.parameters(), 'lr': cfg.lr_backbone},
            {'params': model.module.pose2feat.parameters()},
            {'params': model.module.position_net.parameters()},
            {'params': model.module.rotation_net.parameters()},
            #{'params': model.module.edge_module.parameters()},
        ],
        lr=cfg.lr)
        print('The parameters of backbone, pose2feat, position_net, rotation_net, are added to the optimizer.')

        return optimizer

    def get_optimizer_dilation(self, model):
        optimizer = torch.optim.Adam([
            {'params': model.module.backbone.parameters(), 'lr': cfg.lr_backbone},
            {'params': model.module.pose2feat.parameters()},
            {'params': model.module.position_net.parameters()},
            {'params': model.module.rotation_net.parameters()},
            {'params': model.module.ced.parameters()},
        ],
        lr=cfg.lr)
        # the ced module is built with the kornia toolkit, so its edge operations are differentiable
        print('The parameters of backbone, pose2feat, position_net, rotation_net, and ced are added to the optimizer.')

        return optimizer

    def save_model(self, state, epoch):
        file_path = osp.join(cfg.model_dir,'snapshot_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write snapshot into {}".format(file_path))

    def load_model(self, model, optimizer):
        # resume from the most recent snapshot in cfg.model_dir
        model_file_list = glob.glob(osp.join(cfg.model_dir,'*.pth.tar'))
        cur_epoch = max([int(file_name[file_name.find('snapshot_') + 9 : file_name.find('.pth.tar')]) for file_name in model_file_list])
        ckpt_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
        ckpt = torch.load(ckpt_path)
        start_epoch = ckpt['epoch'] + 1

        if cfg.distillation_pretrained==True:
            model.load_state_dict(torch.load(cfg.distillation_module_path)['network'],strict=False)
        else:
            model.load_state_dict(ckpt['network'], strict=False)
            model=model.module.cpu()
            model.backbone.init_weights()
            model = DataParallel(model).cuda()

        self.logger.info('Load checkpoint from {}'.format(ckpt_path))
        return start_epoch, model, optimizer

    def set_lr(self, epoch):
        # step decay: every param group's lr is set to cfg.lr divided by cfg.lr_dec_factor
        # once for each milestone in cfg.lr_dec_epoch that the current epoch has passed
        for e in cfg.lr_dec_epoch:
            if epoch < e:
                break
        if epoch < cfg.lr_dec_epoch[-1]:
            idx = cfg.lr_dec_epoch.index(e)
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** idx)
        else:
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch))

    def get_lr(self):
        for g in self.optimizer.param_groups:
            cur_lr = g['lr']
        return cur_lr

    def _make_batch_generator(self):
        # data load and construct batch generator
        self.logger.info("Creating dataset...")
        trainset3d_loader = []
        if cfg.smplify==False:
            for i in range(len(cfg.trainset_3d)):
                trainset3d_loader.append(eval(cfg.trainset_3d[i])(transforms.ToTensor(), "train"))
        #print("this")
        trainset2d_loader = []
        for i in range(len(cfg.trainset_2d)):
            trainset2d_loader.append(eval(cfg.trainset_2d[i])(transforms.ToTensor(), "train"))

        if len(trainset3d_loader) > 0 and len(trainset2d_loader) > 0:
            self.vertex_num = trainset3d_loader[0].vertex_num
            self.joint_num = trainset3d_loader[0].joint_num
            trainset3d_loader = MultipleDatasets(trainset3d_loader, make_same_len=False)
            trainset2d_loader = MultipleDatasets(trainset2d_loader, make_same_len=False)
            trainset_loader = MultipleDatasets([trainset3d_loader, trainset2d_loader], make_same_len=True)
        elif len(trainset3d_loader) > 0:
            self.vertex_num = trainset3d_loader[0].vertex_num
            self.joint_num = trainset3d_loader[0].joint_num
            trainset_loader = MultipleDatasets(trainset3d_loader, make_same_len=False)
        elif len(trainset2d_loader) > 0:
            self.vertex_num = trainset2d_loader[0].vertex_num
            self.joint_num = trainset2d_loader[0].joint_num
            trainset_loader = MultipleDatasets(trainset2d_loader, make_same_len=False)
        else:
            assert 0, "Both 3D training set and 2D training set have zero length."

        self.itr_per_epoch = math.ceil(len(trainset_loader) / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = DataLoader(dataset=trainset_loader, batch_size=cfg.num_gpus*cfg.train_batch_size, shuffle=True, num_workers=cfg.num_thread, pin_memory=True)

    def _make_model(self):
        # prepare network
        self.logger.info("Creating graph and optimizer...")
        if cfg.distillation_pretrained==True and cfg.smplify==False:
            smpl_overlap_model=smpl_get_model(30)
            smpl_overlap_model = DataParallel(smpl_overlap_model).cuda()
            smpl_overlap_model.load_state_dict(torch.load(cfg.distillation_module_path)['network'],strict=False)
            print("Load SMPL_overlap_module success!")
        else:
            smpl_overlap_model=None

        model = get_model(self.vertex_num, self.joint_num,smpl_overlap_model, 'train')
        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
        else:
            start_epoch = 0
        # the optimizer is rebuilt here so that it is constructed over the
        # (possibly reloaded and re-wrapped) model returned by load_model
        optimizer = self.get_optimizer(model)
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer
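For context, this is roughly how train.py drives the class above; the loop in the actual release may differ, and cfg.end_epoch, the forward-call signature, and the loss-dictionary handling below are assumptions based on similar codebases:

# Sketch of the training loop that uses the Trainer above.
# cfg.end_epoch, the model forward signature, and the loss-dict handling are assumptions.
trainer = Trainer()
trainer._make_batch_generator()
trainer._make_model()

for epoch in range(trainer.start_epoch, cfg.end_epoch):
    trainer.set_lr(epoch)
    for itr, (inputs, targets, meta_info) in enumerate(trainer.batch_generator):
        trainer.optimizer.zero_grad()
        loss = trainer.model(inputs, targets, meta_info, 'train')
        loss = {k: v.mean() for k, v in loss.items()}  # reduce over DataParallel replicas
        sum(loss[k] for k in loss).backward()
        trainer.optimizer.step()

    trainer.save_model({
        'epoch': epoch,
        'network': trainer.model.state_dict(),
        'optimizer': trainer.optimizer.state_dict(),
    }, epoch)

Here set_lr() is called once per epoch, and save_model() writes a snapshot_{epoch}.pth.tar that load_model() picks up when cfg.continue_train is set.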

I apologize for sending you the code in such a raw form. We will update the repository as soon as the preparations for this conference are over! Please wait a little longer!

Good luck!

YangChangHee commented 12 months ago

We have uploaded the Trainer class :) Closing this.