idiap / fast-transformers

Pytorch library for fast transformer implementations

Memory usage: native PyTorch vs. "full"-Attention #68

Closed GregorKobsik closed 3 years ago

GregorKobsik commented 3 years ago

Hello,

I wanted to leave some of my observations here regarding memory consumption (which is often a critical factor). It might be of some interest for others who want to benchmark their implementation.

The fast-transformers implementation of full self-attention uses around 35% more GPU memory and is slightly slower than the native PyTorch implementation. I would like to note that this is true for my specific setup and that I ran only a limited number of test runs (4 each), which I report here. I only discovered this because my initial configuration/implementation in PyTorch did fit into memory.

Both models use some embedding beforehand and differ only in the TransformerEncoderLayer / TransformerEncoderBuilder. I did not construct a minimal example, I just exchanged the modules in my workflow to test different implementations.

The following numbers belong to this specific configuration:

Architecture: encoder only
Attention mask: causal (upper triangle)
Number of layers: 8
Embedding dimension: 64
Number of heads: 4
Feed-forward dimension: 4 * 64
Max sequence length: 4096
Batch size: 1
GPU: single RTX 2080 Ti

Peak memory usage in each run:
native PyTorch: 6152 - 6200 MB
fast-transformers: 8312 - 8454 MB

Computation time per epoch in each run:
native PyTorch: 9 min 9 s - 9 min 33 s
fast-transformers: 10 min 18 s - 10 min 48 s

The same configuration with 16 layers fits into the GPU (~11 GB) using native PyTorch but throws an OOM with fast-transformers. I suppose this is not an important issue as long as both implementations provide similar results (I might test this on my specific setup in the next couple of days, too), since the focus of the library lies on the efficient implementations.

angeloskath commented 3 years ago

Hi,

Could you provide a simple benchmark script? I am using the following, and the timings are practically identical on my RTX 2060 S, as is the memory.

import torch
from fast_transformers.attention import AttentionLayer, FullAttention
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask, FullMask, TriangularCausalMask

if __name__ == "__main__":

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
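    # x1 is (batch, length, features) for fast_transformers; x2 is (length, batch, features) for torch.nn.TransformerEncoder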
    x1 = torch.rand(32, 512, 256).cuda()
    x2 = torch.rand(512, 32, 256).cuda()
    lengths = LengthMask(torch.full((32,), 512, dtype=torch.long).cuda())
    attn_mask1 = TriangularCausalMask(512, device="cuda")
    attn_mask2 = torch.triu(torch.ones(512, 512)).bool().cuda()

    transformer = torch.nn.TransformerEncoder(
        torch.nn.TransformerEncoderLayer(256, 4, dim_feedforward=1024),
        4,
        torch.nn.LayerNorm(256)
    ).cuda()

    transformer(x2, mask=attn_mask2).sum().backward()
    start.record()
    for i in range(10):
        transformer(x2, mask=attn_mask2).sum().backward()
    stop.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(stop))

    transformer = TransformerEncoderBuilder.from_kwargs(
        n_layers=4,
        n_heads=4,
        query_dimensions=64,
        value_dimensions=64,
        feed_forward_dimensions=1024
    ).get().cuda()

    transformer(x1, attn_mask=attn_mask1).sum().backward()
    start.record()
    for i in range(10):
        transformer(x1, attn_mask=attn_mask1).sum().backward()
    stop.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(stop))

The outcome is

native PyTorch: 1196.923095703125
fast-transformers: 1146.894287109375

Cheers, Angelos

GregorKobsik commented 3 years ago

Hey, thanks for the quick code. I will run it tomorrow, and later this week I will try to cut my model out of my project to run it separately.

I ran the code above on two setups, 4 times each. Let's call it the simple bench script when referring to it. The averaged output is:

GPU: RTX 2070

native PyTorch: 1140.389
fast-transformers: 1175.254

GPU: RTX 2080 Ti

native PyTorch: 795.2336
fast-transformers: 795.4859

Python: 3.7.9
PyTorch: 1.7.1
Driver version: 450.80
CUDA: 11.0

angeloskath commented 3 years ago

Sure, feel free to share a benchmark. Maybe a naive question: have you ensured that the ordering of the input sequence dimensions is correct?

In short, the ordering for PyTorch native is (sequence_length, batch_size, features) while for fast_transformers it is (batch_size, sequence_length, features).
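
For illustration (a minimal sketch, not part of the benchmark itself), converting a tensor between the two conventions is a single transpose:

import torch

x_batch_first = torch.rand(32, 512, 256)     # (batch_size, sequence_length, features), as fast_transformers expects
x_seq_first = x_batch_first.transpose(0, 1)  # (sequence_length, batch_size, features), as torch.nn.TransformerEncoder expects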

Cheers, Angelos

GregorKobsik commented 3 years ago

I suppose my dimensions are correct (it feels like a somewhat strange decision by PyTorch to put the sequence dimension first).

I did not write my own benchmark script, I just slightly modified yours to obtain behaviour similar to what I experience with my model.

Here is the code. Let's name it bench script v2:

import torch
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask, TriangularCausalMask

if __name__ == "__main__":
    NUM_LAYERS = 4
    DIM_EMBEDDING = 64
    NUM_HEADS = 4
    DIM_FF = 4 * DIM_EMBEDDING
    LEN_INPUT = 4096
    BATCH_SIZE = 1
    FAST_TRANSFORMERS = False

    if FAST_TRANSFORMERS:
        # fast-transformer
        x = torch.rand(BATCH_SIZE, LEN_INPUT, DIM_EMBEDDING).cuda()
        attn_mask = TriangularCausalMask(LEN_INPUT, device="cuda")

        transformer = TransformerEncoderBuilder.from_kwargs(
            n_layers=NUM_LAYERS,
            n_heads=NUM_HEADS,
            query_dimensions=DIM_EMBEDDING // NUM_HEADS,
            value_dimensions=DIM_EMBEDDING // NUM_HEADS,
            feed_forward_dimensions=DIM_FF,
            dropout=0.0,
            activation='relu',
        ).get().cuda()
    else:
        # native PyTorch
        x = torch.rand(LEN_INPUT, BATCH_SIZE, DIM_EMBEDDING).cuda()
        attn_mask = torch.triu(torch.ones(LEN_INPUT, LEN_INPUT)).bool().cuda()

        transformer = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(DIM_EMBEDDING, NUM_HEADS, DIM_FF, 0.0, 'relu'),
            NUM_LAYERS,
            torch.nn.LayerNorm(DIM_EMBEDDING),
        ).cuda()

    def step(x, attn_mask):
        if FAST_TRANSFORMERS:
            transformer(x, attn_mask=attn_mask).sum().backward()
        else:
            transformer(x, mask=attn_mask).sum().backward()

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats()
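    # one warm-up iteration before timing starts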
    step(x, attn_mask)
    start.record()
    for i in range(10):
        step(x, attn_mask)
    stop.record()
    torch.cuda.synchronize()

    print("max_allocated:", torch.cuda.max_memory_allocated() / 1024**2)
    print("max_reserved:", torch.cuda.max_memory_reserved() / 1024**2)
    print("total_runtime:", start.elapsed_time(stop))

As before, I ran the script 4 times on two different devices.

NUM_LAYERS 4
DIM_EMBEDDING 64
NUM_HEADS 4
DIM_FF 256
LEN_INPUT 4096
BATCH_SIZE 1

| GPU | FAST_TRANSFORMERS | Metric | run 1 | run 2 | run 3 | run 4 | AVG |
| --- | --- | --- | --- | --- | --- | --- | --- |
| RTX 2070 | FALSE | max_allocated [MiB] | 2890.729 | 2890.729 | 2890.729 | 2890.729 | 2890.729 |
| RTX 2070 | FALSE | total_runtime [ms] | 830.9622 | 817.2667 | 812.1809 | 811.9242 | 818.0835 |
| RTX 2070 | TRUE | max_allocated [MiB] | 3910.765 | 3910.765 | 3910.765 | 3910.765 | 3910.765 |
| RTX 2070 | TRUE | total_runtime [ms] | 999.9157 | 998.6396 | 1003.36 | 1009.149 | 1002.766 |
| RTX 2080 Ti | FALSE | max_allocated [MiB] | 2890.729 | 2890.729 | 2890.729 | 2890.729 | 2890.729 |
| RTX 2080 Ti | FALSE | total_runtime [ms] | 568.4258 | 571.1238 | 577.5462 | 571.9306 | 572.2566 |
| RTX 2080 Ti | TRUE | max_allocated [MiB] | 3910.765 | 3910.765 | 3910.765 | 3910.765 | 3910.765 |
| RTX 2080 Ti | TRUE | total_runtime [ms] | 705.8082 | 701.9031 | 702.5811 | 692.4947 | 700.6968 |

We see that the memory allocation did not change between these devices, only the computation time, and the runtimes remain roughly proportional to each other. Furthermore, neither the memory allocation nor the runtime varied much within the same configuration, so from now on I will only report one run, performed on the RTX 2070.

For the next test, I use two different parameter configurations and vary the number of layers in each run. The first parameter set is exactly the one you proposed in the simple bench script; the second set of parameters is taken from bench script v2.

| Parameter | Configuration 1 | Configuration 2 |
| --- | --- | --- |
| DIM_EMBEDDING | 256 | 64 |
| NUM_HEADS | 4 | 4 |
| DIM_FF | 1024 | 256 |
| LEN_INPUT | 32 | 4096 |
| BATCH_SIZE | 512 | 1 |

Results of Configuration 1:

| GPU | FAST_TRANSFORMERS | Metric | 2 layers | 4 layers | 6 layers | 8 layers |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 2070 | FALSE | max_allocated [MiB] | 669.4316 | 1225.982 | 1782.533 | 2339.084 |
| RTX 2070 | FALSE | total_runtime [ms] | 387.242 | 744.0695 | 1092.286 | 1465.725 |
| RTX 2070 | TRUE | max_allocated [MiB] | 641.4365 | 1217.987 | 1794.538 | 2371.089 |
| RTX 2070 | TRUE | total_runtime [ms] | 377.4643 | 752.7672 | 1093.716 | 1440.55 |

Results of Configuration 2:

| GPU | FAST_TRANSFORMERS | Metric | 2 layers | 4 layers | 6 layers | 8 layers |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 2070 | FALSE | max_allocated [MiB] | 1832.835 | 2890.729 | 3947.624 | 5004.519 |
| RTX 2070 | FALSE | total_runtime [ms] | 432.5035 | 811.8661 | 1222.595 | 1631.668 |
| RTX 2070 | TRUE | max_allocated [MiB] | 2213.869 | 3910.765 | 5607.662 | OOM |
| RTX 2070 | TRUE | total_runtime [ms] | 514.5211 | 994.1755 | 1489.83 | - |

To me, it looks like the first configuration produces similar results for both implementations, but with the changed parameter set the memory consumption as well as the runtime diverge from each other.

I suppose the key factor is the much longer sequence length. This seems to me like an interesting observation, as an efficient implementation will mostly be tested on longer sequences (2k, 4k, 8k, ...) rather than very short sequences, which can already be handled by a vanilla transformer.

angeloskath commented 3 years ago

Awesome, thanks! It appears that the culprit is batch size = 1. I will look into it. It shouldn't be too hard to equalize the performance.

As an aside, you could always try attention_type="linear" or attention_type="improved-clustered" to get a really significant speed boost at those sequence lengths.
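
For reference, a minimal sketch of how that could look with the builder, reusing the hyperparameters from bench script v2 (attention_type is the only change; for the causal setting there is also a "causal-linear" attention type):

from fast_transformers.builders import TransformerEncoderBuilder

# Same hyperparameters as bench script v2; attention_type selects the attention implementation.
transformer = TransformerEncoderBuilder.from_kwargs(
    attention_type="linear",   # or "improved-clustered"
    n_layers=4,
    n_heads=4,
    query_dimensions=16,       # DIM_EMBEDDING // NUM_HEADS
    value_dimensions=16,
    feed_forward_dimensions=256,
    dropout=0.0,
    activation='relu',
).get().cuda()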

Cheers, Angelos

GregorKobsik commented 3 years ago

I am already testing some of the implementations in my use case (3D shape generation), but it seems like the "linear" attention lags behind the vanilla full scaled dot product in terms of relative time to convergence and converged results. Computation-wise it is sadly only as fast as the vanilla transformer, but it uses approximately 3.6x less memory, so I should be able to increase the batch size or use longer sequences to model higher-resolution shapes. P.S. 2: This is strange, as bench script v2 shows significantly less memory consumption and a faster runtime...

The "improved-clustered" attention will be queued next.

My main goal in using your library was to easily test different implementations and swap different attentions without modifying my code base or adding too many dependencies, to keep the complexity of my project down. Thank you for the big and nice code base.

I also saw that you often use einsum in your project. Maybe the opt_einsum lib could provide some more speedup, as suggested by PyTorch.
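
As a rough, hypothetical illustration (the actual einsum expressions inside fast_transformers may differ), opt_einsum is a drop-in replacement for a contraction; its main benefit, however, shows up when three or more operands are contracted in a single call:

import torch
from opt_einsum import contract

# Hypothetical attention-score contraction over (batch, length, heads, dims) tensors.
Q = torch.rand(1, 512, 4, 16)
K = torch.rand(1, 512, 4, 16)

scores_torch = torch.einsum("nlhd,nshd->nhls", Q, K)
scores_opt = contract("nlhd,nshd->nhls", Q, K)  # opt_einsum picks the contraction path
assert torch.allclose(scores_torch, scores_opt, atol=1e-5)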

P.S. I ran bench script v2 with different combinations of batch sizes (1, 2, 4, 16) and reduced sequence lengths (1024, 2048), but the runtime and memory seem to diverge here, too.

angeloskath commented 3 years ago

Hi,

Sorry for the late reply, I just found some time to spend on this.

It is going to sound funny, but what is going on is that the default PyTorch layer has a single dropout parameter, while we have separate parameters for the transformer layers and the attention layer. Simply put, if you set attention_dropout=0.0 the performance is identical (actually ours is a tiny bit faster, but I don't know why :-) ).

The extra memory given the attention dropout makes perfect sense since we need to keep the attention matrix in memory twice. Let me know if you are still experiencing problems. I will also add the test script for completeness and I will close the issue in a few days if there is no change.
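
As a rough sanity check (a back-of-the-envelope estimate, assuming the extra allocation is one additional float32 attention matrix of shape (batch, heads, length, length) per layer):

# Hypothetical estimate for bench script v2 (batch 1, 4 heads, sequence length 4096, 4 layers):
batch, heads, length, bytes_per_float = 1, 4, 4096, 4
extra_per_layer = batch * heads * length * length * bytes_per_float / 1024**2  # 256.0 MiB
print(4 * extra_per_layer)  # 1024.0 MiB, close to the ~1020 MiB gap reported above for 4 layers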

import argparse

import torch
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask, TriangularCausalMask

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_size", type=int, default=1
    )
    parser.add_argument(
        "--length", type=int, default=4096
    )
    parser.add_argument(
        "--n_layers", type=int, default=4
    )
    parser.add_argument(
        "--n_heads", type=int, default=4
    )
    parser.add_argument(
        "--dim_embedding", type=int, default=64
    )
    parser.add_argument(
        "--fast", action="store_true"
    )
    args = parser.parse_args()

    NUM_LAYERS = args.n_layers
    DIM_EMBEDDING = args.dim_embedding
    NUM_HEADS = args.n_heads
    DIM_FF = 4 * DIM_EMBEDDING
    LEN_INPUT = args.length
    BATCH_SIZE = args.batch_size
    FAST_TRANSFORMERS = args.fast

    if FAST_TRANSFORMERS:
        # fast-transformer
        x = torch.rand(BATCH_SIZE, LEN_INPUT, DIM_EMBEDDING).cuda()
        attn_mask = TriangularCausalMask(LEN_INPUT, device="cuda")

        transformer = TransformerEncoderBuilder.from_kwargs(
            n_layers=NUM_LAYERS,
            n_heads=NUM_HEADS,
            query_dimensions=DIM_EMBEDDING // NUM_HEADS,
            value_dimensions=DIM_EMBEDDING // NUM_HEADS,
            feed_forward_dimensions=DIM_FF,
            dropout=0.0,
            attention_dropout=0.0,  # this is the difference                <-----------------------
            activation='relu',
        ).get().cuda()
    else:
        # native PyTorch
        x = torch.rand(LEN_INPUT, BATCH_SIZE, DIM_EMBEDDING).cuda()
        attn_mask = torch.triu(torch.ones(LEN_INPUT, LEN_INPUT)).bool().cuda()

        transformer = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(DIM_EMBEDDING, NUM_HEADS, DIM_FF, 0.0, 'relu'),
            NUM_LAYERS,
            torch.nn.LayerNorm(DIM_EMBEDDING),
        ).cuda()

    def step(x, attn_mask):
        if FAST_TRANSFORMERS:
            transformer(x, attn_mask=attn_mask).sum().backward()
        else:
            transformer(x, mask=attn_mask).sum().backward()

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats()
    step(x, attn_mask)
    start.record()
    for i in range(10):
        step(x, attn_mask)
    stop.record()
    torch.cuda.synchronize()

    print("max_allocated:", torch.cuda.max_memory_allocated() / 1024**2)
    print("max_reserved:", torch.cuda.max_memory_reserved() / 1024**2)
    print("total_runtime:", start.elapsed_time(stop))

angeloskath commented 3 years ago

Hi Gregor,

I am closing the issue. Let me know in case you are still experiencing any problem or in case you find our transformer implementations slower in any way.

Cheers, Angelos

GregorKobsik commented 3 years ago

Hi,

sorry for the late answer, I was quite busy with another project.

(My previous answer was incorrect, as I ran the wrong script, which was quite similar but used different attention types.)

I reran Configuration 1 and Configuration 2 with the latest release, 0.4.0:

Results of Configuration 1:

| GPU | FAST_TRANSFORMERS | Metric | 2 layers | 4 layers | 6 layers | 8 layers |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 2070 | FALSE | max_allocated [MiB] | 669.4316 | 1225.982 | 1782.533 | 2339.084 |
| RTX 2070 | FALSE | total_runtime [ms] | 387.242 | 744.0695 | 1092.286 | 1465.725 |
| RTX 2070 | TRUE | max_allocated [MiB] | 621.436 | 1177.986 | 1734.537 | 2291.088 |
| RTX 2070 | TRUE | total_runtime [ms] | 405.331 | 742.957 | 1088.935 | 1441.433 |

Results of Configuration 2:

| GPU | FAST_TRANSFORMERS | Metric | 2 layers | 4 layers | 6 layers | 8 layers |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 2070 | FALSE | max_allocated [MiB] | 1832.835 | 2890.729 | 3947.624 | 5004.519 |
| RTX 2070 | FALSE | total_runtime [ms] | 432.5035 | 811.8661 | 1222.595 | 1631.668 |
| RTX 2070 | TRUE | max_allocated [MiB] | 1893.868 | 2950.764 | 4007.661 | 5064.557 |
| RTX 2070 | TRUE | total_runtime [ms] | 384.786 | 719.164 | 1047.640 | 1390.363 |

Thanks for the clarification! I will need to update my implementation.

P.S. After the update to 0.4.0, the fast linear attention is twice as fast. My parameters are exactly as in Configuration 2, but with a bigger dataset. Thanks for the fix!