hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0
38.32k stars 4.3k forks source link

[BUG]: `booster.backward(loss, optimizer)` fails in Stable Diffusion DreamBooth inpainting #5305

Open shileims opened 5 months ago

shileims commented 5 months ago

🐛 Describe the bug

The error occurs in `booster.backward(loss, optimizer)`; I am using `GeminiPlugin`.

ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 2044) of binary: /opt/conda/envs/pytorch/bin/python

Environment

Linux, CUDA 11.7, PyTorch 1.13.1

shileims commented 5 months ago

When I use the same booster settings for another model, it works fine.

def main(args):
    """Fine-tune a Stable Diffusion inpainting UNet (DreamBooth-style) with the ColossalAI Booster API.

    Only the UNet is optimised. The VAE and text encoder are frozen, moved to the
    current device and cast to ``weight_dtype``. Checkpoints are written under
    ``args.output_dir``.

    Args:
        args: namespace produced by ``parse_args()`` (defined elsewhere in this script).

    Raises:
        ValueError: if ``args.plugin`` is not torch_ddp*, gemini or low_level_zero.
        NotImplementedError: if ``args.mixed_precision`` is not "fp16", "bf16", "no" or None.
    """
    import json

    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    global_rank = dist.get_rank()
    local_rank = args.local_rank
    world_size = dist.get_world_size()
    logger.info(f'Global rank: {global_rank}')
    logger.info(f'Local rank: {local_rank}')
    logger.info(f'World  size: {world_size}')

    # Create the output directory and record the command-line arguments (rank 0 only).
    if global_rank == 0:
        assert args.output_dir is not None, f'output dir could not be none'
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, 'commandline_args.txt'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)

    assert os.path.exists(args.pretrained_model_name_or_path), f'pretrained_model_name_or_path not exist'
    logger.info(f'Pretrained model path is {args.pretrained_model_name_or_path}', ranks=[0])

    # Load the tokenizer and the three Stable Diffusion sub-models from the pretrained checkpoint.
    logger.info(f"Loading tokenizer from pretrained model", ranks=[0])
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

    logger.info("Loading text encoder from pretrained model", ranks=[0])
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    logger.info("Loading autoencoder from pretrained model", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    logger.info("Loading unet from pretrained model", ranks=[0])
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
    )

    # Only the UNet is trained; the VAE and text encoder are used for inference only.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    assert not args.train_text_encoder, f'colossalai gemini and low_level_zero do not support train two models together'

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Scale the learning rate by the effective batch size relative to the reference batch size.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * (args.gradient_accumulation_steps * args.train_batch_size * world_size / args.sd_batchsize)
        )
    # This restriction does not depend on --scale_lr (it was previously only checked when scale_lr was set).
    if args.gradient_accumulation_steps > 1:
        assert args.plugin != "low_level_zero", f'Low_level_zero does not support graident accumulation'

    # NOTE: `sacle_lr_warmup_steps` (sic) must match the attribute name defined by parse_args().
    if args.sacle_lr_warmup_steps:
        args.lr_warmup_steps = int(
            args.lr_warmup_steps * (args.gradient_accumulation_steps * args.train_batch_size * world_size / args.sd_batchsize)
        )

    # Use the Booster API to run the UNet under DDP / Gemini / low-level ZeRO.
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            offload_optim_frac=args.offload_optim_frac,
            strict_ddp_mode=True,
            initial_scale=2**5,
            enable_gradient_accumulation=True,
        )
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Fail fast instead of hitting a NameError on `plugin` below.
        raise ValueError(f"Unsupported plugin: {args.plugin}")

    booster = Booster(plugin=plugin, **booster_kwargs)

    # NOTE(review): `initial_scale` and `clipping_norm` look like options of the legacy ZeRO optimizer
    # wrapper. Confirm that HybridAdam in the installed ColossalAI version honours them; with the
    # Booster API, loss scaling and gradient clipping are normally configured on the plugin instead.
    optimizer = HybridAdam(
        unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
    )

    # from_pretrained is the supported way to load a scheduler config from a checkpoint directory.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # dtype for the frozen models and for the batch tensors produced by collate_fn.
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    elif args.mixed_precision in (None, "no"):
        # Previously unreachable: every non-fp16/bf16 value raised.
        weight_dtype = torch.float32
    else:
        raise NotImplementedError(f"Unsupported mixed_precision: {args.mixed_precision}")

    train_dataset = DreamBoothInpaintingDataset(
        txt_path=args.txt_path,
        tokenizer=tokenizer,
        global_rank=global_rank,
        world_size=world_size,
        batch_size=args.train_batch_size,
        size=args.resolution,
        center_crop=args.center_crop,
        write_training_data=args.write_training_data,
        debug=args.debug,
    )

    def collate_fn(examples):
        """Stack dataset examples into a dict of batched tensors.

        Image/mask tensors are cast to ``weight_dtype``; prompt ids are padded with the tokenizer.
        ``ori_gt_image`` is divided by 255 and passed to ``torch.from_numpy``, so it is presumably a
        numpy array in the 0-255 range (confirm against DreamBoothInpaintingDataset).
        """
        input_ids = [example["prompt_ids"] for example in examples]

        gt_imgs = []
        ori_gt_imgs = []
        masks = []
        masked_images = []
        for example in examples:
            masked_image = example['masked_image']
            mask = example['mask']
            gt_img = example['gt_image']
            ori_gt_img = example['ori_gt_image']
            ori_gt_img = ori_gt_img / 255.0
            ori_gt_img = torch.from_numpy(ori_gt_img)

            masks.append(mask)
            masked_images.append(masked_image)
            ori_gt_imgs.append(ori_gt_img)
            gt_imgs.append(gt_img)

        gt_imgs = torch.stack(gt_imgs)
        gt_imgs = gt_imgs.to(memory_format=torch.contiguous_format).float()
        gt_imgs = gt_imgs.to(dtype=weight_dtype)

        ori_gt_imgs = torch.stack(ori_gt_imgs)
        ori_gt_imgs = ori_gt_imgs.to(memory_format=torch.contiguous_format).float()
        ori_gt_imgs = ori_gt_imgs.to(dtype=weight_dtype)

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

        masks = torch.stack(masks)
        masks = masks.to(dtype=weight_dtype)
        masked_images = torch.stack(masked_images)
        masked_images = masked_images.to(dtype=weight_dtype)

        batch = {"input_ids": input_ids, "gt_imgs": gt_imgs, "masks": masks, "masked_images": masked_images, 'ori_gt_imgs': ori_gt_imgs}
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4
    )

    dist.barrier()

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # NOTE(review): warmup/total steps are multiplied by world_size, but lr_scheduler.step() below is
    # called once per training step on each rank. Unless the Booster steps the scheduler world_size
    # times (as accelerate does), the schedule runs world_size times slower than args.max_train_steps.
    # Confirm this is intended.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * world_size,
        num_training_steps=args.max_train_steps * world_size,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # The frozen VAE and text encoder are only used for inference, so keep them in weight_dtype on the GPU.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    # Recalculate the total training steps in case the dataloader length changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)

    total_batch_size = args.train_batch_size * world_size

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
    logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Re-dump the arguments: learning_rate, max_train_steps and num_train_epochs may have changed above.
    if global_rank == 0:
        with open(os.path.join(args.output_dir, 'commandline_args.txt'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)

    # Only show the progress bar on rank 0.
    progress_bar = tqdm(range(args.max_train_steps), disable=global_rank != 0)
    progress_bar.set_description("Steps")
    global_step = 0
    torch.cuda.synchronize()
    optimizer.zero_grad()
    for epoch in range(args.num_train_epochs):
        unet.train()
        for batch in train_dataloader:
            torch.cuda.reset_peak_memory_stats()
            # Move the batch to the GPU.
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            optimizer.zero_grad()

            # 0.18215 is presumably the SD VAE latent scaling factor (vae.config.scaling_factor); confirm per checkpoint.
            latents = vae.encode(batch["gt_imgs"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215

            if global_rank == 0:
                logger.info('gt imgs latent done')

            # Convert masked images to latent space.
            masked_latents = vae.encode(
                batch["masked_images"].reshape(batch["gt_imgs"].shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            masked_latents = masked_latents * 0.18215

            if global_rank == 0:
                logger.info('masked imgs latent done')

            # Resize the mask to the latent resolution (assumes the VAE downsamples by 8) so it can be
            # concatenated with the latents.
            masks = batch["masks"]
            mask = torch.nn.functional.interpolate(masks, size=(args.resolution // 8, args.resolution // 8))
            mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

            # Sample noise and a random timestep per image.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Forward diffusion: add noise to the latents according to each timestep.
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # The inpainting UNet input is [noisy latents, mask, masked-image latents] along dim 1.
            latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

            # Text embedding for conditioning.
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            if global_rank == 0:
                logger.info('text latent done')

            torch.cuda.empty_cache()

            # Predict the noise residual.
            noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

            if global_rank == 0:
                logger.info('unet done')

            # The loss target depends on the scheduler's prediction type.
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.snr_gamma is None:
                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
            else:
                # Min-SNR loss weighting, Section 3.4 of https://arxiv.org/abs/2303.09556
                # (adapted for noise prediction as discussed in Section 4.2 of the same paper).
                snr = compute_snr(noise_scheduler, timesteps)
                base_weight = (
                    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                )

                if noise_scheduler.config.prediction_type == "v_prediction":
                    # Velocity objective needs to be floored to an SNR weight of one.
                    mse_loss_weights = base_weight + 1
                else:
                    # Epsilon and sample both use the same loss weights.
                    mse_loss_weights = base_weight
                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()

            if global_rank == 0:
                logger.info('loss computation done')

            booster.backward(loss, optimizer)

            progress_bar.update(1)
            global_step += 1

            if global_rank == 0:
                logger.info('loss backward done')

            # NOTE(review): gradient accumulation is not implemented. The optimizer and scheduler step
            # on every batch and the loss is not divided by args.gradient_accumulation_steps, although
            # that value still feeds into the step/epoch arithmetic above.
            optimizer.step()
            lr_scheduler.step()

            if global_rank == 0:
                logger.info('optimizer step lr scheduler done')

            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])

            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
            progress_bar.set_postfix(**logs)

            if args.save_steps != -1 and global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                if global_rank == 0:
                    # The model file goes into save_path/unet, so create that directory as well.
                    os.makedirs(os.path.join(save_path, 'unet'), exist_ok=True)
                booster.save_model(unet, os.path.join(save_path, 'unet', "diffusion_pytorch_model.bin"))
                if global_rank == 0:
                    if not os.path.exists(os.path.join(save_path, "config.json")):
                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break

        # End-of-epoch checkpoint.
        torch.cuda.synchronize()
        save_path = os.path.join(args.output_dir, f"{epoch}_-1")
        if global_rank == 0:
            os.makedirs(os.path.join(save_path, 'unet'), exist_ok=True)
        booster.save_model(unet, os.path.join(save_path, 'unet', "diffusion_pytorch_model.bin"))
        logger.info(f"Saving model checkpoint to {args.output_dir} on rank {global_rank}")
        if global_rank == 0:
            if not os.path.exists(os.path.join(args.output_dir, "config.json")):
                shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)

        # The inner `break` only ends the current epoch; without this, every remaining epoch would run
        # one extra step past max_train_steps and write another checkpoint.
        if global_step >= args.max_train_steps:
            break
# Script entry point: parse CLI arguments and run training.
# (Markdown rendering had stripped the dunder underscores, which left an undefined `name`.)
if __name__ == "__main__":
    args = parse_args()
    main(args)