arcee-ai / mergekit

Tools for merging pretrained large language models.
GNU Lesser General Public License v3.0

Idea regarding the new 8x22b Mixtral model and the inverse of 'model stock' method #267

Open jukofyork opened 5 months ago

jukofyork commented 5 months ago

I see people are trying to extract the Mistral-22b ancestor from the MoE model by averaging the MLP layers, and I wondered whether the 'model stock' method in mergekit could be inverted:

No idea if it could work, and I don't expect you to go to a lot of trouble to try this, but if anyone reading this is interested, or knows the people currently trying to recover the 22B ancestor, it could be worth investigating.

mhenrichsen commented 5 months ago

It seems that someone has already had some success extracting a single 22B model from the 8 experts. https://huggingface.co/Vezora/Mistral-22B-v0.2

Maybe @cg123 can share the script?

jukofyork commented 5 months ago

> It seems that someone has already had some success extracting a single 22B model from the 8 experts. https://huggingface.co/Vezora/Mistral-22B-v0.2
>
> Maybe @cg123 can share the script?

https://arxiv.org/abs/2403.19522 https://github.com/arcee-ai/mergekit/blob/main/mergekit/merge_methods/model_stock.py

It probably won't work, as it's really aimed at fine-tuned models rather than continued pretraining, and MoE training will try to make the experts specialise, so the assumptions from the paper won't be valid.

There are probably lots of other ways to try to infer what the ancestor model could have been too, e.g.: calculate the gradient for the full MoE over a sample of data, and instead of moving in the direction that improves the model, project backwards and try to find the point where the projections are closest (you could possibly even use a diagonal Hessian estimate to help with this too).
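To make the "find the point where the projections are closest" part concrete, it's essentially a least-squares problem. A toy numpy sketch in a tiny weight space (at real scale you'd have to do this per-tensor and with iterative solvers):

```python
import numpy as np

def closest_point_to_lines(points, directions):
    """Point minimising the summed squared distance to the lines
    {points[e] + t * directions[e]} - i.e. where the experts' backwards
    projections come closest to each other."""
    dim = points.shape[1]
    A = np.zeros((dim, dim))
    b = np.zeros(dim)
    for w, g in zip(points, directions):
        u = g / np.linalg.norm(g)
        P = np.eye(dim) - np.outer(u, u)  # projector orthogonal to this line
        A += P
        b += P @ w
    return np.linalg.lstsq(A, b, rcond=None)[0]

# toy check: 8 'experts' that drifted away from an ancestor at the origin
rng = np.random.default_rng(0)
dirs = rng.normal(size=(8, 16))
experts = 3.0 * dirs + 0.05 * rng.normal(size=(8, 16))
print(closest_point_to_lines(experts, dirs))  # lands near the origin
```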

It's definitely not an easy thing to do, and if they have something working in v0.2 then they must have found some way (just averaging the weights like they tried in v0.1 yesterday was destined to fail).

thomasgauthier commented 5 months ago

@jukofyork I have been playing around with merging the 8x22B experts into a dense model.

So far the best results I had were using the softmaxed router logits over exllamav2 calibration data as weights for linear layer-wise merging. Doing it naively results in a model with higher perplexity than the best performing single dense expert. However when combined with a top_k=2 mask and a low temperature (near 0), I got a lower perplexity than any single dense expert model.

Looking at the router logits, it is clear that lower layers tend to prefer some experts while later layers prefer others. Using the router logits as a source for merging weights results in a dense model that has some consistency with how the activations usually happen in the MoE model, at least more so than when using any single expert or average.

This is not the same as finding the ancestor, however, but it seems promising so far for condensing MoE models into minimally performant dense models.

I should note I have not tried this method yet on the 8x22B model as I was just experimenting with a tiny 4x400M mixtral model.

I will be trying this on the 8x22B, and if the results show the technique scales well I will probably make a PR to integrate this as a new merge method in mergekit.

jukofyork commented 5 months ago

> @jukofyork I have been playing around with merging the 8x22B experts into a dense model.
>
> So far the best results I had were using the softmaxed router logits over exllamav2 calibration data as weights for linear layer-wise merging. Doing it naively results in a model with higher perplexity than the best performing single dense expert. However when combined with a top_k=2 mask and a low temperature (near 0), I got a lower perplexity than any single dense expert model.

Yeah, that seems like a good method and like you say, the distribution it's passing on should be somewhat similar to what happens in the actual MoE.

I assume you are summing up the logits over the calibration sample and then passing this sum through the softmax to get the weights? This is the correct way to do it (as opposed to trying to renormalise the arithmetic mean of the post-softmax outputs, etc).

jukofyork commented 5 months ago

You could probably also extract multiple complementary/orthogonal models with your idea too:

You could then (in theory) put these back in a k×22B MoE, but I think you'd have to retrain the routing networks.

Using k-means is just the simplest way I can think of, but there are probably better ways to do it via an SVD to enforce/encourage orthogonality.

thomasgauthier commented 5 months ago

Yes true, the technique could be useful for creating MoEs too, didn't think of that!

To answer your question, I'm actually doing the softmax per token per layer (same as the model is doing during inference, but with temperature applied before and with top k mask after).

Then, at each layer I produce one value for each expert by taking the median of the expert probability over tokens; that gave me better results than using the mean. Maybe k-means could be useful here too.

Finally I take the average of that last transformation over all samples in the calibration data. That leaves us with a [num_layer x num_experts] tensor that is used as weights for a layer-wise merge.
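Roughly, the weight computation looks something like this (simplified sketch, not my exact code; assumes the per-token router logits have already been collected):

```python
import torch

def layerwise_merge_weights(router_logits, temperature=0.1, top_k=2):
    """router_logits: [num_samples, num_layers, num_tokens, num_experts].
    Returns a [num_layers, num_experts] tensor used as layer-wise linear
    merge weights (simplified sketch of the procedure described above)."""
    # temperature applied before the softmax, per token per layer
    probs = torch.softmax(router_logits / temperature, dim=-1)
    # top-k mask applied after the softmax
    kth = torch.topk(probs, top_k, dim=-1).values[..., -1:]
    probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    # one value per expert per layer: median over tokens...
    per_sample = probs.median(dim=2).values   # [num_samples, num_layers, num_experts]
    # ...then averaged over calibration samples
    weights = per_sample.mean(dim=0)          # [num_layers, num_experts]
    return weights / weights.sum(dim=-1, keepdim=True)  # renormalise per layer
```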

But I think you're right, I will try doing the softmax over the logits sum instead as it probably works better.

I hope my explanation makes sense, those things are hard to describe with words sometimes. Happy to share code if it helps.

jukofyork commented 5 months ago

Yeah, it's not easy to visualise.

[Screenshot: MoE layer diagram from the Switch Transformer paper - the router produces a (blue) categorical distribution over the (light blue) expert FFNs]

From: https://arxiv.org/pdf/2101.03961.pdf

So my thinking was:

  1. For every token and every layer you get the 8 router logit values (taken just before the blue categorical distribution in the picture).
  2. Sum these up into a [num_layer x 8] matrix.
  3. Then at the end, pass the 8 sums for each layer into the softmax function.
  4. Use this to take a weighted average of the 8 MLPs in each layer (the light blue FFNs in the picture).

It's important to sum the logits rather than take the arithmetic mean of the post-softmax outputs. You can take the geometric mean of the post-softmax outputs instead if that's easier to get at in the code: renormalising the geometric mean of the per-token softmax outputs is the same as taking the softmax of the mean logits (https://en.wikipedia.org/wiki/Geometric_mean - via the log/sum/exp identity), and the mean differs from the sum only by an effective temperature.
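In code, steps 1-4 would look roughly like this (sketch only; assumes the per-token router logits have been collected as a tensor and that each expert's FFN weights are available as state dicts):

```python
import torch

def ancestor_weights(router_logits):
    """router_logits: [num_tokens, num_layers, num_experts] collected over
    the calibration data (step 1)."""
    summed = router_logits.sum(dim=0)      # step 2: [num_layers, num_experts]
    # step 3: softmax of the sums; softmax of the *mean* logits equals the
    # renormalised geometric mean of the per-token softmaxes, and the sum
    # only differs from the mean by an effective temperature
    return torch.softmax(summed, dim=-1)

def merge_layer_ffn(expert_state_dicts, layer_weights):
    """Step 4: weighted average of the 8 MLPs for one layer.
    expert_state_dicts: list of {param_name: tensor}, one per expert's FFN."""
    merged = {}
    for name in expert_state_dicts[0]:
        merged[name] = sum(w * sd[name]
                           for w, sd in zip(layer_weights, expert_state_dicts))
    return merged
```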

My idea about using k-means is simply to add a step between 2 and 3 where you calculate multiple centroids of the logit sums and then pass each into the softmax and so on.
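i.e. something like this (sketch: cluster the per-token logits and use the k centroids in place of the single sum; how the k clusters correspond across layers is a separate problem):

```python
import torch
from sklearn.cluster import KMeans

def kmeans_merge_weights(router_logits, k=2):
    """router_logits: [num_tokens, num_layers, num_experts].
    Returns [num_layers, k, num_experts]: k candidate weightings per layer."""
    per_layer = []
    for layer in range(router_logits.shape[1]):
        logits = router_logits[:, layer, :].numpy()   # [num_tokens, num_experts]
        centroids = KMeans(n_clusters=k, n_init=10).fit(logits).cluster_centers_
        per_layer.append(torch.softmax(torch.from_numpy(centroids), dim=-1))
    return torch.stack(per_layer)
```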

Rather than summing into a [num_layer x 8] matrix, you could try storing a [num_layer x num_samples x 8] tensor, run SVD on each [num_samples x 8] slice and then de-stretch/orthogonalise:

[Image: singular value decomposition diagram (Wikipedia)]

but I'd try this after stock k-means as it may not be important.


I'm also not sure if using all 8 of the values, rather than only the top 2 the way Mixtral does, will make a difference (I know people have experimented with using > 2 experts at inference time and it made the perplexity worse), but I'd try using all 8 to start with.

If this is a problem then probably the best thing to do would be to store the sum of the top-2 post-softmax outputs (for every sample) and use this to scale the final softmax weighting factors used to create the averaged weights.

jukofyork commented 5 months ago

I'm not 100% sure I have understood your method, as rereading what I just wrote, it comes back to a very complicated way of making a weighted average...

jukofyork commented 5 months ago

> So far the best results I had were using the softmaxed router logits over exllamav2 calibration data as weights for linear layer-wise merging. Doing it naively results in a model with higher perplexity than the best performing single dense expert. However when combined with a top_k=2 mask.

How does the MoE gating work in Mixtral? Does it actually take the top 2 categorical outputs and then weight the output of the 2 routed MLPs or treat them equally?

EDIT:

[Screenshot: diagram of Mixtral's top-2 expert routing, from the article linked below]

https://www.linkedin.com/pulse/exploring-mixtral-8x7b-deep-dive-its-architectural-ashish-patel--hkz9f

thomasgauthier commented 5 months ago

Yes, in Mixtral the probability distribution from the logits softmax is used to weight the experts (only the top 2); they are not treated equally.
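Conceptually, the gating does something like this per token (simplified sketch, not the actual Mixtral implementation):

```python
import torch

def moe_forward(hidden, gate, experts, top_k=2):
    """hidden: [num_tokens, hidden_dim]; gate: nn.Linear(hidden_dim, num_experts);
    experts: list of dense FFN modules.  Simplified top-2 routing."""
    probs = torch.softmax(gate(hidden), dim=-1)            # [num_tokens, num_experts]
    weights, selected = torch.topk(probs, top_k, dim=-1)   # keep only the top 2 experts
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalise over the top 2
    out = torch.zeros_like(hidden)
    for t in range(hidden.shape[0]):                       # per-token loop for clarity
        for w, e in zip(weights[t], selected[t]):
            out[t] += w * experts[e.item()](hidden[t])
    return out
```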

But what you described sounds like it could work, I think I'll give it a shot!

jukofyork commented 5 months ago

No, I think my interpretation of your idea (which I might have got completely wrong!) is doomed to not work. The only way it could work is if the frequency of different experts isn't uniform.

Also, looking at that, I think you are correct in using the post-softmax outputs for the weights too! Basically you want to try to take a sort of "superimposition" of the 8 networks weighted by their outputs (I think)?

thomasgauthier commented 5 months ago

Yeah that was the basic idea behind my approach, I don't have a strong theoretical basis for why it should work but it felt worth trying.

The intuition is that there must exist some linear combination of the experts that works decently as a dense model; we just have to find it. The router was the first place I thought of to find that combination, as it already contains some intelligence / knowledge regarding the weighting of experts.

> The only way it could work is if the frequency of different experts isn't uniform.

From my testing it is not nearly uniform. But that was on the 4x400M; I don't know about the 8x7B or 8x22B.

jukofyork commented 5 months ago

What about this much simpler idea that doesn't require any averaging:

[Photo: hand-drawn sketch of an MoE layer, with a red cross and a blue cross marking the two points referred to below]

For each layer, find the FFN that contributed the most at the place marked with a red cross while you pass in the calibration dataset,

and also find the average going into the point with the blue cross.

Then construct a network using the maximum-contribution FFN, rerun the calibration dataset through layer 1, and adjust the norm layer after it (or anything you can adjust that is linear) to match the average stored at the blue cross. Then move on to layer 2 and repeat.
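Very roughly, something like this per layer (hand-wavy sketch: 'contribution' here is just the mean output norm, and the correction only matches the per-channel mean):

```python
import torch

@torch.no_grad()
def pick_and_calibrate(experts, calib_hidden, moe_out_mean, next_norm_weight):
    """experts: list of dense FFN modules for this layer.
    calib_hidden: [num_tokens, hidden_dim] inputs to the layer from the calibration data.
    moe_out_mean: [hidden_dim] average layer output recorded from the full MoE
                  (the 'blue cross').
    next_norm_weight: [hidden_dim] weight of the following norm layer."""
    # 'red cross': which FFN contributed the most over the calibration data
    contribs = torch.stack([e(calib_hidden).norm(dim=-1).mean() for e in experts])
    best = int(contribs.argmax())

    # rerun the calibration data through the chosen FFN only
    dense_out_mean = (calib_hidden + experts[best](calib_hidden)).mean(dim=0)

    # fold a per-channel correction into the (linear) norm weight so the
    # averages match -- crude, and near-zero channels will be a problem
    next_norm_weight *= moe_out_mean / (dense_out_mean + 1e-6)
    return best
```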

I think this has a much better chance of working than averaging the weights of the FFNs, and it also relies on the non-uniformity of how often the FFNs get fired?

It avoids a lot of the top-2 argmax mixed with the softmax of the outputs, etc too!

thomasgauthier commented 5 months ago

I can see how focusing on minimizing the layer output differences with the full MoE model might provide good results, but by keeping only one expert, without changing its weights, wouldn't we lose some knowledge / representation? The idea behind merging multiple experts was to preserve some of the features of the MoE, like when merging dense models from different domains.

jukofyork commented 5 months ago

> The intuition is that there must exist some linear combination of the experts that works decently as a dense model; we just have to find it. The router was the first place I thought of to find that combination, as it already contains some intelligence / knowledge regarding the weighting of experts.

The problem with this is if you think about the non-linearity of the activation functions: it can only work if they are very similar, otherwise they will clash. Using relu as an example:

(max(0, x×1.1) + max(0, x×0.9)) / 2 = x×1 (for x > 0)

but:

(max(0, x×1) + max(0, x×-1)) / 2 = x×0.5 (for x > 0), rather than the 0 you get from the weight-averaged max(0, x×0)
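A quick numerical way to see the clash (weight-averaging vs. output-averaging):

```python
import numpy as np

relu = lambda v: np.maximum(0.0, v)
x = np.linspace(-2.0, 2.0, 5)

# similar experts: weight average and output average agree (for x > 0)
print(relu(x * (1.1 + 0.9) / 2))             # weight average: [0, 0, 0, 1, 2]
print((relu(x * 1.1) + relu(x * 0.9)) / 2)   # output average: [0, 0, 0, 1, 2]

# opposing experts: weight average collapses to zero, output average is |x|/2
print(relu(x * (1.0 - 1.0) / 2))             # [0, 0, 0, 0, 0]
print((relu(x * 1.0) + relu(x * -1.0)) / 2)  # [1, 0.5, 0, 0.5, 1]
```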

I would assume that the MoE training would push the weights to be quite diverse to make use of the non-linearity, but it's possible it doesn't, as LLMs seem to be quite undertrained and wasteful of parameters anyway?

jukofyork commented 5 months ago

> I can see how focusing on minimizing the layer output differences with the full MoE model might provide good results, but by keeping only one expert, without changing its weights, wouldn't we lose some knowledge / representation? The idea behind merging multiple experts was to preserve some of the features of the MoE, like when merging dense models from different domains.

Yeah, it's definitely going to lose some knowledge, but it might be the case that keeping the FFN intact and working properly retains more than trying to blend stuff passed through a non-linear function? It's definitely worth trying both and comparing to see what happens!

thomasgauthier commented 5 months ago

Yes agreed! Trying both to compare could be worthwhile.

The model I did my testing on is very much undertrained; I guess that could explain the encouraging results I had. It remains to be seen if the results hold with better-trained models.

jukofyork commented 5 months ago

I also think the idea of projecting back to find the ancestor might have some (tiny) hope too, but I don't really know what would be a good place to start.

If we assume that each FFN came from the same ancestor and we know the dynamics used to move the weights away from this ancestor towards where they are now, then each set of weights should reduce the degrees of freedom for where the original weights came from, and even though it's unlikely we can find the exact ancestor, we might be able to, say, pick the point in the weight space where they were all closest.

The positive gradient over a calibration sample like you are using should point back to where each set of weights came from very recently, and the inverse of the Hessian should tell you how far back, but as soon as you move a little away from the current location there are near-infinite ways to fly off to a local maximum and only 8 degrees of freedom to try to guide us... ☹️

jukofyork commented 5 months ago

> Yes agreed! Trying both to compare could be worthwhile.
>
> The model I did my testing on is very much undertrained; I guess that could explain the encouraging results I had. It remains to be seen if the results hold with better-trained models.

Keep us posted on how you get on! It's a very interesting problem, and I think if you can crack the case of n=1 you can extend to n>1 and probably quite easily make a smaller MoE - just having to retrain the router is probably quite easy in comparison.

jukofyork commented 5 months ago

Just thought of another way that might work, but not related to merging and would take some serious compute:

Eventually the penalty will get so large that all 8 experts' weights will be the same, and with luck the model should still have had time to adapt to this penalty being forced on it.
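If I'm reading the idea right (continued training of the MoE with a penalty that pulls the 8 experts' weights together, ramping the penalty coefficient up over training), the penalty term might look something like this (very rough sketch):

```python
import torch

def expert_tying_penalty(expert_tensors):
    """expert_tensors: list of corresponding weight tensors, one per expert.
    Penalises deviation from the experts' mean; added to the training loss with
    a coefficient that is ramped up until the experts collapse into one."""
    stacked = torch.stack(expert_tensors)            # [num_experts, ...]
    mean = stacked.mean(dim=0, keepdim=True)
    return ((stacked - mean) ** 2).mean()

# sketch of the training loop:
#   loss = lm_loss + coeff(step) * sum(expert_tying_penalty(group)
#                                      for group in corresponding_expert_params)
#   with coeff(step) increasing over training
```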