zhangvia opened this issue 12 months ago
> Could you tell us how to reproduce the bug (the code, the command, etc.)? It will help us locate the problem.
You can use this repo: https://github.com/ankanbhunia/PIDM/tree/383b60eade67ec0c02d6898424f245c488c38f00

I use the Booster API and the Gemini plugin on top of this repo. My `train.py` is:
```python
import os
import warnings
warnings.filterwarnings("ignore")

import time, cv2, torch
from tqdm import tqdm
import numpy as np
import logging
import torch.distributed as dist
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from tensorfn.optim import lr_scheduler
from config.diffconfig import DiffusionConfig, get_model_conf
from config.dataconfig import Config as DataConfig
from tensorfn import load_config as DiffConfig
from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
import data as deepfashion_data

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# use_colossalai=0 -> plain PyTorch DDP, use_colossalai=1 -> ColossalAI booster
if os.getenv('use_colossalai') == '0':
    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s', level=logging.INFO)
    logger = logging
elif os.getenv('use_colossalai') == '1':
    disable_existing_loggers()
    logger = get_dist_logger()
else:
    logging.error("please set env use_colossalai: 0 means plain PyTorch, 1 means ColossalAI")
def init_distributed():
    dist_url = "env://"  # default

    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)

    torch.cuda.set_device(local_rank)
    dist.barrier()
    setup_for_distributed(rank == 0)

def setup_for_distributed(is_master):
    """Disable printing when not in the master process."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

def is_main_process():
    try:
        return dist.get_rank() == 0
    except Exception:
        return True

def sample_data(loader):
    loader_iter = iter(loader)
    epoch = 0

    while True:
        try:
            yield epoch, next(loader_iter)
        except StopIteration:
            epoch += 1
            loader_iter = iter(loader)
            yield epoch, next(loader_iter)

def accumulate(model1, model2, decay=0.9999):
    # EMA update: model1 <- decay * model1 + (1 - decay) * model2
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
def train(conf, loader, model, ema, diffusion, betas, optimizer, scheduler, guidance_prob, cond_scale, device, booster):
    i = 0
    loss_list = []
    loss_mean_list = []
    loss_vb_list = []

    torch.cuda.synchronize()

    for epoch in range(500):
        if os.getenv('use_colossalai') == '0':
            if is_main_process():
                print('#Epoch - ' + str(epoch))
        else:
            logger.info(f'#Epoch - {epoch}', ranks=[0])

        start_time = time.time()

        for batch in tqdm(loader):
            i = i + 1

            img = batch["source_image"]
            target_img = batch["target_image"]
            target_pose = torch.cat([batch['target_image_ref'], batch['target_skeleton']], 1)

            if booster is None:
                img = img.to(device)
                target_img = target_img.to(device)
                target_pose = target_pose.to(device)
                time_t = torch.randint(
                    0,
                    conf.diffusion.beta_schedule["n_timestep"],
                    (img.shape[0],),
                    device=device,
                )
            else:
                # Gemini runs the model in fp16, so the inputs are cast as well
                img = img.to(get_current_device(), dtype=torch.float16)
                target_img = target_img.to(get_current_device(), dtype=torch.float16)
                target_pose = target_pose.to(get_current_device(), dtype=torch.float16)
                time_t = torch.randint(
                    0,
                    conf.diffusion.beta_schedule["n_timestep"],
                    (img.shape[0],),
                    device=get_current_device(),
                )

            loss_dict = diffusion.training_losses(model, x_start=target_img, t=time_t, cond_input=[img, target_pose], prob=1 - guidance_prob)

            loss = loss_dict['loss'].mean()
            loss_mse = loss_dict['mse'].mean()
            loss_vb = loss_dict['vb'].mean()

            if booster is None:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1)
                scheduler.step()
                optimizer.step()
            else:
                booster.backward(loss, optimizer)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            loss_list.append(loss.detach().item())
            loss_mean_list.append(loss_mse.detach().item())
            loss_vb_list.append(loss_vb.detach().item())

            # NOTE: model.module does not exist once the model is wrapped by
            # Gemini/FSDP -- this is the EMA problem I describe below
            accumulate(
                ema, model.module, 0 if i < conf.training.scheduler.warmup else 0.9999
            )

            if i % args.save_checkpoints_every_iters == 0 and is_main_process():
                if conf.distributed:
                    model_module = model.module
                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "conf": conf,
                    },
                    conf.training.ckpt_path + f"/model_{str(i).zfill(6)}.pt"
                )

        if booster is None:
            if is_main_process():
                print('Epoch Time ' + str(int(time.time() - start_time)) + ' secs')
                print('Model Saved Successfully for #epoch ' + str(epoch) + ' #steps ' + str(i))

                if conf.distributed:
                    model_module = model.module
                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "conf": conf,
                    },
                    conf.training.ckpt_path + '/last.pt'
                )
        else:
            logger.info(f'Epoch Time {int(time.time() - start_time)} secs', ranks=[0])
            booster.save_model(model, conf.training.ckpt_path + '/last.pt')
            booster.save_model(ema, conf.training.ckpt_path + '/last_ema.pt')
            logger.info(f'Model Saved Successfully for #epoch {epoch} #steps {i}', ranks=[0])
def main(settings, EXP_NAME):
    [args, DiffConf, DataConf] = settings

    # if is_main_process(): wandb.init(project="person-synthesis", name=EXP_NAME, settings=wandb.Settings(code_dir="."))

    if DiffConf.ckpt is not None:
        DiffConf.training.scheduler.warmup = 0

    DiffConf.distributed = True
    local_rank = int(os.environ['LOCAL_RANK'])

    DataConf.data.train.batch_size = args.batch_size // 2  # src -> tgt, tgt -> src

    model = get_model_conf().make_model()
    ema = get_model_conf().make_model()

    if DiffConf.ckpt is not None:
        ckpt = torch.load(DiffConf.ckpt, map_location=lambda storage, loc: storage)

        if DiffConf.distributed:
            model.module.load_state_dict(ckpt["model"])
        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])
        scheduler.load_state_dict(ckpt["scheduler"])

        if is_main_process():
            print('model loaded successfully')

    if os.getenv('use_colossalai') == '0':
        model = model.to(args.device)
        ema = ema.to(args.device)
    else:
        model = model.to(get_current_device(), dtype=torch.float16)
        ema = ema.to(get_current_device(), dtype=torch.float16)

    if DiffConf.distributed:
        if os.getenv('use_colossalai') == '0':
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                find_unused_parameters=True
            )
            booster = None
        else:
            booster_kwargs = {}
            plugin = GeminiPlugin(placement_policy='static', strict_ddp_mode=True, initial_scale=2 ** 5)
            booster = Booster(plugin=plugin, **booster_kwargs)

    if os.getenv('use_colossalai') == '0':
        val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(DataConf.data, labels_required=True, distributed=True)
    elif os.getenv('use_colossalai') == '1':
        val_dataset, train_dataset = deepfashion_data.get_train_val_dataset(DataConf.data, True)
        train_dataset = plugin.prepare_dataloader(train_dataset, num_workers=DataConf.data.train.batch_size * 2, batch_size=DataConf.data.train.batch_size)
        val_dataset = plugin.prepare_dataloader(val_dataset, num_workers=1, batch_size=1)

    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    val_dataset = iter(cycle(val_dataset))

    if os.getenv('use_colossalai') == '1':
        optimizer = HybridAdam(model.parameters(), lr=2e-5, initial_scale=2 ** 5, clipping_norm=1)
        scheduler = DiffConf.training.scheduler.make(optimizer)
        model, optimizer, _, _, scheduler = booster.boost(model, optimizer, lr_scheduler=scheduler)
    else:
        optimizer = DiffConf.training.optimizer.make(model.parameters())
        scheduler = DiffConf.training.scheduler.make(optimizer)

    betas = DiffConf.diffusion.beta_schedule.make()
    diffusion = create_gaussian_diffusion(betas, predict_xstart=False)

    train(
        DiffConf, train_dataset, model, ema, diffusion, betas, optimizer, scheduler, args.guidance_prob, args.cond_scale, args.device, booster
    )
if __name__ == "__main__":
    if os.getenv('use_colossalai') == '0':
        init_distributed()
    else:
        colossalai.launch_from_torch(config={})

    import argparse

    parser = argparse.ArgumentParser(description='help')
    parser.add_argument('--exp_name', type=str, default='vto_model')
    parser.add_argument('--DiffConfigPath', type=str, default='./config/diffusion.conf')
    parser.add_argument('--DataConfigPath', type=str, default='./config/data_arcsoft.yaml')
    parser.add_argument('--dataset_path', type=str, default='./')
    parser.add_argument('--save_path', type=str, default='checkpoints')
    parser.add_argument('--cond_scale', type=int, default=2)
    parser.add_argument('--guidance_prob', type=float, default=0.1)
    parser.add_argument('--sample_algorithm', type=str, default='ddim')  # ddpm, ddim
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--save_wandb_logs_every_iters', type=int, default=200000)
    parser.add_argument('--save_checkpoints_every_iters', type=int, default=2000)
    parser.add_argument('--save_wandb_images_every_epochs', type=int, default=100000)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--n_gpu', type=int, default=8)
    parser.add_argument('--n_machine', type=int, default=1)
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)

    args = parser.parse_args()

    print('Experiment: ' + args.exp_name)

    DiffConf = DiffConfig(DiffusionConfig, args.DiffConfigPath, args.opts, False)
    DataConf = DataConfig(args.DataConfigPath)

    DiffConf.training.ckpt_path = os.path.join(args.save_path, args.exp_name)
    DataConf.data.path = args.dataset_path

    if is_main_process():
        if not os.path.isdir(args.save_path):
            os.mkdir(args.save_path)
        if not os.path.isdir(DiffConf.training.ckpt_path):
            os.mkdir(DiffConf.training.ckpt_path)

    # DiffConf.ckpt = "checkpoints/vto_model/last.pt"
    # print("Loading model {}.".format(DiffConf.ckpt))

    main(settings=[args, DiffConf, DataConf], EXP_NAME=args.exp_name)
```
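For reference, I launch the script roughly like `use_colossalai=1 torchrun --nproc_per_node=8 train.py` (from memory; the exact path and config flags may differ).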
Besides, there are some dtype errors in the models; you may need to change them to float16.
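To make the dtype problem concrete, here is the kind of change I mean (an illustrative sketch of a typical diffusion timestep-embedding helper, not the exact code in the repo): tensors created inside the model with a hard-coded `torch.float32` dtype break once the model itself runs in float16 under Gemini, so they need to follow the activation dtype.

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, x: torch.Tensor) -> torch.Tensor:
    # was: dtype=torch.float32, which crashes with a dtype mismatch
    # when x is float16 under the Gemini plugin
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=x.dtype) / half)
    angles = t[:, None].to(x.dtype) * freqs[None, :]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
```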
Besides, how can I use the EMA feature with the ColossalAI FSDP plugin or Gemini plugin? I found that the FSDP plugin changes the parameter layout of the model, so I can't run the EMA update at the end of each step the way the repo in my comment does. @Orion-Zheng
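This is the kind of EMA update I am trying to write (my own sketch, untested; `update_ema_from_state_dict` is my own helper name, and it assumes that calling `state_dict()` on the booster-wrapped model gathers the full, unsharded weights):

```python
import torch

@torch.no_grad()
def update_ema_from_state_dict(ema: torch.nn.Module, wrapped_model, decay: float = 0.9999):
    # Assumption: state_dict() on the Gemini/FSDP-wrapped model returns full
    # (unsharded) weights. The live parameters are flattened/sharded, so
    # named_parameters() no longer lines up with the plain EMA copy the way
    # model.module does under ordinary DDP.
    full_sd = wrapped_model.state_dict()
    ema_sd = ema.state_dict()
    for k, v in full_sd.items():
        if k in ema_sd and torch.is_floating_point(ema_sd[k]):
            # in-place update of the EMA copy's storage
            ema_sd[k].mul_(decay).add_(v.to(device=ema_sd[k].device, dtype=ema_sd[k].dtype), alpha=1 - decay)
```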
🐛 Describe the bug
I'm using ColossalAI to train the try-on diffusion model (the paper is TryOnDiffusion), but there is a compatibility error: "ZERO DDP error: the synchronization of gradients doesn't exit properly. The most possible reason is that the model is not compatible with GeminiDDP."
Environment
No response