Closed WilTay1 closed 12 months ago
Yes, this model can be used with any size of LLaMA: 7B, 13B, and 70B.
"""
Adapted from https://github.com/cedrickchee/llama/blob/main/chattyllama/combined/inference.py
"""
key_to_dim = {
"w1": 0,
"w2": -1,
"w3": 0,
"wo": -1,
"wq": 0,
"wk": 0,
"wv": 0,
"output": 0,
"tok_embeddings": -1,
"ffn_norm": None,
"attention_norm": None,
"norm": None,
"rope": None,
}
for i, ckpt in enumerate(ckpts):
checkpoint = torch.load(ckpt, map_location="cpu")
for parameter_name, parameter in self.llama.named_parameters():
short_name = parameter_name.split(".")[-2]
if "gate" in parameter_name or "lora" in parameter_name or "bias" in parameter_name:
continue
if key_to_dim[short_name] is None and i == 0:
parameter.data = checkpoint[parameter_name]
elif key_to_dim[short_name] == 0:
size = checkpoint[parameter_name].size(0)
parameter.data[size * i: size * (i + 1), :] = checkpoint[
parameter_name
]
elif key_to_dim[short_name] == -1:
size = checkpoint[parameter_name].size(-1)
parameter.data[:, size * i: size * (i + 1)] = checkpoint[
parameter_name
]
del checkpoint
The part of our code shown above makes it compatible with the different forms of LLaMA weights, including the checkpoints that are split across multiple files (13B and 70B).
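To illustrate the idea (this is a standalone sketch, not code from the repo, and it uses NumPy in place of torch tensors): each shard fills a contiguous slice of the full parameter along its split dimension, so writing shard `i` into `[size * i : size * (i + 1)]` reconstructs the original matrix. The `merge_shards` helper below is hypothetical.

```python
import numpy as np

def merge_shards(shards, dim):
    """Reassemble checkpoint shards along `dim` (0 or -1), mirroring
    the parameter.data[...] slicing pattern used in the loading loop."""
    full_shape = list(shards[0].shape)
    full_shape[dim] *= len(shards)
    merged = np.empty(full_shape, dtype=shards[0].dtype)
    for i, shard in enumerate(shards):
        size = shard.shape[dim]
        if dim == 0:
            # Row-split (e.g. "wq", "wk", "wv"): stack shards vertically.
            merged[size * i: size * (i + 1), :] = shard
        else:
            # Column-split (e.g. "wo", "w2"): stack shards horizontally.
            merged[:, size * i: size * (i + 1)] = shard
    return merged

# A toy 4x6 weight split two ways, then reassembled.
full = np.arange(24, dtype=np.float32).reshape(4, 6)
assert np.array_equal(merge_shards([full[:2, :], full[2:, :]], 0), full)
assert np.array_equal(merge_shards([full[:, :3], full[:, 3:]], -1), full)
```

Parameters with `None` in `key_to_dim` (the norm layers and rope frequencies) are replicated in every shard, which is why they are copied only once, from shard 0.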
And what is the Python version for this model?
We use Python 3.9.17 for our project (detail has been added to README).
I hope this answers your question; feel free to close the issue if it is resolved. We would appreciate you starring our repo 😊.