Thank you for your excellent work! I have some questions regarding the computation and memory consumption when training VAR and AR models.
Currently, the VAR model is configured with patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), resulting in a total token count of up to 680. In comparison, an equivalent AR model with patch_num=16 has only 256 tokens. Since all tokens are generated in parallel when training VAR and AR models, the significantly larger token count in the VAR model requires substantially more computation and memory when calculating self-attention.
I understand that the VAR model, due to fewer iterations, infers faster than the AR model. However, does the significantly larger token count in the VAR model during training lead to substantially higher computational and memory consumption, making it difficult to train the VAR model on very-high-resolution image datasets or video datasets? Are there any good solutions for this issue?
Thank you for your excellent work! I have some questions regarding the computation and memory consumption when training VAR and AR models.
Currently, the VAR model is configured with patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), resulting in a total token count of up to 680. In comparison, an equivalent AR model with patch_num=16 has only 256 tokens. Since all tokens are generated in parallel when training VAR and AR models, the significantly larger token count in the VAR model requires substantially more computation and memory when calculating self-attention.
I understand that the VAR model, due to fewer iterations, infers faster than the AR model. However, does the significantly larger token count in the VAR model during training lead to substantially higher computational and memory consumption, making it difficult to train the VAR model on very-high-resolution image datasets or video datasets? Are there any good solutions for this issue?
感谢你们的非常出色的工作!我对训练VAR模型和AR模型所消耗的计算量和显存有一些疑问。
目前VAR设置的patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),总的token数量可以达到680。相较于patch_num=16的同等AR模型而言,它们的token数量只有256。由于在训练VAR模型和训练AR模型的时候,所有的token都是并行训练的,因此VAR中很大的token数量在计算自注意力时会需要消耗显著更多的计算量和显存。
我知道VAR模型在推理的时候由于迭代次数更少,它的推理会快于AR模型。但是在训练VAR模型的时候,是否会因为VAR模型显著更多的token数量,导致它会消耗显著更多的计算量和显存,从而使得VAR模型难以在超高分辨率图像数据集或者视频数据集中进行训练?有什么好的解决方法吗?