microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG][stage3] using torch.utils.checkpoint in unet_3d_blocks causes weights size [0] error #4332

Status: Open. MetaBlues opened this issue 1 year ago.

MetaBlues commented 1 year ago

Describe the bug: I want to use torch.utils.checkpoint.checkpoint() in diffusers.models.unet_3d_blocks to reduce VRAM usage, like this:

import torch
import torch.utils.checkpoint
from diffusers.models.unet_3d_blocks import DownBlock3D
from diffusers.utils import is_torch_version


# Subclass the stock block under the same name and override forward() only.
class DownBlock3D(DownBlock3D):
    def forward(self, hidden_states, temb=None, num_frames=1):
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            if self.training and self.gradient_checkpointing:

                # Wrap a submodule so checkpoint() can call it with positional inputs.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                if is_torch_version(">=", "1.11.0"):
                    # use_reentrant=False selects PyTorch's non-reentrant checkpoint implementation.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False
                    )
                else:
                    # Older PyTorch: fall back to the default (reentrant) implementation.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(temp_conv), hidden_states, num_frames
                    )
            else:
                # No checkpointing: run the submodules directly.
                hidden_states = resnet(hidden_states, temb)
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

However, when I used this new block in UNet training, I encountered the error "RuntimeError: output with shape [0] doesn't match the broadcast shape [1280, 1280, 3, 1, 0]". The error disappears when I call unet.disable_gradient_checkpointing().
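The shapes in this message can be read as follows, as far as we understand ZeRO-3: a partitioned parameter keeps only a placeholder of shape [0] in param.data (its full shape is recorded in param.ds_shape), and [1280, 1280, 3, 1, 0] is exactly what broadcasting [0] against [1280, 1280, 3, 1, 1] produces. That full shape is consistent with a Conv3d weight with a (3, 1, 1) kernel, such as the temporal convolutions in unet_3d_blocks (our inference). The sketch below uses plain PyTorch with no DeepSpeed and a scaled-down shape; it reproduces the same message by adding a full-size tensor in place into a shape-[0] tensor. The function name reproduce_shape_error is hypothetical.

```python
import torch


def reproduce_shape_error(full_shape=(4, 4, 3, 1, 1)):
    """Add a full-size tensor in place into a shape-[0] tensor and return the error text.

    The shape-[0] tensor stands in for a partitioned ZeRO-3 parameter (an assumption
    about what param.data looks like while the parameter is not gathered).
    """
    placeholder = torch.empty(0)
    full = torch.ones(full_shape)  # e.g. a Conv3d weight with a (3, 1, 1) kernel
    try:
        placeholder.add_(full)  # in-place ops cannot resize their output tensor
    except RuntimeError as err:
        return str(err)
    return None


# A size-0 dimension broadcasts against a trailing size-1 dimension, giving [..., 0].
print(torch.broadcast_shapes((0,), (4, 4, 3, 1, 1)))
print(reproduce_shape_error())
```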

To Reproduce: Steps to reproduce the behavior:

  1. Override DownBlock3D.forward() with torch.utils.checkpoint as shown above.
  2. Use ZeRO stage 3 to train UNet3DConditionModel:
    with ContextManagers(deepspeed_zero3_init_enabled_context_manager()):
        unet = UNet3DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
    unet._supports_gradient_checkpointing = True
    unet.enable_gradient_checkpointing()

    deepspeed_zero3_init_enabled_context_manager() is a helper method inspired by here. I launch DeepSpeed with Hugging Face accelerate and set zero3_init_flag to false, so zero.Init is enabled only for unet.from_pretrained().

Expected behavior: ZeRO stage 3 training of UNet3DConditionModel works with unet.enable_gradient_checkpointing() enabled.

ds_report output:

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/root/miniconda3/envs/base/lib/python3.10/site-packages/torch']
torch version .................... 2.0.1+cu117
deepspeed install path ........... ['/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.10.3
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7
shared memory (/dev/shm) size .... 251.52 GB

Traceback:

Traceback (most recent call last):
  File "/unet3d_training/main.py", line 21, in <module>
    eval(cfg.common.mode)(cfg)
  File "/unet3d_training/main.py", line 7, in train
    trainer.train()
  File "/unet3d_training/src/trainers/base.py", line 360, in train
    self.accelerator.backward(loss["all"])
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/accelerate/accelerator.py", line 1847, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
    self.engine.backward(loss, **kwargs)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1923, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2080, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/root/miniconda3/envs/base/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: output with shape [0] doesn't match the broadcast shape [1280, 1280, 3, 1, 0]

System info:

Launcher context: Accelerate launcher, with this accelerate config:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /path_to/zero_stage3_config.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_process_ip: MASTER_IP
main_process_port: MASTER_PORT
main_training_function: main
num_machines: 2
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

DeepSpeed config:

{
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients": true,

        "stage3_prefetch_bucket_size" : 1e8,
        "stage3_max_live_parameters" : 1e9,
        "stage3_max_reuse_distance" : 1e9,
        "stage3_param_persistence_threshold" : 1e6,
        "sub_group_size" : 1e12,
        "ignore_unused_parameters": true,
        "stage3_gather_16bit_weights_on_model_save": true,

        "zero_hpz_partition_size": 2,
        "zero_quantized_weights": true,
        "zero_quantized_gradients": true
    },
    "steps_per_print": 100,
    "train_micro_batch_size_per_gpu": "8",
    "gradient_accumulation_steps": "1",
    "train_batch_size": "auto"
}
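With train_batch_size set to "auto", accelerate is expected to fill it in from the other values. DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world size. The sketch below checks that arithmetic for this config, assuming the world size is num_processes: 4 from the accelerate config above (2 machines in total). resolve_train_batch_size is a hypothetical helper, not a DeepSpeed or accelerate API. The report quotes the batch values as strings, so the sketch converts them with int().

```python
import json

# Batch-size keys copied from the DeepSpeed config above (values quoted as in the report).
CONFIG = json.loads(
    '{"train_micro_batch_size_per_gpu": "8",'
    ' "gradient_accumulation_steps": "1",'
    ' "train_batch_size": "auto"}'
)


def resolve_train_batch_size(config, world_size):
    """Hypothetical helper: compute train_batch_size from DeepSpeed's invariant
    train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size."""
    micro = int(config["train_micro_batch_size_per_gpu"])  # int() accepts 8 or "8"
    accum = int(config["gradient_accumulation_steps"])
    expected = micro * accum * world_size
    declared = config.get("train_batch_size", "auto")
    if declared != "auto" and int(declared) != expected:
        raise ValueError(
            f"train_batch_size={declared}, but {micro} * {accum} * {world_size} = {expected}"
        )
    return expected


# world_size assumed to be num_processes: 4 from the accelerate config.
print(resolve_train_batch_size(CONFIG, world_size=4))  # 32
```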
MetaBlues commented 1 year ago

Similarly, https://github.com/huggingface/diffusers/issues/4916 and https://github.com/huggingface/diffusers/issues/4006 report the same bug.

MetaBlues commented 1 year ago

Are there any suggestions?

mrwyattii commented 1 year ago

@MetaBlues thank you for the detailed write up and sorry to see that you are running into these errors. Can you share a bit more about what you have tried? For example, does this problem persist if run on just a single node?

Also, could you please provide a full reproducer that I could copy and run to produce the same error? This will help in trying to debug this error. Thanks!

MetaBlues commented 1 year ago

@mrwyattii thanks for your help. This problem can be reproduced on a single node with deepspeed --include localhost:0,1,2,3 test_stage3_ckpt.py. Either calling self.unet.disable_gradient_checkpointing() or switching to stage 2 prevents the issue.

test_stage3_ckpt.py:

import argparse
import datetime
import os
import traceback

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

import deepspeed
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler

class MinimalDiffusion(nn.Module):

    def __init__(self) -> None:
        super().__init__()

        pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

        self.scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path, subfolder="scheduler",
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae",
        )

        self.vae.requires_grad_(False)

        self.unet.enable_xformers_memory_efficient_attention()
        self.unet.enable_gradient_checkpointing()

    def forward(self, **kwargs):
        if self.training:
            return self._forward_train(**kwargs)
        else:
            return self._forward_eval(**kwargs)

    def _forward_train(
        self,
        *,
        vae_t_image,
        encoder_hidden_states,
        **_,
    ):
        latents = self.vae.encode(vae_t_image).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        bsz, ch, h, w = latents.shape
        device = latents.device

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz,), device=device)
        timesteps = timesteps.long()
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        model_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        ).sample

        loss = F.mse_loss(model_pred, noise, reduction="mean")

        return loss

    def _forward_eval(self, **kwargs):
        # Evaluation is not exercised by this reproducer.
        return None

    def train(self, mode: bool = True):
        self.training = mode
        self.vae.eval()
        self.unet.train(mode)
        return self

def init_distributed_mode(args):
    assert torch.cuda.is_available() and torch.cuda.device_count() > 1

    args.distributed = True
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.local_rank = int(os.environ['LOCAL_RANK'])
    args.dist_backend = "nccl"

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(0, 7200)
    )
    dist.barrier()

def get_parameter_groups(model):
    parameter_group_names = {}
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.ndim <= 1:
            group_name = "wo_decay"
            weight_decay = 0.
        else:
            group_name = "w_decay"
            weight_decay = 0.01

        lr = 5e-5

        if group_name not in parameter_group_names:
            parameter_group_names[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }
            parameter_group_vars[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    return list(parameter_group_vars.values())

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", default=False, action="store_true")
    parser.add_argument("--world-size", default=1, type=int)
    parser.add_argument("--rank", default=-1, type=int)
    parser.add_argument("--gpu", default=-1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
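    parser.add_argument("--dist-backend", default="nccl", type=str)
    # Used by init_distributed_mode(); "env://" reads MASTER_ADDR/MASTER_PORT from the environment.
    parser.add_argument("--dist-url", default="env://", type=str)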
    deepspeed.add_config_arguments(parser)

    return parser.parse_args()

def main():
    args = parse_args()
    args.deepspeed_config = "deepspeed_config.json"
    print(args)
    model = MinimalDiffusion()
    parameters = get_parameter_groups(model)

    model, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=parameters,
    )

    device = torch.device("cuda")
    dtype = torch.float16
    model.train()
    batch_size = 8
    steps = 100
    for _ in range(steps):
        vae_t_image = torch.randn(batch_size, 3, 512, 512, dtype=dtype, device=device)
        encoder_hidden_states = torch.randn(batch_size, 77, 768, dtype=dtype, device=device)

        loss = model(
            vae_t_image=vae_t_image,
            encoder_hidden_states=encoder_hidden_states,
        )

        model.backward(loss)
        model.step()
        torch.cuda.synchronize()

if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        print(ex)
        print(traceback.format_exc())

And deepspeed_config.json:

{
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 10,
  "zero_allow_untested_optimizer": true,
  "optimizer": {
    "type": "Adam",
    "adam_w_mode": true,
    "params": {
      "lr": 5e-05,
      "weight_decay": 0.01,
      "bias_correction": true,
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-08
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "total_num_steps": 100,
      "warmup_min_lr": 0,
      "warmup_max_lr": 5e-05,
      "warmup_num_steps": 10,
      "warmup_type": "linear"
    }
  },
  "fp16": {
    "enabled": true,
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 15,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "consecutive_hysteresis": false,
    "min_loss_scale": 1
  },
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "reduce_bucket_size": 5e7,
    "contiguous_gradients": true,

    "stage3_prefetch_bucket_size" : 1e8,
    "stage3_max_live_parameters" : 1e9,
    "stage3_max_reuse_distance" : 1e9,
    "stage3_param_persistence_threshold" : 1e6,
    "sub_group_size" : 1e12,
    "ignore_unused_parameters": true,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
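The scheduler block above combines a 10-step linear warmup from 0 to 5e-05 with a decay to 0 at step 100. The sketch below approximates the learning rate at each step, assuming WarmupDecayLR ramps linearly during warmup and then decays linearly over the remaining total_num_steps - warmup_num_steps steps. It is our reading of DeepSpeed's schedule rather than its code, and warmup_decay_lr is a hypothetical name.

```python
def warmup_decay_lr(step, total_num_steps=100, warmup_min_lr=0.0,
                    warmup_max_lr=5e-05, warmup_num_steps=10):
    """Approximate learning rate at `step` for a linear warmup followed by linear decay
    (an assumption about WarmupDecayLR, using the values from deepspeed_config.json)."""
    if step < warmup_num_steps:
        gamma = step / warmup_num_steps  # linear ramp from min to max
    else:
        remaining = max(1.0, total_num_steps - warmup_num_steps)
        gamma = max(0.0, (total_num_steps - step) / remaining)  # linear decay to 0
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma


for step in (0, 5, 10, 55, 100):
    print(step, warmup_decay_lr(step))
```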
wizyoung commented 1 year ago

Same problem here. Additional info that might help: torch.utils.checkpoint conflicts with DeepSpeed ZeRO-3 (the same output-shape RuntimeError as above). However, when LoRA is applied with peft, ZeRO-3 works well with torch.utils.checkpoint.

mrwyattii commented 1 year ago

@MetaBlues thanks for the reproducer and @wizyoung thank you for the additional information. I just did some quick testing and I can confirm that I'm able to reproduce the error. It's not immediately clear what is causing this, but I believe it's due to a change in diffusers. If I downgrade from the latest release to diffusers==0.16.1 I am able to successfully run the example. A commit between 0.16.1 and 0.17.0 is causing the error.

Can you both confirm that downgrading diffusers avoids the error? I will continue to debug for the latest diffusers.

My environment:

accelerate                              0.23.0
deepspeed                               0.10.4+78c3b148
diffusers                               0.17.0
torch                                   2.0.1
transformers                            4.33.2
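The bisection above places the change between diffusers 0.16.1 and 0.17.0. A small guard like the one below could flag an installed version in that range before a ZeRO-3 run, at least for the versions discussed in this thread. It uses only the standard library (importlib.metadata) and a deliberately simple parser that keeps the leading numeric parts; the helper names are hypothetical, and packaging.version would be the more robust choice in real code.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Keep the leading digits of the first three dot-separated parts: "0.10.4+78c3b148" -> (0, 10, 4)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def in_affected_range(diffusers_version):
    """Hypothetical check: True for diffusers >= 0.17.0, the range reported to fail under ZeRO-3."""
    return parse_version(diffusers_version) >= (0, 17, 0)


try:
    installed = version("diffusers")
    print(installed, "affected" if in_affected_range(installed) else "not affected")
except PackageNotFoundError:
    print("diffusers is not installed")
```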
MetaBlues commented 1 year ago

@mrwyattii downgrading diffusers to 0.16.1 does work. I'll look into the differences between 0.16.1 and 0.17.0.

wizyoung commented 1 year ago

I'm not using diffusers. I applied ZeRO-3 to Hugging Face Llama 2, and the error occurs during SFT when gradient checkpointing is enabled. A simple trace shows that the error points to the torch.utils.checkpoint call inside the Hugging Face Llama 2 code. @MetaBlues looking forward to your findings. :) My env:

accelerate                              0.22.0
deepspeed                               0.10.0
torch                                   1.13.1+cu117
transformers                            4.33.1
wizyoung commented 1 year ago

@MetaBlues have you found any clues?

ryanzhangfan commented 1 year ago

@MetaBlues More information which may be helpful: I found that in newer versions of diffusers (>=0.17.0), the argument use_reentrant=False is passed to torch.utils.checkpoint.checkpoint in unet_2d_blocks.py, vae.py and transformer_2d.py, which may cause this error. If I remove every use_reentrant=False, the model trains normally (I have not tested the quality of the trained model, but at least the backward step works).
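For comparison, in plain PyTorch without ZeRO-3 the two settings of use_reentrant are expected to give identical gradients. The sketch below checks this on a toy nn.Linear, not the UNet, which is consistent with the thread's finding that the failure appears only with stage 3 rather than with the flag on its own. grads_with_checkpoint is a hypothetical helper. The input is created with requires_grad=True because the reentrant implementation only propagates gradients through a checkpointed segment when one of its tensor inputs requires grad.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def grads_with_checkpoint(use_reentrant):
    """Run one forward/backward through a checkpointed layer and return its weight gradient."""
    torch.manual_seed(0)  # identical weights and inputs for both runs
    layer = nn.Linear(8, 8)
    x = torch.randn(4, 8, requires_grad=True)
    out = checkpoint(layer, x, use_reentrant=use_reentrant)
    out.sum().backward()
    return layer.weight.grad.clone()


# Without ZeRO-3 partitioning, both modes should agree.
print(torch.allclose(grads_with_checkpoint(True), grads_with_checkpoint(False)))
```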

MetaBlues commented 1 year ago

@ryanzhangfan @mrwyattii I found that DeepSpeed has a checkpointing.non_reentrant_checkpoint, which is only used by runtime.pipe. Maybe we need a way to handle use_reentrant automatically.

wizyoung commented 1 year ago

@MetaBlues @ryanzhangfan Great findings! After setting use_reentrant to True, my ZeRO-3 training pipeline runs normally. According to the official PyTorch docs here and https://github.com/huggingface/transformers/issues/21381, we need to handle carefully the case where part of the model is frozen, i.e., avoid wrapping frozen parts inside the checkpointing function. A future version of PyTorch will change the default of this argument to False.
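One concrete form of the frozen-module caveat, as we understand the PyTorch docs: with use_reentrant=True, if none of the tensors passed into checkpoint() require grad (for example, activations coming out of a frozen VAE or embedding layer), the checkpointed output does not require grad either, so the trainable layer inside receives no gradient. The non-reentrant path does not have this restriction. The sketch below shows this with a toy nn.Linear; output_requires_grad is a hypothetical helper. A common workaround is to call requires_grad_() on the input; transformers exposes this for the embedding output as enable_input_require_grads().

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def output_requires_grad(use_reentrant, input_requires_grad):
    """Checkpoint a trainable layer whose input comes from a frozen part of the model."""
    layer = nn.Linear(8, 8)  # trainable block inside the checkpoint
    x = torch.randn(4, 8)    # e.g. the output of a frozen encoder
    if input_requires_grad:
        x.requires_grad_(True)  # the usual workaround for reentrant mode
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # reentrant mode warns when no input requires grad
        out = checkpoint(layer, x, use_reentrant=use_reentrant)
    return out.requires_grad


print(output_requires_grad(use_reentrant=True, input_requires_grad=False))   # False
print(output_requires_grad(use_reentrant=True, input_requires_grad=True))    # True
print(output_requires_grad(use_reentrant=False, input_requires_grad=False))  # True
```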

One simple fix is to place the following code at the top of your imports:

from functools import partial
import torch.utils.checkpoint
# Make use_reentrant=True the default for later torch.utils.checkpoint.checkpoint calls.
torch.utils.checkpoint.checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=True)
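One caveat with this fix, based on standard functools.partial semantics: keyword arguments supplied at call time override those stored in the partial. A call site that passes use_reentrant=False explicitly, as ryanzhangfan reports diffusers >= 0.17.0 does, would therefore still run non-reentrant checkpointing. The sketch below demonstrates this with a stand-in function instead of torch (fake_checkpoint and force_reentrant are hypothetical names) and shows a wrapper that overwrites the keyword instead of only defaulting it.

```python
from functools import partial, wraps


def fake_checkpoint(function, *args, use_reentrant=None, **kwargs):
    """Stand-in for torch.utils.checkpoint.checkpoint: reports the flag it was called with."""
    return use_reentrant


def force_reentrant(checkpoint_fn):
    """Hypothetical wrapper: always call checkpoint_fn with use_reentrant=True,
    even when the caller passes use_reentrant=False explicitly."""
    @wraps(checkpoint_fn)
    def wrapper(function, *args, **kwargs):
        kwargs["use_reentrant"] = True  # overwrite the caller's value
        return checkpoint_fn(function, *args, **kwargs)

    return wrapper


def identity(x):
    """Placeholder for the checkpointed module."""
    return x


patched_partial = partial(fake_checkpoint, use_reentrant=True)
patched_forced = force_reentrant(fake_checkpoint)

# A call site that passes the keyword explicitly, like diffusers >= 0.17.0:
print(patched_partial(identity, 1, use_reentrant=False))  # False: the call-site keyword wins
print(patched_forced(identity, 1, use_reentrant=False))   # True
```

Applied to PyTorch, the wrapper would be installed the same way as the partial, e.g. torch.utils.checkpoint.checkpoint = force_reentrant(torch.utils.checkpoint.checkpoint), before training starts. Either patch only affects code that looks the function up as torch.utils.checkpoint.checkpoint at call time; modules that already ran from torch.utils.checkpoint import checkpoint keep the original function.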