facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

`memory_efficient_attention`: `torch.compile` compatibility #920

Open achalddave opened 8 months ago

achalddave commented 8 months ago

🐛 Bug

Using `xformers.memory_efficient_attention` with FSDP and `torch.compile` fails when using bfloat16, but works when using float32. It's unclear to me whether this is an xformers bug, an FSDP bug, or a torch.compile bug. It might be related to https://github.com/pytorch/pytorch/issues/112164, and it came up in our codebase, which uses xformers: https://github.com/mlfoundations/open_lm/issues/72

Command

torchrun --nproc_per_node 2 script.py

To Reproduce

Steps to reproduce the behavior:

  1. Save code sample below as script.py
  2. Run torchrun --nproc_per_node 2 script.py
```python
# script.py
import torch
import torch.nn as nn

from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from xformers.ops import memory_efficient_attention
import xformers.ops as xops

class Layer(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        B, N, C = x.shape
        x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
        return self.linear_out(x.reshape([B, N, C]))

###
dtype = torch.bfloat16  # Setting this to torch.float32 makes this code work.
###

torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(device)
FEAT_SIZE = 128
MAX_LEN = 100
BATCH_SIZE = 8

batch = torch.zeros(BATCH_SIZE, MAX_LEN, FEAT_SIZE).to(device)
mha = Layer(FEAT_SIZE).to(device)

mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
mha_fsdp = FSDP(mha, use_orig_params=True, device_id=device, mixed_precision=mp_policy)

compile_mha = torch.compile(mha_fsdp).to(device)
output = compile_mha(batch)
output.mean().backward()
```

Expected behavior

Code runs without error.

Environment

Output of PyTorch's environment collection script (`python -m torch.utils.collect_env`):
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1028-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==1.5.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] pytorch-ranger==0.1.1
[pip3] st-moe-pytorch==0.1.1
[pip3] torch==2.0.1+cu118
[pip3] torch-optimizer==0.3.0
[pip3] torchdata==0.6.0
[pip3] torchmetrics==0.11.3
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2+cu118
[conda] numpy                     1.25.2                   pypi_0    pypi

Additional context

xformers version: 0.0.22.

danthe3rd commented 7 months ago

Hi, thanks for reporting this! Many operators in xFormers don't support torch.compile at the moment. This is on our roadmap, but it might take months to get there (we might also need to fix some bugs in PyTorch...).

achalddave commented 7 months ago

Ah, okay, thanks! Is there an issue tracking this that we could follow? We'd love to support torch.compile + xformers attention in our repo.

danthe3rd commented 7 months ago

We can use this issue to track it. However, this particular error might be related to FSDP: xFormers operators will most likely incur a graph break (which will make performance worse), but they shouldn't cause an exception or error.
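A minimal sketch of making that graph break explicit, assuming PyTorch 2.0 or later: wrapping the attention call in `torch._dynamo.disable` tells TorchDynamo not to trace into it, so the compiled graph is split at that point and the call runs eagerly (newer releases also expose this as `torch.compiler.disable`). To keep the sketch runnable on CPU without xFormers, FSDP, or NCCL, it uses PyTorch's own `F.scaled_dot_product_attention(..., is_causal=True)` as a stand-in for `memory_efficient_attention(..., attn_bias=xops.LowerTriangularMask())` and compiles with `backend="eager"`. The helper name `causal_attention` is hypothetical, and this has not been checked against the bf16 + FSDP failure above.

```python
import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical helper. In the real layer this would call
# xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask());
# here SDPA with is_causal=True stands in so the sketch runs on CPU.
# torch._dynamo.disable makes Dynamo skip this function, forcing a graph break here.
@torch._dynamo.disable
def causal_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


class Layer(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        x = causal_attention(x, x, x)  # runs eagerly, outside the compiled graph
        return self.linear_out(x)      # traced again after the graph break


torch.manual_seed(0)
layer = Layer(128)
batch = torch.randn(8, 100, 128)

# backend="eager" only exercises Dynamo tracing, so no C++/Triton toolchain is needed.
compiled = torch.compile(layer, backend="eager")
out_compiled = compiled(batch)
out_eager = layer(batch)
out_compiled.mean().backward()
```

Swapping in the real xFormers call and the default Inductor backend would be the next step; whether that sidesteps the FSDP error is an open question in this thread.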