Closed isaacrob closed 5 months ago
Thanks @isaacrob. AR inference optimizations like the caching you mentioned or speculative decoding will improve both AR and VAR models. In Sec. 3.1 we focus on the plain decoding process, just to give a more intuitive sense of the speedup from VAR's parallel decoding and small number of iterations.
Thanks for the response!
Does caching reduce the complexity of VAR? It seems like your largest scale has O(n^2) tokens, so any attention would take at minimum O(n^4).
@isaacrob Yeah, I find it's true that the complexity order hasn't been reduced (though there may be significant improvements in the constants). Nonetheless, it is important to note that complexity only reflects the total amount of computation, ignoring whether the computations can be parallelized or have to be sequential. In our tests with caching, the wall-clock cost of VAR was at least an order of magnitude lower than AR's. In other words, even though both have a total complexity of O(n^4), AR's computations are highly sequential, resulting in lower practical efficiency.
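The parallel-vs-sequential point can be made concrete by counting forward passes rather than total operations. Below is a minimal sketch; the resolution schedule is a hypothetical example, not VAR's actual configuration.

```python
# Count sequential forward passes (the quantity that dominates wall-clock
# time), ignoring per-pass cost.

def ar_sequential_steps(n: int) -> int:
    # Plain AR: one forward pass per token, n*n tokens total.
    return n * n

def var_sequential_steps(scales: list[int]) -> int:
    # VAR-style decoding: one forward pass per scale; all tokens of a
    # scale are produced in parallel within that pass.
    return len(scales)

# Example: a 16x16 final token map with a hypothetical scale schedule.
n = 16
scales = [1, 2, 4, 8, 16]
print(ar_sequential_steps(n))        # 256 sequential passes
print(var_sequential_steps(scales))  # 5 sequential passes
```

Even if both approaches perform a comparable total amount of computation, the number of sequential passes differs by orders of magnitude, which is what shows up in wall-clock time on parallel hardware.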
I understand that complexity doesn't necessarily map to wall-clock time; I was just surprised that you claimed a complexity improvement in the paper when it looked like both should be O(n^4) for real-world implementations of transformers.
So just to make sure I'm understanding, it sounds like both are O(n^4) but you still argue for a speedup due to parallel decoding vs sequential decoding. Is this because each resolution is embedded via a transformer encoder as opposed to a transformer decoder? And the autoregressive statement is because, even though each individual scale is not autoregressive, the model as a whole is autoregressive because the transformer encoder for each resolution is conditioned on the embedding from the previous?
Basically yes, but I'd prefer to say it's more like a transformer decoder, just like those non-autoregressive or n-tokens-in-one-pass language models. A transformer encoder wouldn't have a causal mask or causal dependency.
Great results! :) I have a question though about complexity
In section 3.1, you provide three issues with autoregressive transformer models. The third says that an autoregressive transformer generating n^2 tokens will take n^6 time. In your Appendix you show this by assuming that for each token i the attentions with the previous (i-1)^2 tokens need to be computed, which takes O(i^2), so the total time is the sum of i=1 to n^2 of i^2 which is O(n^6).
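The no-cache count described above can be sanity-checked numerically: recomputing attention over the full prefix at step i costs O(i^2), so the total is the sum of i^2 for i = 1 to n^2, which grows as n^6. A quick check:

```python
# Total attention operations for plain AR decoding WITHOUT a KV cache:
# at step i, attention over the full prefix costs ~i^2 operations.

def total_ops_no_cache(n: int) -> int:
    return sum(i * i for i in range(1, n * n + 1))

# Doubling n should multiply the total by roughly 2^6 = 64.
ratio = total_ops_no_cache(64) / total_ops_no_cache(32)
print(round(ratio))  # ≈ 64, consistent with O(n^6)
```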
However, in practice we cache the intermediate representations for each token during autoregressive sampling. See here for an explanation. This means for each i we reuse the attention computed for the preceding (i-1) tokens, and so only have to compute attention between the ith token and each of the previous tokens, which takes O(i) time. Therefore the sum is i=1 to n^2 of i, which is O(n^4), the same as the result that you report for your method.
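The cached count admits the same numerical sanity check: with a KV cache, step i only computes attention between the new token and the i cached keys, so the total is the sum of i for i = 1 to n^2, which grows as n^4:

```python
# Total attention operations for AR decoding WITH a KV cache:
# at step i, only the new token attends to the i cached positions.

def total_ops_with_cache(n: int) -> int:
    return sum(range(1, n * n + 1))

# Doubling n should multiply the total by roughly 2^4 = 16.
ratio = total_ops_with_cache(64) / total_ops_with_cache(32)
print(round(ratio))  # ≈ 16, consistent with O(n^4)
```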
Please let me know if for some reason the caching trick doesn't apply to the AR methods you're comparing against, or if your method can also benefit from caching in that way to speed up generation!