ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Error 'Received invalid kth 2along axis -1 for array with shape: (1,2)' when generating with a Mixtral model in 4-bit quantized MLX format #504

Closed: hurongliang closed this issue 8 months ago

hurongliang commented 8 months ago

Original model: cloudyu/Yi-34Bx2-MoE-60B-DPO

Error stack trace:

Traceback (most recent call last):
  File "/Users/hurongliang/git/llmresearch/test_hello.py", line 20, in <module>
    response = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 214, in generate
    for (token, prob), n in zip(
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 157, in generate_step
    logits, cache = model(y[None], cache=cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 260, in __call__
    out, cache = self.model(inputs, cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 243, in __call__
    h, cache[e] = layer(h, mask, cache[e])
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 209, in __call__
    r = self.block_sparse_moe(self.post_attention_layernorm(h))
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 160, in __call__
    mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
ValueError: [argpartition] Received invalid kth 2along axis -1 for array with shape: (1,2)

Dependencies

pip list | grep mlx
mlx                0.4.0
mlx-lm             0.0.13

awni commented 8 months ago

Could you share the command you ran to get that issue? (Also that argpartition error message looks a bit funky, we should probably fix it).

mzbac commented 8 months ago

In the current MoE implementation, top-k expert selection only works when there are more local experts than experts per token (for example, 2 of 3), so a 2-of-2 configuration fails. I have run into a similar error before but haven't had a chance to look into the details.
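To make the routing concrete, here is a minimal NumPy sketch of a Mixtral-style sparse MoE block. It is an illustration, not the mlx_lm code: each token's router logits pick the top `ne` experts with `argpartition` (using a zero-based `kth=ne - 1`), the selected logits are softmax-normalized, and the expert outputs are mixed with those weights. The experts are hypothetical plain matrices standing in for the real MLP experts, and `sparse_moe` is an illustrative name.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sparse_moe(x: np.ndarray, gate_w: np.ndarray, expert_ws: list, ne: int) -> np.ndarray:
    """Route each token to its top-`ne` experts and mix their outputs.

    x:         (tokens, dim) hidden states
    gate_w:    (dim, num_experts) router weights
    expert_ws: one hypothetical (dim, dim) matrix per expert
    ne:        experts per token (num_experts_per_tok)
    """
    num_experts = gate_w.shape[1]
    assert 1 <= ne <= num_experts
    gates = x @ gate_w  # (tokens, num_experts) router logits
    # kth is zero-based, so kth=ne - 1 stays in bounds even when ne == num_experts.
    inds = np.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
    # Softmax over only the selected experts' logits.
    scores = softmax(np.take_along_axis(gates, inds, axis=-1), axis=-1)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(ne):
            e = inds[t, j]
            out[t] += scores[t, j] * (x[t] @ expert_ws[e])
    return out
```

When `ne` equals the number of experts, as in a 2-of-2 model, every expert is selected and the result matches a dense softmax-weighted mixture over all experts.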

awni commented 8 months ago

Huh, so that's generation with 2 experts? I can check; maybe there is a bug in top-k.

mzbac commented 8 months ago

cloudyu/Yi-34Bx2-MoE-60B-DPO

Yeah, the model selects 2 experts per token out of only 2 local experts, which doesn't work with the current MoE sparse block. https://huggingface.co/cloudyu/Yi-34Bx2-MoE-60B-DPO/blob/main/config.json#L17-L20
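A quick way to spot this situation is to read the two expert counts from the model's config.json. The sketch below assumes the Hugging Face Mixtral config keys `num_experts_per_tok` and `num_local_experts`; the helper names are hypothetical, and the check mirrors the `kth=ne` call in the traceback, which is only in bounds when `ne` is strictly less than the number of experts.

```python
import json
import tempfile
from pathlib import Path


def moe_expert_counts(config_path: Path) -> tuple[int, int]:
    """Return (num_experts_per_tok, num_local_experts) from a Mixtral-style config.json."""
    config = json.loads(Path(config_path).read_text())
    return config["num_experts_per_tok"], config["num_local_experts"]


def kth_ne_in_bounds(ne: int, num_experts: int) -> bool:
    """True if argpartition(..., kth=ne) is valid on an axis of length num_experts.

    kth is a zero-based index, so it must be strictly less than the axis length.
    """
    return ne < num_experts


# Usage with a temporary file holding only the two relevant fields (2 of 2 experts).
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "config.json"
    path.write_text(json.dumps({"num_experts_per_tok": 2, "num_local_experts": 2}))
    ne, num_experts = moe_expert_counts(path)

print(ne, num_experts, kth_ne_in_bounds(ne, num_experts))  # 2 2 False
```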

awni commented 8 months ago

Btw, I think the bug is in how we use argpartition: it should receive kth=ne - 1 instead of kth=ne here.

Basically, argpartition puts the kth value in its sorted position (kth is zero-based), with no larger values before it. So to get the two smallest items you would do argpartition(a, kth=1)[:2]. If kth equals the number of experts, it is out of bounds, so argpartition raises an error.
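NumPy's `np.argpartition` uses the same zero-based `kth` convention, so it can illustrate the point without Apple silicon. The assumption here is that `mx.argpartition` behaves the same way, which the error in the traceback is consistent with; `kth_is_valid` is a hypothetical helper for the demonstration.

```python
import numpy as np

a = np.array([5, 1, 4, 2, 3])

# kth=1 puts the 2nd-smallest value at sorted position 1 with no larger values
# before it, so the first two indices point at the two smallest values.
two_smallest = np.argpartition(a, kth=1)[:2]
print(sorted(a[two_smallest].tolist()))  # [1, 2]


def kth_is_valid(kth: int, n: int) -> bool:
    """Check whether argpartition accepts `kth` for an axis of length n."""
    try:
        np.argpartition(np.zeros(n), kth=kth)
        return True
    except ValueError:
        return False


ne = 2  # experts per token, equal to the number of local experts in this model
print(kth_is_valid(ne, ne))      # False: kth=ne is out of bounds
print(kth_is_valid(ne - 1, ne))  # True: kth=ne - 1 still selects all ne experts
```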

hurongliang commented 8 months ago

@awni

> Could you share the command you ran to get that issue? (Also that argpartition error message looks a bit funky, we should probably fix it).

Steps to reproduce the issue:

  1. Download cloudyu/Yi-34Bx2-MoE-60B-DPO

    huggingface-cli download --resume-download cloudyu/Yi-34Bx2-MoE-60B-DPO
  2. Convert the weights to MLX format with 4-bit quantization.

    from mlx_lm import convert
    convert('cloudyu/Yi-34Bx2-MoE-60B-DPO', mlx_path='mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx', quantize=True, q_bits=4)
  3. Run the test script below:

    python test_hello.py mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx -mlx

# file test_hello.py
from mlx_lm import load, generate
import argparse
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run a prompt against a Hugging Face Hub model, in MLX format or with transformers.")
    parser.add_argument("repo_id", help="repo_id on huggingface.co")
    parser.add_argument("-mlx", "--mlx", help="True if model is mlx format", action="store_true", default=False)
    parser.add_argument("-chat", "--chat", help="True if model is a chat model", action="store_true", default=False)
    parser.add_argument("--prompt", help="Prompt", default="hello")
    args = parser.parse_args()

    is_mlx = args.mlx
    repo_id = args.repo_id
    is_chat_model = args.chat
    prompt = args.prompt

    if is_mlx:
        # MLX path: load the converted model and generate with mlx_lm
        model, tokenizer = load(repo_id)
        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        print(response)
    else:
        if is_chat_model:
            # Chat models that ship custom code exposing a .chat() method
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).to('mps')
            model = model.eval()
            response, history = model.chat(tokenizer, prompt, history=[])
            print(response)
        else:
            # Plain causal LM through transformers
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForCausalLM.from_pretrained(repo_id)
            inputs = tokenizer(prompt, return_tensors="pt")
            outputs = model.generate(**inputs, max_new_tokens=10000)
            print(tokenizer.decode(outputs[0]))

Console output

==========
Prompt: hello
Traceback (most recent call last):
  File "/Users/hurongliang/git/llmresearch/test_hello.py", line 20, in <module>
    response = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 214, in generate
    for (token, prob), n in zip(
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 157, in generate_step
    logits, cache = model(y[None], cache=cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 260, in __call__
    out, cache = self.model(inputs, cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 243, in __call__
    h, cache[e] = layer(h, mask, cache[e])
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 209, in __call__
    r = self.block_sparse_moe(self.post_attention_layernorm(h))
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 160, in __call__
    mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
ValueError: [argpartition] Received invalid kth 2along axis -1 for array with shape: (1,2)

awni commented 8 months ago

Thank you, @hurongliang. The fix is in #515.