FoundationVision / VAR

[NeurIPS 2024 Oral][GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!
MIT License

How does the VQVAE (stage 1) use the multi-scale VectorQuantizer? #42

Closed YilanWang closed 7 months ago

YilanWang commented 7 months ago

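Thanks to the authors for open-sourcing the code. Algorithm 2 in the paper says that each z_k is interpolated up as the resolution increases and that the results are then fed into the decoder together. In the code, however, the decoder seems to take z_k directly, and I cannot find any related interpolation: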

    class Decoder(nn.Module):

        def forward(self, z):
            # z to block_in
            # middle
            h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

            for i_level in reversed(range(self.num_resolutions)):
                for i_block in range(self.num_res_blocks + 1):
                    h = self.up[i_level].block[i_block](h)
                    if len(self.up[i_level].attn) > 0:
                        h = self.up[i_level].attn[i_block](h)
                if i_level != 0:
                    h = self.up[i_level].upsample(h)

            h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
            return h


The corresponding upsample module is also just a plain 2x upsampling:

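    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Upsample2x(nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            # Nearest-neighbor 2x spatial upsampling followed by a 3x3 convolution
            return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))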


So where is Algorithm 2 reflected in the code? Thanks a lot!

keyu-tian commented 7 months ago

@YilanWang See the interpolation code in quant.py.

YilanWang commented 7 months ago

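Thanks for the reply! That is exactly what I noticed: quant.py does contain interpolation, and it is used in var, but vqvae.py does not seem to use it. The decoder takes the 1x32x16x16 tensor directly (assuming a 256 input), and the ONNX graph I exported shows the same thing :)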

[Screenshot: Snipaste_2024-04-29_10-37-49]

@keyu-tian

keyu-tian commented 7 months ago

Yes. Just check where that code is actually called and it should be clear.
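
For readers landing on this thread, the point above can be illustrated with a minimal sketch. It is not the repository's code: `multiscale_quantize`, `tokens_to_f_hat`, the toy codebook size, and the scale list are hypothetical names and values chosen for illustration, and the per-scale convolution that the paper applies after upsampling is left out. The sketch assumes the interpolation from Algorithm 2 happens in the quantization step, so the tensor handed to the decoder is already the full-resolution sum over all scales. That would be consistent with the single 1x32x16x16 decoder input seen in the exported graph.

```python
import torch
import torch.nn.functional as F


def multiscale_quantize(f, codebook, scales):
    """Residual quantization over several scales (illustrative sketch).

    f:        feature map of shape (B, C, H, W)
    codebook: embedding table of shape (V, C)
    scales:   increasing side lengths; the last one equals H and W
    Returns the per-scale token maps and the accumulated f_hat (B, C, H, W).
    """
    B, C, H, W = f.shape
    f_rest = f.clone()              # residual still to be explained
    f_hat = torch.zeros_like(f)     # running full-resolution reconstruction
    token_maps = []
    for s in scales:
        # Downsample the residual to the current scale.
        r = F.interpolate(f_rest, size=(s, s), mode="area")
        # Nearest codebook entry for every spatial position.
        flat = r.permute(0, 2, 3, 1).reshape(-1, C)
        idx = torch.cdist(flat, codebook).argmin(dim=1).view(B, s, s)
        token_maps.append(idx)
        # Look up the embeddings and interpolate back to full resolution.
        z = codebook[idx].permute(0, 3, 1, 2)
        z_up = F.interpolate(z, size=(H, W), mode="bicubic")
        f_hat = f_hat + z_up
        f_rest = f_rest - z_up
    return token_maps, f_hat


def tokens_to_f_hat(token_maps, codebook, out_hw):
    """Rebuild f_hat from token maps alone: look up, upsample, and sum."""
    f_hat = 0
    for idx in token_maps:
        z = codebook[idx].permute(0, 3, 1, 2)
        f_hat = f_hat + F.interpolate(z, size=out_hw, mode="bicubic")
    return f_hat


torch.manual_seed(0)
scales = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)   # illustrative scale schedule
f = torch.randn(1, 32, 16, 16)                # stand-in for the encoder output
codebook = torch.randn(64, 32)                # toy codebook, hypothetical size
token_maps, f_hat = multiscale_quantize(f, codebook, scales)
decoder_input = f_hat                         # one tensor, already summed over scales
```

In this sketch the decoder never sees the individual z_k maps. Every scale has been upsampled and accumulated into `f_hat` before decoding, so the decoder's own `Upsample2x` layers only handle the usual latent-to-pixel upsampling.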