Leeroo-AI / mergoo

A library for easily merging multiple LLM experts and efficiently training the merged LLM.
https://www.leeroo.com/
GNU Lesser General Public License v3.0

[Fix] Fix the Error of q, k, and v states must have the same dtype when using flash attention forward. #15

Closed jacklanda closed 1 month ago

jacklanda commented 1 month ago

Step for Error Reproduction

Enable the argument use_cache=True in the Hugging Face GenerationConfig, for example:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    bos_token_id=128000,
    eos_token_id=128001,
    pad_token_id=self.tokenizer.pad_token_id,
    use_cache=True,
)

Then simply pass this generation config (with use_cache=True) to the model and run a forward pass as usual; you should hit the same error I did.
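To illustrate how such a mismatch can arise (a minimal sketch, assuming only torch; the shapes are arbitrary and not taken from mergoo): with use_cache=True, the cached key/value states can remain in the model's half-precision dtype while the fresh query states end up upcast to float32 somewhere in the forward pass, which flash attention then rejects.

```python
import torch

# Cached K/V stay in the model's half-precision dtype.
key_cache = torch.randn(1, 8, 16, 64, dtype=torch.float16)
value_cache = torch.randn(1, 8, 16, 64, dtype=torch.float16)

# The fresh query states were upcast to float32 along the way.
query_states = torch.randn(1, 8, 1, 64).float()

# Flash attention requires q, k, v to share one dtype; this check fails.
dtypes_match = query_states.dtype == key_cache.dtype == value_cache.dtype
print(dtypes_match)  # False
```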

Error Message

I get the following error message: "query and key must have the same dtype". I found that this is due to a dtype mismatch between query_states and key_states.

Potential Solution

I suspect the q, k, and v states get cast somewhere along the way. Hence, I cast them back to the same target dtype before the flash attention call.
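That idea can be sketched as follows (the helper name is mine, not mergoo's actual API): pick one target dtype, defaulting here to the key states' dtype since that is the dtype of the KV cache, and cast all three states to it.

```python
import torch

def cast_qkv_to_common_dtype(query_states, key_states, value_states,
                             target_dtype=None):
    """Cast q/k/v to a single dtype before calling flash attention.

    Hypothetical helper for illustration only: by default the target
    dtype is taken from the key states (i.e. the KV cache dtype).
    """
    if target_dtype is None:
        target_dtype = key_states.dtype
    return (query_states.to(target_dtype),
            key_states.to(target_dtype),
            value_states.to(target_dtype))

q = torch.randn(1, 8, 1, 64, dtype=torch.float32)   # upcast query
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # cached key
v = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # cached value
q, k, v = cast_qkv_to_common_dtype(q, k, v)
print(q.dtype, k.dtype, v.dtype)  # all torch.float16
```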

Reference

The committed code closely follows the existing implementation in mergoo.