FoundationVision / VAR

[GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!

AR Time Complexity? #35

Closed · isaacrob closed 5 months ago

isaacrob commented 5 months ago

Great results! :) I have a question, though, about complexity.

In Section 3.1, you list three issues with autoregressive transformer models. The third says that an autoregressive transformer generating n^2 tokens takes O(n^6) time. In your appendix you show this by assuming that for each token i, the full attention over the i-token prefix has to be recomputed, which takes O(i^2), so the total time is the sum over i = 1 to n^2 of i^2, which is O(n^6).
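
As a sanity check on that no-caching count, here is a small Python sketch (mine, not from the paper or repo) that sums the per-step attention cost and compares it to n^6:

```python
# Without caching, step i recomputes full attention over the i-token prefix,
# costing on the order of i^2 operations; summing over all n^2 steps gives O(n^6).
def uncached_attention_ops(n: int) -> int:
    total_tokens = n * n
    return sum(i * i for i in range(1, total_tokens + 1))

for n in (4, 8, 16, 32):
    ops = uncached_attention_ops(n)
    print(f"n={n:3d}  ops={ops:16d}  ops/n^6={ops / n**6:.3f}")
# The ratio ops/n^6 settles near 1/3, i.e. the total is indeed on the order of n^6.
```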

However, in practice we cache the intermediate representations (the keys and values) of each token during autoregressive sampling. See here for an explanation. This means that at each step i we reuse what was already computed for the preceding (i-1) tokens, and only have to compute the attention between the i-th token and each of the previous tokens, which takes O(i) time. The sum over i = 1 to n^2 of i is therefore O(n^4), the same as the result that you report for your method.
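
For reference, here is a minimal single-head toy sketch of that reuse in PyTorch (my own illustration, not the repo's implementation): the keys and values of earlier tokens are cached, so step i only computes the new token's attention against the i cached entries, i.e. O(i) work per step and O(n^4) in total.

```python
import torch

def decode_step(x_t, W_q, W_k, W_v, cache):
    """One autoregressive step with a KV cache (toy single-head attention).

    x_t: (1, d) embedding of the newly generated token.
    cache: keys/values of all previously generated tokens.
    """
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v             # each (1, d)
    # Append the new key/value instead of recomputing them for old tokens.
    cache["k"] = torch.cat([cache["k"], k], dim=0)        # (i, d)
    cache["v"] = torch.cat([cache["v"], v], dim=0)        # (i, d)
    # Only the new query attends to the i cached keys: O(i) work this step.
    attn = torch.softmax(q @ cache["k"].T / cache["k"].shape[-1] ** 0.5, dim=-1)
    return attn @ cache["v"]                              # (1, d)

d = 8
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache = {"k": torch.empty(0, d), "v": torch.empty(0, d)}
for _ in range(5):                                        # a few toy decoding steps
    out = decode_step(torch.randn(1, d), W_q, W_k, W_v, cache)
```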

Please let me know if for some reason the caching trick doesn't apply to the AR methods you're comparing against, or if your method can also benefit from caching in that way to speed up generation!

keyu-tian commented 5 months ago

Thanks @isaacrob. AR inference optimizations like the caching you mentioned or speculative decoding will improve both AR and VAR models. In Section 3.1 we focus on the plain decoding process, just to give a more intuitive sense of the speedup from VAR's parallel decoding and its small number of iterations.
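
To make the contrast concrete, here is a rough sketch of the two plain decoding loops (`predict_next_token` and `predict_scale` are placeholder names, not the repo's API; the scale list mirrors the paper's 256x256 schedule but is only illustrative): AR runs one sequential forward pass per token, while VAR runs one forward pass per scale and predicts all tokens of that scale in parallel.

```python
# Plain AR decoding: n*n sequential forward passes, one token each.
def ar_decode(model, n):
    tokens = []
    for _ in range(n * n):                    # n^2 sequential steps
        tokens.append(model.predict_next_token(tokens))              # placeholder API
    return tokens

# VAR-style decoding: one forward pass per scale; every token of the current
# scale is predicted in parallel, conditioned on the coarser scales so far.
def var_decode(model, scales=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)):
    token_maps = []
    for s in scales:                          # ~10 sequential steps
        token_maps.append(model.predict_scale(token_maps, size=s))   # placeholder API
    return token_maps
```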

isaacrob commented 5 months ago

Thanks for the response!

Does caching reduce the complexity of VAR? It seems like your largest scale has O(n^2) tokens, so any attention over it would take at least O(n^4).

keyu-tian commented 5 months ago

@isaacrob yeah, it's true that the complexity order hasn't been reduced (though there can be significant improvements in the constants). Nonetheless, it is important to note that the complexity only reflects the total amount of computation, ignoring whether the computations can be parallelized or have to be sequential. In our tests with caching, the wall-clock cost of VAR was at least an order of magnitude lower than AR's. In other words, even though both have a total complexity of O(n^4), AR's computations are highly sequential, resulting in lower practical efficiency.
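
A back-of-the-envelope illustration of that point (the numbers are illustrative, reusing the 10-scale schedule from the sketch above): the total attention work is O(n^4) in both cases, but the number of sequential forward passes differs by more than an order of magnitude.

```python
# Sequential depth vs. total work for a 16x16 final token map (illustrative only).
n = 16
scales = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

ar_sequential_steps = n * n           # 256 forward passes, one per token
var_sequential_steps = len(scales)    # 10 forward passes, one per scale

# With KV caching, both spend on the order of (total tokens)^2 ~ n^4 attention ops;
# the difference is how many of those ops have to wait on one another.
print(ar_sequential_steps, var_sequential_steps)   # 256 vs. 10
```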

isaacrob commented 5 months ago

I understand that complexity doesn't necessarily map to wall-clock time; I was just surprised that you claimed a complexity improvement in the paper when it looks like both should be O(n^4) for real-world implementations of transformers.

So just to make sure I'm understanding: it sounds like both are O(n^4), but you still argue for a speedup due to parallel decoding vs. sequential decoding. Is this because each resolution is embedded via a transformer encoder as opposed to a transformer decoder? And the autoregressive statement is because, even though each individual scale is not autoregressive, the model as a whole is autoregressive, since the transformer encoder for each resolution is conditioned on the embeddings from the previous one?

keyu-tian commented 5 months ago

Basically yes, but I'd prefer to say it's more like a transformer decoder, just like those non-autoregressive or n-tokens-in-one-pass language models. A transformer encoder wouldn't have a causal mask or causal dependencies.
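
To illustrate that structure (parallel within a scale, autoregressive across scales), here is a minimal sketch of a block-wise causal attention mask (my own illustration assuming the same 10-scale schedule as above, not code copied from the repo): tokens within one scale attend to each other freely, and every scale can attend to all coarser scales but never to finer ones.

```python
import torch

def blockwise_causal_mask(scales=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)):
    """True = attention allowed. Tokens of scale k see scales 1..k, never finer scales."""
    lens = [s * s for s in scales]                              # tokens per scale
    # scale index of every token position in the flattened multi-scale sequence
    scale_id = torch.cat([torch.full((l,), i) for i, l in enumerate(lens)])
    # a query at scale i may attend to any key at scale j <= i
    return scale_id.unsqueeze(1) >= scale_id.unsqueeze(0)       # (total, total) bool

mask = blockwise_causal_mask()
print(mask.shape)   # torch.Size([680, 680]) for this schedule
```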