keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Polishing T5 inference #1698

Open monatis opened 4 months ago

monatis commented 4 months ago

Hi, as reported in several issues (#1271, #1413), T5 still lacks some workflows. In particular, I'm trying to optimize T5 conditional generation. I started by porting code from BartSeq2SeqLM, but one thing that immediately caught my attention is that T5 uses its own MHA implementation, which lacks the kv cache functionality implemented in CachedMultiHeadAttention. This could be addressed in two ways:

  1. Add rel_attn_bias support to CachedMultiHeadAttention, or
  2. Add kv cache support to T5MultiHeadAttention.

I'm also planning to upstream what I came up with (a rough sketch is below). The question is: which one would you prefer, and which one do you think would be easier to hack? I lean toward option 2, but is there anything I'm missing?
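Roughly, option 2 would mean threading a cache through T5MultiHeadAttention's call(). A minimal sketch of the cache update, assuming the [batch, 2, max_length, num_heads, head_dim] cache layout that CachedMultiHeadAttention uses; the function name and signature here are hypothetical, not the current T5MultiHeadAttention API:

```python
import keras
from keras import ops

def update_kv_cache(key, value, cache, cache_update_index):
    """Write new key/value projections into a running cache.

    Sketch only: assumes a cache shaped
    [batch, 2, max_length, num_heads, head_dim], as in
    CachedMultiHeadAttention; T5MultiHeadAttention would need an
    equivalent `cache`/`cache_update_index` argument in call().
    """
    key_cache = cache[:, 0, ...]
    value_cache = cache[:, 1, ...]
    start = [0, cache_update_index, 0, 0]
    key = ops.slice_update(key_cache, start, key)
    value = ops.slice_update(value_cache, start, value)
    cache = ops.stack((key, value), axis=1)
    return key, value, cache

# Example decode step: project key/value for a single new token and
# write it at position 3 of the cache.
batch, max_len, num_heads, head_dim = 2, 16, 8, 64
cache = ops.zeros((batch, 2, max_len, num_heads, head_dim))
new_key = keras.random.normal((batch, 1, num_heads, head_dim))
new_value = keras.random.normal((batch, 1, num_heads, head_dim))
key, value, cache = update_kv_cache(new_key, new_value, cache, 3)
```

At each decode step only the new token's key/value projections would be computed and written at cache_update_index, with attention running over the full cached tensors; the relative attention bias for the current position still has to be produced alongside this, which is the part CachedMultiHeadAttention doesn't cover.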
mattdangerw commented 4 months ago

Much obliged! That'd be quite helpful.

Add kv cache support to T5MultiHeadAttention.

This is the better option IMO. It's become too hard to keep all of today's different LLMs on a single attention class, so we've basically relegated CachedMultiHeadAttention/TransformerDecoder/TransformerEncoder to models that follow the original transformer architecture, like BERT, GPT2, etc.

Basically every model these days diverges a bit (usually by putting position information into attention), so extending T5MultiHeadAttention is consistent with our current state. We can always refactor in the future if the need is clear. Though honestly, RoPE is the dominant choice now, not a relative bias like T5's.
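To make that contrast concrete, a rough sketch of the two positioning schemes (names and shapes are illustrative, not keras-hub APIs): T5's relative bias is added to the attention logits, while RoPE rotates the query/key vectors before the dot product.

```python
from keras import ops

def scores_with_relative_bias(q, k, bias):
    # T5-style: position information enters as an additive bias on the
    # attention logits (bias is looked up from bucketed relative
    # distances; the bucketing is omitted here).
    # q, k: [batch, seq, heads, head_dim]; bias: [1, heads, q_len, k_len]
    return ops.einsum("bqhd,bkhd->bhqk", q, k) + bias

def apply_rope(x, positions, head_dim):
    # RoPE-style: position information enters by rotating the query/key
    # vectors themselves, so nothing extra is added to the logits.
    # x: [batch, seq, heads, head_dim]; positions: [seq]
    half = head_dim // 2
    inv_freq = 1.0 / ops.power(
        10000.0, ops.arange(0, half, dtype="float32") / half
    )
    angles = ops.einsum("q,d->qd", ops.cast(positions, "float32"), inv_freq)
    cos = ops.reshape(ops.cos(angles), (1, -1, 1, half))
    sin = ops.reshape(ops.sin(angles), (1, -1, 1, half))
    x1, x2 = x[..., :half], x[..., half:]
    return ops.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

The caching consequence is that an additive bias depends on the current query position and has to be recomputed or sliced at every decode step, whereas RoPE only touches the projections of the new token.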