huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Extend Fx supported models with KV cache #33132

Open xuzifei-dmatrix opened 2 weeks ago

xuzifei-dmatrix commented 2 weeks ago

Feature request

I noticed that only the Llama and OPT models currently support FX tracing with KV cache. Is there a plan to extend this to more models? Thanks!

Motivation

I would like to run FX-traced GraphModules with generate(), which uses the KV cache. This currently works for OPT and Llama, but I would like to try it on more models.

Your contribution

If someone could point me to the general design pattern for making a model FX-traceable with KV cache, or to the changes in modeling_opt.py or modeling_llama.py that made them work, I would be happy to submit PRs to support more models.

LysandreJik commented 2 weeks ago

May be of interest to @michaelbenayoun!