daskol / lotr

Low Tensor Rank adaptation of large language models
https://arxiv.org/abs/2402.01376
Apache License 2.0

Please publish end-to-end application example #2

Open dmikushin opened 1 month ago

dmikushin commented 1 month ago

Dear all,

It would be great to see an end-to-end practical example of LoTR. By "practical" I mean that one takes, for example, some existing LLM weights file, compresses it into a smaller weights file with LoTR, and then uses the new weights file for inference. For the first part I imagine something like this:

"""Loads your model and its weights, freezes all parameters, and then replaces the `torch.nn.Linear` layers with `LoTRLinear` layers. The compressed model is then saved to a new weights file.
"""

from argparse import ArgumentParser, Namespace
import torch
from lotr import LoTR, LoTRLinear
from transformers import AutoModelForSequenceClassification

parser = ArgumentParser(description=__doc__)
parser.add_argument('--model', type=str, required=True, help='Pretrained model name or path')
parser.add_argument('--input_state', type=str, required=True, help='Path to the .pth file to load weights from')
parser.add_argument('--output_state', type=str, required=True, help='Path to the .pth file to save the compressed model weights')
parser.add_argument('--rank', type=int, default=2, help='Rank for the LoTRLinear layers')

def main(ns: Namespace):
    # Load your model
    model = AutoModelForSequenceClassification.from_pretrained(ns.model)

    # Load your .pth file
    model.load_state_dict(torch.load(ns.input_state))

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Create a shared LoTR container
    lotr = LoTR(model.config.hidden_size, model.config.hidden_size, rank=ns.rank)

    # Replace default torch.nn.Linear layers with LoTRLinear variant
    for layer in model.roberta.encoder.layer:
        layer.attention.self.query = LoTRLinear.from_linear(
            linear=layer.attention.self.query,
            lotr=lotr,
            scale=1.0,
        )
        layer.attention.self.value = LoTRLinear.from_linear(
            linear=layer.attention.self.value,
            lotr=lotr,
            scale=1.0,
        )

    # Save the compressed model
    torch.save(model.state_dict(), ns.output_state)

if __name__ == '__main__':
    main(parser.parse_args())
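
And for the second, inference part, I imagine something along these lines. This is only a sketch, untested; it assumes the same LoTRLinear replacement has to be re-applied so that the saved state dict keys match, and the AutoTokenizer usage and example sentence are just for illustration:

from argparse import ArgumentParser, Namespace
import torch
from lotr import LoTR, LoTRLinear
from transformers import AutoModelForSequenceClassification, AutoTokenizer

parser = ArgumentParser(description='Load LoTR-compressed weights and run inference')
parser.add_argument('--model', type=str, required=True, help='Pretrained model name or path')
parser.add_argument('--input_state', type=str, required=True, help='Path to the compressed .pth file')
parser.add_argument('--rank', type=int, default=2, help='Rank used during compression')

def main(ns: Namespace):
    model = AutoModelForSequenceClassification.from_pretrained(ns.model)

    # Re-apply the same LoTRLinear wrapping so that the state dict keys line up.
    lotr = LoTR(model.config.hidden_size, model.config.hidden_size, rank=ns.rank)
    for layer in model.roberta.encoder.layer:
        layer.attention.self.query = LoTRLinear.from_linear(
            linear=layer.attention.self.query, lotr=lotr, scale=1.0)
        layer.attention.self.value = LoTRLinear.from_linear(
            linear=layer.attention.self.value, lotr=lotr, scale=1.0)

    # Load the compressed weights and run a single forward pass.
    model.load_state_dict(torch.load(ns.input_state))
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(ns.model)
    with torch.no_grad():
        inputs = tokenizer('An example sentence to classify.', return_tensors='pt')
        print(model(**inputs).logits)

if __name__ == '__main__':
    main(parser.parse_args())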

Does this make sense?

daskol commented 1 month ago

Unfortunately, there is no built-in support.

If I understood you correctly, you want to convert LoTRLinear layers back to Linear in order to merge the adapter weights into the original weight matrix $W$. So you are probably looking for a LoTRLinear.to_linear method, but it is not implemented at the moment. However, it is quite easy to do manually: for each LoTRLinear layer, one contracts the factors of the $s$-th slice as follows.

def to_linear(self: LoTRLinear) -> Linear:
    # Fold the low-rank correction into the frozen weight matrix in place.
    with torch.no_grad():
        self.linear.weight += torch.einsum('ij,jk,kl->il', self.lotr.rhs, self.lotr.mid, self.lotr.lhs)
    return self.linear

In this way, you can restore the original model architecture and save a checkpoint that can easily be loaded later.
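
For example, a rough sketch (untested) of merging every adapted projection back and saving a plain checkpoint, reusing the replacement loop from your script above; the output path is just an example:

# Fold each LoTR correction back into its torch.nn.Linear, so the model has
# the original architecture again and can be saved as a plain checkpoint.
for layer in model.roberta.encoder.layer:
    layer.attention.self.query = to_linear(layer.attention.self.query)
    layer.attention.self.value = to_linear(layer.attention.self.value)

torch.save(model.state_dict(), 'merged.pth')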

dmikushin commented 1 month ago

I'm not sure. To simplify this discussion, let's speak in terms of linear algebra. Suppose I have a linear equation with a dense matrix, Ax = b, and assume it can only be solved with an approximate method such as a Krylov subspace method, e.g. BiCGStab. BiCGStab has to multiply the problem matrix by a vector many times, but the method is generic: it does not care how or where the matrix is multiplied by a vector; it only asks me for the result of the multiplication. The multiply operation is a black box from the solver's point of view. So I would compress the matrix into some form that is not too dense and is easy to multiply, and multiply it directly whenever the solver asks. This is what I practically expect from LoTR.
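
To make the analogy concrete, here is a minimal, self-contained sketch (NumPy/SciPy; the factor U and the operator A = I + U U.T are made up for illustration) where the solver only ever sees a black-box matvec:

import numpy as np
from scipy.sparse.linalg import LinearOperator, bicgstab

n, r = 1024, 8
rng = np.random.default_rng(0)
U = rng.standard_normal((n, r))  # low-rank factor standing in for a compressed matrix

def matvec(x):
    # Multiply by A = I + U @ U.T without ever forming the dense n-by-n matrix.
    return x + U @ (U.T @ x)

A = LinearOperator((n, n), matvec=matvec)
b = rng.standard_normal(n)
x, info = bicgstab(A, b)
print(info, np.linalg.norm(matvec(x) - b))  # info == 0 means BiCGStab converged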

In my understanding, the whole purpose of LoTR is to compress the weights and never go back to the original weights again. So we should never call to_linear, because that means re-creating the whole original dense matrix from the compressed one just in order to multiply. Instead, we need to teach torch to multiply by the LoTR representation of the tensor on the fly. I'm eager to add the missing mechanics to LoTR to allow that :pray:
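
For LoTR, that would mean a forward pass roughly like the sketch below. The factor names rhs/mid/lhs are borrowed from your to_linear snippet; the contraction order is my guess, not necessarily how the library actually does it:

import torch

def lotr_matmul(x, linear, lhs, mid, rhs, scale=1.0):
    # Base projection with the frozen dense weight, as usual.
    y = linear(x)
    # Apply the low-rank correction factor by factor, so the dense
    # delta W = rhs @ mid @ lhs (shape out x in) is never materialized.
    return y + scale * (x @ lhs.T) @ mid.T @ rhs.T

# Quick check against the materialized correction (batch 4, width 16, rank 2).
lin = torch.nn.Linear(16, 16)
lhs, mid, rhs = torch.randn(2, 16), torch.randn(2, 2), torch.randn(16, 2)
x = torch.randn(4, 16)
reference = lin(x) + x @ (rhs @ mid @ lhs).T
assert torch.allclose(lotr_matmul(x, lin, lhs, mid, rhs), reference, atol=1e-4)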

daskol commented 1 month ago

If I understand you correctly, you think that LoTR is used to compress the whole weight matrix. However, I think there is some confusion. LoTR is used to represent the correction $\delta W$ in low-rank form directly; it is not used to represent the whole weight matrix $W$, only the correction $\delta W$ to the original weight matrix $W$. The whole weight matrix can be of high rank while the corrections are low-rank. For your use case, you might consider a TT decomposition or other tensor decompositions of the whole weight matrix, e.g. (Novikov, 2015) or (Chekalina, 2023).
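
A tiny illustration of the difference, with arbitrary shapes and the factor names from the to_linear sketch above:

import torch

d, r = 768, 2
W = torch.randn(d, d)                        # pretrained weight: full rank in general
rhs, mid, lhs = torch.randn(d, r), torch.randn(r, r), torch.randn(r, d)
delta = torch.einsum('ij,jk,kl->il', rhs, mid, lhs)  # LoTR-style correction, rank <= r
print(torch.linalg.matrix_rank(W))           # ~768
print(torch.linalg.matrix_rank(delta))       # <= 2
print(torch.linalg.matrix_rank(W + delta))   # ~768: merging does not make W low-rank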