Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

training time increase epoch by epoch #20076

Open Eric-Lin-CVTE opened 1 month ago

Eric-Lin-CVTE commented 1 month ago

Bug description

When I run the following code, the training time of each epoch increases epoch by epoch. For example, the first epoch takes 3:39 min, the second one takes 4:21 min, the third one takes 5:46 min, and so on. I don't know why. My code is below. The Lightning version I used is 2.3.1.
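
For reference, one way to confirm and quantify the per-epoch slowdown is a small timing callback (a sketch added for illustration, not part of the original report; `EpochTimer` is a hypothetical name):

```python
import time
from lightning.pytorch.callbacks import Callback

class EpochTimer(Callback):
    """Print how long each training epoch takes, to confirm the slowdown."""

    def on_train_epoch_start(self, trainer, pl_module):
        self._t0 = time.monotonic()

    def on_train_epoch_end(self, trainer, pl_module):
        elapsed = time.monotonic() - self._t0
        print(f"epoch {trainer.current_epoch}: {elapsed:.1f} s")
```

Passing `EpochTimer()` in the Trainer's `callbacks` list prints the wall-clock time of every epoch.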

What version are you seeing the problem on?

master

How to reproduce the bug

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   train_pl.py
@Time    :   2024/07/04 17:07:44
@Author  :   Lin Weiquan 
@Version :   1.0
@Desc    :   Training based on pytorch-lightning
'''

from lib.models.besizer_crnn_v3_mask import CRNN_V3_1
from lib.models.besizer_crnn_v7_mask import CRNN_v7
from torch.utils.data import DataLoader
from lib.dataset import get_dataset
from lib.dataset.variable_width import DistCollateFn
from easydict import EasyDict as edict
from lib.utils.utils import model_info
from numpy import mean  # only mean is needed from numpy
from pathlib import Path
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger  # use the same lightning namespace as the other imports
import lib.utils.utils as utils
import torch
import torch.nn as nn
import editdistance
import argparse
import yaml
import time
import os
import torch.backends.cudnn as cudnn
import lightning as pl
import lib.config.alphabets as alphabets
import lib.config.alphabets_shuffle as alphabets_shuffle
import lib.config.alphabets_shuffle2 as alphabets_shuffle2
import lib.config.alphabets_shuffle3 as alphabets_shuffle3

def parse_arg():
    parser = argparse.ArgumentParser(description="train crnn_v7")

    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)

    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        #config = yaml.load(f)
        config = edict(config)

    if not config.DATASET.SHUFFLE:
        config.DATASET.ALPHABETS = alphabets.alphabet
    else:
        config.DATASET.ALPHABETS = alphabets_shuffle3.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    print("NUM_CLASSES: ", config.MODEL.NUM_CLASSES)

    # try:
    #     config.TRAIN.DISTRIBUTED.LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    # except:
    #     config.TRAIN.DISTRIBUTED.LOCAL_RANK = -1

    return config

class LightningFreeWrite(pl.LightningModule):
    def __init__(self, config):
        super(LightningFreeWrite, self).__init__()
        self.model = CRNN_V3_1(nclass=config.MODEL.NUM_CLASSES + 1,  nh=config.MODEL.NUM_HIDDEN)
        self.criterion = torch.nn.CTCLoss(zero_infinity=True)
        self.config = config
        self.converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
        self.best_char_acc = 0.0
        self.train_step_loss_outputs = []
        self.validation_step_loss_outputs = []
        self.validation_step_char_outputs = []
        self.validation_step_accuracy_outputs = []

    def forward(self, x, mask):
        output, mask = self.model(x, mask)
        return output, mask

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.TRAIN.LR)
        last_epoch = self.config.TRAIN.BEGIN_EPOCH
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, self.config.TRAIN.LR_STEP,
            self.config.TRAIN.LR_FACTOR, last_epoch-1
        )
        opt_sched = {"scheduler": lr_scheduler, "interval": "epoch"}
        return [optimizer], [opt_sched]

    def ctc_loss(self, preds, gts, input_lens, length):
        return self.criterion(preds, gts, input_lens, length)

    def training_step(self, train_batch, batch_idx):

        inp, labels, masks, input_lens = train_batch

        bs = inp.size(0)

        # model infer
        preds, masks = self.forward(inp, masks)

        preds = preds.permute(1,0,2)

        # compute loss

        gts, length, ace_labels = self.converter.encode(labels)

        gts = gts.long()

        preds_size = torch.IntTensor([preds.size(0)] * bs)

        # loss = self.criterion(preds, gts, input_lens, length)
        loss = self.criterion(preds, gts, preds_size, length)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        # self.train_step_loss_outputs.append(loss)
        # self.log('train_loss', loss,sync_dist=True)
        return loss

    # def on_train_epoch_end(self):

    #     avg_loss = torch.stack(self.train_step_loss_outputs).mean().item()
    #     self.log('avg_train_loss', avg_loss, sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
    #     # self.train_step_loss_outputs.clear()
    #     del self.train_step_loss_outputs
    #     self.train_step_loss_outputs = []
        # return avg_loss

    def validation_step(self, val_batch, batch_idx):
        inp, labels, masks, input_lens = val_batch

        # model infer
        preds, masks = self.forward(inp, masks)

        preds = preds.permute(1,0,2)

        # compute loss

        bs = inp.size(0)

        gts, length, ace_labels = self.converter.encode(labels)

        gts = gts.long()

        preds_size = torch.IntTensor([preds.size(0)] * bs)

        # loss = self.criterion(preds, gts, input_lens, length)
        loss = self.criterion(preds, gts, preds_size, length)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        # print(preds.data)
        sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)

        n_correct = 0
        sum_char = 0
        error_char = 0
        for pred, target in zip(sim_preds, labels):
            if pred == target:
                n_correct += 1
            sum_char += len(target)
            edit_distance = editdistance.eval(pred, target)
            error_char += edit_distance
        accuracy = n_correct / len(labels)
        char_acc = 1 - error_char / sum_char
        # self.log('tl_acc', accuracy,sync_dist=True)
        # self.log('char_acc', char_acc,sync_dist=True)

        self.validation_step_loss_outputs.append(loss)
        self.validation_step_char_outputs.append(char_acc)
        self.validation_step_accuracy_outputs.append(accuracy)
        return loss

    def on_validation_epoch_end(self):
        # print(self.validation_step_outputs)

        avg_loss = torch.stack(self.validation_step_loss_outputs).mean().item()
        avg_char_acc = mean(self.validation_step_char_outputs)
        avg_tl_acc = mean(self.validation_step_accuracy_outputs)
        self.log('avg_val_loss', avg_loss,sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('avg_char_acc', avg_char_acc,sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('avg_tl_acc', avg_tl_acc,sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
        # self.validation_step_loss_outputs.clear() # free memory
        # self.validation_step_char_outputs.clear()
        # self.validation_step_accuracy_outputs.clear()
        del self.validation_step_loss_outputs, self.validation_step_char_outputs, self.validation_step_accuracy_outputs
        self.validation_step_loss_outputs = []
        self.validation_step_char_outputs = []
        self.validation_step_accuracy_outputs = []

def main():
    # os.environ['OMP_NUM_THREADS'] = '1'
    pl.seed_everything(9958, workers=True)
    torch.set_float32_matmul_precision('high')

    # load config
    config = parse_arg()

    # output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    # cudnn.benchmark = config.CUDNN.BENCHMARK
    # cudnn.deterministic = config.CUDNN.DETERMINISTIC
    # cudnn.enabled = config.CUDNN.ENABLED

    custom_preprocess = None

    train_dataset = get_dataset(config)(config, custom_preprocess, is_train=True)
    val_dataset = get_dataset(config)(config, custom_preprocess, is_train=False)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        # shuffle=config.TRAIN.SHUFFLE,
        drop_last=True,
        collate_fn = DistCollateFn(training = True),
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        # shuffle=config.TEST.SHUFFLE,
        drop_last=True,
        collate_fn = DistCollateFn(training = True),
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    time_step = time.strftime("%Y-%m-%d-%H-%M", time.localtime())

    save_log_path = os.path.join('/result', config.OUTPUT_DIR, config.DATASET.DATASET, config.MODEL.NAME, time_step, "log")
    save_ckpt_path = os.path.join('/result', config.OUTPUT_DIR, config.DATASET.DATASET, config.MODEL.NAME, time_step, "checkpoint")

    each_checkpoint = ModelCheckpoint(dirpath=save_ckpt_path,
                                        every_n_epochs=1, monitor="avg_char_acc", save_top_k=-1, save_last=True,filename='{epoch}-{step}-{avg_char_acc:.4f}')

    model = LightningFreeWrite(config=config)
    # create a progress bar callback instance
    progress_bar = TQDMProgressBar(refresh_rate=500)
    logger = TensorBoardLogger('tb_logs', name='crnn_v7',version=0, log_graph=True)
    trainer = pl.Trainer(accelerator="gpu", 
                        devices=[0,1], 
                        strategy="ddp", 
                        logger = logger,
                        # progress_bar_refresh_rate=100,
                        max_epochs=config.TRAIN.END_EPOCH,
                        default_root_dir=save_log_path,
                        enable_checkpointing=True,
                        gradient_clip_val=1.0,  # gradient_clip_val expects a number, not a bool
                        callbacks=[each_checkpoint,progress_bar])

    trainer.fit(model, train_dataloader, val_dataloader)

if __name__ == '__main__':

    main()

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- PyTorch Lightning Version (e.g., 1.5.0):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
```

More info

No response

cc @borda

awaelchli commented 1 month ago

Hey @Eric-Lin-CVTE, thank you for sharing the code. Could you help make the code more minimal so the issue can be reproduced? As it stands, the code imports from a lib package, so it cannot be run directly. A good way to narrow down where the issue is, is to disable certain features and check that they are unrelated. For example, disable validation to check whether the issue is in validation or not, or disable checkpointing, etc.
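
For illustration, a stripped-down run along those lines could look like the sketch below (standard `Trainer` flags; it reuses `pl`, `config`, `model`, and `train_dataloader` from the script above):

```python
# Sketch: disable validation, checkpointing and logging to isolate the slowdown.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0, 1],
    strategy="ddp",
    limit_val_batches=0,          # skip validation entirely
    enable_checkpointing=False,   # no ModelCheckpoint I/O
    logger=False,                 # no TensorBoard logging
    max_epochs=config.TRAIN.END_EPOCH,
)
trainer.fit(model, train_dataloader)
```

If the per-epoch time still grows with everything disabled, the problem is in the training loop itself; otherwise the feature that was re-enabled last is the suspect.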

biomolecule commented 1 week ago

I encountered a similar issue. After checking, I found that the time spent in the model's forward and step phases grows rapidly as the number of epochs increases. Further analysis showed that PyTorch Lightning's automatic zero-grad was not taking effect: even without gradient accumulation, parameter gradients were still present before each forward pass. Manually doing the zero-grad, forward, and backward steps in training_step resolved the issue.
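
A minimal sketch of such a manual-optimization `training_step` in Lightning 2.x (assuming the same CTC setup as the script above; this uses the standard manual optimization API, not the exact code the commenter ran):

```python
class LightningFreeWriteManual(LightningFreeWrite):
    def __init__(self, config):
        super().__init__(config)
        # Take over optimization so zero_grad/backward/step happen explicitly.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()                      # explicitly clear stale gradients

        inp, labels, masks, input_lens = train_batch
        preds, masks = self.forward(inp, masks)
        preds = preds.permute(1, 0, 2)
        gts, length, ace_labels = self.converter.encode(labels)
        preds_size = torch.IntTensor([preds.size(0)] * inp.size(0))
        loss = self.criterion(preds, gts.long(), preds_size, length)

        self.manual_backward(loss)           # backward through Lightning
        opt.step()
        self.log('train_loss', loss, on_step=True, prog_bar=True, sync_dist=True)
        return loss
```

Note that with `automatic_optimization = False` the LR scheduler returned by `configure_optimizers` is no longer stepped automatically; it would have to be stepped by hand, e.g. via `self.lr_schedulers().step()` in `on_train_epoch_end`.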