shansongliu / MU-LLaMA

MU-LLaMA: Music Understanding Large Language Model
GNU General Public License v3.0

Great work! Can I use llama-13B instead of 7B in the model? #3

Closed WilTay1 closed 12 months ago

WilTay1 commented 12 months ago

And what is the Python version for this model?

crypto-code commented 12 months ago

Yes, this model can be used with any version of LLaMA: 7B, 13B, or 70B.

"""
Adapted from https://github.com/cedrickchee/llama/blob/main/chattyllama/combined/inference.py
"""
key_to_dim = {
    "w1": 0,
    "w2": -1,
    "w3": 0,
    "wo": -1,
    "wq": 0,
    "wk": 0,
    "wv": 0,
    "output": 0,
    "tok_embeddings": -1,
    "ffn_norm": None,
    "attention_norm": None,
    "norm": None,
    "rope": None,
}
for i, ckpt in enumerate(ckpts):
    checkpoint = torch.load(ckpt, map_location="cpu")
    for parameter_name, parameter in self.llama.named_parameters():
        short_name = parameter_name.split(".")[-2]
        if "gate" in parameter_name or "lora" in parameter_name or "bias" in parameter_name:
            continue
        if key_to_dim[short_name] is None and i == 0:
            parameter.data = checkpoint[parameter_name]
        elif key_to_dim[short_name] == 0:
            size = checkpoint[parameter_name].size(0)
            parameter.data[size * i: size * (i + 1), :] = checkpoint[
                parameter_name
            ]
        elif key_to_dim[short_name] == -1:
            size = checkpoint[parameter_name].size(-1)
            parameter.data[:, size * i: size * (i + 1)] = checkpoint[
                parameter_name
            ]
    del checkpoint

The part of our code shown above makes it compatible with the different formats of LLaMA weights, including those shipped as split shards (13B and 70B).
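To make the splitting convention concrete, here is a toy sketch of how the two shard layouts reassemble. It uses NumPy arrays in place of real checkpoint tensors, and all shapes and variable names are made up for illustration:

```python
import numpy as np

num_shards = 2

# Column-parallel weight (key_to_dim == 0): each shard holds a block of rows.
# Pretend each of two shards holds a (4, 4) slice of an (8, 4) "wq" weight.
wq_shards = [np.arange(16).reshape(4, 4) + 100 * i for i in range(num_shards)]
wq_full = np.empty((8, 4), dtype=wq_shards[0].dtype)
for i, shard in enumerate(wq_shards):
    size = shard.shape[0]
    wq_full[size * i: size * (i + 1), :] = shard  # stack slices along dim 0

# Row-parallel weight (key_to_dim == -1): each shard holds a block of columns.
wo_shards = [np.full((4, 4), float(i)) for i in range(num_shards)]
wo_full = np.empty((4, 8))
for i, shard in enumerate(wo_shards):
    size = shard.shape[-1]
    wo_full[:, size * i: size * (i + 1)] = shard  # place slices side by side

# Each shard's slice lands in its own block of the full matrix
assert np.array_equal(wq_full[:4], wq_shards[0])
assert np.array_equal(wq_full[4:], wq_shards[1])
assert np.array_equal(wo_full[:, 4:], wo_shards[1])
```

A single-file 7B checkpoint is simply the degenerate case where the loop runs once and every slice covers the whole parameter.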

And what is the Python version for this model?

We use Python 3.9.17 for our project (this detail has been added to the README).

I hope this answers your question, and you can close the issue if this is resolved. We would appreciate you starring our repo 😊.