[Open] tgaddair opened this issue 8 months ago
Quick sanity check on the math:
```python
import torch

# Hidden sizes for the q, k, v projections, LoRA rank, and batch size.
h1 = 32  # input hidden size
h2 = 32  # q output size
h3 = 8   # k output size
h4 = 8   # v output size
r = 8    # LoRA rank
b = 4    # batch size

x = torch.randn((b, h1))

# Separate LoRA A/B pairs for q, k, v.
qA = torch.randn((h1, r))
qB = torch.randn((r, h2))
kA = torch.randn((h1, r))
kB = torch.randn((r, h3))
vA = torch.randn((h1, r))
vB = torch.randn((r, h4))

# Reference: apply each LoRA separately and concatenate the outputs.
y_q = (x @ qA) @ qB
y_k = (x @ kA) @ kB
y_v = (x @ vA) @ vB
y = torch.cat([y_q, y_k, y_v], dim=1)
print(y, y.shape)

# Fused: concatenate the A matrices along the rank dimension, and place the
# B matrices on the block diagonal of a single (3r, h2 + h3 + h4) matrix.
A = torch.zeros((h1, r * 3))
B = torch.zeros((r * 3, h2 + h3 + h4))
A[:, 0:r] = qA
A[:, r:r*2] = kA
A[:, r*2:r*3] = vA
B[0:r, 0:h2] = qB
B[r:r*2, h2:h2+h3] = kB
B[r*2:r*3, h2+h3:h2+h3+h4] = vB
print(A.shape, B.shape)

# A single fused matmul should reproduce the concatenated outputs.
y2 = (x @ A) @ B
print(y2, y2.shape)
print(torch.allclose(y, y2))
```
Everything looks good. There is some increased memory overhead due to needing to pad the B tensor with zeros:
```python
# Compare element counts: separate q/k/v LoRA tensors vs. the fused A/B pair.
elems1 = sum(v.numel() for v in [qA, qB, kA, kB, vA, vB])
elems2 = sum(v.numel() for v in [A, B])
print(elems1, elems2, elems2 / elems1, elems1 / elems2)
```
We get about a 67% increase in memory overhead, so we may want to make this optional.
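For the shapes in this example, the counts work out as follows:

```python
# Separate tensors: qA, kA, vA are (32, 8); qB is (8, 32); kB and vB are (8, 8).
elems1 = 3 * (32 * 8) + (8 * 32) + (8 * 8) + (8 * 8)  # = 1152
# Fused tensors: A is (32, 24) and B is (24, 48).
elems2 = (32 * 24) + (24 * 48)                        # = 1920
# 1920 / 1152 ≈ 1.67, i.e. roughly a 67% increase.
```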
Performance difference with Mistral-7B, 64 tokens generated on 1x A100, measured for two configurations: rank 8 (q, v) and rank 16 (q, k, v).
The latency reduction, particularly when all three of q, k, v are used, is meaningful, but it's not clear it's worth the 67% increase in adapter memory usage.
Also, it looks like there are numerical challenges with the SGMV kernel that are causing some corruption with these fused ranks. We'll need to resolve those to get this working.
Due to the numerical issues, we could revisit this after tackling #160, which will allow us to pad SGMV ops to a particular (supported) rank.
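As a rough illustration of what that padding could look like (a minimal sketch only; `pad_rank` is a hypothetical helper, not the actual SGMV/LoRAX API), zero-padding the shared rank dimension leaves the LoRA output unchanged:

```python
import torch

def pad_rank(A: torch.Tensor, B: torch.Tensor, target_r: int):
    """Zero-pad the shared rank dimension of a LoRA (A, B) pair up to target_r.

    Hypothetical helper for illustration; padding with zeros does not change
    (x @ A) @ B because the extra rank slots contribute nothing.
    """
    r = A.shape[1]
    assert B.shape[0] == r and target_r >= r
    A_pad = torch.zeros((A.shape[0], target_r), dtype=A.dtype)
    B_pad = torch.zeros((target_r, B.shape[1]), dtype=B.dtype)
    A_pad[:, :r] = A
    B_pad[:r, :] = B
    return A_pad, B_pad
```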
It seems the AWQ quantized models already support fused QKV; we might consider that as well.
Currently, we treat each of the Q, K, V LoRAs as distinct tensors, meaning we do 3 SGMV calls per layer instead of 1. We should fuse them to improve batching.
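For reference, a minimal sketch of the packing this would require (illustrative only; `fuse_qkv_lora` is not an existing function in the codebase): concatenate the A matrices along the rank dimension and lay the B matrices out block-diagonally, exactly as in the sanity check above.

```python
import torch

def fuse_qkv_lora(As, Bs):
    """Pack per-projection LoRA weights into one (A, B) pair.

    As: list of (h_in, r_i) LoRA A matrices for q, k, v.
    Bs: list of (r_i, h_out_i) LoRA B matrices for q, k, v.
    """
    ranks = [a.shape[1] for a in As]
    outs = [b.shape[1] for b in Bs]
    # A matrices are concatenated along the rank dimension.
    A = torch.cat(As, dim=1)                  # (h_in, sum(ranks))
    # B matrices go on the block diagonal so each rank only feeds its own output slice.
    B = torch.zeros((sum(ranks), sum(outs)))
    r0 = o0 = 0
    for r_i, o_i, b_i in zip(ranks, outs, Bs):
        B[r0:r0 + r_i, o0:o0 + o_i] = b_i
        r0 += r_i
        o0 += o_i
    return A, B

# Usage: (x @ A) @ B matches torch.cat([x @ qA @ qB, x @ kA @ kB, x @ vA @ vB], dim=1).
# A_fused, B_fused = fuse_qkv_lora([qA, kA, vA], [qB, kB, vB])
```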