[FSDP] FSDP wrapped UNet model running backward crashes #82201

Closed Ā· XiangfanLi closed this issue 2 years ago

XiangfanLi commented 2 years ago

šŸ› Describe the bug


### **The error stack traces**

```bash
creating model and diffusion...
creating data loader...
training...
/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:9: UserWarning: is_namedtuple is deprecated, please use the python checks instead
  warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
[forward_backward] forward finished
Asserting FSDP instance is: FullyShardedDataParallel(
  (_fsdp_wrapped_module): FlattenParamsWrapper(
    (_fpw_module): Sequential(
      (0): FullyShardedDataParallel(
        (_fsdp_wrapped_module): FlattenParamsWrapper(
          (_fpw_module): GroupNorm32(32, 192, eps=1e-05, affine=True)
        )
      )
      (1): FullyShardedDataParallel(
        (_fsdp_wrapped_module): FlattenParamsWrapper(
          (_fpw_module): SiLU()
        )
      )
      (2): FullyShardedDataParallel(
        (_fsdp_wrapped_module): FlattenParamsWrapper(
          (_fpw_module): Dropout(p=0.1, inplace=False)
        )
      )
      (3): FullyShardedDataParallel(
        (_fsdp_wrapped_module): FlattenParamsWrapper(
          (_fpw_module): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
  )
)
ERROR: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>] but current state is TrainingState_.IDLE
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3178, in _wait_for_post_backward
    m._assert_state(TrainingState_.BACKWARD_PRE)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3560, in _assert_state
    traceback.print_stack()
Traceback (most recent call last):
  File "scripts/image_train.py", line 91, in <module>
    main()
  File "scripts/image_train.py", line 45, in main
    TrainLoop(
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 227, in run_loop
    self.run_step(batch, cond)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 262, in run_step
    self.forward_backward(batch, cond)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 304, in forward_backward
    self.mp_trainer.backward(loss)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/fp16_util.py", line 188, in backward
    self.scaler.scale(loss).backward()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f8a6e2e3f80> returned NULL without setting an error
Traceback (most recent call last):
  File "scripts/image_train.py", line 91, in <module>
    main()
  File "scripts/image_train.py", line 45, in main
    TrainLoop(
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 227, in run_loop
    self.run_step(batch, cond)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 262, in run_step
    self.forward_backward(batch, cond)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 304, in forward_backward
    self.mp_trainer.backward(loss)
  File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/fp16_util.py", line 188, in backward
    self.scaler.scale(loss).backward()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f8a6e2e3f80> returned NULL without setting an error
```

### Versions

Collecting environment information...
PyTorch version: 1.12.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1013-oracle-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 470.129.06
cuDNN version: Probably one of the following:
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] ema-pytorch==0.0.8
[pip3] numpy==1.23.1
[pip3] pytorch-warmup==0.1.0
[pip3] torch==1.12.0+cu113
[pip3] torchaudio==0.12.0+rocm5.1.1
[pip3] torchvision==0.13.0+cu113

XiangfanLi commented 2 years ago

Some of the forward passes in the UNet definition call forward on the elements of an nn.Sequential separately; after I removed those "split forwards", this issue disappeared.

(Two screenshots attached: "Screenshot 2022-07-26 at 6 52 23 PM" and "Screenshot 2022-07-26 at 6 52 33 PM".)
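For readers hitting the same assertion, a minimal sketch of what such a "split forward" might look like, and what removing it changes, is below. It is modeled on the `out_layers` Sequential in guided_diffusion's `ResBlock` (GroupNorm → SiLU → Dropout → Conv2d, matching the FSDP structure printed in the assertion above); the class, the names `out_layers` and `emb_out`, and the scale/shift logic are illustrative assumptions, not code taken from the attached screenshots.

```python
import torch
import torch.nn as nn


class ResBlockSketch(nn.Module):
    """Hypothetical sketch of the pattern described above; not the reporter's exact code."""

    def __init__(self, channels: int = 192, dropout: float = 0.1):
        super().__init__()
        # The Sequential that appears (auto-wrapped by FSDP) in the assertion output:
        # GroupNorm -> SiLU -> Dropout -> Conv2d.
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward_split(self, h: torch.Tensor, emb_out: torch.Tensor) -> torch.Tensor:
        # "Split forward": indexing/slicing the Sequential calls its (FSDP-wrapped)
        # children one by one and never runs the forward of the FSDP instance that
        # wraps the Sequential itself, so that outer instance can stay in
        # TrainingState_.IDLE, which appears to be what trips the post-backward check.
        # Expects emb_out of shape (B, 2 * channels, 1, 1).
        out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
        scale, shift = torch.chunk(emb_out, 2, dim=1)
        h = out_norm(h) * (1 + scale) + shift
        return out_rest(h)

    def forward(self, h: torch.Tensor, emb_out: torch.Tensor) -> torch.Tensor:
        # After removing the split: a single call through the Sequential, so the
        # outer FSDP wrapper's pre/post-forward hooks run as expected.
        # Expects emb_out of shape (B, channels, 1, 1).
        h = h + emb_out
        return self.out_layers(h)
```

Under recursive auto-wrapping, both `out_layers` and each of its four children become separate FullyShardedDataParallel instances (exactly the nesting printed in the assertion), which is why only the unsplit call keeps the outer wrapper's training-state machine consistent.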
xdd12135 commented 2 years ago

How? Please be more specific.