facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Activation checkpointing on fused SwiGLU is not working when AMP is enabled. #1151

Closed: warpuv closed this issue 1 week ago

warpuv commented 1 week ago

🐛 Bug

Activation checkpointing on the fused SwiGLU fails when AMP (torch.autocast) is enabled: the backward pass raises a CheckpointError. With AMP disabled, everything works as expected.

To Reproduce

Steps to reproduce the behavior:

import torch
from torch.utils.checkpoint import checkpoint
from xformers.ops import SwiGLU

loss_fn = torch.nn.MSELoss()
scaler = torch.amp.GradScaler()

def main():
    device = 'cuda'
    dtype = torch.float32

    shape = (64, 128, 256, 128)  # (batch, in_features, hidden_features, out_features)
    in_features, hidden_features, out_features = shape[1:]

    x = torch.randn(shape[:2], dtype=dtype, device=device)
    x.requires_grad_(True)
    model = SwiGLU(in_features, hidden_features, out_features, bias=True, _pack_weights=True).to(device=device, dtype=dtype)

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        y = checkpoint(model, x, use_reentrant=False)
        y = loss_fn(y, y.detach())  # dummy loss; it only serves to trigger a backward pass
    scaler.scale(y).backward()

if __name__ == '__main__':
    main()

Running the script fails during the backward pass:

Traceback (most recent call last):
  File "/workspace/src/my/fused_swiglu/proof_with_amp.py", line 25, in <module>
    main()
  File "/workspace/src/my/fused_swiglu/proof_with_amp.py", line 22, in main
    scaler.scale(y).backward()
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1129, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 865, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 5
Number of tensors saved during recomputation: 0
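
One way to narrow this down is a control run of the same checkpoint-plus-autocast pattern on an unfused SwiGLU built from stock PyTorch ops. The sketch below is only an illustration: `ReferenceSwiGLU` and `checkpointed_backward_error` are hypothetical names, not part of xformers, and the reference module is a stand-in that is not guaranteed to match the xformers implementation. The helper omits GradScaler for brevity and falls back to CPU with bfloat16 autocast when no GPU is present. Whether the control passes on a given setup is something to check, not something this issue establishes.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class ReferenceSwiGLU(torch.nn.Module):
    """Unfused SwiGLU from stock PyTorch ops (hypothetical stand-in, not the xformers op)."""

    def __init__(self, in_features, hidden_features, out_features, bias=True):
        super().__init__()
        self.w1 = torch.nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = torch.nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = torch.nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        # silu(x @ w1) gates (x @ w2); w3 projects the result to out_features.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


def checkpointed_backward_error(module, x, use_amp):
    """Run one checkpointed forward/backward pass; return the exception raised, or None."""
    device_type = x.device.type
    # Assumption: float16 on CUDA (as in the report), bfloat16 on CPU.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    x = x.detach().clone().requires_grad_(True)
    module.zero_grad(set_to_none=True)
    try:
        with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_amp):
            y = checkpoint(module, x, use_reentrant=False)
            loss = y.float().pow(2).mean()
        loss.backward()  # non-reentrant checkpointing recomputes the forward here
    except Exception as exc:
        return exc
    return None


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ReferenceSwiGLU(128, 256, 128).to(device)
    x = torch.randn(64, 128, device=device)
    for use_amp in (False, True):
        err = checkpointed_backward_error(model, x, use_amp)
        print(f"use_amp={use_amp}: error={err!r}")
```

Passing the xformers `SwiGLU` module from the script above to the same helper would let both cases be compared side by side on the affected machine.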

Environment

Collecting environment information...
PyTorch version: 2.5.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.13.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB

Nvidia driver version: 560.28.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   43 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          112
On-line CPU(s) list:             0-111
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC Processor
CPU family:                      23
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              56
Socket(s):                       1
Stepping:                        2
BogoMIPS:                        3999.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 arat umip
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.8 MiB (56 instances)
L1i cache:                       3.5 MiB (56 instances)
L2 cache:                        28 MiB (56 instances)
L3 cache:                        8 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-111
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; LFENCE, IBPB conditional, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-copyright==0.2.4
[pip3] lion-pytorch==0.2.2
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.4.0
[pip3] rotary-embedding-torch==0.8.4
[pip3] torch==2.5.0
[pip3] torchaudio==2.5.1
[pip3] torchmetrics==1.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] Could not collect