punica-ai / punica

Serving multiple LoRA finetuned LLM as one
Apache License 2.0

[Question] How to avoid matrices conflict #48

Open gyliu513 opened 3 months ago

gyliu513 commented 3 months ago

Assuming W of shape [H1, H2] is the weight of the pretrained model, LoRA adds two small matrices: A of shape [H1, r] and B of shape [r, H2]. Running an input x on the finetuned model gives y := x @ (W + A@B), which is the same as y := x@W + x@A@B. When there are n LoRA models, there are A1, B1, A2, B2, ..., An, Bn. Given an input batch X := (x1, x2, ..., xn) in which each xi maps to its own LoRA model, the output is Y := X@W + (x1@A1@B1, x2@A2@B2, ..., xn@An@Bn). The first term, X@W, runs the input batch through the pretrained model. It is quite efficient: the latency is almost the same as with a single input, thanks to the strong batching effect.
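To make the formula concrete, here is a minimal NumPy sketch of it. It is a plain Python loop, not Punica's actual kernel; the function name `batched_lora_forward`, the `adapter_ids` argument, and the toy shapes are all made up for illustration.

```python
import numpy as np

def batched_lora_forward(X, W, adapters, adapter_ids):
    """Compute Y[i] = X[i] @ W + X[i] @ A_k @ B_k, with k = adapter_ids[i].

    Hypothetical helper for illustration only.
    X: [n, H1], W: [H1, H2], adapters: list of (A [H1, r], B [r, H2]).
    """
    # Shared base-model term: one batched matmul for the whole batch.
    Y = X @ W
    # Per-request LoRA term: each row only uses its own (A, B) pair.
    for i, k in enumerate(adapter_ids):
        A, B = adapters[k]
        Y[i] += X[i] @ A @ B
    return Y

# Toy sizes, chosen arbitrarily.
rng = np.random.default_rng(0)
H1, H2, r, n = 8, 6, 2, 3
W = rng.standard_normal((H1, H2))
adapters = [
    (rng.standard_normal((H1, r)), rng.standard_normal((r, H2)))
    for _ in range(n)
]
X = rng.standard_normal((n, H1))

W_before = W.copy()
Y = batched_lora_forward(X, W, adapters, adapter_ids=[0, 1, 2])

# Reference: run each input through its own merged weight W + A_i @ B_i.
Y_ref = np.stack(
    [X[i] @ (W + adapters[i][0] @ adapters[i][1]) for i in range(n)]
)
```

In this sketch W is only read, never written, and row i of the output depends only on its own pair (A_i, B_i).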

So the question is: if A1, A2, and A3 conflict, for example by modifying the same parameters in the base model, will the result still be correct? Thanks!