bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation

No acceleration compared with timm vit block #410

Closed: woolpeeker closed this issue 2 years ago

woolpeeker commented 2 years ago

I used the code below to test the speed of a ViT block. The output shows that timm (native PyTorch) and LightSeq run at almost the same speed.

Did I miss something?

Output for forward only:

timm finished 500 running, avg_time: 76.379987 ms
light_seq finished 500 running, avg_time: 75.543549 ms

The output for forward + backward:

timm finished 500 running, avg_time: 228.803998 ms
light_seq finished 500 running, avg_time: 227.007331 ms

from timm.models.vision_transformer import Block
from lightseq.training.ops.pytorch.transformer_encoder_layer import LSTransformerEncoderLayer
from easydict import EasyDict as edict
import torch.nn as nn
import torch
import time
import sys
sys.path.append('./')

torch.backends.cudnn.benchmark = True

def generate_dummy_data(args):
    inputs = torch.randn([args.bs, args.num_token, args.dim]).cuda()
    return (inputs, )

def get_timm_block(args):
    # Dropout rates are probabilities (floats); 0.0 disables them.
    return Block(
        dim=args.dim,
        num_heads=args.num_heads,
        mlp_ratio=args.mlp_ratio,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm
    )

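class LSBlockWrapper(LSTransformerEncoderLayer):
    def forward(self, x):
        B, N, C = x.shape
        # LightSeq's encoder layer takes a [batch_size, seq_len] padding mask
        # (1 = padding, 0 = real token). ViT has no padding, so keep every position.
        mask = torch.zeros([B, N], device=x.device, dtype=x.dtype)
        return super().forward(x, mask)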

def get_ls_block(args):
    config = LSBlockWrapper.get_config(
        max_batch_tokens=args.num_token * args.bs,
        max_seq_len=args.num_token,
        hidden_size=args.dim,
        intermediate_size=int(args.mlp_ratio * args.dim),
        nhead=args.num_heads,
        attn_prob_dropout_ratio=0,
        hidden_dropout_ratio=0,
        activation_dropout_ratio=0,
        pre_layer_norm=True,
        fp16=False,
        local_rank=0,
        activation_fn='gelu')
    return LSBlockWrapper(
            config=config,
            initial_weights=None,
            initial_biases=None
        )

def run(module, args, name='Unknown'):
    inputs = generate_dummy_data(args)

    # warmup iterations, excluded from the timing below
    for _ in range(50):
        if args.backward:
            module(*inputs).sum().backward()
        else:
            module(*inputs)

    torch.cuda.synchronize()
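    t0 = time.perf_counter()  # monotonic, high-resolution clock

    for _ in range(args.num_iter):
        if args.backward:
            module(*inputs).sum().backward()
        else:
            module(*inputs)

    # CUDA kernels run asynchronously: wait for them before reading the clock
    torch.cuda.synchronize()
    t1 = time.perf_counter()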

    avg_time = (t1 - t0) * 1000 / args.num_iter
    print(
        f'>>> {name} finished {args.num_iter} running, avg_time: {avg_time:.6f} ms')
    return avg_time

def main():
    args = edict()
    args.num_iter = 500
    args.backward = False

    args.bs = 128
    args.dim = 1280
    args.num_heads = 16
    args.mlp_ratio = 4.0
    args.num_token = 256

    timm_block = get_timm_block(args).cuda()
    ls_block = get_ls_block(args).cuda()

    run(timm_block, args, name='timm')
    run(ls_block, args, name='light_seq')

    print('Finished.')

if __name__ == '__main__':
    main()
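A reusable version of the timing loop in `run()`, as a minimal sketch: it keeps the same structure (untimed warmup, one synchronization before each clock read) but repeats the timed loop a few times so run-to-run noise becomes visible. The `benchmark` name and its `repeats`/`sync` parameters are hypothetical; for GPU modules `sync` would be `torch.cuda.synchronize`, because CUDA kernels are launched asynchronously and the clock must not be read before they finish.

```python
import statistics
import time
from typing import Callable, Optional, Tuple


def benchmark(fn: Callable[[], object], warmup: int = 50, iters: int = 500,
              repeats: int = 3,
              sync: Optional[Callable[[], None]] = None) -> Tuple[float, float]:
    """Return (best, median) average time per call in ms over `repeats` timed loops."""
    if sync is None:
        sync = lambda: None  # CPU-only callables need no synchronization
    # Untimed warmup, as in run() above.
    for _ in range(warmup):
        fn()
    sync()
    per_iter_ms = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()  # make sure all queued work has finished before reading the clock
        per_iter_ms.append((time.perf_counter() - t0) * 1000.0 / iters)
    return min(per_iter_ms), statistics.median(per_iter_ms)
```

For example, `benchmark(lambda: ls_block(*inputs), sync=torch.cuda.synchronize)` would time the LightSeq block from the script above. Reporting the best and median of several loops makes small gaps, such as the ~1 ms differences in this thread, easier to judge against noise.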
Taka152 commented 2 years ago

It seems you are using fp32; could you try fp16?

woolpeeker commented 2 years ago

Sure, I tested it with fp16.

With backward=False:

timm finished 500 running, avg_time: 10.408471 ms
light_seq finished 500 running, avg_time: 9.462291 ms

With backward=True:

timm finished 500 running, avg_time: 31.718561 ms
light_seq finished 500 running, avg_time: 30.036270 ms

That is only a 1.7 ms difference.

The test environment is:

PyTorch version: 1.12.1
CUDA used to build PyTorch: 11.3
Python version: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0] (64-bit runtime)
lightseq==2.2.1
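To put the reported averages on a common scale, the timm time can be divided by the LightSeq time. A minimal sketch using only the numbers posted in this thread (the `speedup` helper is hypothetical):

```python
from typing import Tuple


def speedup(baseline_ms: float, candidate_ms: float) -> Tuple[float, float]:
    """Return (speedup ratio, percent of time saved) of candidate vs. baseline."""
    ratio = baseline_ms / candidate_ms
    saved_pct = (baseline_ms - candidate_ms) / baseline_ms * 100.0
    return ratio, saved_pct


# Averages reported in this issue: (timm ms, LightSeq ms) per iteration.
results = {
    "fp32 forward": (76.379987, 75.543549),
    "fp32 forward+backward": (228.803998, 227.007331),
    "fp16 forward": (10.408471, 9.462291),
    "fp16 forward+backward": (31.718561, 30.036270),
}

for name, (timm_ms, ls_ms) in results.items():
    ratio, saved = speedup(timm_ms, ls_ms)
    print(f"{name}: {ratio:.3f}x, {saved:.1f}% less time")
```

On these numbers LightSeq is about 1% faster in fp32, about 1.10x faster for the fp16 forward pass, and about 1.06x faster for fp16 forward plus backward.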

woolpeeker commented 2 years ago

Do you have official benchmark results comparing ViT in native PyTorch and in LightSeq?

Taka152 commented 2 years ago

We tested on 8x A100, and these are the results:

[image: ViT benchmark results on 8x A100]

BTW, could you tell me which GPU you are using? I will check whether I can reproduce your results.

woolpeeker commented 2 years ago

I used a single A100-80G GPU

godweiyang commented 2 years ago


Hi, I tested your script in fp16 precision. Below is the result: [image: benchmark output]

I ran it multiple times, and the results were consistent.

I also tried other hidden dims and batch sizes, and in every case LightSeq was faster.

woolpeeker commented 2 years ago

Thanks for testing. Your LightSeq result matches mine, around 9.5 ms on our machine, but the original timm layer runs a little faster on our machine than on yours: around 10.4 ms. LightSeq is always faster than the original layer; the gap is just smaller on our machine.

I will close this issue.

BTW, do you plan to implement FlashAttention in LightSeq? I heard it is much faster than previous methods.

syorami commented 1 year ago

Hi @woolpeeker, I'm also trying to integrate LightSeq into my project, which uses a timm ViT model. After switching to the LightSeq layer, can we still load the same pretrained weights?

woolpeeker commented 1 year ago

Yes, you just need to arrange the weight tensors as described in the LightSeq docs. I have tested it, and the outputs match timm.
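As a minimal sketch of that rearrangement: assuming LightSeq's `initial_weights`/`initial_biases` lists follow the order used in its Hugging Face BERT example (query, key, value, attention output, attention LayerNorm, FFN intermediate, FFN output, FFN LayerNorm), the timm `Block` parameters could be regrouped as below. The function name is hypothetical, plain nested lists stand in for tensors so the sketch runs anywhere, and the ordering should be checked against the LightSeq docs for your version. timm stores query, key and value as one fused `attn.qkv` projection whose rows are laid out q, then k, then v, so it is split into three equal row blocks; with `pre_layer_norm=True`, `norm1` feeds attention and `norm2` feeds the MLP.

```python
from typing import Dict, List, Tuple


def timm_block_to_ls_lists(state: Dict[str, list], dim: int) -> Tuple[List, List]:
    """Hypothetical helper: regroup timm Block params into (weights, biases) lists.

    The ordering is an ASSUMPTION based on LightSeq's Hugging Face BERT example.
    """
    qkv_w = state["attn.qkv.weight"]  # 3 * dim rows: q rows, then k rows, then v rows
    if len(qkv_w) != 3 * dim:
        raise ValueError("attn.qkv.weight must have 3 * dim rows")
    q_w, k_w, v_w = qkv_w[:dim], qkv_w[dim:2 * dim], qkv_w[2 * dim:]

    # The benchmark builds the block with qkv_bias=False, so there may be no bias.
    qkv_b = state.get("attn.qkv.bias")
    if qkv_b is None:
        qkv_b = [0.0] * (3 * dim)
    q_b, k_b, v_b = qkv_b[:dim], qkv_b[dim:2 * dim], qkv_b[2 * dim:]

    weights = [q_w, k_w, v_w,
               state["attn.proj.weight"], state["norm1.weight"],
               state["mlp.fc1.weight"], state["mlp.fc2.weight"],
               state["norm2.weight"]]
    biases = [q_b, k_b, v_b,
              state["attn.proj.bias"], state["norm1.bias"],
              state["mlp.fc1.bias"], state["mlp.fc2.bias"],
              state["norm2.bias"]]
    return weights, biases
```

With real tensors from `block.state_dict()` the same slicing works; only the zero-bias fallback would need to become `torch.zeros(3 * dim)`.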

syorami commented 1 year ago

@woolpeeker Thanks! This saves much time for me.

syorami commented 1 year ago

Hi @woolpeeker, I followed the docs and integrated the LightSeq transformer layer into timm ViT, but the speed improvement is negligible. Did you observe any speedup? I'm using the same GPU (A100), and the snippet above gives me exactly the same results as yours. I suspect the relative improvement depends heavily on the batch size and other hyperparameters.

BTW, switching to FlashAttention gave me a speedup of around 10% for the whole timm ViT model, and 45% for a single attention block. I hope this helps.
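The two FlashAttention figures can be related with Amdahl's law. Assuming "45%" and "10%" mean 1.45x and 1.10x throughput ratios, and that nothing outside attention changes, the attention blocks would need to take up roughly 29% of the model's runtime for both numbers to hold. A small sketch of that arithmetic (function names are illustrative):

```python
def amdahl_speedup(fraction: float, part_speedup: float) -> float:
    """Overall speedup when `fraction` of the runtime is sped up by `part_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / part_speedup)


def implied_fraction(part_speedup: float, overall_speedup: float) -> float:
    """Invert Amdahl's law: the runtime share the accelerated part must have."""
    return (1.0 - 1.0 / overall_speedup) / (1.0 - 1.0 / part_speedup)


f = implied_fraction(part_speedup=1.45, overall_speedup=1.10)
print(f"implied attention share of runtime: {f:.0%}")
```

The same reasoning explains why a faster kernel for one part of a block yields a much smaller end-to-end gain once the rest of the model (MLP, LayerNorm, data movement) is included.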