pytorch / examples

A set of examples around PyTorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

DDP training question #1096

Open Henryplay opened 1 year ago

Henryplay commented 1 year ago

Hi, I'm using the tutorial https://github.com/pytorch/tutorials/blob/master/intermediate_source/ddp_tutorial.rst (the Basic Use Case section) to add DDP training to my own code, running on 4 GPUs. But after I finished the modification, the run gets stuck, even though GPU memory is already allocated. Could you help me?
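
For reference, the Basic Use Case pattern I'm following looks roughly like this (a minimal sketch; nn.Linear stands in for the tutorial's ToyModel, and port 12355 is the tutorial's example value):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def demo_basic(rank, world_size):
    # one process per GPU; every process joins the same group
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # build the model inside the worker process, then wrap it
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    loss_fn(outputs, torch.randn(20, 10).to(rank)).backward()
    optimizer.step()

    dist.destroy_process_group()  # tear down once, at the very end

if __name__ == "__main__":
    world_size = 4
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)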

Henryplay commented 1 year ago

And here is my code:

import os
import torch
import argparse
from tqdm import tqdm
from utils.scheduler import GradualWarmupScheduler
from modeling.model import CNN
from modeling.loss import CTCLoss
from utils.dataset import CharDict, LoadData, ImageTransform
from utils.utils import paser_config, edit_distance_score, setup_logger
from torch.utils.data import DataLoader

import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:

    def __init__(self, config_file):
        self.configs = paser_config(config_file)
        # os.environ['CUDA_VISIBLE_DEVICES'] = self.configs['trainer']['gpus']
        self.build_dataloader()
        self.build_model()
        self.start_epoch = 0
        self.max_epochs = self.configs['trainer']['epochs']
        self.save_dir = os.path.join(self.configs['trainer']['output_dir'], self.configs['name'])
        os.makedirs(self.save_dir, exist_ok=True)
        log_file_mode = 'a' if self.configs['trainer']["resume_ckpt"] else 'w'
        self.logger = setup_logger(log_file_path=os.path.join(self.save_dir, 'train.log'), log_file_mode=log_file_mode)
        self.checkpoint = {
            'epoch': 0,
            'history_acc': [],
            'history_eds': [],
            'model': {},
            'optimizer': {},
            'lr_scheduler': {},
            'configs': self.configs
        }
        if self.configs['trainer']["finetune_ckpt"]:
            self.model.load_state_dict(torch.load(self.configs['trainer']["finetune_ckpt"])['model'], False)
            #ckpt = torch.load(self.configs['trainer']["finetune_ckpt"])['model']
            #self.model.load_state_dict({k: v for k, v in ckpt.items() if 'fc' not in k},False)
        elif self.configs['trainer']["resume_ckpt"]:
            self.checkpoint = torch.load(self.configs['trainer']["resume_ckpt"])
            self.model.load_state_dict(self.checkpoint['model'])
            self.optimizer.load_state_dict(self.checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(self.checkpoint['lr_scheduler'])
            self.checkpoint['model'].clear()
            self.checkpoint['optimizer'].clear()
            self.checkpoint['lr_scheduler'].clear()
            self.start_epoch = self.checkpoint['epoch'] + 1
        # previous single-process alternative: wrap with DataParallel
        # self.model = torch.nn.DataParallel(self.model)

    def setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup(self):
        dist.destroy_process_group()

    def train(self, rank, world_size):
        self.setup(rank, world_size)
        self.model = self.model.to(rank)
        self.model = DDP(self.model, device_ids=[rank])
        for epoch in range(self.start_epoch, self.max_epochs):
            self.model.train()
            self.checkpoint['epoch'] = epoch
            for i, datas in enumerate(self.train_dataloader):
                img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
                img = img.to(rank)
                preds = self.model(img)
                loss = self.criterion(preds, targets.to(rank), target_lens.to(rank))
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # log info
                if i % 10 == 0:
                    batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
                    msg = "Epoch: %d/%d, " % (epoch, self.max_epochs) + \
                          "Batch: %d/%d, "%(i, len(self.train_dataloader)) + \
                          "Lr: %.6f, " %  self.scheduler_warmup.get_last_lr()[0] + \
                          "Loss: %.3f, " % loss.item() + \
                          "Acc: %.3f, EDS: %.3f" % (batch_acc, batch_eds)
                    self.logger.info(msg)
            self.scheduler_warmup.step()
            # evaluate and save on one rank only, to avoid concurrent checkpoint writes
            if rank == 0:
                self.eval()
        # destroy the process group once, after all epochs; calling cleanup()
        # inside the epoch loop tears the group down while the DDP model is
        # still in use and can hang the remaining ranks
        self.cleanup()

    @torch.no_grad()
    def eval(self):
        # unwrap DDP so single-rank evaluation triggers no collective ops
        model = self.model.module if isinstance(self.model, DDP) else self.model
        model.eval()
        nbatch = len(self.test_dataloader)
        acc, eds = 0, 0
        for datas in tqdm(self.test_dataloader, desc="Testing..."):
            img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
            preds = model(img.cuda())
            batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
            acc += batch_acc
            eds += batch_eds
        mean_acc = acc / nbatch
        mean_eds = eds / nbatch

        self.save_model(mean_acc, mean_eds)
        return mean_acc, mean_eds

    def metrics(self, preds, targets, target_lens):
        """WARNING:
            This function will consume a lot of time. Don't use it frequently.
        """
        bs = preds.size(0)
        preds_prob, preds_idx = preds.permute(0, 2, 1).detach().softmax(dim=2).max(2)
        decode_idx, decode_prob, _ = self.chardict.ctc_decode(preds_idx.cpu().numpy(), preds_prob.cpu().numpy())
        preds_texts = [self.chardict.idx2text(i, reserve_char='\a') for i in decode_idx]
        target_texts = [self.chardict.idx2text(t[:l], reserve_char='') for t, l in zip(targets, target_lens)]
        ed_score = 0.0
        n_correct = 0
        for s1, s2 in zip(preds_texts, target_texts):
            ed_score += edit_distance_score(s1, s2)
            n_correct += (s1 == s2)
        ed_score /= bs
        batch_acc = n_correct / bs
        return batch_acc, ed_score

    def save_model(self, cur_acc, cur_eds):
        best_acc_path = os.path.join(self.save_dir, "model_best_acc.pth")
        best_eds_path = os.path.join(self.save_dir, "model_best_eds.pth")
        model_last_path = os.path.join(self.save_dir, "model_last.pth")
        self.checkpoint['history_acc'].append(cur_acc)
        self.checkpoint['history_eds'].append(cur_eds)
        self.checkpoint['model'] = self.model.module.state_dict()
        self.checkpoint['optimizer'] = self.optimizer.state_dict()
        self.checkpoint['lr_scheduler'] = self.lr_scheduler.state_dict()

        torch.save(self.checkpoint, model_last_path)
        self.logger.info("Current acc: %.3f, eds: %.3f" % (cur_acc, cur_eds))
        self.logger.info("Save current epoch model to: %s" % model_last_path)
        best_acc = max(self.checkpoint['history_acc'])
        # note: this assumes a higher edit-distance score (EDS) is better; use min() if lower is better
        best_eds = max(self.checkpoint['history_eds'])
        if cur_acc >= best_acc:
            torch.save(self.checkpoint, best_acc_path)
            self.logger.info("Best acc: %.3f", cur_acc)
            self.logger.info("Save best Acc model to: %s" % best_acc_path)
        if cur_eds >= best_eds:
            torch.save(self.checkpoint, best_eds_path)
            self.logger.info("Best eds: %.3f", cur_eds)
            self.logger.info("Save best EDS model to: %s" % best_eds_path)

        # release
        self.checkpoint['model'].clear()
        self.checkpoint['optimizer'].clear()
        self.checkpoint['lr_scheduler'].clear()

    def build_model(self):
        in_dim = 1 if self.configs['dataset']['img_mode'] == 'gray' else 3
        out_dim = self.configs['dataset']['ncls']
        self.model = CNN(in_dim, out_dim)
        self.optimizer = getattr(torch.optim, self.configs['optimizer']['type'])(
            self.model.parameters(), **self.configs['optimizer']['args'])
        # set lr decay
        lr_scheduler_type = self.configs['lr_scheduler']['type']
        if lr_scheduler_type == "StepLR":
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler_type)(
                self.optimizer, **self.configs['lr_scheduler']['args'])
        else:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler_type)(
                self.optimizer, 5)
        # self.scheduler_warmup is used in train() but was never created; assuming
        # the common GradualWarmupScheduler(optimizer, multiplier, total_epoch,
        # after_scheduler) signature here -- adjust to your utils.scheduler version
        self.scheduler_warmup = GradualWarmupScheduler(
            self.optimizer, multiplier=1.0, total_epoch=5, after_scheduler=self.lr_scheduler)
        self.criterion = CTCLoss()

    def build_dataloader(self):
        self.chardict = CharDict(
            self.configs['dataset']['dict'], self.configs['dataset']['ncls'])
        imtrans = ImageTransform(
            self.configs['dataset']['img_mode'], self.configs['dataset']['img_size'])
        trainset = LoadData(
            self.configs['dataset']['trainset'], self.chardict, imtrans)
        self.train_dataloader = DataLoader(
            trainset, self.configs['dataset']['batch_size'], shuffle=True, collate_fn=trainset.collate_fn, num_workers=16)
        testset = LoadData(
            self.configs['dataset']['testset'], self.chardict, imtrans)
        self.test_dataloader = DataLoader(
            testset, self.configs['dataset']['batch_size'], shuffle=False, collate_fn=testset.collate_fn, num_workers=16)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', default='config/pycrnn.yaml', type=str)
    args = parser.parse_args()
    trainer = Trainer(args.config_file)
    world_size = 4
    mp.spawn(trainer.train,
             args=(world_size,),
             nprocs=world_size,
             join=True)
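
For reference, DDP normally shards the dataset across ranks with a DistributedSampler instead of shuffle=True; the loaders above don't do this, so every rank trains on the full dataset. A minimal sketch of how the train loader could be built (trainset, collate_fn, and configs refer to the objects in the code above; the sampler requires the process group to be initialized first):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# shard the training set across ranks; replaces shuffle=True
train_sampler = DistributedSampler(trainset)
train_dataloader = DataLoader(
    trainset,
    batch_size=configs['dataset']['batch_size'],
    sampler=train_sampler,            # mutually exclusive with shuffle=True
    collate_fn=trainset.collate_fn,
    num_workers=4)

# call once per epoch inside the training loop so each epoch reshuffles:
# train_sampler.set_epoch(epoch)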
AntyRia commented 1 year ago

Hi, do you have a problem with the application getting stuck after starting multiple nodes? On my side, running the official multi-node example also gets stuck.
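
One thing worth checking in the multi-node case: MASTER_ADDR must be an address of node 0 that is reachable from every node ('localhost' only works single-node), and ranks are global across nodes. A minimal sketch of the setup (the IP, port, and function name are placeholders):

import os
import torch.distributed as dist

def setup_multinode(local_rank, node_rank, gpus_per_node, world_size):
    # MASTER_ADDR is node 0's address, reachable from every node;
    # 'localhost' here would make non-master nodes hang at init
    os.environ['MASTER_ADDR'] = '10.0.0.1'   # placeholder: node 0's IP
    os.environ['MASTER_PORT'] = '29500'
    # the rank passed to init_process_group is global across all nodes
    global_rank = node_rank * gpus_per_node + local_rank
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)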