Leeroo-AI / mergoo

A library for easily merging multiple LLM experts and efficiently training the merged LLM.
https://www.leeroo.com/
GNU Lesser General Public License v3.0
360 stars 19 forks

Phi3 merge issue #10

Closed PhilipMay closed 1 month ago

PhilipMay commented 1 month ago

Hi.

I am executing this code to merge Phi-3 models: https://github.com/Leeroo-AI/mergoo/blob/main/notebooks/integrate_phi3_experts.ipynb

The result looks like this:

ll integrate_phi3_experts_float16/
total 14963528
-rw-r--r--  1 A337384  staff   293B May  8 10:18 added_tokens.json
-rw-r--r--  1 A337384  staff   4.6K May  8 10:16 config.json
-rw-r--r--  1 A337384  staff   7.1G May  8 10:18 model.safetensors
-rw-r--r--  1 A337384  staff   569B May  8 10:18 special_tokens_map.json
-rw-r--r--  1 A337384  staff   1.8M May  8 10:18 tokenizer.json
-rw-r--r--  1 A337384  staff   488K May  8 10:18 tokenizer.model
-rw-r--r--  1 A337384  staff   3.1K May  8 10:18 tokenizer_config.json

IMHO the merge of 3 Phi-3 models should be much bigger than 7.1G on disk, because 7.1G is the size of one single Phi-3 model in float16 (roughly 3.8B params x 2 bytes ≈ 7.1 GiB)... Something is wrong here.

For debugging here is the result of a model print:

from transformers import AutoModelForCausalLM

# path of the merged checkpoint written by the notebook
model_id = "./integrate_phi3_experts_float16"
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
print(model)

Prints:

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3SuScaledRotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
alirezamshi commented 1 month ago

Please use from mergoo.models.modeling_phi3 import Phi3ForCausalLM to load the merged model, as suggested in the notebook, then re-print the model. Thanks

PhilipMay commented 1 month ago

Please use from mergoo.models.modeling_phi3 import Phi3ForCausalLM to load the merged model, as suggested in the notebook, then re-print the model. Thanks

Ok, sure. New code:

from mergoo.models.modeling_phi3 import Phi3ForCausalLM

model_id = "./integrate_phi3_experts_float16"
model = Phi3ForCausalLM.from_pretrained(model_id)
print(model)

Output:

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3SuScaledRotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): MoeLayer(
            (gate): Linear(in_features=3072, out_features=3, bias=False)
            (experts): ModuleList(
              (0-2): 3 x Linear(in_features=3072, out_features=16384, bias=False)
            )
          )
          (down_proj): MoeLayer(
            (gate): Linear(in_features=8192, out_features=3, bias=False)
            (experts): ModuleList(
              (0-2): 3 x Linear(in_features=8192, out_features=3072, bias=False)
            )
          )
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
alirezamshi commented 1 month ago

The MoE architecture shows up in the FF layers, e.g. (gate): Linear(in_features=8192, out_features=3, bias=False). Our Mixture-of-Experts feature is inspired by this work, which converts the FF layers to MoE style and averages the attention layers. So this is the expected output model.
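
To make the conversion concrete, here is a minimal illustrative sketch of the idea, not mergoo's actual implementation (the class and function names MoeLayer and average_linears are chosen for illustration only): each expert's FF projection is kept as its own Linear inside an MoE wrapper with a small gating Linear, while the attention projections are averaged into a single Linear.

import torch
import torch.nn as nn

class MoeLayer(nn.Module):
    # Illustrative MoE wrapper: a gating Linear weights the per-expert Linears.
    def __init__(self, expert_linears):
        super().__init__()
        self.gate = nn.Linear(expert_linears[0].in_features, len(expert_linears), bias=False)
        self.experts = nn.ModuleList(expert_linears)

    def forward(self, x):
        weights = torch.softmax(self.gate(x), dim=-1)                 # (..., n_experts)
        outputs = torch.stack([e(x) for e in self.experts], dim=-1)   # (..., out_features, n_experts)
        return (outputs * weights.unsqueeze(-2)).sum(dim=-1)          # soft routing, for simplicity

def average_linears(linears):
    # Attention projections are not converted to MoE; their weights are averaged instead.
    merged = nn.Linear(linears[0].in_features, linears[0].out_features, bias=False)
    with torch.no_grad():
        merged.weight.copy_(torch.stack([l.weight for l in linears]).mean(dim=0))
    return merged

In the merged Phi-3 printed above, gate_up_proj and down_proj in every decoder layer are such MoE layers with 3 experts, while qkv_proj and o_proj stay single Linears holding the averaged expert weights.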

PhilipMay commented 1 month ago

Well, I uploaded the merged model to HF:

https://huggingface.co/PhilipMay/Mergoo-Phi-3-MoE-example

On that page you can see that the MoE model has 3.82B params. That is exactly the same as a single base model.

How can an MoE model with 3 experts have exactly the same number of parameters as one single expert? IMO something is wrong here - sorry.

PhilipMay commented 1 month ago

PS: In my opinion, that's the whole point of MoE models: to combine several base models and then select a subset on demand at runtime. But if I am selecting a subset out of the same total number of parameters as a single model, my model doesn't get better, it gets worse/dumber.

gitsailor5 commented 1 month ago

@PhilipMay, with identical weights only one copy is saved instead of three. Could you verify the GPU memory usage when loading Phi-3 and the Phi-3 MoE? The disparity should be noticeable. After training the model, the weights will no longer be uniform, so all expert weights will be saved.
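
A minimal sketch of such a memory comparison, assuming a single CUDA device, the base checkpoint from the notebook, and the local merged checkpoint path used above (exact numbers will vary with dtype and transformers version):

import torch
from transformers import AutoModelForCausalLM
from mergoo.models.modeling_phi3 import Phi3ForCausalLM

def allocated_gb():
    # currently allocated tensor memory on the default CUDA device
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() / 1e9

# single Phi-3 expert
base = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    torch_dtype=torch.float16, trust_remote_code=True).to("cuda")
print(f"single Phi-3: {allocated_gb():.2f} GB")
del base
torch.cuda.empty_cache()

# merged MoE checkpoint produced by the notebook
moe = Phi3ForCausalLM.from_pretrained("./integrate_phi3_experts_float16").half().to("cuda")
print(f"merged MoE:   {allocated_gb():.2f} GB")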

gitsailor5 commented 1 month ago

*given that those weights are the same, do check that first.

PhilipMay commented 1 month ago

But we merge 3 different models. The weights should not be the same.

gitsailor5 commented 1 month ago

Here's a test that verifies whether the weights you're trying to merge are exactly the same.

import torch
from transformers import AutoModelForCausalLM

# The three expert checkpoints used in the notebook.
model_1 = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    torch_dtype=torch.float16,
    device_map="auto")

model_2 = AutoModelForCausalLM.from_pretrained(
    "RDson/Phi-3-mini-code-finetune-128k-instruct-v1",
    torch_dtype=torch.float16,
    device_map="auto")

model_3 = AutoModelForCausalLM.from_pretrained(
    "NickyNicky/Phi-3-mini-128k-instruct_function",
    torch_dtype=torch.float16,
    device_map="auto")

# Compare the FF (MLP) projections layer by layer; an assertion fails as soon as
# any expert's weights differ from the base model's.
for i in range(len(model_1.model.layers)):
    layers = ["gate_up_proj", "down_proj"]
    for layer in layers:
        assert (
            getattr(model_1.model.layers[i].mlp, layer).weight ==
            getattr(model_2.model.layers[i].mlp, layer).weight).all().item()
        assert (
            getattr(model_1.model.layers[i].mlp, layer).weight ==
            getattr(model_3.model.layers[i].mlp, layer).weight).all().item()

gitsailor5 commented 1 month ago

@PhilipMay ^^