CUNY-CL / yoyodyne

Small-vocabulary sequence-to-sequence generation with optional feature conditioning
Apache License 2.0

Add caching for transformer inference #154

Open Adamits opened 10 months ago

Adamits commented 10 months ago

Transformer inference (i.e., decoding without teacher forcing) is slow. In practice, implementations typically add some form of caching so that at each timestep we do not need to recompute the embeddings and attentions over all previously decoded timesteps.

I have a quick-and-dirty implementation of this in an experimental fork, where I basically tell the decoder layer to compute attention only for the most recently decoded target and concatenate the cached representations of the earlier positions onto it. There are probably other tricks I can find by, e.g., inspecting some Hugging Face transformers inference code.

I propose adding an option to the transformer encoder-decoders to enable caching, wherein a CacheTransformerDecoder module is used.

This is a low-priority TODO since we validate with accuracy rather than loss, and accuracy can be reliably estimated with teacher forcing when the targets are provided.
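For concreteness, here is a rough sketch of what a cached decoder step could look like (names like CachedDecoderLayer and step are placeholders for illustration, not actual yoyodyne classes; a real version would also have to handle positional encodings and keep one cache per layer):

```python
from typing import Optional, Tuple

import torch
from torch import nn


class CachedDecoderLayer(nn.Module):
    """Decoder layer that attends from the newest target position only.

    Hypothetical sketch: the layer input for each decoded position is cached,
    so each step only computes attention for the single new position.
    """

    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def step(
        self,
        new_tgt: torch.Tensor,          # B x 1 x d_model: newest target embedding.
        cache: Optional[torch.Tensor],  # B x T x d_model: earlier target embeddings.
        memory: torch.Tensor,           # B x S x d_model: encoder output.
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The keys/values are everything decoded so far plus the new position.
        kv = new_tgt if cache is None else torch.cat([cache, new_tgt], dim=1)
        # Self-attention: query with only the newest position; no causal mask is
        # needed since nothing later than the query has been decoded yet.
        attn_out, _ = self.self_attn(new_tgt, kv, kv, need_weights=False)
        x = self.norm1(new_tgt + attn_out)
        # Cross-attention over the (static) encoder output.
        attn_out, _ = self.cross_attn(x, memory, memory, need_weights=False)
        x = self.norm2(x + attn_out)
        x = self.norm3(x + self.ffn(x))
        # Return this position's output and the updated cache for the next step.
        return x, kv
```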

kylebgorman commented 10 months ago

So this is an approximation/hack, right? I'm fine with it, and maybe we could treat it as a separate architecture to keep things simple.

Adamits commented 10 months ago

> So this is an approximation/hack, right?

I don't think so. I think, e.g., the attention of timestep t with respect to t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.

> treat it as a separate architecture to keep things simple.

Sure, we can do it that way.
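To sanity-check the claim above, here is a tiny standalone example (using torch.nn.MultiheadAttention directly; not repository code) showing that querying with only the newest position over the full cached prefix matches recomputing masked self-attention over the whole prefix:

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
tgt = torch.randn(2, 5, 16)  # Batch of 2, with 5 positions decoded so far.

with torch.no_grad():
    # Full recomputation: causal mask over the whole prefix; keep the last position.
    mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
    full, _ = attn(tgt, tgt, tgt, attn_mask=mask, need_weights=False)
    # Cached computation: query only with the newest position; K/V are the cache.
    cached, _ = attn(tgt[:, -1:], tgt, tgt, need_weights=False)

print(torch.allclose(full[:, -1:], cached, atol=1e-6))  # True
```

So the cached computation is exact; the only cost is holding the cache in memory.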

kylebgorman commented 10 months ago

> > So this is an approximation/hack, right?

> I don't think so. I think, e.g., the attention of timestep t with respect to t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.

Thanks for the clarification. All the better, then.