microsoft / MInference

[NeurIPS'24 Spotlight] To speed up long-context LLM inference, MInference computes attention with approximate, dynamic sparsity, reducing pre-filling latency by up to 10x on an A100 while maintaining accuracy.
https://aka.ms/MInference
MIT License
798 stars 38 forks

[Feature Request]: Support Mistral Model #39

Open PatchouliTIS opened 4 months ago

PatchouliTIS commented 4 months ago

Is your feature request related to a problem? Please describe.

May I ask whether you plan to support Mistral and Mixtral models in your framework, including the vllm optimizations?

Describe the solution you'd like

No response

Additional context

No response

iofu728 commented 4 months ago

Hi @PatchouliTIS, thanks for your suggestion.

We are happy to support Mistral-style models. However, we have not found a powerful and widely recognized version with long context windows greater than 128K. Do you have any recommendations?

PatchouliTIS commented 4 months ago

Thank you for your reply :) In my scenario I want to use your framework for inference with a distilled Mistral-style model whose max_seq_lens is 4096, which is far from a 'long context window', so I'm afraid I can't offer much useful insight.

btw, there are more questions about MInference I would like to ask:

  1. Why is there only the vertical_and_slash method in ./config/XXX_best_pattern.json? The Block-Sparse AttnImpl never seems to be used.
  2. How can I generate the optimal sparse pattern for a Mistral-style model? I followed the instructions in Offline Kernel-Aware Sparse Pattern Search:
cd experiments/infinite_bench
python run_infinitebench.py \
    --task kv_retrieval \
    --model_name_or_path gradientai/Llama-3-8B-Instruct-262k \
    --data_dir ./data \
    --output_dir ./results \
    --max_seq_length 30000 \
    --rewrite \
    --is_search \
    --start_example_id 3 \
    --topk_dims_file_path Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json \
    --num_eval_examples 20 --topk 1 --starting_layer 0 --attn_type minference

But the sparse pattern generation script NEEDS topk_dims_file_path as its input, which is exactly the pattern config I am trying to generate!

  3. Could you please list the exact versions of the required dependencies? I found that MInference sometimes runs on py3.9 but fails on py3.10, and succeeds with vllm==0.4.2/0.4.3 but fails with 0.4.0.
iofu728 commented 4 months ago

Hi @PatchouliTIS, thanks for your questions,

  1. The config we have released uses only the vertical+slash pattern, which generalizes relatively well. In our tests, using "stream_llm" and "block_sparse" caused performance loss for some lengths and tasks.
  2. Apologies for any confusion caused by our parameter naming. When using --is_search, --topk_dims_file_path refers to the path for the output config.
  3. Due to an oversight, we currently only support vllm>=0.4.1. You can upgrade to this version to use MInference.
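For reference, a search run might look roughly like the command below. With --is_search, --topk_dims_file_path is where the searched config will be written, and --attn_type minference must be set; the model and output file names here are placeholders:

cd experiments/infinite_bench
python run_infinitebench.py \
    --task kv_retrieval \
    --model_name_or_path <your_long_context_model> \
    --data_dir ./data \
    --output_dir ./results \
    --max_seq_length 30000 \
    --rewrite \
    --is_search \
    --start_example_id 3 \
    --topk_dims_file_path <your_model>_best_pattern.json \
    --num_eval_examples 20 --topk 1 --starting_layer 0 --attn_type minference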
PatchouliTIS commented 4 months ago

(screenshot) I put together the sparse pattern generation script above, created the corresponding JSON file under the current directory, and launched the script, but only received errors like this: (screenshot)

It seems topk_dims_file_path is still treated as a required input to parse in the current version of MInference.

iofu728 commented 4 months ago

Hi @PatchouliTIS, I checked the logs, and it seems that you cannot use vllm during the search process. You need to set --attn_type minference.

polarispw commented 4 months ago

Hi @iofu728, @PatchouliTIS, how is the pattern search going? When applying the script to my customized Llama-based MoE models, I ran into some problems.

  1. The search exits with "assert False" after searching the last layer, instead of showing results. Is that OK? (screenshot)
  2. Although it exits abnormally, the optimal patterns are saved, and I manually added them to model2path.py. But when evaluating PPL, things get worse. Do you have any idea about this? (screenshot)

All experiments are based on transformers=4.41.0

iofu728 commented 4 months ago

Hi @polarispw, thanks for the feedback.

  1. Yes, assert False indicates that the search process has ended.
  2. If the 1k PPL is that high, I suspect there might be an issue with the patch. Try using minference_with_dense; if it's also high, it confirms that there's a problem with the patch.
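As a rough sanity check, the dense fallback can be applied following the usage pattern in the README; exact argument names may differ between versions, and the model name below is just a placeholder:

from transformers import AutoModelForCausalLM
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"  # placeholder; use your own model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Patch with the dense fallback first: if PPL is still abnormally high here,
# the problem is in the patch itself rather than in the sparse pattern.
minference_patch = MInference("minference_with_dense", model_name)
model = minference_patch(model)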
polarispw commented 4 months ago

@iofu728 Thanks for your reply, but using minference_with_dense doesn't seem to work. Could you please provide the following information to help debug the patch:

  1. Are the changes made by the patch mainly in MHA, and should they be transparent to the MLP?
  2. Which version of transformers was mainly used?
  3. The scores in the terminal output above look fine (0.7~0.9), right? Your detailed reply would be helpful :)

    It works when only patching the attention part of the model, leaving other modules unchanged.
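A minimal sketch of that attention-only patching, assuming a Llama-style layer layout; patched_attn_forward is a placeholder for whichever sparse-attention forward is being used, not an exact MInference API:

import types

def patch_attention_only(model, patched_attn_forward):
    # Replace only each layer's self-attention forward;
    # the MLP / MoE expert modules are left untouched.
    for layer in model.model.layers:  # assumes a Llama-style module layout
        layer.self_attn.forward = types.MethodType(patched_attn_forward, layer.self_attn)
    return model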

chenwuperth commented 3 months ago

Hi @iofu728, @PatchouliTIS

Thanks for the discussion

We are happy to support Mistral-style models. However, we have not found a powerful and widely recognized version with long context windows greater than 128K. Do you have any recommendations?

Would you consider supporting the following Mistral long-context model?

https://huggingface.co/aws-prototyping/MegaBeam-Mistral-7B-512k

It is a Mistral-7B model that supports a 512K context, scored 88.7 on the RULER long-context benchmark (where Llama3.1-8B scored 88.3), and achieves 100% on NIAH. It works with vLLM (0.4.2 and 0.5.1 tested).

Currently, at 512K the TTFT is around 52 seconds on 8 A100s without any optimisation, and still 44 seconds with prefix caching. It would be interesting to see how it performs with MInference!

iofu728 commented 3 months ago

Hi @chenwuperth,

Thank you for your support of MInference. We will look into this further. If you have the searched config, feel free to submit a PR. You can follow the guidelines here: Offline Kernel-aware Sparse Pattern Search.