predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Fuse q,k,v LoRAs #158

Open tgaddair opened 8 months ago

tgaddair commented 8 months ago

Currently, we treat each of the Q, K, V LoRAs as distinct tensors, meaning we do 3 SGMV calls per layer instead of 1. We should fuse them to improve batching.

tgaddair commented 8 months ago

Quick sanity check on the math:

import torch

# h1: input hidden dim; h2, h3, h4: q/k/v output dims; r: LoRA rank; b: batch size
h1 = 32
h2 = 32
h3 = 8
h4 = 8
r = 8
b = 4

x = torch.randn((b, h1))

qA = torch.randn((h1, r))
qB = torch.randn((r, h2))

kA = torch.randn((h1, r))
kB = torch.randn((r, h3))

vA = torch.randn((h1, r))
vB = torch.randn((r, h4))

# Reference: apply each LoRA pair separately and concatenate the outputs.
y_q = (x @ qA) @ qB
y_k = (x @ kA) @ kB
y_v = (x @ vA) @ vB
y = torch.cat([y_q, y_k, y_v], dim=1)
print(y, y.shape)

# Fused version: stack the A matrices along the rank dim and place the
# B matrices on a block diagonal.
A = torch.zeros((h1, r * 3))
B = torch.zeros((r * 3, h2 + h3 + h4))

A[:, 0:r] = qA
A[:, r:r*2] = kA
A[:, r*2:r*3] = vA

B[0:r, 0:h2] = qB
B[r:r*2, h2:h2+h3] = kB
B[r*2:r*3, h2+h3:h2+h3+h4] = vB

print(A.shape, B.shape)

y2 = (x @ A) @ B
print(y2, y2.shape)

print(torch.allclose(y, y2))  # should print True
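
For reference, here's the same construction wrapped in a small helper, plus splitting the fused output back into per-projection deltas (this reuses the tensors from the check above; fuse_qkv_lora is just an illustrative name, not something in the codebase):

def fuse_qkv_lora(qA, qB, kA, kB, vA, vB):
    # Stack the A matrices along the rank dim and place the B matrices on a
    # block diagonal, so one pair of matmuls covers all three projections.
    h1, r = qA.shape
    out_dims = [qB.shape[1], kB.shape[1], vB.shape[1]]
    A = torch.zeros((h1, 3 * r), dtype=qA.dtype)
    B = torch.zeros((3 * r, sum(out_dims)), dtype=qB.dtype)
    col = 0
    for i, (a, b) in enumerate([(qA, qB), (kA, kB), (vA, vB)]):
        A[:, i * r:(i + 1) * r] = a
        B[i * r:(i + 1) * r, col:col + b.shape[1]] = b
        col += b.shape[1]
    return A, B, out_dims

A_f, B_f, out_dims = fuse_qkv_lora(qA, qB, kA, kB, vA, vB)
# Split the fused output back into the q/k/v deltas.
y_q2, y_k2, y_v2 = torch.split((x @ A_f) @ B_f, out_dims, dim=1)
print(torch.allclose(y_q, y_q2), torch.allclose(y_k, y_k2), torch.allclose(y_v, y_v2))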

Everything looks good. There is some increased memory overhead due to needing to pad the B tensor with zeros:

# Compare total parameter counts: separate adapters vs. fused (zero-padded) layout.
elems1 = sum(v.numel() for v in [qA, qB, kA, kB, vA, vB])
elems2 = sum(v.numel() for v in [A, B])
print(elems1, elems2, elems2 / elems1, elems1 / elems2)

We get about a 67% increase in memory overhead (1920 elements vs. 1152 in this example), so we may want to make this optional.
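
For context, the same ratio holds at Mistral-7B scale; assuming hidden size 4096 and GQA with 8 KV heads of dim 128 (so q projects to 4096 and k/v to 1024 each), the fused layout is ~1.67x the separate one regardless of rank, since both counts scale linearly in r:

# Hypothetical per-layer counts for Mistral-7B-style attention dims.
h, d_q, d_kv, r = 4096, 4096, 1024, 8
separate = 3 * h * r + r * (d_q + 2 * d_kv)       # three A's + three B's
fused = h * (3 * r) + (3 * r) * (d_q + 2 * d_kv)  # stacked A + block-diagonal B
print(separate, fused, fused / separate)          # -> 147456 245760 1.666...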

Performance difference with Mistral-7B, 64 tokens generated on 1x A100, measured for rank 8 (q, v) and rank 16 (q, k, v) adapters:

The latency reduction, particularly when all three of q, k, and v have LoRAs, is meaningful, but it's not clear it's worth the 67% increase in adapter memory usage.

Also, it looks like there are numerical issues with the SGMV kernel that cause some output corruption at these fused ranks. We'll need to resolve those to get this working.

tgaddair commented 8 months ago

Branch: https://github.com/predibase/lorax/tree/fuse-qkv

tgaddair commented 8 months ago

Due to the numerical issues, we could revisit this after tackling #160, which will allow us to pad SGMV ops to a particular (supported) rank.
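
As a rough sketch of what that could look like for the fused case, continuing from the sanity check above: zero-padding the fused rank (3 * 8 = 24 here) up to a hypothetical supported rank like 32 leaves the output unchanged, since the extra rows/columns contribute nothing (pad_rank is an illustrative name, not part of #160):

def pad_rank(A, B, target_r):
    # Zero-pad the rank dimension; the padded rows/cols contribute 0 to (x @ A) @ B.
    h1, r = A.shape
    A_pad = torch.zeros((h1, target_r), dtype=A.dtype)
    B_pad = torch.zeros((target_r, B.shape[1]), dtype=B.dtype)
    A_pad[:, :r] = A
    B_pad[:r, :] = B
    return A_pad, B_pad

A_pad, B_pad = pad_rank(A, B, target_r=32)  # fused rank 24 -> 32
print(torch.allclose((x @ A) @ B, (x @ A_pad) @ B_pad))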

thincal commented 5 months ago

It seems that AWQ-quantized models already support fused QKV, so it might be worth considering that here as well.