facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

`memory_efficient_attention` makes no difference #678

Open FrancescoSaverioZuppichini opened 1 year ago

FrancescoSaverioZuppichini commented 1 year ago

❓ Questions and Help

Hi guys,

Thanks a lot for the amazing work. I am trying to use xformers on CLIP; following the timm tutorial, I've put together the following code:

MODEL_NAME = "ViT-L/14"
import clip
import torch
from torch import nn
from xformers.components import MultiHeadDispatch
from xformers.ops import memory_efficient_attention_forward

model, _ = clip.load(MODEL_NAME, device="cuda")
from torch.utils import benchmark

class Attention(torch.nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.attn_mask = attn_mask

    def forward(self, x, *args, **kwargs):
        x = x.permute(1,0,2)
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        # Merge the heads into the batch dim: (3, B * num_heads, N, head_dim).
        # memory_efficient_attention_forward also accepts 3D [batch, seq, head_dim] inputs.
        qkv = qkv.flatten(1, 2)
        q, k, v = qkv.unbind()

        x = memory_efficient_attention_forward(q, k, v, op=None)
        x = x.reshape(B, self.num_heads, N, C // self.num_heads)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.permute(1,0,2)

def profile_model(fn, min_run_time=2):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    res = benchmark.Timer(
        stmt='fn()',
        globals={"fn": fn},
        label="profile",
        sub_label="",
        description=""
    ).blocked_autorange(min_run_time=min_run_time)
    torch.cuda.synchronize()
    memory = torch.cuda.max_memory_allocated() / 2 ** 20
    memory = f"Memory used: {memory} MB"
    print(res)
    print(memory)

with torch.no_grad():

    vit =  model.visual

    x = torch.randn((1, 3, 224, 224), device="cuda").half()

    print(vit(x).shape)

    profile_model(lambda : vit(x))

    for i, resblock in enumerate(vit.transformer.resblocks):
        embed_dim = resblock.attn.embed_dim
        num_heads = resblock.attn.num_heads
        resblock.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=True)

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, Attention):
            l = l.cuda().half()

    vit.apply(_convert_weights_to_fp16)

    profile_model(lambda : vit(x))

    print(vit(x).shape)

It outputs

<torch.utils.benchmark.utils.common.Measurement object at 0x7f0401be5c70>
profile
  Median: 12.10 ms
  IQR:    0.76 ms (11.43 to 12.20)
  17 measurements, 10 runs per measurement, 1 thread
Memory used: 908.04052734375 MB
<torch.utils.benchmark.utils.common.Measurement object at 0x7f0401c60df0>
profile
  Median: 11.47 ms
  IQR:    0.11 ms (11.42 to 11.54)
  18 measurements, 10 runs per measurement, 1 thread
Memory used: 908.04052734375 MB

So basically, no change. Am I doing something wrong?

Thanks a lot,

Fra

danthe3rd commented 1 year ago

Your batch size is very small (1), and you might be CPU-bound. You should try increasing it:

B = 64 # Or lower if you get OOMs
x = torch.randn((B, 3, 224, 224), device="cuda").half()

This should give you some speedup. You can also further improve the speed by avoiding the transpose calls, something like this:

    def forward(self, x, *args, **kwargs):
        x = x.permute(1,0,2)
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
        )

        q, k, v = qkv.unbind(2)

        x = memory_efficient_attention_forward(q, k, v, op=None)
        # x = x.reshape(B, self.num_heads, N, C // self.num_heads).transpose(1, 2)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.permute(1,0,2)
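
With 4D inputs the op already consumes and produces the [batch, seq_len, num_heads, head_dim] layout, which is why the transposes and the intermediate reshape are no longer needed. A quick standalone shape check, with arbitrary sizes, just to illustrate the layout:

```python
import torch
from xformers.ops import memory_efficient_attention_forward

# Arbitrary sizes for illustration: batch=2, seq_len=257, heads=16, head_dim=64.
q = torch.randn(2, 257, 16, 64, device="cuda", dtype=torch.half)
out = memory_efficient_attention_forward(q, q, q)
print(out.shape)  # torch.Size([2, 257, 16, 64]) -- same layout in and out
```
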
FrancescoSaverioZuppichini commented 1 year ago

@danthe3rd thanks a lot for the reply. I've updated the code, but I can't see any speedup with batch_size=64:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f99d81e9460>
profile
  Median: 248.12 ms
  IQR:    0.32 ms (248.02 to 248.35)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB
CLIP xformers
<torch.utils.benchmark.utils.common.Measurement object at 0x7f99d8026640>
profile
  Median: 231.34 ms
  IQR:    1.04 ms (231.02 to 232.06)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB
danthe3rd commented 1 year ago

Can you report the output of this command:

python -m xformers.info
FrancescoSaverioZuppichini commented 1 year ago

@danthe3rd

(dl) ➜  ~ python -m xformers.info
xFormers 0.0.16
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        available
memory_efficient_attention.tritonflashattB:        available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
is_functorch_available:                            False
pytorch.version:                                   1.13.1+cu117
pytorch.cuda:                                      available
gpu.compute_capability:                            8.6
gpu.name:                                          NVIDIA GeForce RTX 3090
build.info:                                        available
build.cuda_version:                                1107
build.python_version:                              3.9.16
build.torch_version:                               1.13.1+cu117
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.16
source.privacy:                                    open source
danthe3rd commented 1 year ago

I ran your script on my machine (with an A100 GPU) and got a nice speedup:

torch.Size([64, 768])
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2a13d94460>
profile
  Median: 160.57 ms
  IQR:    0.02 ms (160.55 to 160.58)
  13 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB
<torch.utils.benchmark.utils.common.Measurement object at 0x7f2a13c558e0>
profile
  Median: 101.79 ms
  IQR:    0.06 ms (101.77 to 101.83)
  20 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB
torch.Size([64, 768])

Maybe you can try the latest xformers development version (`pip install --pre -U xformers`)?

NOTE: I modified the script slightly:

script.py

```python
MODEL_NAME = "ViT-L/14"
import clip
import torch
from torch import nn
from xformers.components import MultiHeadDispatch
from xformers.ops import memory_efficient_attention_forward

model, _ = clip.load(MODEL_NAME, device="cuda")
from torch.utils import benchmark


class Attention(torch.nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.attn_mask = attn_mask

    def forward(self, x, *args, **kwargs):
        x = x.permute(1,0,2)
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
        )

        q, k, v = qkv.unbind(2)

        x = memory_efficient_attention_forward(q, k, v, op=None)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.permute(1,0,2)


def profile_model(fn, min_run_time=2):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    res = benchmark.Timer(
        stmt='fn()',
        globals={"fn": fn},
        label="profile",
        sub_label="",
        description=""
    ).blocked_autorange(min_run_time=min_run_time)
    torch.cuda.synchronize()
    memory = torch.cuda.max_memory_allocated() / 2 ** 20
    memory = f"Memory used: {memory} MB"
    print(res)
    print(memory)


with torch.no_grad():
    vit = model.visual

    x = torch.randn((64, 3, 224, 224), device="cuda").half()

    print(vit(x).shape)

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, Attention):
            l = l.cuda().half()

    vit.apply(_convert_weights_to_fp16)

    profile_model(lambda : vit(x))

    for i, resblock in enumerate(vit.transformer.resblocks):
        embed_dim = resblock.attn.embed_dim
        num_heads = resblock.attn.num_heads
        resblock.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=True)

    vit.apply(_convert_weights_to_fp16)

    profile_model(lambda : vit(x))

    print(vit(x).shape)
```
FrancescoSaverioZuppichini commented 1 year ago

I've updated, but I still don't see any real difference:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f364f553bb0>
profile
  Median: 247.37 ms
  IQR:    0.72 ms (246.91 to 247.63)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB
CLIP xformers
<torch.utils.benchmark.utils.common.Measurement object at 0x7f364f5c1ee0>
profile
  Median: 229.60 ms
  IQR:    1.31 ms (229.01 to 230.33)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 1432.09716796875 MB

Memory is the same. Which torch version are you using?

This is so interesting ahahha, maybe xformers doesn't play nice with my 3090? Which driver are you running?

danthe3rd commented 1 year ago

You still have some 7% speedup. Memory might not differ much because you run in no_grad mode, and your sequences are not that long (257).
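
To see the memory gap, you can benchmark the op in isolation with gradients enabled and a longer sequence, something like this sketch (arbitrary sizes, just for illustration):

```python
import torch
from xformers.ops import memory_efficient_attention

# Arbitrary sizes for illustration: long sequence, gradients enabled.
B, M, K = 8, 4096, 64
q = torch.randn(B, M, K, device="cuda", dtype=torch.half, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

def naive_attention(q, k, v):
    # Materializes the full (B, M, M) attention matrix.
    attn = (q @ k.transpose(-2, -1) * K ** -0.5).softmax(dim=-1)
    return attn @ v

def peak_mem_mb(fn):
    torch.cuda.reset_peak_memory_stats()
    fn().sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2 ** 20

print("naive:   ", peak_mem_mb(lambda: naive_attention(q, k, v)), "MB")
print("xformers:", peak_mem_mb(lambda: memory_efficient_attention(q, k, v)), "MB")
```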

maybe xformers doesn't play nice with my 3090

Yes it's possible. Different GPUs have different characteristics, and our kernels have been mostly optimized for V100/A100 as that's what we use internally for research.

FrancescoSaverioZuppichini commented 1 year ago

You still have some 7% speedup.

Still a win 🥳

Memory might not differ much because you run in no_grad mode, and your sequences are not that long (257).

I tried using memory_efficient_attention, and there is indeed some memory saving:

profile
  Median: 248.75 ms
  IQR:    0.69 ms (248.62 to 249.31)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 21231.14501953125 MB
CLIP xformers
<torch.utils.benchmark.utils.common.Measurement object at 0x7fde79d89d60>
profile
  Median: 232.57 ms
  IQR:    0.98 ms (232.18 to 233.17)
  9 measurements, 1 runs per measurement, 1 thread
Memory used: 18147.27001953125 MB
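
For reference, the swap is essentially one call: `memory_efficient_attention` is the autograd-aware op, while `memory_efficient_attention_forward` only computes the forward pass. A minimal sketch (arbitrary sizes, not the exact diff I ran):

```python
import torch
from xformers.ops import memory_efficient_attention

# Same [batch, seq_len, num_heads, head_dim] layout as in the Attention module above.
q = torch.randn(2, 257, 16, 64, device="cuda", dtype=torch.half, requires_grad=True)
out = memory_efficient_attention(q, q, q, p=0.0)  # drop-in replacement in forward()
out.sum().backward()  # gradients flow; this op registers an autograd function
```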

Yes it's possible. Different GPUs have different characteristics, and our kernels have been mostly optimized for V100/A100 as that's what we use internally for research.

Any resources on that?

Moreover, is there an optimization guide? Like, when is it best to use some attention ops compared to others? I am happy to learn more about this and contribute with articles and blog posts.