Step for Error Reproduction
Enable `use_cache=True` in the Hugging Face `generation_config`, pass that generation config to the model, and run the forward pass as usual. I believe you'll get the same error I met.
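A minimal sketch of the reproduction, assuming the model and tokenizer are already loaded and passed in by the caller. The helper names `build_generation_config` and `reproduce`, the prompt, and `max_new_tokens=8` are illustrative assumptions, not part of the original report.

```python
from transformers import GenerationConfig


def build_generation_config() -> GenerationConfig:
    # use_cache=True is the setting from the report;
    # max_new_tokens=8 is an arbitrary value chosen for this sketch.
    return GenerationConfig(use_cache=True, max_new_tokens=8)


def reproduce(model, tokenizer, prompt: str = "Hello"):
    # Hypothetical helper: tokenize a prompt and generate with the KV cache on.
    # `model` and `tokenizer` are whatever the caller has already loaded.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    return model.generate(**inputs, generation_config=build_generation_config())
```

Calling `reproduce(model, tokenizer)` on the affected model should then exercise the cached attention path where the error shows up.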
Error Message
I get the following error message: "query and key must have the same dtype". I found that it is caused by a mismatch between the `dtype` of `query_state` and `value_state`.
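To locate this kind of mismatch, one option is to inspect the dtypes of the three states right before the attention call. The sketch below uses hypothetical helpers (`qkv_dtypes`, `dtypes_match`) and made-up tensors. The specific dtypes (`float32` for the query, `float16` for key and value) are assumptions for illustration only.

```python
import torch


def qkv_dtypes(query_states, key_states, value_states):
    # Collect the dtype of each attention state for inspection.
    return {
        "query": query_states.dtype,
        "key": key_states.dtype,
        "value": value_states.dtype,
    }


def dtypes_match(query_states, key_states, value_states):
    # True only when all three states share a single dtype.
    return len(set(qkv_dtypes(query_states, key_states, value_states).values())) == 1


# Dummy tensors shaped (batch, heads, seq_len, head_dim); dtypes are assumed.
query_states = torch.zeros(1, 2, 3, 4, dtype=torch.float32)
key_states = torch.zeros(1, 2, 3, 4, dtype=torch.float16)
value_states = torch.zeros(1, 2, 3, 4, dtype=torch.float16)

print(qkv_dtypes(query_states, key_states, value_states))
print(dtypes_match(query_states, key_states, value_states))  # False
```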
Potential Solution
I suspect that some of the q, k, and v states get cast implicitly along the way. Hence, I cast q, k, and v back to the same target `dtype`.
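The idea can be sketched as follows. This is not the committed code: `align_qkv_dtype` is a hypothetical helper, and falling back to the query's dtype when no target is given is an assumption made for this example.

```python
import torch


def align_qkv_dtype(query_states, key_states, value_states, target_dtype=None):
    # Assumption: if no target dtype is given, use the query's dtype.
    if target_dtype is None:
        target_dtype = query_states.dtype
    # Tensor.to returns the tensor itself when the dtype already matches,
    # so states that are already aligned are not copied.
    return (
        query_states.to(target_dtype),
        key_states.to(target_dtype),
        value_states.to(target_dtype),
    )


# Dummy states with mismatched dtypes (assumed values for illustration).
q = torch.zeros(1, 2, 3, 4, dtype=torch.float32)
k = torch.zeros(1, 2, 3, 4, dtype=torch.float16)
v = torch.zeros(1, 2, 3, 4, dtype=torch.float16)

q, k, v = align_qkv_dtype(q, k, v, target_dtype=torch.float16)
```

In a real attention module the target would more likely come from the model's configured compute dtype than from one of the states, which is why it is left as a parameter here.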
Reference
The code in this commit closely follows the existing implementation in mergoo.