I tried to use FSDP to wrap a UNet model, but during the backward pass an assertion failed, suggesting that the FSDP model is in an unexpected state. It seems that the `pre_backward_hook` was not properly triggered for some FSDP wrappers.

For the `auto_wrap_policy`, I am using `functools.partial(size_based_auto_wrap_policy, min_num_params=int(0))` to recursively wrap all submodules of the UNet. The UNet definition I am using is https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py (with some trivial code changes to make FSDP work), and I am running https://github.com/openai/guided-diffusion/blob/main/scripts/image_train.py to train the model.
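For reference, a minimal sketch of the wrapping setup (the `create_unet_model` factory and the process-group initialization are my assumptions for illustration; the actual training script wires this up differently):

```python
import functools

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Assumed setup: one process per GPU, launched e.g. with torchrun.
dist.init_process_group(backend="nccl")

# Hypothetical factory standing in for guided_diffusion's UNetModel construction.
unet = create_unet_model().cuda()

# min_num_params=0 makes the size-based policy wrap every submodule in its
# own (nested) FSDP instance, which is the configuration that hits the
# assertion below.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(0))
model = FSDP(unet, auto_wrap_policy=policy)
```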
### **The error stack traces**
```bash
creating model and diffusion...
creating data loader...
training...
/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:9: UserWarning: is_namedtuple is deprecated, please use the python checks instead
warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
[forward_backward] forward finished
Asserting FSDP instance is: FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Sequential(
(0): FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): GroupNorm32(32, 192, eps=1e-05, affine=True)
)
)
(1): FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): SiLU()
)
)
(2): FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Dropout(p=0.1, inplace=False)
)
)
(3): FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
)
)
ERROR: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>] but current state is TrainingState_.IDLE
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3178, in _wait_for_post_backward
m._assert_state(TrainingState_.BACKWARD_PRE)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3560, in _assert_state
traceback.print_stack()
Traceback (most recent call last):
File "scripts/image_train.py", line 91, in <module>
main()
File "scripts/image_train.py", line 45, in main
TrainLoop(
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 227, in run_loop
self.run_step(batch, cond)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 262, in run_step
self.forward_backward(batch, cond)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 304, in forward_backward
self.mp_trainer.backward(loss)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/fp16_util.py", line 188, in backward
self.scaler.scale(loss).backward()
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
torch.autograd.backward(
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f8a6e2e3f80> returned NULL without setting an error
Traceback (most recent call last):
File "scripts/image_train.py", line 91, in <module>
main()
File "scripts/image_train.py", line 45, in main
TrainLoop(
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 227, in run_loop
self.run_step(batch, cond)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 262, in run_step
self.forward_backward(batch, cond)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/train_util.py", line 304, in forward_backward
self.mp_trainer.backward(loss)
File "/home/ubuntu/lixf/diffusion_zero3/guided_diffusion/fp16_util.py", line 188, in backward
self.scaler.scale(loss).backward()
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
torch.autograd.backward(
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f8a6e2e3f80> returned NULL without setting an error
```

### **Versions**

```text
Collecting environment information...
PyTorch version: 1.12.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1013-oracle-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 470.129.06
cuDNN version: Probably one of the following:
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] ema-pytorch==0.0.8
[pip3] numpy==1.23.1
[pip3] pytorch-warmup==0.1.0
[pip3] torch==1.12.0+cu113
[pip3] torchaudio==0.12.0+rocm5.1.1
[pip3] torchvision==0.13.0+cu113
```

### **Resolution**

Some forward methods in the UNet definition call forward on individual elements of an nn.Sequential instead of on the Sequential as a whole. After I removed those "split forwards", this issue disappeared.
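For context, here is a paraphrase of the pattern (based on the ResBlock in guided_diffusion/unet.py; simplified, names approximate) together with a restructuring that avoids it. The FSDP instance asserted in the trace above wraps exactly such a Sequential of GroupNorm32/SiLU/Dropout/Conv2d: because the forward indexes into the Sequential, the wrapper's own forward never runs, its pre-backward hook is never registered, and it is still IDLE when `_wait_for_post_backward` checks for BACKWARD_PRE.

```python
import torch
import torch.nn as nn

class ResBlockOut(nn.Module):
    """Paraphrased from the guided-diffusion ResBlock (simplified).

    The original keeps norm + SiLU + Dropout + Conv2d in one nn.Sequential
    and splits its forward call:

        out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
        h = out_norm(h) * (1 + scale) + shift
        h = out_rest(h)

    With min_num_params=0 auto-wrapping, self.out_layers is itself an FSDP
    instance; indexing into it invokes the child wrappers directly, so the
    outer wrapper's forward (which registers the pre-backward hook) never
    executes and the wrapper stays IDLE during backward.
    """

    def __init__(self, channels: int, dropout: float = 0.1):
        super().__init__()
        # Fix: split the *module structure* instead of the forward call, so
        # every FSDP-wrapped unit is always invoked through its own forward.
        self.out_norm = nn.GroupNorm(32, channels)
        self.out_rest = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, h: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        h = self.out_norm(h) * (1 + scale) + shift
        return self.out_rest(h)
```

The rule of thumb seems to be: with nested auto-wrapping, every FSDP-wrapped module must be invoked through its own forward, so any indexing or slicing of a wrapped container needs to be replaced by whole-module calls.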