Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.23k stars 3.38k forks source link

When training on large datasets, training is very slow — about ten times slower than plain PyTorch #13180

Closed hitlzy closed 2 years ago

hitlzy commented 2 years ago

🐛 Bug

To Reproduce

When training on large datasets, training is very slow — about ten times slower than plain PyTorch. When we reduced the dataset size, the training speed returned to normal.

Because this is proprietary company code, we cannot publish the model. We are sharing the training code below:

import time
import os
import random
import torch
import pytorch_lightning as pl

from torch.optim import lr_scheduler

from models.ourmodel import OURModel
from losses.ourloss import OURLoss
from utils.utils import projection, WarmupMultiStepLR, batch_rodrigues

class OURTrainer(pl.LightningModule):
    """LightningModule that trains ``OURModel`` with ``OURLoss``.

    Args:
        hparams: hyper-parameter container merged into ``self.hparams``. This
            class reads ``summary_freq``, ``lr``, ``wdecay`` and ``lrepochs``
            from it; it is also passed to ``OURLoss``.
        len_traindataloader: number of training iterations per epoch on one
            process. Used to convert the epoch milestones in ``lrepochs`` into
            iteration milestones and to report progress.
    """

    def __init__(self, hparams, len_traindataloader):
        super().__init__()

        self.hparams.update(hparams)

        # 6 input planes = RGB concatenated with the normal map (see
        # training_step). NOTE(review): assumes both have 3 channels.
        self.model = OURModel(in_planes=6, num_layers=50)
        self.loss_fn = OURLoss(hparams)

        self.len_traindataloader = len_traindataloader

    def forward(self, *args):
        """Delegate directly to the wrapped ``OURModel``."""
        return self.model(*args)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and return ``{'loss': total_loss}``.

        ``batch`` is a dict. The keys read here are ``rgb``, ``normal``,
        ``gender``, ``intrinsic``, ``gt_joints2d``, ``shape``, ``opt_shape``,
        ``pose`` and ``cam``. The keys ``gt_rotmat``, ``gt_j3d`` and ``gt_j2d``
        are written into it before it is passed to the loss.
        """
        start_time = time.time()
        do_summary = self.global_step % self.hparams.summary_freq == 0

        # Channel-wise (dim 1) concatenation of the RGB image and normal map.
        model_input = torch.cat((batch["rgb"], batch["normal"]), 1)
        # Feed the ground-truth shape with probability 0.4, else the optimised one.
        gt_shape_use = random.uniform(0.0, 1.0) > 0.6

        output, root_coord, cls_kp, reg_kp, in_shape, dp_out = self(
            model_input,
            batch['gender'],
            batch['intrinsic'],
            # 2-D location of joint 7 divided by 224 -- presumably normalising
            # pixel coordinates of a 224-px crop; TODO confirm.
            batch['gt_joints2d'][:, 7, :2] / 224.0,
            batch['shape'] if gt_shape_use else batch['opt_shape'],
        )

        # 24 axis-angle rotations -> 24 rotation matrices (24 * 9 = 216 values).
        gt_rotmat = batch_rodrigues(batch['pose'].view(-1, 3)).view(-1, 24, 3, 3)
        flat_rotmat = gt_rotmat.view(-1, 216)
        _, gt_j3d = self.model.get_verts_and_joints(flat_rotmat, in_shape, batch['gender'])
        _, gt_s_j3d = self.model.get_verts_and_joints(flat_rotmat, batch['shape'], batch['gender'])
        gt_j2d = projection(gt_s_j3d, batch['cam'], batch['intrinsic'])

        batch['gt_rotmat'] = gt_rotmat
        batch['gt_j3d'] = gt_j3d
        batch['gt_j2d'] = gt_j2d

        total_loss, loss_dict, loss_fbc, loss_root2d, loss_cls, loss_coord2d, loss_z, \
            loss_U, loss_V, loss_IndexUV, loss_segAnn = self.loss_fn(batch, output, root_coord, cls_kp, reg_kp, dp_out)

        if self.trainer.is_global_zero and do_summary:
            self.train_summaries(start_time, batch_idx, total_loss, loss_dict, loss_fbc, loss_root2d, loss_cls,
                                 loss_coord2d, loss_z, loss_U, loss_V, loss_IndexUV, loss_segAnn)

        return {'loss': total_loss}

    def train_summaries(self, start_time, batch_idx, total_loss, loss_dict, loss_fbc, loss_root2d, loss_cls,
                        loss_coord2d, loss_z, loss_U, loss_V, loss_IndexUV, loss_segAnn):
        """Print a one-line summary of the current training step.

        ``loss_dict`` and ``loss_fbc`` are indexed in parallel; each entry is a
        dict of per-stage losses with the keys used in the format strings below.
        """
        # Read the learning rate the optimizer is actually using. For PyTorch
        # schedulers, ``get_lr()`` called outside ``step()`` is not guaranteed
        # to return the current value (``get_last_lr()`` is the supported API).
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        parts = ["Epoch {}/{}, Iter {}/{}. lr {:.6f}, train loss = {:.5f}".format(
            self.current_epoch, self.trainer.max_epochs, batch_idx, self.len_traindataloader, lr, total_loss)]
        for i, loss in enumerate(loss_dict):
            parts.append(
                ", pose loss[{}] = {:.5f}, j3d loss[{}] = {:.5f}, j2d loss[{}] = {:.5f}, scale loss[{}] = {:.5f}".format(
                    i, loss['loss_pose'], i, loss['loss_j3d'], i, loss['loss_j2d'], i, loss['loss_scale']))
            parts.append(", fbc loss[{}] = {:.5f}".format(i, loss_fbc[i]['loss_iuv']))
            parts.append(", real j2d loss[{}] = {:.5f}, real j3d loss[{}] = {:.5f}".format(
                i, loss_fbc[i]['loss_j2d'], i, loss_fbc[i]['loss_j3d']))
        parts.append(", root2d loss = {:.5f}".format(loss_root2d))
        parts.append(", cls loss = {:.5f}, coord2d loss = {:.5f}, z loss = {:.5f}".format(
            loss_cls, loss_coord2d, loss_z))
        parts.append(", umap loss = {:.5f}, vmap loss = {:.5f}, imap loss = {:.5f}, smap loss = {:.5f}".format(
            loss_U, loss_V, loss_IndexUV, loss_segAnn))
        parts.append(", time = {:.3f}".format(time.time() - start_time))
        self.print("".join(parts))

    def configure_optimizers(self):
        """Build Adam plus a warm-up multi-step LR schedule stepped every iteration.

        ``hparams.lrepochs`` has the form ``"e1,e2,...:rate"`` (e.g.
        ``"5,9,13,17:2"``): the LR is divided by ``rate`` at each listed epoch.
        """
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=self.hparams.lr,
                                     betas=(0.9, 0.999), weight_decay=self.hparams.wdecay)

        spec = self.hparams.lrepochs.split(':')
        # Convert epoch milestones into iteration milestones.
        milestones = [self.len_traindataloader * int(epoch_idx) for epoch_idx in spec[0].split(',')]
        lr_gamma = 1. / float(spec[-1])
        scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1. / 3,
                                      warmup_iters=500, last_epoch=-1)  # -1: start a fresh schedule
        return {
            "optimizer": optimizer,
            # Milestones and warmup_iters are counted in iterations, so the
            # scheduler must step after every optimizer step. Lightning's
            # default interval is "epoch", which would step it only once per
            # epoch and never reach the milestones.
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
import os
import sys
import argparse
import datetime
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

sys.path.append('')
from core.our_trainer import OURTrainer
from configs.config import get_confing
from datasets.cmu_rgbd_normal import CMUDataset
from datasets.tonglu_rgbd_normal import TLDataset
from datasets.merge import Merge
from torch.utils.data import DataLoader

def train_dataset():
    """Build the merged training set.

    Returns:
        ``Merge`` over a CMU dataset followed by a Tonglu dataset.
    """
    sources = [
        CMUDataset(),  # ~200000 images (per original note)
        TLDataset(),   # ~100000 images (per original note)
    ]
    return Merge(sources)

def train_dataloader(hparams):
    """Create the shuffled training ``DataLoader``.

    Args:
        hparams: reads ``batch_size``, ``num_workers`` and ``pin_m``.

    Returns:
        A ``DataLoader`` over ``train_dataset()`` that keeps the last partial batch.
    """
    traindataset = train_dataset()
    num_workers = hparams.num_workers
    return DataLoader(
        dataset=traindataset,
        batch_size=hparams.batch_size,
        shuffle=True,
        num_workers=num_workers,
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only keep workers alive when there are some.
        persistent_workers=num_workers > 0,
        pin_memory=hparams.pin_m,
        drop_last=False,
    )

def _find_latest_checkpoint(logdir):
    """Return the most recently modified ``.ckpt`` file below ``logdir``, or None.

    Modification time is used rather than parsing file names: the directory
    can hold ``last.ckpt`` (``save_last=True``) next to ``epoch=E-step=S.ckpt``
    files, and ``logdir`` itself may contain ``-`` characters.
    """
    candidates = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(logdir)
        for name in files
        if name.endswith('.ckpt')
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def main(hparams, train_dataloader, fast_dev_run=False, detect_anomaly=False):
    """Configure a Lightning ``Trainer`` and fit ``OURTrainer``.

    Args:
        hparams: reads ``seed``, ``logdir``, ``resume``, ``gpus``,
            ``save_freq`` and ``epochs``; also passed to ``OURTrainer``.
        train_dataloader: the training ``DataLoader``.
        fast_dev_run: forwarded to ``pl.Trainer``.
        detect_anomaly: forwarded to ``pl.Trainer``. Off by default because
            autograd anomaly detection adds heavy per-step overhead; enable it
            only while debugging NaN/Inf gradients.

    Raises:
        SystemError: if CUDA is not available.
    """
    import math

    if not torch.cuda.is_available():
        raise SystemError("need use gpu")
    if hparams.seed >= 0:
        print('random seed', hparams.seed)
        os.environ['PYTHONHASHSEED'] = str(hparams.seed)
        pl.seed_everything(hparams.seed)

    os.makedirs(hparams.logdir, exist_ok=True)

    current_time_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print('current time', current_time_str)

    loadckpt = None
    if hparams.resume:
        loadckpt = _find_latest_checkpoint(hparams.logdir)
        print("resume ", loadckpt)

    # Under DDP the default DistributedSampler gives each rank
    # ceil(N / world_size) samples, i.e. ceil(len(loader) / gpus) batches,
    # so round up instead of truncating.
    iters_per_epoch = math.ceil(len(train_dataloader) / max(hparams.gpus, 1))
    # No explicit .to('cuda'): Lightning moves the module to each process's
    # device; doing it here would put a copy on the default GPU in every process.
    model = OURTrainer(hparams=hparams, len_traindataloader=iters_per_epoch)

    ckpt_callback = ModelCheckpoint(
        filename='{epoch}-{step}',
        save_last=True,
        every_n_epochs=hparams.save_freq,
    )
    trainer_kwargs = dict(
        profiler="simple",
        accelerator="gpu",
        devices=hparams.gpus,
        max_epochs=hparams.epochs,
        callbacks=[ckpt_callback],
        default_root_dir=hparams.logdir,
        detect_anomaly=detect_anomaly,
        enable_progress_bar=False,
        reload_dataloaders_every_n_epochs=0,
        fast_dev_run=fast_dev_run,
    )
    if hparams.gpus > 1:
        trainer_kwargs["strategy"] = "ddp_find_unused_parameters_false"
    trainer = pl.Trainer(**trainer_kwargs)

    print('training')
    # ckpt_path=None is Trainer.fit's default, i.e. train from scratch.
    trainer.fit(model, train_dataloaders=train_dataloader, ckpt_path=loadckpt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="model training")

    parser.add_argument("--iuvname", type=str, default="gt_iuv", help="the labeled IUV name")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.0001)
    # A longer schedule used previously: --lrepochs "10,18,26,34:2" --epochs 45
    parser.add_argument("--lrepochs", type=str, default="5,9,13,17:2",
                        help="epoch ids to downscale lr and the downscale rate")
    parser.add_argument("--epochs", type=int, default=22)
    parser.add_argument("--wdecay", type=float, default=0.0001)
    parser.add_argument("--gpus", type=int, default=2)
    parser.add_argument("--logdir", default="/data1/3D_Reconstruction/rgbd-human/experiments/rgbd_model",
                        help="the directory to save checkpoints/logs")
    parser.add_argument("--seed", type=int, default=1, metavar='S', help="random seed")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--summary_freq", type=int, default=1, help="print and summary frequency")
    parser.add_argument("--save_freq", type=int, default=5, help="save checkpoint frequency")
    parser.add_argument("--resume", action="store_true", help="continue to train the model")
    # Pinning is on by default. With only a store_true flag it could never be
    # turned off, so --no_pin_m writes False to the same destination.
    parser.add_argument("--pin_m", dest="pin_m", default=True, action="store_true",
                        help="data loader pin memory")
    parser.add_argument("--no_pin_m", dest="pin_m", action="store_false",
                        help="disable data loader pin memory")
    parser.add_argument("--eval_only", action="store_true", help="eval_only")
    parser.add_argument('--fdr', action='store_true')

    args = parser.parse_args()

    hparams = get_confing(args)
    print(hparams)

    main(hparams=hparams, train_dataloader=train_dataloader(hparams), fast_dev_run=args.fdr)

Expected behavior

Environment

Additional context

cc @borda @akihironitta

akihironitta commented 2 years ago

@hitlzy We need to check whether the comparison against your PyTorch script is fair. Are you sure your pure PyTorch script also uses all the PL features you enabled? For example, are you using the anomaly detection context manager in your PyTorch script? (#12344) https://pytorch.org/docs/stable/autograd.html#anomaly-detection

Here's the checklist: #12398


Feel free to reopen this issue.