MetaBlues opened 1 year ago
Similarly, https://github.com/huggingface/diffusers/issues/4916 and https://github.com/huggingface/diffusers/issues/4006 also ran into this bug.
Are there any suggestions?
@MetaBlues thank you for the detailed write up and sorry to see that you are running into these errors. Can you share a bit more about what you have tried? For example, does this problem persist if run on just a single node?
Also, could you please provide a full reproducer that I could copy and run to produce the same error? This will help in trying to debug this error. Thanks!
@mrwyattii thanks for your help. This problem can be reproduced on a single node with `deepspeed --include localhost:0,1,2,3 test_stage3_ckpt.py`. Either calling `self.unet.disable_gradient_checkpointing()` or switching to stage 2 prevents the issue.
test_stage3_ckpt.py:
```python
import argparse
import datetime
import os
import traceback

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

import deepspeed
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler


class MinimalDiffusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
        self.scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path, subfolder="scheduler",
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae",
        )
        # The VAE stays frozen; only the UNet is trained (with gradient checkpointing).
        self.vae.requires_grad_(False)
        self.unet.enable_xformers_memory_efficient_attention()
        self.unet.enable_gradient_checkpointing()

    def forward(self, **kwargs):
        if self.training:
            return self._forward_train(**kwargs)
        else:
            return self._forward_eval(**kwargs)

    def _forward_train(
        self,
        *,
        vae_t_image,
        encoder_hidden_states,
        **_,
    ):
        latents = self.vae.encode(vae_t_image).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor
        bsz, ch, h, w = latents.shape
        device = latents.device
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz,), device=device)
        timesteps = timesteps.long()
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        ).sample
        loss = F.mse_loss(model_pred, noise, reduction="mean")
        return loss

    def _forward_eval(self, **kwargs):
        pass

    def train(self, mode: bool = True):
        self.training = mode
        self.vae.eval()
        self.unet.train(mode)
        return self


def init_distributed_mode(args):
    # Not called in main(); deepspeed.initialize() sets up the process group itself.
    assert torch.cuda.is_available() and torch.cuda.device_count() > 1
    args.distributed = True
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ["WORLD_SIZE"])
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.dist_backend = "nccl"
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method="env://",
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(0, 7200),
    )
    dist.barrier()


def get_parameter_groups(model):
    parameter_group_names = {}
    parameter_group_vars = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1:
            group_name = "wo_decay"
            weight_decay = 0.
        else:
            group_name = "w_decay"
            weight_decay = 0.01
        lr = 5e-5
        if group_name not in parameter_group_names:
            parameter_group_names[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }
            parameter_group_vars[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }
        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)
    return list(parameter_group_vars.values())


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", default=False, action="store_true")
    parser.add_argument("--world-size", default=1, type=int)
    parser.add_argument("--rank", default=-1, type=int)
    parser.add_argument("--gpu", default=-1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist-backend", default="nccl", type=str)
    deepspeed.add_config_arguments(parser)
    return parser.parse_args()


def main():
    args = parse_args()
    args.deepspeed_config = "deepspeed_config.json"
    print(args)
    model = MinimalDiffusion()
    parameters = get_parameter_groups(model)
    model, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=parameters,
    )
    device = torch.device("cuda")
    dtype = torch.float16
    model.train()
    batch_size = 8
    steps = 100
    for _ in range(steps):
        # Random tensors stand in for real image / text-encoder inputs.
        vae_t_image = torch.randn(batch_size, 3, 512, 512, dtype=dtype, device=device)
        encoder_hidden_states = torch.randn(batch_size, 77, 768, dtype=dtype, device=device)
        loss = model(
            vae_t_image=vae_t_image,
            encoder_hidden_states=encoder_hidden_states,
        )
        model.backward(loss)
        model.step()
    torch.cuda.synchronize()


if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        print(ex)
        print(traceback.format_exc())
```
and `deepspeed_config.json`:

```json
{
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "steps_per_print": 10,
    "zero_allow_untested_optimizer": true,
    "optimizer": {
        "type": "Adam",
        "adam_w_mode": true,
        "params": {
            "lr": 5e-05,
            "weight_decay": 0.01,
            "bias_correction": true,
            "betas": [0.9, 0.999],
            "eps": 1e-08
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "total_num_steps": 100,
            "warmup_min_lr": 0,
            "warmup_max_lr": 5e-05,
            "warmup_num_steps": 10,
            "warmup_type": "linear"
        }
    },
    "fp16": {
        "enabled": true,
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 15,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "consecutive_hysteresis": false,
        "min_loss_scale": 1
    },
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients": true,
        "stage3_prefetch_bucket_size": 1e8,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_param_persistence_threshold": 1e6,
        "sub_group_size": 1e12,
        "ignore_unused_parameters": true,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
```
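For completeness, the stage-2 workaround mentioned above only needs the `zero_optimization` block changed. A rough, untested sketch (`deepspeed.initialize` also accepts the config as a Python dict via its `config` argument):

```python
# Untested sketch: same setup as above, but with ZeRO stage 2 instead of stage 3.
ds_config_stage2 = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "fp16": {"enabled": True, "initial_scale_power": 15},
    "zero_optimization": {
        "stage": 2,                 # stage 2 partitions only optimizer states/gradients, not parameters
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e7,
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=parameters,
    config=ds_config_stage2,        # passing a dict instead of pointing at deepspeed_config.json
)
```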
Same problem here. Some additional info that might help: torch.utils.checkpoint conflicts with DeepSpeed ZeRO-3 (same output-shape runtime error as above). However, if LoRA is applied with peft, ZeRO-3 works fine together with torch.utils.checkpoint.
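For reference, a rough sketch of the LoRA-via-peft setup I mean (the model name, target modules and LoRA hyperparameters here are just placeholders):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
)
base.gradient_checkpointing_enable()   # goes through torch.utils.checkpoint internally
base.enable_input_require_grads()      # so checkpointed blocks still receive gradients

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(base, lora_config)  # base weights frozen, only LoRA params trainable
model.print_trainable_parameters()
```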
@MetaBlues thanks for the reproducer, and @wizyoung thank you for the additional information. I just did some quick testing and I can confirm that I'm able to reproduce the error. It's not immediately clear what is causing this, but I believe it's due to a change in `diffusers`. If I downgrade from the latest release to `diffusers==0.16.1`, I am able to run the example successfully, so a commit between 0.16.1 and 0.17.0 is causing the error.
Can you both confirm that downgrading `diffusers` avoids the error? I will continue debugging against the latest `diffusers`.
My environment:
```
accelerate    0.23.0
deepspeed     0.10.4+78c3b148
diffusers     0.17.0
torch         2.0.1
transformers  4.33.2
```
@mrwyattii downgrading diffusers to 0.16.1 does work. I'll look into the differences between 0.16.1 and 0.17.0.
I'm not using diffusers; I just applied ZeRO-3 to Hugging Face Llama 2, and the error occurs when doing SFT with gradient checkpointing enabled. A quick trace shows the error points to the torch.utils.checkpoint call inside the Llama 2 code in Hugging Face transformers. @MetaBlues waiting for your findings. :) My env:
```
accelerate    0.22.0
deepspeed     0.10.0
torch         1.13.1+cu117
transformers  4.33.1
```
@MetaBlues have you got any clues?
@MetaBlues Some more information which may be helpful: in newer versions of diffusers (>= 0.17.0), the parameter use_reentrant=False is passed to the torch.utils.checkpoint.checkpoint call in unet_2d_blocks.py, vae.py and transformer_2d.py, which may be what causes this error. If I remove every use_reentrant=False, the model trains normally (I have not tested the quality of the trained model, but at least the backward step works fine).
@ryanzhangfan @mrwyattii I find that DeepSpeed has a checkpointing.non_reentrant_checkpoint, which is only used by runtime.pipe. Maybe we need a way to handle this use_reentrant setting automatically.
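Something like the following wrapper is what I have in mind. It is untested, and it assumes non_reentrant_checkpoint follows the same (function, *args) calling convention as torch.utils.checkpoint.checkpoint:

```python
# Untested sketch: route use_reentrant=False requests through DeepSpeed's own
# non-reentrant checkpoint instead of torch's implementation.
import torch.utils.checkpoint
from deepspeed.runtime.activation_checkpointing.checkpointing import non_reentrant_checkpoint

_torch_checkpoint = torch.utils.checkpoint.checkpoint

def _zero3_friendly_checkpoint(function, *args, use_reentrant=True, **kwargs):
    if use_reentrant:
        return _torch_checkpoint(function, *args, use_reentrant=True, **kwargs)
    return non_reentrant_checkpoint(function, *args)

torch.utils.checkpoint.checkpoint = _zero3_friendly_checkpoint
```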
@MetaBlues @ryanzhangfan Great findings! After setting use_reentrant to True, my ZeRO-3 training pipeline runs normally. According to the official PyTorch docs and https://github.com/huggingface/transformers/issues/21381, we need to handle carefully the case where part of the model is frozen, i.e., not wrap the frozen part inside the checkpointing function. A future version of PyTorch will set this argument to False by default.
One simple fix is to place the following code at the top of your imports:

```python
from functools import partial

import torch.utils.checkpoint

torch.utils.checkpoint.checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=True)
```
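One caveat with the functools.partial trick: keyword arguments supplied at the call site override those bound by partial, so callers that explicitly pass use_reentrant=False (as diffusers >= 0.17.0 does) are not affected by it. A small wrapper that overwrites the keyword unconditionally covers that case too (untested sketch):

```python
import torch.utils.checkpoint

_orig_checkpoint = torch.utils.checkpoint.checkpoint

def _force_reentrant_checkpoint(function, *args, **kwargs):
    # Overwrite the flag even when the caller passes use_reentrant=False explicitly.
    kwargs["use_reentrant"] = True
    return _orig_checkpoint(function, *args, **kwargs)

torch.utils.checkpoint.checkpoint = _force_reentrant_checkpoint
```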
Describe the bug
I want to use torch.utils.checkpoint() in diffusers.models.unet_3d_blocks to reduce the VRAM used, like this:
Yet when I used this new block in UNet training I encountered "RuntimeError: output with shape [0] doesn't match the broadcast shape [1280, 1280, 3, 1, 0]", and the error went away when I set unet.disable_gradient_checkpointing().

To Reproduce
Steps to reproduce the behavior:

Expected behavior
Use stage 3 for UNet3DConditionModel training with unet.enable_gradient_checkpointing().
Launcher context
Accelerate launcher
Deepspeed config: