arcee-ai / mergekit

Tools for merging pretrained large language models.
GNU Lesser General Public License v3.0

Idea: Scaling the Down-Projection Matrix in 'Mixture of Experts' Models #294

Open jukofyork opened 2 months ago

jukofyork commented 2 months ago

Problem

In a Mixture of Experts (MoE) LLM, the gating network outputs a categorical distribution over $n$ values (chosen from $n_{max}$), which is then used to create a convex combination of the outputs of the $n$ chosen expert MLP blocks only (e.g. $n = 2$ and $n_{max} = 8$ for Mixtral-8x7b and Mixtral-8x22b). If the model was trained to choose only the top $n$ experts and we want to change the chosen number of experts to $m$, how should we scale the down-projection matrix of the MLP to maintain the expected norm of the final output?
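
(For reference, this is my mental model of the forward pass of a Mixtral-style sparse MoE block; `gate_W` and `experts` below are just made-up stand-ins for the router's weight matrix and the list of expert MLPs:)

import numpy as np

def moe_forward(x, gate_W, experts, n):
    """Sketch of a Mixtral-style sparse MoE block: pick the top-n experts by
    router logit and return the gate-weighted sum of their MLP outputs."""
    logits = gate_W @ x                         # one routing logit per expert
    top = np.argsort(logits)[::-1][:n]          # indices of the n largest logits
    weights = np.exp(logits[top] - logits[top].max())
    weights /= weights.sum()                    # softmax restricted to the chosen experts
    # (equivalent to taking the softmax over all experts and re-normalising the top-n)
    return sum(w * experts[i](x) for w, i in zip(weights, top))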

Solution

For simplicity, let's assume that the output of each expert is an i.i.d. random vector with a norm of $r$ and the gating network outputs a discrete uniform distribution where $g_i = \frac{1}{n}$ for all $i$. The final output is a convex combination of the expert outputs:

$$\vec{S_n} = \sum_{i=1}^n g_i \vec{v_i}$$

The expected norm of this output is (writing $\vec{v_i} = r\,\vec{u_i}$, where $\vec{u_i}$ is a unit vector):

$$E[|\vec{S_n}|] = rE\left[\left|\sum_{i=1}^n g_i \vec{u_i}\right|\right] = r\sqrt{\sum_{i=1}^n g_i^2} = \frac{r}{\sqrt{n}}$$

NOTE: The last equality holds only for a balanced distribution, where $g_i = \frac{1}{n}$ for all $i$.
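
(To spell out the middle step: for i.i.d. unit vectors the cross terms vanish in expectation, i.e. $E[\vec{u_i} \cdot \vec{u_j}] = 0$ for $i \neq j$, so

$$E\left[\left|\sum_{i=1}^n g_i \vec{u_i}\right|^2\right] = \sum_{i=1}^n \sum_{j=1}^n g_i g_j E[\vec{u_i} \cdot \vec{u_j}] = \sum_{i=1}^n g_i^2$$

and in high dimensions the norm concentrates tightly enough that $E[|\cdot|] \approx \sqrt{E[|\cdot|^2]}$.)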

If we instead route to $m$ experts, and assume the gate still gives each selected expert a weight of roughly $g_i = \frac{1}{n}$ (i.e. the top-$m$ weights aren't re-balanced to sum to one), the expected norm of the output becomes:

$$E[|\vec{S_m}|] = rE\left[\left|\sum_{i=1}^m g_i \vec{u_i}\right|\right] = r\sqrt{\sum_{i=1}^m g_i^2} = \frac{r\sqrt{m}}{n}$$

To make the expected norm of the output with $m$ experts equal to the expected norm of the output with $n$ experts, we need to scale the down-projection matrix of the MLP by a factor of $\sqrt{\frac{n}{m}}$:

$$\vec{v_i}' = \sqrt{\frac{n}{m}} \vec{v_i}$$

With this scaling, the expected norm of the output with $m$ experts becomes:

$$E[|\vec{S_m}|] = E\left[\left|\sum_{i=1}^m g_i \vec{v_i}'\right|\right] = r\sqrt{\frac{n}{m}}\,E\left[\left|\sum_{i=1}^m g_i \vec{u_i}\right|\right] = r\sqrt{\frac{n}{m}}\sqrt{\sum_{i=1}^m g_i^2} = \frac{r}{\sqrt{n}}$$

Which is the same as the expected norm of the output with $n$ experts.

Scale Factor

The scale factor $\sqrt{\frac{n}{m}}$ depends only on the ratio of the original number of experts ($n$) to the new number of experts ($m$). It does not depend on the norm $r$ of the expert outputs (with the given assumptions...).
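
As a quick sanity check of the algebra (not of the real model!), here's a little Monte Carlo simulation under exactly the assumptions above: random high-dimensional unit vectors stand in for the expert outputs and each selected expert keeps its trained gate weight of $\frac{1}{n}$:

import numpy as np

# Monte Carlo check: with i.i.d. unit vectors as expert outputs and a fixed gate
# weight of 1/n per selected expert, scaling the m-expert sum by sqrt(n/m) should
# give the same expected norm as the original n-expert sum.
rng = np.random.default_rng(0)
d, n, m, trials = 4096, 2, 3, 1000

norms_n, norms_m_scaled = [], []
for _ in range(trials):
    u = rng.standard_normal((max(n, m), d))
    u /= np.linalg.norm(u, axis=1, keepdims=True)           # unit "expert outputs"
    norms_n.append(np.linalg.norm(u[:n].sum(axis=0) / n))   # top-n sum, weights 1/n
    norms_m_scaled.append(np.sqrt(n / m) * np.linalg.norm(u[:m].sum(axis=0) / n))

print(np.mean(norms_n))         # ~1/sqrt(2) ~= 0.707
print(np.mean(norms_m_scaled))  # ~0.707 too, after scaling by sqrt(n/m)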


(sorry for the AI generated text again - but it's so much easier than trying to write all that Latex!)

This all assumes I have correctly understood what the Mixtral-style MoE architecture is doing though (it's not 100% clear from the paper).

If this shows promise then the i.i.d. assumption and the discrete uniform distribution simplification can be removed by sampling the actual outputs of the expert MLPs / gating networks (the i.i.d. assumption can be improved on if we are happy to just guess values for $\rho$ [see the other thread for example], but to use a concrete categorical distribution we would need to sample from it I think).
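
For example, if we just guess a single pairwise correlation $\rho$ between the (unit-normalised) expert outputs instead of assuming they are uncorrelated, the expected squared norm becomes $(1-\rho)\sum_i g_i^2 + \rho\left(\sum_i g_i\right)^2$ and the scale factor follows directly (the gate weights and $\rho$ below are guesses, not measured values):

import numpy as np

def expected_norm_sq(g, rho):
    """E[|sum_i g_i u_i|^2] for unit vectors with a common pairwise correlation rho."""
    g = np.asarray(g, dtype=float)
    return (1 - rho) * np.sum(g**2) + rho * np.sum(g)**2

def scale_factor(g_n, g_m, rho):
    """Down-projection scale so the m-expert output matches the n-expert output in RMS norm."""
    return np.sqrt(expected_norm_sq(g_n, rho) / expected_norm_sq(g_m, rho))

# Made-up gate weights for the chosen experts and a guessed correlation:
print(scale_factor([0.23, 0.16], [0.23, 0.16, 0.13], rho=0.0))  # i.i.d. case
print(scale_factor([0.23, 0.16], [0.23, 0.16, 0.13], rho=0.5))  # with some correlation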

I'm going to try this on Mixtral-8x7b-Instruct now and see if it improves the perplexity vs previous attempts:

https://rentry.org/HowtoMixtral
https://old.reddit.com/r/LocalLLaMA/comments/18m6zjz/for_exllamav2_how_many_mixtral_experts_are/

@cg123 I see you already have a parameter called residual_scale, so for the mergekit-moe merges it should be pretty easy to try scaling the models that weren't designed to be used in a MoE by $\frac{1}{\sqrt{m}}$, etc.

jukofyork commented 2 months ago

Also, it would be nice if we could post these sorts of ideas to 'Discussions' instead of 'Issues'! :)

jukofyork commented 2 months ago

Well, there is a pretty big problem with the discrete uniform distribution assumption: it's causing the weights to be scaled far too much... So, without actually being able to measure anything, the next best assumption is a Zipf distribution:

import numpy as np

def zipf_distribution(N, s):
    """Generate Zipf distribution for N experts with parameter s."""
    ranks = np.arange(1, N+1)
    weights = 1 / np.power(ranks, s)
    normalization = np.sum(weights)
    probabilities = weights / normalization
    return probabilities

def expected_norm_squared(probabilities, num_experts):
    """Calculate the expected norm squared for a subset of experts."""
    return np.sum(probabilities[:num_experts]**2)

def calculate_scaling_factor(N, n, m, s):
    """Calculate the scaling factor alpha for given N, n, m, and s."""
    probabilities = zipf_distribution(N, s)
    norm_squared_n = expected_norm_squared(probabilities, n)
    norm_squared_m = expected_norm_squared(probabilities, m)
    alpha = np.sqrt(norm_squared_n / norm_squared_m)
    return alpha

N = 8  # num_local_experts
n = 2  # num_experts_per_tok

# Loop over the skew parameter s (0 = Uniform, 0.5 = Square-Root, 1 = Zipf's law)
for s in [0, 0.5, 1]:
    # Print the Zipf distribution for the given s
    probabilities = zipf_distribution(N, s)
    print(f"Zipf distribution for s = {s}: {[f'{p:.4f}' for p in probabilities]}")

    # Loop over all values of m from 1 to N
    for m in range(1, N+1):
        alpha = calculate_scaling_factor(N, n, m, s)
        print(f"For m = {m}, Scaling factor alpha: {alpha:.4f}")

which gives:
Zipf distribution for s = 0: ['0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250']
For m = 1, Scaling factor alpha: 1.4142
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.8165
For m = 4, Scaling factor alpha: 0.7071
For m = 5, Scaling factor alpha: 0.6325
For m = 6, Scaling factor alpha: 0.5774
For m = 7, Scaling factor alpha: 0.5345
For m = 8, Scaling factor alpha: 0.5000
Zipf distribution for s = 0.5: ['0.2288', '0.1618', '0.1321', '0.1144', '0.1023', '0.0934', '0.0865', '0.0809']
For m = 1, Scaling factor alpha: 1.2247
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9045
For m = 4, Scaling factor alpha: 0.8485
For m = 5, Scaling factor alpha: 0.8105
For m = 6, Scaling factor alpha: 0.7825
For m = 7, Scaling factor alpha: 0.7606
For m = 8, Scaling factor alpha: 0.7429
Zipf distribution for s = 1: ['0.3679', '0.1840', '0.1226', '0.0920', '0.0736', '0.0613', '0.0526', '0.0460']
For m = 1, Scaling factor alpha: 1.1180
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9583
For m = 4, Scaling factor alpha: 0.9370
For m = 5, Scaling factor alpha: 0.9241
For m = 6, Scaling factor alpha: 0.9155
For m = 7, Scaling factor alpha: 0.9093
For m = 8, Scaling factor alpha: 0.9046

I'll see if I can run a grid-search overnight.

jukofyork commented 2 months ago

Here's the yaml file if anybody is interested:

# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw//wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

#const_tag: &RESIDUAL_SCALE_FACTOR 1.1180  # [s=0 --> 7.2995]
#const_tag: &RESIDUAL_SCALE_FACTOR 1.0     # 4.4103 +/- 0.02355
const_tag: &RESIDUAL_SCALE_FACTOR 0.9583  # [s=0 --> 4.6758]

# The `down_proj` of each MLP expert seems to be held in the `w2.weight` tensor for Mixtral:
# > current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
# > current_hidden_states = self.w2(current_hidden_states)
models:
  - model: *MODEL
    parameters:
      scale:
        - filter: w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough

jukofyork commented 2 months ago

This isn't doing anything very useful... For 3 experts:

s=0   --> PPL = 4.6758
s=0.5 --> PPL = 4.5406
s=1   --> PPL = 4.4835
s=2   --> PPL = 4.4546

and s=2 is almost the same as not scaling at all:

Zipf distribution for s = 2: ['0.6547', '0.1637', '0.0727', '0.0409', '0.0262', '0.0182', '0.0134', '0.0102']
For m = 1, Scaling factor alpha: 1.0308
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9942
For m = 4, Scaling factor alpha: 0.9924
For m = 5, Scaling factor alpha: 0.9917
For m = 6, Scaling factor alpha: 0.9913
For m = 7, Scaling factor alpha: 0.9912
For m = 8, Scaling factor alpha: 0.9910

vs 2 experts & stock settings --> PPL = 4.4103 +/- 0.02355

It may still be useful to try setting residual_scale for the mergekit-moe merges, as their experts are likely to be much more correlated and less likely to mess up the early embedding-transformation layers...

jukofyork commented 2 months ago

So next I'm going to try to attenuate the MOE-routing softmax-gate's distribution:

# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

const_tag: &QK_ATTENUATION_FACTOR 1.0    # NOTE: The scaling effect is QK_ATTENUATION_FACTOR^2 because of the dot-product!!!
const_tag: &GATE_ATTENUATION_FACTOR 0.9  # NOTE: Setting this < 1 will attenuate the MOE-routing softmax-gate's distribution.
const_tag: &RESIDUAL_SCALE_FACTOR 1.0    # NOTE: Attempt to rescale the residual stream when we change `num_experts_per_tok`.

models:
  - model: *MODEL
    parameters:
      scale:
        - filter: q_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: k_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: block_sparse_moe.gate.weight
          value: *GATE_ATTENUATION_FACTOR
        - filter: experts.w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough

and then try scaling the score matrix like we did for the frankenmerges.
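
For intuition on what GATE_ATTENUATION_FACTOR actually does: since the gate is just a linear layer whose logits get fed through a softmax, scaling gate.weight by a factor $\alpha$ scales the routing logits by $\alpha$, which acts like a softmax temperature of $\frac{1}{\alpha}$: $\alpha < 1$ flattens the routing distribution over the experts and $\alpha > 1$ sharpens it. A toy illustration (the logits here are completely made up):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits = np.array([2.1, 1.7, 0.3, -0.2, -0.5, -0.9, -1.4, -2.0])  # made-up router logits

for alpha in [1.2, 1.0, 0.9]:
    # Scaling gate.weight by alpha scales the logits by alpha (temperature 1/alpha).
    print(alpha, np.round(softmax(alpha * logits), 3))
# alpha < 1 spreads the probability mass over more experts; alpha > 1 concentrates it.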

jukofyork commented 2 months ago

Not really worth bothering with, I think. At best we're just going to get something about the same, but slower to run:

# 2 experts & stock settings   : PPL = 4.4103 +/- 0.02355

# 3 experts
# QK_ATTENUATION_FACTOR 1.10   : PPL = 4.5309 +/- 0.02444
# QK_ATTENUATION_FACTOR 1.05   : PPL = 4.4808 +/- 0.02415
# QK_ATTENUATION_FACTOR 0.95   : PPL = 4.4471 +/- 0.02401
# QK_ATTENUATION_FACTOR 0.90   : PPL = 4.4858 +/- 0.02431
# GATE_ATTENUATION_FACTOR 1.50 : PPL = 4.5641 +/- 0.02446
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4235 +/- 0.02377
# GATE_ATTENUATION_FACTOR 1.10 : PPL = 4.4329 +/- 0.02385
# GATE_ATTENUATION_FACTOR 0.98 : PPL = 4.4561 +/- 0.02404
# GATE_ATTENUATION_FACTOR 0.95 : PPL = 4.4639 +/- 0.02410
# GATE_ATTENUATION_FACTOR 0.90 : PPL = 4.4807 +/- 0.02422
# GATE_ATTENUATION_FACTOR 0.80 : PPL = 4.5236 +/- 0.02454
# QK_ATTENUATION_FACTOR 0.95 & GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4218 +/- 0.02380

# 4 experts
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4539 +/- 0.02402

jukofyork commented 2 months ago

Maybe this idea does have some use after all. If we can scale the gate weight tensor so that n=8 works as closely as possible to n=2, then very-low-bit quantized models (2-3 bpw) might actually work better and suffer a smaller perplexity increase (due to more active weights cancelling out more of the noise caused by quantization).

This assumes that the optimal scale factor doesn't just end up approximating the hard n=2 thresholding with a soft n=8 version that barely uses the other 6 sets of MLP weights (i.e. doesn't shift the lower-valued logits so far down that their Gumbel error distributions effectively head towards -inf and contribute almost nothing to the gated sum...).
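
Toy sketch of the "more active weights cancel out more quantization noise" intuition above (nothing to do with an actual quantized model; just averaging copies of the same signal with independent noise added to each):

import numpy as np

rng = np.random.default_rng(0)
d, trials = 4096, 200
signal = rng.standard_normal(d)

for m in [1, 2, 4, 8]:
    errs = []
    for _ in range(trials):
        noisy = signal + 0.1 * rng.standard_normal((m, d))  # m independently "quantised" copies
        errs.append(np.linalg.norm(noisy.mean(axis=0) - signal))
    print(m, round(float(np.mean(errs)), 3))                # error shrinks roughly as 1/sqrt(m)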