mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

Has any thought been given to using LoRA to increase the number of experts (100x) with minimal memory? #95

Open sixChar opened 7 months ago

sixChar commented 7 months ago

As I understand the current MoeLayer, a gate calculates the weight to be applied to the output of each expert, the top-k experts are selected and run on the data, and finally their results are multiplied by their respective weights and summed.

This means you have to store n copies of the layers, one for each expert.

If instead you had a single base set of parameters and each expert was defined by a low-rank matrix, you could hold a lot more experts in the same memory. Calculating the gate weights would be the same, but instead of taking a weighted sum of the outputs you could take a weighted sum of the parameters and then use the summed parameters to calculate the result. (This would start to look a lot like Schmidhuber's fast weight programmers.)
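To make the contrast concrete before the full module code further down, here's a toy single-matrix sketch (my own illustration, not the actual MoeLayer code) of output mixing versus parameter mixing:

```python
import torch

torch.manual_seed(0)
dim, hid, rank, k = 8, 16, 2, 2                 # toy sizes
x = torch.randn(dim)
w = torch.softmax(torch.randn(k), dim=0)        # gate weights for the k selected experts

# Standard sparse MoE: k full expert matrices, mix the *outputs*.
experts = [torch.randn(dim, hid) for _ in range(k)]
y_outputs = sum(wi * (x @ We) for wi, We in zip(w, experts))

# Proposed variant: one shared base matrix plus per-expert low-rank factors,
# mix the *parameters*, then run a single matmul.
w_base = torch.randn(dim, hid)
loras = [(torch.randn(dim, rank), torch.randn(rank, hid)) for _ in range(k)]
w_mixed = w_base + sum(wi * (A @ B) for wi, (A, B) in zip(w, loras))
y_params = x @ w_mixed
```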

Using the rough code further below, I could add 100 experts (at rank 4) to a feed-forward module with dim=512 and hid_dim=2048 with only a ~2x increase in the number of parameters. I would expect this ratio to get better as dim/hid_dim get larger.
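The rough count behind that ~2x figure:

```python
dim, hid_dim, rank, num_experts = 512, 2048, 4, 100

base = 3 * dim * hid_dim                       # w1, w2, w3 base matrices: ~3.1M params
lora_per_expert = 3 * (dim + hid_dim) * rank   # an A and a B for each of the three matrices: 30,720
total = base + num_experts * lora_per_expert   # ~6.2M, i.e. roughly 2x the plain feed-forward
print(base, lora_per_expert, total)
```

Since the base matrices grow with dim * hid_dim while the LoRA factors grow with dim + hid_dim, the ratio should improve as the layer gets wider.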

Is there some fatal flaw that makes this approach not worth it?

Here are the possible flaws I could think of, but none of them strike me as compelling.

I didn't test performance, so it may not work as well. Although LoRA has worked pretty well in a number of places, it may be that MoE relies on high-rank differences between experts. It could also be that a linear combination of the matrices prior to the non-linearities is not as powerful as combining the results.

Additionally, there is a performance hit from mixing the experts: you have to do on the order of 3 * (dim + hid_dim) * rank * num_tokens * num_experts_per_token operations to mix the parameters, plus the extra (dim, rank) x (rank, hid_dim) matrix multiplications and the addition with the base matrix. I haven't really checked, but I'm pretty sure this is made up for by the fact that once the parameters are mixed you are running a single feed-forward on the input rather than several.

Thoughts?

Some rough code to illustrate the idea, based on the FeedForward and MoeLayer modules:

```python

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


class MultiLoraLinear(nn.Module):
    def __init__(self, ins: int, outs: int, num_loras: int, rank: int):
        super().__init__()
        # nn.Linear initializes its weights from uniform(-1/sqrt(in_features), +1/sqrt(in_features)),
        # so match that range here
        init_scale = 2 / sqrt(ins)
        self.w_base = nn.Parameter((torch.rand(ins, outs) - 0.5) * init_scale)
        self.w_loras_a = nn.Parameter((torch.rand(num_loras, ins, rank) - 0.5) * init_scale)
        self.w_loras_b = nn.Parameter((torch.rand(num_loras, rank, outs) - 0.5) * init_scale)
        self.num_loras = num_loras
        self.ins = ins
        self.outs = outs

    def forward(self, x, expert_weights, expert_indices):
        ## construct the weight matrix from a weighted sum of the lora params
        # select the lora params to use
        selected_w_loras_a = self.w_loras_a[expert_indices, :, :]

        # multiply the selected lora params by their gate weights and sum over the selected experts
        w_lora_a = torch.sum(selected_w_loras_a * expert_weights.unsqueeze(-1).unsqueeze(-1), dim=1)
        selected_w_loras_b = self.w_loras_b[expert_indices, :, :]
        w_lora_b = torch.sum(selected_w_loras_b * expert_weights.unsqueeze(-1).unsqueeze(-1), dim=1)

        # Construct the full lora matrix as lora_a @ lora_b (per token/batch element)
        # b: batch, i: num ins, k: rank, o: num outs
        w_lora = torch.einsum("bik,bko->bio", w_lora_a, w_lora_b)
        w = self.w_base + w_lora

        return torch.einsum("bi,bio->bo", x, w)

class MoreMoeFeedForward(nn.Module):
    def __init__(self, gate: nn.Module, dim: int, hid_dim: int, num_experts: int, lora_rank: int, num_experts_per_tok=5):
        super().__init__()
        assert num_experts > 0
        self.gate = gate
        self.num_experts = num_experts
        self.lora_rank = lora_rank
        self.num_experts_per_tok = num_experts_per_tok

        self.w1 = MultiLoraLinear(dim, hid_dim, num_experts, lora_rank)
        self.w2 = MultiLoraLinear(hid_dim, dim, num_experts, lora_rank)
        self.w3 = MultiLoraLinear(dim, hid_dim, num_experts, lora_rank)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mostly copied from MoeLayer
        inputs_squashed = x.view(-1, x.shape[-1])
        gate_logits = self.gate(inputs_squashed)
        weights, indices = torch.topk(gate_logits, self.num_experts_per_tok)
        weights = F.softmax(weights, dim=1, dtype=torch.float).type_as(x)

        # Mostly copied from FeedForward
        res_squashed = self.w2(
            F.silu(self.w1(inputs_squashed, weights, indices))
            * self.w3(inputs_squashed, weights, indices),
            weights,
            indices,
        )
        return res_squashed.view(x.shape)

```
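For reference, a quick usage sketch continuing from the code above (toy sizes so it actually runs; using a plain nn.Linear as the gate is my assumption, not necessarily how the real MoeLayer gate is built):

```python
# Toy sizes so this actually runs; the gate is just a plain linear router.
dim, hid_dim, num_experts, rank = 64, 128, 8, 4

gate = nn.Linear(dim, num_experts, bias=False)
ffn = MoreMoeFeedForward(gate, dim, hid_dim, num_experts, rank, num_experts_per_tok=3)

x = torch.randn(2, 16, dim)   # (batch, seq_len, dim)
out = ffn(x)
print(out.shape)              # torch.Size([2, 16, 64])
```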

WillJStone commented 7 months ago

A relevant publication

https://arxiv.org/abs/2310.18339

sixChar commented 7 months ago

That's basically what I was thinking of, except applied to the feed-forward module itself and not only for fine-tuning.

sixChar commented 7 months ago

Also, I did a little profiling on CPU with a smaller model: batch size 4, 1024 tokens, and 8 experts (3 used per token).

Initializing the model and running one inference on a random input gives:

CPU Time: 44.692ms CPU Memory: 9.39 Gb

CPU Time: 10.150ms CPU Memory: 1.60 Gb

Additionally, I don't think you would need to retrain something like Mixtral 8x7B from scratch. You could do some stuff with singular value decomposition to get a good approximation of the base and LoRA matrices.
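Roughly what I have in mind (a sketch; `w_expert` stands in for one expert's existing weight matrix and `w_base` for some shared base, e.g. the mean across experts):

```python
import torch

dim, hid_dim, rank = 512, 2048, 4

# Stand-ins for an existing expert matrix and a shared base (e.g. the mean of all experts).
w_expert = torch.randn(dim, hid_dim)
w_base = torch.randn(dim, hid_dim)

# Best rank-r approximation of the difference via SVD (Eckart-Young).
diff = w_expert - w_base
U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
lora_a = U[:, :rank] * S[:rank]   # (dim, rank)
lora_b = Vh[:rank, :]             # (rank, hid_dim)

approx = w_base + lora_a @ lora_b  # low-rank reconstruction of the expert
```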

Edit: For fun I profiled the same model but with 100 experts using 30 of them per token:

CPU Time: 31.468ms CPU Memory: 2.62 Gb

Edit 2: Thinking a bit more, I'm not sure this setup will be as effective, since the combined results of multiple networks won't necessarily match the result of a single network whose parameters are a linear combination of theirs. You might need to run the experts separately, which could take a similar amount of time as the original approach or even more. However, memory usage should still be dramatically improved.

WillJStone commented 7 months ago

Yeah, my plan was to implement something that (unfortunately) has to call the FFN num_experts_per_token times, each with a different adapter. Slow, but low memory and... maybe it has some benefits? I guess the test would be determining whether you're better off using a larger LoRA rank with the standard method, or doing the "MoLoRA" method with roughly the same number of new params as the one big LoRA.
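Roughly what I mean, as a sketch (a single linear layer stands in for the FFN; `w_base`, `loras`, and `gate` are placeholders, not anything from this repo):

```python
import torch
import torch.nn.functional as F


def moe_lora_forward(x, w_base, loras, gate, k):
    # x: (tokens, dim); w_base: (dim, out); loras: list of (A, B) pairs; gate: nn.Linear(dim, num_experts)
    weights, indices = torch.topk(gate(x), k)
    weights = F.softmax(weights, dim=-1, dtype=torch.float).type_as(x)

    out = torch.zeros(x.shape[0], w_base.shape[1])
    # One pass through the shared layer per expert that was actually selected; outputs are mixed afterwards.
    for e in range(len(loras)):
        hit = indices == e                         # (tokens, k) mask of slots routed to expert e
        if not hit.any():
            continue
        tok = hit.any(dim=-1)                      # tokens that use expert e at all
        A, B = loras[e]
        y = x[tok] @ (w_base + A @ B)              # base weights + this expert's LoRA delta
        out[tok] += (weights * hit)[tok].sum(dim=-1, keepdim=True) * y
    return out
```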

(PS Happy new years everyone :smile:)

sixChar commented 7 months ago

Happy new year!

You're right that it would be slower than the method currently used for sparse mixture of experts, but I don't know if it would be that much slower, since the current method runs num_experts_per_token different full networks. Adding on the calculations for the LoRA version should keep it within about 2x. It also might be more amenable to cache optimization, since most of the FFN operations would be using the base parameters.

WillJStone commented 6 months ago

Hey, did you ever end up trying to train this? As soon as I turn even a small number of gradients on, it OOMs pretty quickly. I tried with my own implementation, which I developed without looking at yours, and I just tried again with yours (slightly modified) and it still happens.

sixChar commented 6 months ago

So it turns out I'm really f**ing stupid and forgot to add a line to actually run the model on some input when I was profiling.

It is not actually better on memory; it is much worse (something like 6-7x worse on a small example).

I think it's because you're essentially constructing a new weight matrix for every token. There might be a workaround, but I'm not sure.
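For a rough sense of scale, using the dim/hid_dim from my first example and the batch from the profiling run (float32):

```python
batch, seq_len, dim, hid_dim = 4, 1024, 512, 2048
tokens = batch * seq_len

# The einsum in MultiLoraLinear materializes a separate (dim, hid_dim) matrix per token.
per_token_matrix_bytes = dim * hid_dim * 4        # float32
total_gb = tokens * per_token_matrix_bytes / 1e9
print(round(total_gb, 1))                         # ~17.2 GB just for the mixed w1, before w2/w3 or gradients
```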

Sorry for wasting your time.

WillJStone commented 6 months ago

Nah, no time wasted. I was already playing with this idea before I saw your initial comments. I'm working on implementing the paper I linked above now. The authors released their code, but it's a mess.