
[BUG]: model not compatible with GeminiDDP #4958

Open zhangvia opened 12 months ago

zhangvia commented 12 months ago

🐛 Describe the bug

I'm using ColossalAI to train a try-on diffusion model (the TryOnDiffusion paper), but I get a compatibility error: "ZERO DDP error: the synchronization of gradients doesn't exit properly. The most possible reason is that the model is not compatible with GeminiDDP."

Environment

No response

Orion-Zheng commented 12 months ago

Could you tell us how to reproduce the bug (the code, the command, etc.)? It will help us locate the problem.

zhangvia commented 12 months ago

> Could you tell us how to reproduce the bug (the code, the command, etc.)? It will help us locate the problem.

You can use this repo: https://github.com/ankanbhunia/PIDM/tree/383b60eade67ec0c02d6898424f245c488c38f00

I use the Booster API and the Gemini plugin on top of this repo. My `train.py` is:

```python

from gc import disable
import os
import warnings

warnings.filterwarnings("ignore")

import time, cv2, torch
from tqdm import tqdm
import numpy as np
import logging

import torch.distributed as dist
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from tensorfn.optim import lr_scheduler

from config.diffconfig import DiffusionConfig, get_model_conf
from config.dataconfig import Config as DataConfig
from tensorfn import load_config as DiffConfig
from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
import data as deepfashion_data

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

if os.getenv('use_colossalai') == '0':
    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s', level=logging.INFO)
    logger = logging
elif os.getenv('use_colossalai') == '1':
    disable_existing_loggers()
    logger = get_dist_logger()
else:
    logging.error("please set env use_colossalai,0 represents using torch,1 represents using colossalai")

def init_distributed():
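    """Initialize torch.distributed (NCCL backend) from the torchrun env vars and pin the local CUDA device."""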
    dist_url = "env://" # default

    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    dist.init_process_group(
            backend="nccl",
            init_method=dist_url,
            world_size=world_size,
            rank=rank)

    torch.cuda.set_device(local_rank)
    dist.barrier()
    setup_for_distributed(rank == 0)

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

def is_main_process():
    try:
        if dist.get_rank()==0:
            return True
        else:
            return False
    except:
        return True

def sample_data(loader):
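    """Infinite iterator over the dataloader: yields (epoch, batch) and restarts the loader when it is exhausted."""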
    loader_iter = iter(loader)
    epoch = 0

    while True:
        try:
            yield epoch, next(loader_iter)

        except StopIteration:
            epoch += 1
            loader_iter = iter(loader)

            yield epoch, next(loader_iter)

def accumulate(model1, model2, decay=0.9999):
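    """EMA update, in place: model1 = decay * model1 + (1 - decay) * model2."""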
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

def train(conf, loader, model, ema, diffusion, betas, optimizer, scheduler, guidance_prob, cond_scale, device, booster):

    import time

    i = 0

    loss_list = []
    loss_mean_list = []
    loss_vb_list = []

    torch.cuda.synchronize()

    for epoch in range(500):
        if os.getenv('use_colossalai') == '0':
            if is_main_process(): print('#Epoch - ' + str(epoch))
        else:
            logger.info(f'#Epoch - {str(epoch)}',ranks=[0])

        start_time = time.time()

        for batch in tqdm(loader):

            i = i + 1

            img = batch["source_image"]
            target_img = batch["target_image"]
            target_pose = torch.cat([batch['target_image_ref'], batch['target_skeleton']], 1)
            if booster is None:
                img = img.to(device)
                target_img = target_img.to(device)
                target_pose = target_pose.to(device)
                time_t = torch.randint(
                    0,
                    conf.diffusion.beta_schedule["n_timestep"],
                    (img.shape[0],),
                    device=device,
                )
            else:
                img = img.to(get_current_device(),dtype=torch.float16)
                target_img = target_img.to(get_current_device(),dtype=torch.float16)
                target_pose = target_pose.to(get_current_device(),dtype=torch.float16)
                time_t = torch.randint(
                    0,
                    conf.diffusion.beta_schedule["n_timestep"],
                    (img.shape[0],),
                    device=get_current_device(),
                )

            loss_dict = diffusion.training_losses(model, x_start = target_img, t = time_t, cond_input = [img, target_pose], prob = 1 - guidance_prob)

            loss = loss_dict['loss'].mean()
            loss_mse = loss_dict['mse'].mean()
            loss_vb = loss_dict['vb'].mean()
            if booster is None:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1)
                scheduler.step()
                optimizer.step()
                loss = loss_dict['loss'].mean()
            else:
                booster.backward(loss, optimizer)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            loss_list.append(loss.detach().item())
            loss_mean_list.append(loss_mse.detach().item())
            loss_vb_list.append(loss_vb.detach().item())

            accumulate(
                ema, model.module, 0 if i < conf.training.scheduler.warmup else 0.9999
            )

            if i%args.save_checkpoints_every_iters == 0 and is_main_process():

                if conf.distributed:
                    model_module = model.module

                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "conf": conf,
                    },
                    conf.training.ckpt_path + f"/model_{str(i).zfill(6)}.pt"
                )
        if booster is None:
            if is_main_process():

                print ('Epoch Time '+str(int(time.time()-start_time))+' secs')
                print ('Model Saved Successfully for #epoch '+str(epoch)+' #steps '+str(i))

                if conf.distributed:
                    model_module = model.module

                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "conf": conf,
                    },
                    conf.training.ckpt_path + '/last.pt'

                )
        else:
            logger.info(f'Epoch Time {int(time.time() - start_time)} secs', ranks=[0])
            booster.save_model(model, f"{conf.training.ckpt_path}/last.pt")
            booster.save_model(ema, f"{conf.training.ckpt_path}/last_ema.pt")
            logger.info(f'Model Saved Successfully for #epoch {epoch} #steps {i}', ranks=[0])

def main(settings, EXP_NAME):

    [args, DiffConf, DataConf] = settings

    # if is_main_process(): wandb.init(project="person-synthesis", name = EXP_NAME,  settings = wandb.Settings(code_dir="."))

    if DiffConf.ckpt is not None: 
        DiffConf.training.scheduler.warmup = 0

    DiffConf.distributed = True
    local_rank = int(os.environ['LOCAL_RANK'])

    DataConf.data.train.batch_size = args.batch_size//2  #src -> tgt , tgt -> src

    model = get_model_conf().make_model()
    ema = get_model_conf().make_model()

    if DiffConf.ckpt is not None:
        ckpt = torch.load(DiffConf.ckpt, map_location=lambda storage, loc: storage)

        if DiffConf.distributed:
            model.module.load_state_dict(ckpt["model"])

        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])
        scheduler.load_state_dict(ckpt["scheduler"])

        if is_main_process():  print ('model loaded successfully')

    if os.getenv('use_colossalai') == '0':
        model = model.to(args.device)
        ema = ema.to(args.device)
    else:
        model = model.to(get_current_device(),dtype=torch.float16)
        ema = ema.to(get_current_device(),dtype=torch.float16)

    if DiffConf.distributed:
        if os.getenv('use_colossalai') == '0':
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                find_unused_parameters=True
            )
            booster = None
        else:
            booster_kwargs = {}
            plugin = GeminiPlugin(placement_policy='static', strict_ddp_mode=True, initial_scale=2 ** 5)
            booster = Booster(plugin=plugin, **booster_kwargs)

    if os.getenv('use_colossalai') == '0':    
        val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(DataConf.data, labels_required = True, distributed = True)
    elif os.getenv('use_colossalai') == '1':
        val_dataset,train_dataset = deepfashion_data.get_train_val_dataset(DataConf.data,True)
        train_dataset = plugin.prepare_dataloader(train_dataset,num_workers=DataConf.data.train.batch_size*2,batch_size=DataConf.data.train.batch_size)
        val_dataset = plugin.prepare_dataloader(val_dataset, num_workers=1, batch_size=1)

    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    val_dataset = iter(cycle(val_dataset))

    if os.getenv('use_colossalai') == '1':
        optimizer = HybridAdam(model.parameters(), lr=2e-5, initial_scale=2**5,clipping_norm=1)
        scheduler = DiffConf.training.scheduler.make(optimizer)
        model, optimizer, _, _, scheduler = booster.boost(model, optimizer, lr_scheduler=scheduler)
    else:
        optimizer = DiffConf.training.optimizer.make(model.parameters())
        scheduler = DiffConf.training.scheduler.make(optimizer)

    betas = DiffConf.diffusion.beta_schedule.make()
    diffusion = create_gaussian_diffusion(betas, predict_xstart = False)

    train(
        DiffConf, train_dataset, model, ema, diffusion, betas, optimizer, scheduler, args.guidance_prob, args.cond_scale, args.device,booster
    )

if __name__ == "__main__":
    if os.getenv('use_colossalai') == '0':
        init_distributed()
    else:
        colossalai.launch_from_torch(config={})

    import argparse

    parser = argparse.ArgumentParser(description='help')
    parser.add_argument('--exp_name', type=str, default='vto_model')
    parser.add_argument('--DiffConfigPath', type=str, default='./config/diffusion.conf')
    parser.add_argument('--DataConfigPath', type=str, default='./config/data_arcsoft.yaml')
    parser.add_argument('--dataset_path', type=str, default='./')
    parser.add_argument('--save_path', type=str, default='checkpoints')
    parser.add_argument('--cond_scale', type=int, default=2)
    parser.add_argument('--guidance_prob', type=float, default=0.1)
    parser.add_argument('--sample_algorithm', type=str, default='ddim') # ddpm, ddim
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--save_wandb_logs_every_iters', type=int, default=200000)
    parser.add_argument('--save_checkpoints_every_iters', type=int, default=2000)
    parser.add_argument('--save_wandb_images_every_epochs', type=int, default=100000)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--n_gpu', type=int, default=8)
    parser.add_argument('--n_machine', type=int, default=1)
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)

    args = parser.parse_args()

    print ('Experiment: '+ args.exp_name)

    DiffConf = DiffConfig(DiffusionConfig,  args.DiffConfigPath, args.opts, False)
    DataConf = DataConfig(args.DataConfigPath)

    DiffConf.training.ckpt_path = os.path.join(args.save_path, args.exp_name)
    DataConf.data.path = args.dataset_path

    if is_main_process():

        if not os.path.isdir(args.save_path): os.mkdir(args.save_path)
        if not os.path.isdir(DiffConf.training.ckpt_path): os.mkdir(DiffConf.training.ckpt_path)

    # DiffConf.ckpt = "checkpoints/vto_model/last.pt"
    # print("Loading model {}.".format(DiffConf.ckpt))

    main(settings = [args, DiffConf, DataConf], EXP_NAME = args.exp_name)

```

Besides, there are some dtype errors in the models; you may need to cast them to float16.
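To illustrate the kind of dtype fix I mean (a minimal, generic sketch, not code from the repo): tensors created inside a model default to float32, which clashes with the float16 parameters once the model is cast for the Gemini plugin, so they need to be created or cast with the parameter dtype.

```python
import torch

# Minimal illustration only (not code from PIDM or ColossalAI): with float16 weights,
# tensors built on the fly inside forward() must be cast to the same dtype,
# otherwise you hit "expected Half but found Float"-style errors.
weight = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16))

t = torch.arange(4)                 # int64 by default
emb = t.to(dtype=weight.dtype)      # cast to the parameter dtype (float16 here)
out = weight * emb                  # dtypes now agree
print(out.dtype)                    # torch.float16
```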

zhangvia commented 12 months ago

Besides, how can I use the EMA feature with the ColossalAI FSDP plugin or the Gemini plugin? I found that the FSDP plugin changes the model parameter format, so I can't run the EMA update at the end of each step with the repo from my comment above. @Orion-Zheng
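For reference, this is roughly the update I'd like to keep working (a sketch adapted from the `accumulate()` above, not a tested recipe; I'm not sure whether calling `state_dict()` on the boosted model is the intended hook for this under Gemini):

```python
import torch

@torch.no_grad()
def accumulate_from_state_dict(ema_model, wrapped_model, decay=0.9999):
    # With Gemini/FSDP the wrapped model's parameters are sharded/flattened, so the
    # original accumulate() over named_parameters() no longer lines up with the EMA
    # copy. Updating the EMA from the (gathered) state dict is one possible workaround.
    ema_params = dict(ema_model.named_parameters())
    for name, tensor in wrapped_model.state_dict().items():
        if name in ema_params:
            p = ema_params[name]
            p.data.mul_(decay).add_(tensor.to(dtype=p.dtype, device=p.device), alpha=1 - decay)
```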