Closed hurongliang closed 8 months ago
Could you share the command you ran to get that issue? (Also that argpartition error message looks a bit funky, we should probably fix it).
In the current implementation of MoE, it cannot handle less than 2x3, so 2x2 won't work with top-k. I have encountered a similar error before but haven't had the chance to look into the details.
Huh, so that’s generation with 2 experts? I can check it, maybe there is a bug in topk.
cloudyu/Yi-34Bx2-MoE-60B-DPO
Yeah, the model uses 2 experts per token with 2 local experts, which doesn't work with the current MoE sparse block. https://huggingface.co/cloudyu/Yi-34Bx2-MoE-60B-DPO/blob/main/config.json#L17-L20
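For reference, a minimal sketch of how that configuration trips the gating code; the gate values are made up and the call mirrors the mx.argpartition line in the traceback further down:

import mlx.core as mx

gates = mx.array([[0.7, 0.3]])  # made-up router scores for one token, 2 local experts
ne = 2                          # experts selected per token

try:
    # kth equal to the axis length is out of bounds, which is what the
    # ValueError in the traceback below reports
    mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
except ValueError as e:
    print(e)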
Btw, I think the bug is how we use argpartition. It should receive ne - 1 instead of ne here. Basically argpartition puts the kth value in its sorted spot (where kth is zero-based). So to get the smallest two items you would do argpartition(a, kth=1)[:2]. If kth is the number of experts then it's out of bounds, so argpartition crashes.
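As a sketch of the suggested change, assuming the same made-up gates and that only the kth argument moves from ne to ne - 1:

import mlx.core as mx

gates = mx.array([[0.3, 0.7]])  # made-up router scores, 2 local experts
ne = 2                          # experts selected per token

# kth is zero-based, so ne - 1 stays in bounds even when ne equals the
# number of local experts; the first ne indices are the top-ne experts
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.take_along_axis(gates, inds, axis=-1)
print(inds, scores)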
@awni
Could you share the command you ran to get that issue? (Also that argpartition error message looks a bit funky, we should probably fix it).
Steps to reproduce the issue:
1. Download cloudyu/Yi-34Bx2-MoE-60B-DPO:
huggingface-cli download --resume-download cloudyu/Yi-34Bx2-MoE-60B-DPO
2. Convert the weights to MLX format with 4-bit quantization:
from mlx_lm import convert
convert('cloudyu/Yi-34Bx2-MoE-60B-DPO', mlx_path='mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx', quantize=True, q_bits=4)
3. Run python test_hello.py mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx -mlx.
# file test_hello.py
from mlx_lm import load, generate
import argparse
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Convert models on huggingface hub to mlx format.")
    parser.add_argument("repo_id", help="repo_id on huggingface.co")
    parser.add_argument("-mlx", "--mlx", help="True if model is mlx format", action="store_true", default=False)
    parser.add_argument("-chat", "--chat", help="True if model is a chat model", action="store_true", default=False)
    parser.add_argument("--prompt", help="Prompt", default="hello")
    args = parser.parse_args()

    is_mlx = args.mlx
    repo_id = args.repo_id
    is_chat_model = args.chat
    prompt = args.prompt

    if is_mlx:
        model, tokenizer = load(repo_id)
        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        print(response)
    else:
        if is_chat_model:
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).to('mps')
            model = model.eval()
            response, history = model.chat(tokenizer, prompt, history=[])
            print(response)
        else:
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForCausalLM.from_pretrained(repo_id)
            input_ids = tokenizer(prompt, return_tensors="pt")
            outputs = model.generate(**input_ids, max_new_tokens=10000)
            print(tokenizer.decode(outputs[0]))
Console output
==========
Prompt: hello
Traceback (most recent call last):
  File "/Users/hurongliang/git/llmresearch/test_hello.py", line 20, in <module>
    response = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 214, in generate
    for (token, prob), n in zip(
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 157, in generate_step
    logits, cache = model(y[None], cache=cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 260, in __call__
    out, cache = self.model(inputs, cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 243, in __call__
    h, cache[e] = layer(h, mask, cache[e])
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 209, in __call__
    r = self.block_sparse_moe(self.post_attention_layernorm(h))
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 160, in __call__
    mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
ValueError: [argpartition] Received invalid kth 2along axis -1 for array with shape: (1,2)
Thank you @hurongliang. The fix is in #515