keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras

Keeping the KV cache as a list of tensors may be better than one tensor #1562

Open lingzhi98 opened 3 months ago

lingzhi98 commented 3 months ago

Describe the bug If we keep the KV cache as a list of tensors, there is no need to concatenate the KV caches of each decoder block (https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_causal_lm.py#L225). This would help model performance.

Expected behavior Remove the unnecessary concatenation to improve performance.
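For illustration, here is a rough sketch of the two cache layouts being compared. The function names, shapes, and per-layer call signature are hypothetical stand-ins, not the actual keras-nlp implementation:

```python
# Sketch only: illustrative names and shapes, not the real keras-nlp code.
from keras import ops

def call_with_cache_stacked(x, cache, transformer_layers):
    # Current style: `cache` is one tensor, e.g. of shape
    # (batch, num_layers, 2, max_len, num_heads, head_dim).
    updated_caches = []
    for i, layer in enumerate(transformer_layers):
        layer_cache = cache[:, i, ...]              # slice out the per-layer cache
        x, layer_cache = layer(x, cache=layer_cache)
        updated_caches.append(layer_cache)
    # The cross-layer stack/concatenate this issue is about:
    cache = ops.stack(updated_caches, axis=1)
    return x, cache

def call_with_cache_list(x, caches, transformer_layers):
    # Proposed style: `caches` is a list with one tensor per decoder block,
    # so no cross-layer slice or concatenate is needed.
    updated_caches = []
    for layer, layer_cache in zip(transformer_layers, caches):
        x, layer_cache = layer(x, cache=layer_cache)
        updated_caches.append(layer_cache)
    return x, updated_caches
```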

lingzhi98 commented 3 months ago

Splitting the KV cache into a separate key cache and value cache is also important (https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_attention.py#L166). A rough sketch of what that could look like is below.
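The following is an illustrative sketch of the two per-layer update patterns, assuming a cache layout of (batch, 2, max_len, num_heads, head_dim) with axis 1 holding key and value; it is not the exact gemma_attention code:

```python
# Sketch only: assumed shapes and names, not the real gemma_attention code.
from keras import ops

def update_cache_combined(cache, key_update, value_update, start_index):
    # Current style: key and value live in one tensor (axis 1 = key, value),
    # so they must be sliced out and then re-stacked after the update.
    key_cache = cache[:, 0, ...]
    value_cache = cache[:, 1, ...]
    start = [0, start_index, 0, 0]
    key = ops.slice_update(key_cache, start, key_update)
    value = ops.slice_update(value_cache, start, value_update)
    return key, value, ops.stack((key, value), axis=1)

def update_cache_split(key_cache, value_cache, key_update, value_update, start_index):
    # Proposed style: separate key and value caches are each updated directly,
    # with no slicing or re-stacking around the update.
    start = [0, start_index, 0, 0]
    key = ops.slice_update(key_cache, start, key_update)
    value = ops.slice_update(value_cache, start, value_update)
    return key, value
```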

mattdangerw commented 3 months ago

@lingzhi98 thanks! We are planning some generation improvements, so we will definitely check this out. Agreed that we can let performance be our guide, probably particularly JAX-compiled performance.

Were you thinking of a specific backend/compiled with XLA/not compiled? What's motivating the suggestion?

lingzhi98 commented 3 months ago

I use JAX as the Keras backend. I have seen the concatenation become the main overhead as the batch size increases. Because the KV caches are kept as one tensor, we need to slice the cache to get the corresponding key/value cache, compute the attention output, and then update the cache. The dynamic-update-slice fusion is blocked by this slice op (https://github.com/openxla/xla/blob/main/xla/service/gpu/ir_emission_utils.cc#L472), which hurts performance further.
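As a minimal JAX-level illustration of that point (assumed cache shapes and hypothetical names, not keras-nlp code):

```python
# Sketch only: shows the pattern, not the actual keras-nlp decode step.
import jax.numpy as jnp
from jax import lax

def step_combined(cache, k_new, v_new, index):
    # Combined layout: slicing key/value out of the cache before the update
    # places slice ops between the cache and the dynamic-update-slice, which
    # can prevent XLA from performing the update in place.
    key_cache = cache[:, 0]
    value_cache = cache[:, 1]
    key_cache = lax.dynamic_update_slice(key_cache, k_new, (0, index, 0, 0))
    value_cache = lax.dynamic_update_slice(value_cache, v_new, (0, index, 0, 0))
    return jnp.stack((key_cache, value_cache), axis=1)

def step_split(key_cache, value_cache, k_new, v_new, index):
    # Split layout: each dynamic-update-slice consumes its cache tensor
    # directly, the pattern XLA can turn into an in-place buffer update.
    key_cache = lax.dynamic_update_slice(key_cache, k_new, (0, index, 0, 0))
    value_cache = lax.dynamic_update_slice(value_cache, v_new, (0, index, 0, 0))
    return key_cache, value_cache
```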