FoundationVision / VAR

[GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!
MIT License

How do you make Transformer generate tokens in parallel? #1

Open jzhang38 opened 6 months ago

jzhang38 commented 6 months ago

It is mentioned multiple times in the paper that all tokens from the same scale r are generated in parallel. Did I overlook it, or is there actually little description of how tokens are generated in parallel in the VAR transformer?

keyu-tian commented 6 months ago

The parallel decoding is done by feeding next-scale queries (e.g., 4x4=16) to the transformer decoder and getting 16 predicted token distributions simultaneously. Sampling from them can get the 16 tokens in parallel. We'll consider including a more thorough explanation like this in our paper. Thank you @jzhang38.
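
For intuition, here is a minimal PyTorch sketch of that parallel sampling step (the shapes and tensor names are illustrative, not the repo's actual code):

```python
import torch

# Illustrative shapes: B images, 16 query positions for the 4x4 scale, vocab size V.
B, N, V = 2, 16, 4096
logits = torch.randn(B, N, V)          # transformer outputs for the 16 next-scale queries

probs = torch.softmax(logits, dim=-1)  # one categorical distribution per query position
tokens = torch.multinomial(probs.view(-1, V), num_samples=1).view(B, N)
# `tokens` now holds all 16 tokens of the 4x4 scale, drawn in a single parallel step.
```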

jzhang38 commented 6 months ago

Thanks for the quick answer!

The parallel decoding is done by feeding next-scale queries (e.g., 4x4=16) to the transformer decoder and getting 16 predicted token distributions simultaneously.

This is still not very clear to me. Do you still use causal masks within these 16 tokens? Eagerly waiting for the code.

jbaron34 commented 6 months ago

The parallel decoding is done by feeding next-scale queries (e.g., 4x4=16) to the transformer decoder and getting 16 predicted token distributions simultaneously.

Are the 16 tokens predicted independently or are they allowed to attend to each other during decoding?

Sampling from them can get the 16 tokens in parallel

Do you just use greedy sampling for each token? Did you experiment with different sampling strategies?

keyu-tian commented 6 months ago

@jzhang38 @jbaron34 During decoding these 16 tokens can freely attend to each other with no mask between them (they just cannot attend to future token maps like the next 5x5). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used; they are sampled in parallel as if independent of each other. It's almost the same as the sampling of GPT, but VAR will sample multiple tokens at the same time.

More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

PS: you can also check the code at https://github.com/FoundationVision/VAR/blob/main/models/var.py#L144
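
For readers who want to reproduce this mask programmatically, here is a small sketch (an illustration only, not the repo's code) that builds the block-wise causal mask for the 1x1 -> 2x2 -> 3x3 example:

```python
import torch

def build_blockwise_causal_mask(scales=(1, 2, 3)):
    # Each scale s contributes s*s tokens; tokens of one scale attend to themselves
    # and to all earlier scales, but not to later ones.
    lens = [s * s for s in scales]                            # [1, 4, 9]
    scale_id = torch.cat([torch.full((n,), i) for i, n in enumerate(lens)])
    mask = scale_id.unsqueeze(0) <= scale_id.unsqueeze(1)     # (14, 14) bool
    return mask.int()

print(build_blockwise_causal_mask())   # reproduces the 14x14 matrix above
```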

chenllliang commented 6 months ago

@jzhang38 @jbaron34 These 16 tokens can attend to each other during decoding ("decoding" means a standard transformer forward process). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, we directly sample 16 tokens in parallel. It's almost the same as the standard sampling of GPT but we sample multiple tokens at the same time.

More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

@keyu-tian Thanks for your great work! I feel that the prediction within each scale's token map is more like MAE style than AR style, in that each input hidden state will finally predict the VQ token at the same position, and they can see the other tokens within the same scale and tokens in previous scales, instead of the next position as in AR. The 14 output VQ tokens for the given attention mask during training would be 'r1, r2-1, r2-2, r2-3, r2-4, r3-1, r3-2, r3-3, r3-4, r3-5, r3-6, r3-7, r3-8, r3-9'. The AR process happens across the generation of the different scale maps. Am I correct about this?

Another question is how do you set the input_ids and labels for transformer in the training stage, can you also give an example?

Also the ArXiv badge seems to have a wrong paper id in the readme.

keyu-tian commented 6 months ago

@jzhang38 @jbaron34 These 16 tokens can attend to each other during decoding ("decoding" means a standard transformer forward process). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, we directly sample 16 tokens in parallel. It's almost the same as the standard sampling of GPT but we sample multiple tokens at the same time. More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

@keyu-tian Thanks for your great work! I feel that the prediction within each scale's token map is more like MAE style instead of AR style, that each input hidden state will finally predict the VQ token at the same position and they can see the other tokens within the same scale and tokens in previous scale, instead of next position like AR. The output 14 VQ token of the given attention map during training would be 'r1, r2-1, r2-2, r2-3, r2-4, r3-1, r3-2, r3-3, r3-4, r3-5, r3-6, r3-7, r3-8, r3-9'. The AR process is among different scale map's generation. Am I correct about it?

Another question is how do you set the input_ids and labels for transformer in the training stage, can you also give an example?

BERT is not like VAR or AR. I feel "they can see the other tokens within the same scale and tokens in previous scale" is just the definition of Autoregressive learning -- akin to how GPT's autoregressive mechanism allows each token to be aware of all preceding tokens.

To clarify, in a GPT-like language model, each step uses the previous token as the current input, with visibility limited to all prior tokens. VAR operates similarly, using the previous token map to predict the next, building upon all earlier maps. Taking the final step of "1x1->2x2->3x3" as an example, VAR will interpolate its last prediction, i.e. the 2x2 map, to 3x3, as the next step's input. This 3x3 map can attend to all of its preceding maps (both 1x1 and 2x2). So this process is almost the same as what GPT is doing. A token is to GPT what a token map is to VAR.

During training, GPT will use a "right shift" to convert the ground truth into the input. Likewise, VAR uses the interpolations 1x1 -> 2x2 and 2x2 -> 3x3 to get the inputs of the 2nd and 3rd AR steps, and the first step's input is a learnable start token. PS: these interpolations are done in the continuous embedding space, not the token space (integer space).
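
As an illustration of this "right shift by interpolation", here is a hedged sketch of how the teacher-forcing inputs could be assembled for the 1x1 -> 2x2 -> 3x3 example (the tensor names and the interpolation mode are assumptions; the actual logic lives in models/quant.py):

```python
import torch
import torch.nn.functional as F

# Hypothetical ground-truth token *embeddings* per scale, already looked up from
# the VQ codebook: (B, C, k, k) with embedding dim C.
B, C = 2, 32
gt_embed = {1: torch.randn(B, C, 1, 1),
            2: torch.randn(B, C, 2, 2),
            3: torch.randn(B, C, 3, 3)}
sos = torch.randn(B, C, 1, 1)                # a learnable start token in the real model

inputs = [sos.flatten(2)]                    # step 1 input: [SOS]
for k_prev, k_next in [(1, 2), (2, 3)]:      # "right shift": scale k feeds step k+1
    up = F.interpolate(gt_embed[k_prev], size=(k_next, k_next),
                       mode='bicubic', align_corners=False)
    inputs.append(up.flatten(2))             # (B, C, k_next*k_next)

x = torch.cat(inputs, dim=2).transpose(1, 2) # (B, 1+4+9, C) transformer input
```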

BERT is another radically different system. It uses bidirectional attention (while GPT/VAR uses triangular/block-wise-triangular causal mask to ensure unidirectional dependency). It also doesn't follow the "last output as next input" pattern that characterizes AR or VAR models.

And thank you for the heads-up! We'll correct that right away.

bonejay commented 6 months ago

So during inference, in order to generate rk, we attend to r1 through r(k-1). Is that really needed? Can't we simply attend to r(k-1) and mask everything else for the cross-attention, provided of course we do the same during training?

I mean diffusion models also only see one step (t-1) of the denoised image not the whole history.

keyu-tian commented 6 months ago

So during inference in order to generate rk we attend r1 til r(k-1). Is that really needed. Can't we simply just attend r(k-1) and mask everything else for the cross attention when we ofc do the same during training?

I mean diffusion models also only see one step (t-1) of the denoised image not the whole history.

IMO, conditioning only on step (t-1) could be an interesting experiment, akin to diffusion models. This would make VAR more like a "recurrent super-resolution" model. But for now, our VAR just follows AR, so all preceding token maps will be used.

BTW we've just released a demo notebook for everyone to play with! @bonejay @jzhang38 @chenllliang @jbaron34. Check it out at https://github.com/FoundationVision/VAR/blob/main/demo_sample.ipynb for a deeper dive into our sampling process. Enjoy and cheers! 🍻

chenllliang commented 6 months ago

So during inference in order to generate rk we attend r1 til r(k-1). Is that really needed. Can't we simply just attend r(k-1) and mask everything else for the cross attention when we ofc do the same during training? I mean diffusion models also only see one step (t-1) of the denoised image not the whole history.

IMO, conditioning only on step (t-1) could be an interesting experiment, akin to diffusion models. This would make VAR more like a "recurrent super-resolution" model. But for now, our VAR just follows AR, so all preceding token maps will be used.

BTW we've just released a demo notebook for everyone to play with! @bonejay @jzhang38 @chenllliang @jbaron34. Check it out at https://github.com/FoundationVision/VAR/blob/main/demo_sample.ipynb for a deeper dive into our sampling process. Enjoy and cheers! 🍻

Thanks! It works for me.

[image: sample]

jungle-gym-ac commented 5 months ago

@jzhang38 @jbaron34 During decoding these 16 tokens can freely attend to each other with no mask between them (they just cannot attend to future token maps like the next 5x5). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, they are sampled in parallel as if independent to each other. It's almost the same as the sampling of GPT but VAR will sample multiple tokens at the same time.

More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

PS: you can also check the code at https://github.com/FoundationVision/VAR/blob/main/models/var.py#L144

Thank you for your explanation! It seems to me that a single generation step of VAR is very similar to the Perceiver Resampler Module in Flamingo. The model is fed with new query tokens in every autoregressive "next-scale prediction" step, unlike GPT whose output of one step will be fed as input for the next step. I think you should really explain this with graphs and text in your paper, not only to avoid confusion, but also because it is a great innovation.

By the way, it seems there is a line of [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] missing in this attention mask matrix.

keyu-tian commented 5 months ago

@jungle-gym-ac Many thanks for your advice; we'll explain this in more detail in our updated version. Also thanks for pointing out this typo lol.

eisneim commented 5 months ago

For anyone who might have a hard time understanding why the attention mask looks like what @keyu-tian has shown, here is my illustration:

[image: illustration of the block-wise attention mask (Screenshot 2024-04-07 at 22:24:12)]

icoz69 commented 5 months ago

@jzhang38 @jbaron34 These 16 tokens can attend to each other during decoding ("decoding" means a standard transformer forward process). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, we directly sample 16 tokens in parallel. It's almost the same as the standard sampling of GPT but we sample multiple tokens at the same time. More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

@keyu-tian Thanks for your great work! I feel that the prediction within each scale's token map is more like MAE style instead of AR style, that each input hidden state will finally predict the VQ token at the same position and they can see the other tokens within the same scale and tokens in previous scale, instead of next position like AR. The output 14 VQ token of the given attention map during training would be 'r1, r2-1, r2-2, r2-3, r2-4, r3-1, r3-2, r3-3, r3-4, r3-5, r3-6, r3-7, r3-8, r3-9'. The AR process is among different scale map's generation. Am I correct about it? Another question is how do you set the input_ids and labels for transformer in the training stage, can you also give an example?

BERT is not like VAR or AR. I feel "they can see the other tokens within the same scale and tokens in previous scale" is just the definition of Autoregressive learning -- akin to how GPT's autoregressive mechanism allows each token to be aware of all preceding tokens.

To clarify, in a GPT-like language model, each step uses the previous token as the current input, with visibility limited to all prior tokens. VAR operates similarly, using the previous token map to predict the next, building upon all earlier maps. Taking the final step of "1x1->2x2->3x3" as example, VAR will interpolate its last prediction, i.e. the 2x2 map, to 3x3, as the next step's inout. This 3x3 map can attend to all of its preceding maps (both 1x1 and 2x2). So this process is almost the same as what GPT's doing. A token is to GPT what a token map is to VAR.

During training, GPT will use "right shift" to convert ground truth into input. Likewise, VAR use interpolations 1x1 -> 2x2, 2x2->3x3 to get the inputs of 2nd and 3rd AR steps, and the first step's input would be a learnable start token. ps: these interpolations are done in the continuous embedding space, not the token space (integer space).

BERT is another radically different system. It uses bidirectional attention (while GPT/VAR uses triangular/block-wise-triangular causal mask to ensure unidirectional dependency). It also doesn't follow the "last output as next input" pattern that characterizes AR or VAR models.

And thank you for the heads-up! We'll correct that right away.

Hi, in training, do you mean interpolating between the token embeddings in the dictionary and using these embeddings as inputs? During inference, the embeddings of the predicted tokens are interpolated, right?

keyu-tian commented 5 months ago

@jzhang38 @jbaron34 These 16 tokens can attend to each other during decoding ("decoding" means a standard transformer forward process). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, we directly sample 16 tokens in parallel. It's almost the same as the standard sampling of GPT but we sample multiple tokens at the same time. More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

@keyu-tian Thanks for your great work! I feel that the prediction within each scale's token map is more like MAE style instead of AR style, that each input hidden state will finally predict the VQ token at the same position and they can see the other tokens within the same scale and tokens in previous scale, instead of next position like AR. The output 14 VQ token of the given attention map during training would be 'r1, r2-1, r2-2, r2-3, r2-4, r3-1, r3-2, r3-3, r3-4, r3-5, r3-6, r3-7, r3-8, r3-9'. The AR process is among different scale map's generation. Am I correct about it? Another question is how do you set the input_ids and labels for transformer in the training stage, can you also give an example?

BERT is not like VAR or AR. I feel "they can see the other tokens within the same scale and tokens in previous scale" is just the definition of Autoregressive learning -- akin to how GPT's autoregressive mechanism allows each token to be aware of all preceding tokens. To clarify, in a GPT-like language model, each step uses the previous token as the current input, with visibility limited to all prior tokens. VAR operates similarly, using the previous token map to predict the next, building upon all earlier maps. Taking the final step of "1x1->2x2->3x3" as example, VAR will interpolate its last prediction, i.e. the 2x2 map, to 3x3, as the next step's inout. This 3x3 map can attend to all of its preceding maps (both 1x1 and 2x2). So this process is almost the same as what GPT's doing. A token is to GPT what a token map is to VAR. During training, GPT will use "right shift" to convert ground truth into input. Likewise, VAR use interpolations 1x1 -> 2x2, 2x2->3x3 to get the inputs of 2nd and 3rd AR steps, and the first step's input would be a learnable start token. ps: these interpolations are done in the continuous embedding space, not the token space (integer space). BERT is another radically different system. It uses bidirectional attention (while GPT/VAR uses triangular/block-wise-triangular causal mask to ensure unidirectional dependency). It also doesn't follow the "last output as next input" pattern that characterizes AR or VAR models. And thank you for the heads-up! We'll correct that right away.

hi, in training, do you mean interpolations between token embeddings in the dictionary and use these embeddings as inputs? During inference, the embeddings of the predicted tokens are interpolated, right?

@icoz69 yeah, they're both like that. The "right shifted" interpolations for producing teacher-forced training inputs can be found at https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169. The interpolations during inference on the embeddings of predicted tokens are at https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L187. All interpolations occur in the VQ embedding space, which we've chosen for its spatial structure that lends itself well to interpolation.

icoz69 commented 5 months ago

@jzhang38 @jbaron34 These 16 tokens can attend to each other during decoding ("decoding" means a standard transformer forward process). After that, a single linear layer followed by softmax will convert each token feature to a probability distribution. No greedy sampling is used, we directly sample 16 tokens in parallel. It's almost the same as the standard sampling of GPT but we sample multiple tokens at the same time. More specifically: for a multi-scale token map 1x1 -> 2x2 -> 3x3, the attention mask used in training will be like:

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
].

During inference we have 3 autoregressive iterations in total, and in the last one we will get a feature shaped (B, 9, C) and then linearly projected to (B, 9, V). The sampling is done on (Bx9, V) to get Bx9 tokens.

@keyu-tian Thanks for your great work! I feel that the prediction within each scale's token map is more like MAE style instead of AR style, that each input hidden state will finally predict the VQ token at the same position and they can see the other tokens within the same scale and tokens in previous scale, instead of next position like AR. The output 14 VQ token of the given attention map during training would be 'r1, r2-1, r2-2, r2-3, r2-4, r3-1, r3-2, r3-3, r3-4, r3-5, r3-6, r3-7, r3-8, r3-9'. The AR process is among different scale map's generation. Am I correct about it? Another question is how do you set the input_ids and labels for transformer in the training stage, can you also give an example?

BERT is not like VAR or AR. I feel "they can see the other tokens within the same scale and tokens in previous scale" is just the definition of Autoregressive learning -- akin to how GPT's autoregressive mechanism allows each token to be aware of all preceding tokens. To clarify, in a GPT-like language model, each step uses the previous token as the current input, with visibility limited to all prior tokens. VAR operates similarly, using the previous token map to predict the next, building upon all earlier maps. Taking the final step of "1x1->2x2->3x3" as example, VAR will interpolate its last prediction, i.e. the 2x2 map, to 3x3, as the next step's inout. This 3x3 map can attend to all of its preceding maps (both 1x1 and 2x2). So this process is almost the same as what GPT's doing. A token is to GPT what a token map is to VAR. During training, GPT will use "right shift" to convert ground truth into input. Likewise, VAR use interpolations 1x1 -> 2x2, 2x2->3x3 to get the inputs of 2nd and 3rd AR steps, and the first step's input would be a learnable start token. ps: these interpolations are done in the continuous embedding space, not the token space (integer space). BERT is another radically different system. It uses bidirectional attention (while GPT/VAR uses triangular/block-wise-triangular causal mask to ensure unidirectional dependency). It also doesn't follow the "last output as next input" pattern that characterizes AR or VAR models. And thank you for the heads-up! We'll correct that right away.

hi, in training, do you mean interpolations between token embeddings in the dictionary and use these embeddings as inputs? During inference, the embeddings of the predicted tokens are interpolated, right?

@icoz69 yeah, they're both like that. The "right shifted" interpolations for producing teacher-forced training inputs can be found at https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169. The interpolations during inference on the embeddings of predicted tokens are at https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L187. All interpolations occur in the VQ embedding space, which we've chosen for its spatial structure that lends itself well to interpolation.

thanks for the clarifications. it's clear to me now. good work!

plutoyuxie commented 5 months ago

From https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169 I found the VAR input is prepared following these steps:

for k = 1, ..., K do
    rk = queue_pop(R)
    zk = lookup(Z, rk)
    zk = interpolate(zk, hK, wK)                     # upsampling
    f_hat = f_hat + phi_k(zk)
    next_scale = interpolate(f_hat, h(k+1), w(k+1))  # downsampling
    ...

These (h(k+1) * w(k+1)) tokens seem to be the initial value.
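
Rendered as runnable (but heavily simplified) PyTorch, that loop might look like the following; phi_k (a small convolution in the real code) is omitted and the interpolation modes are assumptions, so treat this purely as a sketch of the data flow, not the repository's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def prepare_var_inputs(gt_tokens, codebook, scales=(1, 2, 3)):
    """Sketch of the input-preparation loop described above.
    gt_tokens: dict {k: LongTensor (B, k, k)} of ground-truth indices per scale.
    codebook:  nn.Embedding (V, C), the VQ embedding table.
    Returns the 'next scale' maps fed to the transformer at steps 2..K."""
    hK = wK = scales[-1]
    C = codebook.embedding_dim
    B = gt_tokens[scales[0]].shape[0]
    f_hat = torch.zeros(B, C, hK, wK)
    next_inputs = []
    for i, k in enumerate(scales):
        zk = codebook(gt_tokens[k]).permute(0, 3, 1, 2)            # (B, C, k, k)
        zk = F.interpolate(zk, size=(hK, wK),
                           mode='bicubic', align_corners=False)    # upsample to final size
        f_hat = f_hat + zk               # the real code applies a conv phi_k here; omitted
        if i + 1 < len(scales):
            k_next = scales[i + 1]
            next_inputs.append(F.interpolate(f_hat, size=(k_next, k_next),
                                             mode='area'))         # downsample to next scale
    return next_inputs

codebook = nn.Embedding(4096, 32)
gt = {k: torch.randint(0, 4096, (2, k, k)) for k in (1, 2, 3)}
steps = prepare_var_inputs(gt, codebook)    # shapes: (2, 32, 2, 2) and (2, 32, 3, 3)
```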

lorenmt commented 5 months ago

@keyu-tian Considering the fundamental difference between autoregressive language modelling and autoregressive image modelling: individual tokens can be ambiguous without history context in language, while the same is probably not true for image-scale modelling in your setup. I am wondering whether you have considered modelling the scaling of image resolution as a Markov process p(r_k | r_{k-1}) instead, to further save inference time?

Really love to hear your insights.

keyu-tian commented 5 months ago

Hi @lorenmt, if you think about the actual process of the VQVAE multi-scale encoding (introduced in Section 3.2 and Algorithms 1 and 2), each r_k contains only a part of the full information of an image. So all previous token maps r_{<k} are needed for generating r_k.

ZetangForward commented 5 months ago

hi @lorenmt, if you think about the actual process of VQVAE multi-scale encoding (introduced in Section 3.2 and Algorithm 1, 2), each rk contains only a part of the full information of an image. So all previous token maps r{<k} are needed for generating r_{k}.

Then, I think this is a big attention map if it includes so many previous token maps. I would like to know what GPUs you used and the maximum size of r_k you can train with.

keyu-tian commented 5 months ago

@ZetangForward that depends on the specific GPU architecture. For 512x512, a 40GB A100 is enough. You should use DeepSpeed ZeRO-2/3, FSDP, or sequence parallelism to save GPU memory.
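
For reference, a minimal FSDP setup looks roughly like this (the Linear layer is only a stand-in for the VAR transformer; launch with torchrun, one process per GPU):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")     # run under torchrun
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 4096).cuda()  # stand-in for the VAR transformer
model = FSDP(model)                         # shards parameters, gradients, and optimizer state
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```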

ZetangForward commented 5 months ago

@ZetangForward that depends on the specific GPU architecture. For 512x512, 40G-A100 is enough. Should use Deepspeed Zero2/3, FSDP, or Sequence Parallel for saving GPU memory.

ok, thanks, get it.

daiyixiang666 commented 4 months ago

I am confused about the difference between the training and inference steps. During training, for example with 1x1 - 2x2 - 3x3, we have 1x1 + 2x2 = 5 input tokens, where the 2x2 part carries the information of the 1x1 map. But during inference, when generating the 2x2 tokens, we feed the 2x2 map upsampled from the 1x1, which is different from the training stage. This confuses me. Shouldn't the mask instead be the following, since during inference we do not actually have the previous information?

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
]

eanson023 commented 4 months ago

Hi @keyu-tian, I am very concerned about whether VAR can handle variable-length data (i.e. not fixed H, W). If the H and W of an image change to 14, and the VAR scale schedule K is still {1x1, ..., 13x13, 16x16}, how can I predict an [EOS] token like GPT does? Because the input at each scale of VAR relies on the interpolation of the previous scale (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169), the learned codebook cannot be modified to add an [EOS] idx when training the transformer (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L39), like the traditional AR embedding:

# two dummy tokens, one for [EOS], one for padding
self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)

Hope you understand what I mean; this problem has been bothering me and I hope there is a good solution.

MiracleDance commented 4 months ago

I am confused about the difference between the training and inference steps. During training, for example with 1x1 - 2x2 - 3x3, we have 1x1 + 2x2 = 5 input tokens, where the 2x2 part carries the information of the 1x1 map. But during inference, when generating the 2x2 tokens, we feed the 2x2 map upsampled from the 1x1, which is different from the training stage. This confuses me. Shouldn't the mask instead be the following, since during inference we do not actually have the previous information?

[
  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
]

that is also my confusion, looking forward to the answer~

keyu-tian commented 4 months ago

Hi @daiyixiang666 @MiracleDance, you may have missed that we also do the upsampling during training.

See https://github.com/FoundationVision/VAR/blob/main/trainer.py#L68C58-L68C76 and https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169.

In that function, we upsample the original 1x1 token map to 2x2, the original 2x2 to 3x3, and discard (omit) the original 3x3. So this 2x2 and this 3x3 contain the information of the original 1x1 and the original 2x2, respectively. At the beginning of VAR's forward pass, this 2x2+3x3 will be concatenated with a 1x1 [SOS] on its left, to get a full 1x1+2x2+3x3 sequence. So this sequence contains the information of [SOS], the original 1x1, and the original 2x2. Taking this as input, VAR will predict 1x1 given [SOS], predict 2x2 given [SOS]+original 1x1, and predict 3x3 given [SOS]+original 1x1+original 2x2. This requires a causal mask, not that "do not have the previous information"-style mask.

During inference, VAR will also predict 1x1 given [SOS], predict 2x2 given [SOS]+predicted 1x1, and predict 3x3 given [SOS]+predicted 1x1+predicted 2x2. So the training and inference are aligned well, and are the same as what GPT does.
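
To make the inference side concrete, here is a hedged sketch of that loop; var_transformer and codebook are stand-ins for the real modules, and the KV cache and f_hat accumulation used in the repository are omitted:

```python
import torch
import torch.nn.functional as F

def var_sample(var_transformer, codebook, sos, scales=(1, 2, 3)):
    """Sketch: autoregress over scales, sampling all tokens of a scale in parallel.
    var_transformer: callable (B, L, C) -> (B, L, V) logits (stand-in).
    codebook:        nn.Embedding (V, C) used to embed sampled tokens.
    sos:             (B, C, 1, 1) start-token embedding."""
    B, C = sos.shape[0], sos.shape[1]
    context = sos.flatten(2).transpose(1, 2)            # (B, 1, C): the [SOS] query
    predicted = []
    for i, k in enumerate(scales):
        logits = var_transformer(context)[:, -k * k:]   # logits for the current scale only
        probs = torch.softmax(logits, dim=-1)
        tok = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1).view(B, k, k)
        predicted.append(tok)
        if i + 1 < len(scales):                         # build the next step's queries
            k_next = scales[i + 1]
            emb = codebook(tok).permute(0, 3, 1, 2)     # (B, C, k, k)
            up = F.interpolate(emb, size=(k_next, k_next),
                               mode='bicubic', align_corners=False)
            context = torch.cat([context, up.flatten(2).transpose(1, 2)], dim=1)
    return predicted                                     # list of (B, k, k) token maps
```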

keyu-tian commented 4 months ago

Hi @keyu-tian , I am very concerned about whether VAR can handle variable length data (ie not fixed H, W). If the H and W of an image change to 14, and the VAR scale K is still {1x1, ..., 13x13, 16x16}, how can I predict [EOS] token like GPT? Because the input of each scale of VAR relies on the interpolation of the previous scale (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169), the learned codebook cannot be modified to add an [EOS] idx when training the transformer (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L39), like the traditional AR embedding:

# two dummy tokens, one for [EOS], one for padding
self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)

Hope you understand what I mean, this problem has been bothering me and I hope there is a must good solution.

When we want VAR to generate an image of some other size, we should give a new 'schedule'. E.g., if we want to generate an image with 12x18 latent (a 2:3 image) we can set the schedule as [1x1, 2x2, 2x3, 4x6, 6x9, 8x12, 12x18]. With appropriate position embedding interpolation, VAR can generate the 12x18 latent.

So VAR is currently not designed to output an [EOS]. But I believe this is worth exploring in the future :D.
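
For the position-embedding part, here is a hedged sketch of how a learned square grid of positional embeddings might be interpolated to a non-square scale such as 12x18 (the names, shapes, and interpolation mode are assumptions, not the repo's code):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_hw):
    """pos_embed: (1, H*W, C) learned embeddings for a square HxW training grid.
    Returns embeddings interpolated to a new (h, w) grid, e.g. (12, 18)."""
    _, L, C = pos_embed.shape
    side = int(L ** 0.5)                                  # square training grid, e.g. 16
    grid = pos_embed.transpose(1, 2).reshape(1, C, side, side)
    grid = F.interpolate(grid, size=new_hw, mode='bicubic', align_corners=False)
    return grid.flatten(2).transpose(1, 2)                # (1, h*w, C)

# Example schedule for a 2:3 image, as suggested above:
schedule = [(1, 1), (2, 2), (2, 3), (4, 6), (6, 9), (8, 12), (12, 18)]
pe_16 = torch.randn(1, 16 * 16, 64)                       # hypothetical learned pos-emb
pe_12x18 = resize_pos_embed(pe_16, (12, 18))              # (1, 216, 64)
```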

eanson023 commented 4 months ago

Hi @keyu-tian , I am very concerned about whether VAR can handle variable length data (ie not fixed H, W). If the H and W of an image change to 14, and the VAR scale K is still {1x1, ..., 13x13, 16x16}, how can I predict [EOS] token like GPT? Because the input of each scale of VAR relies on the interpolation of the previous scale (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169), the learned codebook cannot be modified to add an [EOS] idx when training the transformer (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L39), like the traditional AR embedding:

# two dummy tokens, one for [EOS], one for padding
self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)

Hope you understand what I mean, this problem has been bothering me and I hope there is a must good solution.

When we want VAR to generate an image of some other size, we should give a new 'schedule'. E.g., if we want to generate an image with 12x18 latent (a 2:3 image) we can set the schedule as [1x1, 2x2, 2x3, 4x6, 6x9, 8x12, 12x18]. With appropriate position embedding interpolation, VAR can generate the 12x18 latent.

So VAR is currently not designed to be able to output a [EOS]. But i believe this is worth exploring in the future :D.

Thanks for your reply, got it

daiyixiang666 commented 3 months ago

I have a confuse between the train and inference step, during training, for example 1x1 - 2x2 - 3x3, we have the 1x1 + 2x2 = 5 , where the 2x2 has the information of 1x1, but during the inference when get the 2x2 token, we feed the 2x2 (upsample) from the 1x1, which is different from the traning stage, which make me confused, I think should not the mask become the following is a better choose ? becasuse during inference we actually does not have the previous information [ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1] ].

that is also my confusion, looking forward to the answer~

KV-cache!!! Now I get it; this was what confused me about the input shape during inference.

HalvesChen commented 3 months ago

I have a confuse between the train and inference step, during training, for example 1x1 - 2x2 - 3x3, we have the 1x1 + 2x2 = 5 , where the 2x2 has the information of 1x1, but during the inference when get the 2x2 token, we feed the 2x2 (upsample) from the 1x1, which is different from the traning stage, which make me confused, I think should not the mask become the following is a better choose ? becasuse during inference we actually does not have the previous information [ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1] ].

that is also my confusion, looking forward to the answer~

KV-cache!!! I know, this make me confused about the input shape during inference

Got it! You are right!

dunbar12138 commented 2 months ago

Hi @keyu-tian , I am very concerned about whether VAR can handle variable length data (ie not fixed H, W). If the H and W of an image change to 14, and the VAR scale K is still {1x1, ..., 13x13, 16x16}, how can I predict [EOS] token like GPT? Because the input of each scale of VAR relies on the interpolation of the previous scale (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L169), the learned codebook cannot be modified to add an [EOS] idx when training the transformer (https://github.com/FoundationVision/VAR/blob/main/models/quant.py#L39), like the traditional AR embedding:

# two dummy tokens, one for [EOS], one for padding
self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)

Hope you understand what I mean, this problem has been bothering me and I hope there is a must good solution.

When we want VAR to generate an image of some other size, we should give a new 'schedule'. E.g., if we want to generate an image with 12x18 latent (a 2:3 image) we can set the schedule as [1x1, 2x2, 2x3, 4x6, 6x9, 8x12, 12x18]. With appropriate position embedding interpolation, VAR can generate the 12x18 latent.

So VAR is currently not designed to be able to output a [EOS]. But i believe this is worth exploring in the future :D.

Thanks for the amazing work and the intensive discussion!! I read through all the conversations and found it super helpful in better understanding the paper.

Regarding this thread, does it mean VAR could generate a 12x18 latent even without training on samples of this size? This surprises me because, during training, VAR is only exposed to interpolations and positional embeddings at square dimensions (e.g., 2x2, 3x3). I find it unintuitive that it could generalize to new dimensions (e.g., 2x3, 12x18).

I thought that to generate a 12x18 latent, VAR had to do outpainting with a sliding context window of square dimensions.

Looking forward to your reply!

MohamedAliRashad commented 2 months ago

I apologize to everyone in this discussion, but I still can't wrap my head around how transformers take a token map and output a token map that has different dimensions at each step.

My simple understanding of transformers is that they take [B, L, C], where B is the batch size, L is the sequence length, and C is the embedding dimension of the input tokens. Before the LM head, the model produces [B, L, D], where D is the new embedding of the tokens after processing. This output goes to the LM head and produces probability distributions over the next tokens for all the input tokens => [B, L, V], where V is the vocab size and the values are normalized.

Now, how does this work in VAR (if my understanding is correct)? Especially given that VAR outputs a different number of tokens at each step.

daiyixiang666 commented 2 months ago

I aplogoize for everyone in this discussion but i still can't wrap my head on how transformers take a token map and output a token map that has a different dimensions in each step.

My simple understanding of transformers that they take [B, L, C] where B is the batch size, L is the sequence length and C is embedding dimension of the tokens inputted. Before the LM Head the model generates [B, L, D] where D is the new embedding of the tokens after processing. This output goes to the LM Head and produce probability distributions for the next tokens for all the input tokens => [B, L, V] where V is the vocab size and the values are normalized.

Now, how this work in VAR (If my understanding is correct) ? especially, That VAR output different number of tokens at each step

Try running the code yourself and printing the shapes you want to figure out during inference.

MohamedAliRashad commented 2 months ago

@daiyixiang666 My closest analogy to what's going on here is that for rk (the scale we want to predict), we input it flattened (of course) and rely on the nature of transformers, which predict a next token for every input token they receive.

Let's say I want to predict the 3x3 scale: I make 9 dummy tokens in a sequence and input them after the last scale I generated (also flattened, with positional and level encoding added), and those 9 dummy tokens will each have a next token predicted for them. Those 9 next tokens are the output I want.

This works without those 9 dummy tokens affecting the results because of the attention mask, which tells the model not to take them into consideration and to look only at the previous ones.

The only thing still scratching my head about this theory is that it is not very scalable to bigger scales. For a 256x256 image you would need 65536 dummy tokens, which is not feasible to process.

daiyixiang666 commented 2 months ago

That is why we need the VQVAE to compress the image; for example, 256x256 will be downsampled to 16x16.

MohamedAliRashad commented 2 months ago

@daiyixiang666 I thought the final scale would be the original image scale. If this is the case, then everything makes sense.

daiyixiang666 commented 2 months ago

No no no, you can print the input shape of the transformer yourself.