Open · Adamits opened this issue 10 months ago
So this is an approximation/hack, right? I'm fine with it, and maybe we could treat it as a separate architecture to keep things simple.
> So this is an approximation/hack, right?

I don't think so. I think, e.g., the attention of timestep t w.r.t. timestep t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.
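A minimal self-contained check of that claim (PyTorch, with purely illustrative names and shapes, not this repo's actual decoder): attending with only the newest query over cached keys and values reproduces full causal self-attention exactly, so the only cost of caching is memory.

```python
# Check that single-step attention over a K/V cache matches full causal
# self-attention. Everything here is illustrative, not the project's decoder.
import torch

torch.manual_seed(0)
d = 8
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
x = torch.randn(5, d)  # hidden states for 5 decoded timesteps


def full_causal_attention(h):
    """Recompute attention over all timesteps with a causal mask."""
    q, k, v = h @ W_q, h @ W_k, h @ W_v
    scores = q @ k.T / d**0.5
    mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v


def cached_attention_step(h_t, cache):
    """Attend only for the newest timestep; earlier K/V come from the cache."""
    cache["k"] = torch.cat([cache["k"], h_t @ W_k])
    cache["v"] = torch.cat([cache["v"], h_t @ W_v])
    scores = (h_t @ W_q) @ cache["k"].T / d**0.5
    return torch.softmax(scores, dim=-1) @ cache["v"]


cache = {"k": torch.empty(0, d), "v": torch.empty(0, d)}
stepwise = torch.cat([cached_attention_step(x[t : t + 1], cache) for t in range(5)])
print(torch.allclose(full_causal_attention(x), stepwise))  # True
```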
> treat it as a separate architecture to keep things simple.

Sure, we can do it that way.
> > So this is an approximation/hack, right?
>
> I don't think so. I think, e.g., the attention of timestep t w.r.t. timestep t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.

Thanks for the clarification. All the better, then.
Transformer inference (i.e. with no teacher forcing) is slow. In practice I think people typically implement some kind of caching so that at each timestep, we do not need to recompute the embeddings and attentions between all previously decoded timesteps.
I have a quick and dirty implementation of this in an experimental fork, where I basically tell the decoder layer to compute attention only for the most recently decoded target, and the cached representations for all previous timesteps are concatenated on. There are probably other tricks that I can find by e.g. inspecting some huggingface transformers inference code.
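For reference, the caching pattern in huggingface transformers inference code looks roughly like this; GPT-2 and the greedy loop are only stand-ins to show the use_cache/past_key_values interface, not this project's decoder:

```python
# Sketch of the huggingface caching pattern: keep past_key_values around and
# feed only the newest token at each step. GPT-2 is just a stand-in model.
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("transformer inference", return_tensors="pt").input_ids
past = None
with torch.no_grad():
    for _ in range(5):
        # After the first step, only the token not yet in the cache is passed in.
        step_ids = input_ids if past is None else input_ids[:, -1:]
        out = model(step_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
print(tokenizer.decode(input_ids[0]))
```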
I propose adding an option to the transformer encoder-decoders to use caching, wherein a CacheTransformerDecoder module is used.
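For concreteness, one possible shape for such a module, purely as a hypothetical sketch (the class name, cache layout, and post-norm layer details below are illustrative assumptions, not existing code):

```python
# Hypothetical caching decoder layer: each call consumes only the newest target
# embedding, attends over the cached prefix, and returns the extended cache.
import torch
from torch import nn


class CacheTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_ff: int = 1024):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt_step, memory, cache):
        # tgt_step: (B, 1, d) embedding of the most recently decoded target.
        # cache: (B, t, d) inputs this layer has already seen; memory: encoder output.
        prefix = torch.cat([cache, tgt_step], dim=1)
        # Query only the newest position; keys/values span the whole prefix,
        # which is equivalent to causal self-attention over the full sequence.
        attn, _ = self.self_attn(tgt_step, prefix, prefix, need_weights=False)
        h = self.norm1(tgt_step + attn)
        attn, _ = self.cross_attn(h, memory, memory, need_weights=False)
        h = self.norm2(h + attn)
        return self.norm3(h + self.ff(h)), prefix
```

A stacked decoder would keep one such cache per layer (initialized empty, e.g. with tgt_step.new_zeros(batch, 0, d_model)) and extend each cache by one position per decoding step, so only the newest target is ever pushed through the layers.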
This is a low-priority TODO, since we validate with accuracy rather than loss, and accuracy can be reliably predicted with teacher forcing when the targets are provided.