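Activation checkpointing (`torch.utils.checkpoint.checkpoint` with `use_reentrant=False`) fails on xformers' `SwiGLU` when AMP is enabled. The backward pass raises a `CheckpointError` because the original forward and the recomputation save a different number of tensors.
When AMP is disabled, everything works as expected.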
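One way to narrow this down is a control run: keep the same non-reentrant checkpoint inside `torch.autocast`, but replace xformers' op with a plain-PyTorch SwiGLU. `ReferenceSwiGLU` and `run_control` are hypothetical helpers written for this sketch. The module only approximates a packed-weights layout (one `w12` projection split in half, then `w3`) and is not xformers' implementation. `GradScaler` is omitted because `scaler.scale(loss)` only multiplies the loss by a scale factor before backward. On a GPU the sketch uses CUDA with float16, like the reproduction, and otherwise falls back to CPU with bfloat16. If this backward succeeds where the xformers version fails, the mismatch most likely comes from the xformers op rather than from checkpointing or autocast themselves. That is an inference, not a confirmed diagnosis.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint


class ReferenceSwiGLU(nn.Module):
    """Hypothetical plain-PyTorch SwiGLU with a packed first projection (w12)."""

    def __init__(self, in_features, hidden_features, out_features, bias=True):
        super().__init__()
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        # Split the packed projection into the gate and value halves.
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(x1) * x2)


def run_control(batch=64, in_features=128, hidden_features=256, out_features=128):
    # Match the report's CUDA/float16 setup when a GPU exists, else CPU/bfloat16.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    amp_dtype = torch.float16 if use_cuda else torch.bfloat16

    model = ReferenceSwiGLU(in_features, hidden_features, out_features).to(device)
    x = torch.randn(batch, in_features, device=device, requires_grad=True)
    target = torch.randn(batch, out_features, device=device)

    with torch.autocast(device_type=device, dtype=amp_dtype):
        # Same pattern as the reproduction: non-reentrant checkpoint under AMP.
        y = checkpoint(model, x, use_reentrant=False)
        loss = F.mse_loss(y.float(), target)
    loss.backward()  # recomputes model(x) under the saved autocast state
    return x.grad


if __name__ == "__main__":
    print(run_control().shape)
```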
To Reproduce
Steps to reproduce the behavior:
```python
import torch
from torch.utils.checkpoint import checkpoint
from xformers.ops import SwiGLU

loss_fn = torch.nn.MSELoss()
scaler = torch.amp.GradScaler()


def main():
    device = 'cuda'
    dtype = torch.float32
    shape = (64, 128, 256, 128)  # (batch, in_features, hidden_features, out_features)
    in_features, hidden_features, out_features = shape[1:]
    x = torch.randn(shape[:2], dtype=dtype, device=device)
    x.requires_grad_(True)
    model = SwiGLU(in_features, hidden_features, out_features, bias=True, _pack_weights=True).to(device=device, dtype=dtype)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        # Non-reentrant activation checkpointing of the whole SwiGLU block under AMP
        y = checkpoint(model, x, use_reentrant=False)
        y = loss_fn(y, y.detach())
    # Backward recomputes the checkpointed forward; this is where the error is raised
    scaler.scale(y).backward()


if __name__ == '__main__':
    main()
```
Running the script fails in the backward pass with:

```
Traceback (most recent call last):
  File "/workspace/src/my/fused_swiglu/proof_with_amp.py", line 25, in <module>
    main()
  File "/workspace/src/my/fused_swiglu/proof_with_amp.py", line 22, in main
    scaler.scale(y).backward()
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1129, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/workspace/envs_python/tiny_llama/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 865, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 5
Number of tensors saved during recomputation: 0
```
Environment
Collecting environment information...
PyTorch version: 2.5.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35
Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.13.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
Nvidia driver version: 560.28.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 112
On-line CPU(s) list: 0-111
Vendor ID: AuthenticAMD
Model name: AMD EPYC Processor
CPU family: 23
Model: 1
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 1
Stepping: 2
BogoMIPS: 3999.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 arat umip
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.8 MiB (56 instances)
L1i cache: 3.5 MiB (56 instances)
L2 cache: 28 MiB (56 instances)
L3 cache: 8 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-111
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; LFENCE, IBPB conditional, STIBP conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-copyright==0.2.4
[pip3] lion-pytorch==0.2.2
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.4.0
[pip3] rotary-embedding-torch==0.8.4
[pip3] torch==2.5.0
[pip3] torchaudio==2.5.1
[pip3] torchmetrics==1.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] Could not collect