Hi @LPudding, thanks for your interest in our work! What are the width and height of the input 2D sequence? It's possible that the CUDA operator is not as well optimized as the latest version of PyTorch, or that it is slower on long sequences.
Thank you for your patient explanation. We used Inception Transformer as the backbone, so the feature-map widths and heights involved in the attention operations at the different stages are 32×32, 16×16, 16×16, and 8×8, respectively. In addition, due to server limitations, we used PyTorch 1.12.1+cu102.
Thank you for providing this information.
The CUDA implementation of iRPE may be slow on sequences whose length is larger than 14×14. I will try to improve the training speed.
Here are my performance table (milliseconds per forward+backward iteration) and the benchmark code. The CUDA operator is faster than the pure PyTorch implementation.
| Sequence length | 8×8 | 16×16 | 32×32 |
| --- | --- | --- | --- |
| torch | 0.52 | 6.67 | 105.52 |
| CUDA impl. | 0.25 | 0.44 | 9.55 |
```python
import time

import torch

from irpe import build_rpe, get_rpe_config

rpe_config = get_rpe_config(
    ratio=1.9,
    method="product",
    mode='ctx',
    shared_head=True,
    skip=0,
    rpe_on='k',
)
head_dim = 128
num_heads = 12
irpe_q, irpe_k, irpe_v = build_rpe(rpe_config,
                                   head_dim=head_dim,
                                   num_heads=num_heads)
irpe_k.cuda()


def get_cost(E):
    """Average forward+backward time (ms) of the key RPE on an E x E feature map."""
    B = 128
    L = E * E
    # the per-head channel dimension must match the head_dim passed to build_rpe
    q = torch.randn(B, num_heads, L, head_dim, device='cuda')

    # warm-up so CUDA kernels and caches are initialized before timing
    warm = 10
    for _ in range(warm):
        out = irpe_k(q)
        out.sum().backward()
    torch.cuda.synchronize()

    # timed iterations
    T = 100
    tic = time.time()
    for _ in range(T):
        out = irpe_k(q)
        out.sum().backward()
    torch.cuda.synchronize()
    toc = time.time()
    cost = (toc - tic) / T * 1000  # milliseconds per iteration
    return cost


costs = []
for e in [8, 16, 32]:
    cost = get_cost(e)
    print(e, cost)
    costs.append(cost)
print(costs)
```
Thank you for your help. I found that when applying iRPE to q and k the time difference is not significant, but when applying iRPE to v the runtime increases significantly.
Yes. The iRPE on the value is not optimized with a custom CUDA operator.
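For anyone who wants to verify this on their own setup, here is a minimal sketch (not from the repo) that times the key-RPE and value-RPE modules separately, reusing the `get_rpe_config`/`build_rpe` calls from the benchmark above. The one assumption is that the value RPE consumes the attention map of shape `(B, num_heads, L, L)`, as it is used in the iRPE attention code (`out += rpe_v(attn)`), while the key RPE consumes the query tensor of shape `(B, num_heads, L, head_dim)`. The batch size, head count, and head dimension below are arbitrary choices.

```python
import time

import torch

from irpe import build_rpe, get_rpe_config

num_heads = 12
head_dim = 64


def make_rpe(rpe_on):
    # build an RPE module for the requested tensor ('q', 'k', or 'v')
    config = get_rpe_config(ratio=1.9, method="product", mode='ctx',
                            shared_head=True, skip=0, rpe_on=rpe_on)
    rpe_q, rpe_k, rpe_v = build_rpe(config, head_dim=head_dim, num_heads=num_heads)
    return {'q': rpe_q, 'k': rpe_k, 'v': rpe_v}[rpe_on].cuda()


def time_module(module, x, iters=100, warmup=10):
    # average forward+backward time in milliseconds
    for _ in range(warmup):
        module(x).sum().backward()
    torch.cuda.synchronize()
    tic = time.time()
    for _ in range(iters):
        module(x).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - tic) / iters * 1000


B, E = 64, 16
L = E * E
# input for the key RPE: the query tensor
q = torch.randn(B, num_heads, L, head_dim, device='cuda')
# assumed input for the value RPE: the attention map, as in `out += rpe_v(attn)`
attn = torch.randn(B, num_heads, L, L, device='cuda')

print('rpe on k:', time_module(make_rpe('k'), q))
print('rpe on v:', time_module(make_rpe('v'), attn))
```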
Hello, thank you for sharing such meaningful work.
After compiling the CUDA operator, I found that using iRPE significantly slows down training. One epoch takes about 18 minutes without iRPE, but about 38 minutes with it. May I ask where the problem might be?
For simplicity, we used the following parameters and code. Due to the large number of tokens, we used average pooling to reduce the complexity of self-attention.
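(The parameters and code mentioned above were not included in this excerpt. Purely as an illustration of the pooling idea, here is a hypothetical sketch of self-attention whose keys and values are average-pooled before the attention product; none of the class or parameter names below come from the original post.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PooledSelfAttention(nn.Module):
    """Self-attention whose keys/values are average-pooled to cut the O(L^2) cost.

    Hypothetical illustration only; not the code from the original post.
    """

    def __init__(self, dim, num_heads=8, pool_stride=2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool_stride = pool_stride
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, height, width):
        B, L, C = x.shape  # L == height * width
        q = self.q(x).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # average-pool the spatial map before computing keys/values,
        # so the attention matrix is L x (L / stride^2) instead of L x L
        x2d = x.transpose(1, 2).reshape(B, C, height, width)
        x2d = F.avg_pool2d(x2d, kernel_size=self.pool_stride, stride=self.pool_stride)
        xp = x2d.flatten(2).transpose(1, 2)  # (B, L', C)

        kv = self.kv(xp).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, L, C)
        return self.proj(out)
```

For example, on a 32×32 feature map with `pool_stride=2`, the attention matrix shrinks from 1024×1024 to 1024×256.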