microsoft / Cream

This is a collection of our NAS and Vision Transformer work.

Question about the training speed of IRPE #225

Closed · LPudding closed this issue 4 months ago

LPudding commented 4 months ago

Hello, thank you for sharing such meaningful work.

After compiling the CUDA operator, I found that using iRPE significantly slows down training: one epoch takes about 18 minutes without iRPE, but about 38 minutes with it. May I ask where the problem might be?

For simplicity, we used the following parameters and code. Because we have a large number of tokens, we apply average pooling to reduce the complexity of self-attention.

import torch.nn as nn

from irpe import build_rpe, get_rpe_config


class LowMixerWithREP(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
                 pool_size=2, rpe_config=None, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0, count_include_pad=False) if pool_size > 1 else nn.Identity()
        self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()

        # image relative position encoding
        if rpe_config is None:
            rpe_config = get_rpe_config(
                ratio=1.9,
                method="product",
                mode='ctx',
                shared_head=True,
                skip=0,
                rpe_on='qkv',
            )

        self.rpe_q, self.rpe_k, self.rpe_v = \
            build_rpe(rpe_config,
                      head_dim=head_dim,
                      num_heads=num_heads)

    def att_fun(self, q, k, v, B, N, C):
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # image relative position on keys
        if self.rpe_k is not None:
            attn += self.rpe_k(q)

        # image relative position on queries
        if self.rpe_q is not None:
            attn += self.rpe_q(k * self.scale).transpose(2, 3)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = attn @ v
        # image relative position on values
        if self.rpe_v is not None:
            out += self.rpe_v(attn)
        x = out.transpose(2, 3).reshape(B, C, N)
        return x

    def forward(self, x):
        # B, C, H, W
        B, _, _, _ = x.shape
        xa = self.pool(x)
        xa = xa.permute(0, 2, 3, 1).view(B, -1, self.dim)
        B, N, C = xa.shape
        qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
        xa = self.att_fun(q, k, v, B, N, C)
        xa = xa.view(B, C, int(N**0.5), int(N**0.5))#.permute(0, 3, 1, 2)

        xa = self.uppool(xa)
        return xa
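
As a side note, here is a minimal, hedged profiling sketch for locating where the extra time goes, assuming the class above and the irpe package are importable; the sizes (dim=256, num_heads=8, a 32x32 feature map, batch 8) are illustrative, not the actual training setting.

import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative sizes only; the real model and batch size may differ.
mixer = LowMixerWithREP(dim=256, num_heads=8, pool_size=2).cuda()
x = torch.randn(8, 256, 32, 32, device='cuda')  # (B, C, H, W)

# Warm up once so CUDA kernels are loaded before profiling.
mixer(x)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = mixer(x)
    torch.cuda.synchronize()

# The slowest operators (e.g. the iRPE kernels) should dominate this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))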
wkcn commented 4 months ago

Hi @LPudding, thanks for your attention to our work! What are the width and height of the input 2D sequence? It's possible that the CUDA operator is not as well optimized as the latest version of PyTorch, or that it is slower on long sequences.

LPudding commented 4 months ago

Thank you for your patient explanation. We use Inception Transformer as the backbone, so the feature maps involved in the attention operations at the different stages are 32x32, 16x16, 16x16, and 8x8, respectively. In addition, due to server limitations, we use PyTorch 1.12.1+cu102.

wkcn commented 4 months ago

Thank you for providing this information.

The CUDA implementation of iRPE may be slow on sequences longer than 14x14. I will try to improve the training speed.

wkcn commented 4 months ago

Here are my performance table and the related code; times are the mean per-call cost in milliseconds over 100 forward+backward iterations. The CUDA operator is faster than the pure PyTorch implementation.

Sequence Length    8x8     16x16   32x32
torch              0.52    6.67    105.52
cuda impl.         0.25    0.44    9.55
from irpe import build_rpe
from irpe import get_rpe_config
import torch
import time

rpe_config = get_rpe_config(
    ratio=1.9,
    method="product",
    mode='ctx',
    shared_head=True,
    skip=0,
    rpe_on='k',
)

head_dim = 128
num_heads = 12
# With rpe_on='k', build_rpe returns an RPE module only for keys;
# irpe_q and irpe_v are None.
irpe_q, irpe_k, irpe_v = build_rpe(rpe_config,
    head_dim=head_dim,
    num_heads=num_heads)

irpe_k.cuda()

def get_cost(E):
    B = 128
    L = E * E
    # q is shaped (B, num_heads, L, head_dim) to match the head_dim passed to build_rpe
    q = torch.randn(B, num_heads, L, head_dim, device='cuda')

    warm = 10
    for _ in range(warm):
        out = irpe_k(q)
        out.sum().backward()
        torch.cuda.synchronize()
    T = 100
    tic = time.time()
    for _ in range(T):
        out = irpe_k(q)
        out.sum().backward()
        torch.cuda.synchronize()
    toc = time.time()
    cost = (toc - tic) / T * 1000  # mean time per iteration, in milliseconds
    return cost

costs = []
for e in [8, 16, 32]:
    cost = get_cost(e)
    print(e, cost)
    costs.append(cost)

print(costs)
LPudding commented 4 months ago

Thank you for your help. I found that when applying iRPE to q and k, the time difference is not significant, but applying iRPE to v significantly increases the training time.
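
To illustrate the gap, here is a minimal, hedged adaptation of the benchmark above that times only the value branch. The sizes are illustrative, and it assumes the same irpe package and that rpe_v consumes the attention map, as in att_fun above (out += self.rpe_v(attn)).

import time

import torch
from irpe import build_rpe, get_rpe_config

# Illustrative sizes, not the actual training configuration.
head_dim, num_heads = 64, 12

# Build an RPE module on values only; with rpe_on='v', the q and k entries are None.
rpe_config_v = get_rpe_config(
    ratio=1.9,
    method="product",
    mode='ctx',
    shared_head=True,
    skip=0,
    rpe_on='v',
)
_, _, irpe_v = build_rpe(rpe_config_v, head_dim=head_dim, num_heads=num_heads)
irpe_v.cuda()

# rpe_v takes the attention map as input, as in att_fun above.
B, L = 64, 16 * 16
attn = torch.randn(B, num_heads, L, L, device='cuda').softmax(dim=-1)

# Single timed call; a proper benchmark would warm up and average as in get_cost.
torch.cuda.synchronize()
tic = time.time()
out = irpe_v(attn)
out.sum().backward()
torch.cuda.synchronize()
print('rpe_v forward+backward (ms):', (time.time() - tic) * 1000)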

wkcn commented 4 months ago

Yes. The iRPE on values is not optimized with a custom CUDA operator.
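
A possible workaround, sketched under the assumption that rpe_on accepts any combination of the characters 'q', 'k', 'v' (as suggested by the 'k' and 'qkv' configs earlier in this thread): keep the encoding on queries and keys only, so that build_rpe returns None for the value entry and the unoptimized branch in att_fun is skipped.

from irpe import build_rpe, get_rpe_config

# Same configuration as in the module above, but without 'v' in rpe_on.
rpe_config = get_rpe_config(
    ratio=1.9,
    method="product",
    mode='ctx',
    shared_head=True,
    skip=0,
    rpe_on='qk',  # drop the value branch
)
# head_dim and num_heads are illustrative values.
rpe_q, rpe_k, rpe_v = build_rpe(rpe_config, head_dim=64, num_heads=8)
assert rpe_v is None  # `if self.rpe_v is not None:` in att_fun is then skipped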