FoundationVision / VAR

[GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!
MIT License
3.78k stars, 285 forks

Image reconstruction via Transformer. #55

Open minimini-1 opened 1 month ago

minimini-1 commented 1 month ago

Hello, I have a question about image reconstruction with VAR. I want the transformer to predict the ground-truth tokens, just as in training: I obtain the image tokens with the VQ encoder, interpolate them into the next-scale inputs, and finally feed those inputs to the transformer (similar to inversion in diffusion models).
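The interpolation step described above is, as far as I can tell, what `VectorQuantizer2.idxBl_to_var_input` does: each scale's embedded tokens are upsampled to the final resolution and accumulated into `f_hat`, and `f_hat` is then area-downsampled to the next scale's size to form that scale's input. Below is a simplified, single-channel pure-Python sketch of that idea, not VAR's code. It assumes nearest-neighbour upsampling in place of VAR's bicubic interpolation, treats the `quant_resi` convolutions as identity, and uses scalar grids instead of `Cvae`-dimensional embeddings; all function names are hypothetical.

```python
# Simplified sketch of building next-scale teacher-forcing inputs.
# Assumptions: nearest-neighbour upsampling instead of bicubic, quant_resi convs
# treated as identity, one scalar channel per position instead of Cvae dims.

def upsample_nearest(grid, size):
    """Nearest-neighbour resize of a square 2D list to size x size."""
    n = len(grid)
    return [[grid[i * n // size][j * n // size] for j in range(size)] for i in range(size)]


def downsample_area(grid, size):
    """Adaptive average pooling of a square 2D list to size x size (like mode='area')."""
    n = len(grid)

    def span(i):  # input rows/cols covered by output index i: [floor(i*n/size), ceil((i+1)*n/size))
        return range(i * n // size, ((i + 1) * n + size - 1) // size)

    return [
        [sum(grid[r][c] for r in span(i) for c in span(j)) / (len(span(i)) * len(span(j)))
         for j in range(size)]
        for i in range(size)
    ]


def next_scale_inputs(scale_maps, patch_nums):
    """Return one input grid per scale after the first (the last scale's map is never an input)."""
    final = patch_nums[-1]
    f_hat = [[0.0] * final for _ in range(final)]
    inputs = []
    for si in range(len(patch_nums) - 1):
        up = upsample_nearest(scale_maps[si], final)
        # accumulate this scale's contribution at full resolution
        f_hat = [[a + b for a, b in zip(row_f, row_u)] for row_f, row_u in zip(f_hat, up)]
        # the input for scale si+1 is f_hat pooled down to that scale's size
        inputs.append(downsample_area(f_hat, patch_nums[si + 1]))
    return inputs


PN = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # assumed default patch_nums
maps = [[[1.0] * pn for _ in range(pn)] for pn in PN]
inputs = next_scale_inputs(maps, PN)
print(sum(len(m) ** 2 for m in inputs))  # 679
```

The total number of input positions equals the length of `tr_input_embed` (everything except the first 1×1 scale, which comes from the start-of-sequence embedding instead).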

However, when I ran my implementation, the reconstructed image differs from the original. Did I miss something, or is this approach not feasible?

Here are the original image and the reconstructed image:

original image

recon image

And here's my code:

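```python
# img: [B, 3, H, W], normalized to [-1, 1] as in VAR's training data pipeline
# gt_idx: list of ground-truth token indices, one [B, pn*pn] tensor per scale
gt_idx = vae.img_to_idxBl(img)
# teacher-forcing input for the transformer; quantize_local is the VAE's quantizer (vae.quantize)
tr_input_embed = quantize_local.idxBl_to_var_input(gt_idx)
```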
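Where the 679 in the shape comes from: a minimal sketch, assuming VAR's default `patch_nums` for 256×256 images, that lays out the flattened multi-scale sequence the same way VAR appears to build `self.begin_ends` and `self.lvl_1L`. The helper name `token_layout` is hypothetical. The full sequence has 680 positions; the first (1×1) scale is produced from the start-of-sequence embedding, so the word-embedding input carries the remaining 679. In training, the logits at position t are compared with the ground-truth token at position t, so position 0 of the output corresponds to the 1×1 scale.

```python
# Assumed default VAR patch_nums for 256x256 images; substitute your model's tuple.
PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)


def token_layout(patch_nums):
    """Return (begin_ends, lvl) for a flattened multi-scale token sequence.

    begin_ends[si] is the [start, end) slice of scale si; lvl[t] is the scale index of position t.
    """
    begin_ends, lvl, cur = [], [], 0
    for si, pn in enumerate(patch_nums):
        begin_ends.append((cur, cur + pn * pn))
        lvl.extend([si] * (pn * pn))
        cur += pn * pn
    return begin_ends, lvl


begin_ends, lvl = token_layout(PATCH_NUMS)
L = begin_ends[-1][1]         # total tokens over all scales
first_l = PATCH_NUMS[0] ** 2  # positions filled by the start-of-sequence embedding
print(L, L - first_l)         # 680 679
```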

`tr_input_embed` has shape [B, 679, 32]. I implemented the following method in the VAR class:

```python
import torch

@torch.no_grad()
def image_recon_forward(self, tr_input_embed, gt_start_emd):
    # tr_input_embed: [B, L - first_l, Cvae] teacher-forcing input from idxBl_to_var_input
    # gt_start_emd:   [B, 1] ground-truth token index of the first (1x1) scale
    bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
    B = tr_input_embed.shape[0]

    with torch.cuda.amp.autocast(enabled=False):
        # label == num_classes (1000 for ImageNet) selects the unconditional (class-dropped) embedding
        label_B = torch.full((B,), self.num_classes, dtype=torch.long, device=tr_input_embed.device)
        sos = cond_BD = self.class_emb(label_B)  # [B, C]
        sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
        if self.prog_si == 0:
            x_BLC = sos
        else:
            x_BLC = torch.cat((sos, self.word_embed(tr_input_embed.float())), dim=1)
        x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

    f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
    attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
    cond_BD_or_gss = self.shared_ada_lin(cond_BD)

    # dtype probe from VAR.forward: yields fp16/bf16 if an outer autocast is active
    temp = x_BLC.new_ones(8, 8)
    main_type = torch.matmul(temp, temp).dtype
    x_BLC = x_BLC.to(dtype=main_type)
    cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
    attn_bias = attn_bias.to(dtype=main_type)

    # one full-sequence (teacher-forced) pass with the block-causal mask
    for b in self.blocks:
        x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
    logits_BLV = self.get_logits(x_BLC.float(), cond_BD)

    # greedy (top-1) token per position; position 0 is replaced by the GT first-scale token
    idx_Bl = logits_BLV.argmax(dim=-1)
    idx_Bl[:, 0] = gt_start_emd.squeeze(1)
    h_BLC = self.vae_quant_proxy[0].embedding(idx_Bl)  # [B, L, Cvae]

    # accumulate each scale's residual into f_hat, as in autoregressive_infer_cfg
    SN = len(self.patch_nums)
    for si, (st, en) in enumerate(self.begin_ends):
        pn = self.patch_nums[si]
        h_BChw = h_BLC[:, st:en, :].transpose(1, 2).reshape(B, self.Cvae, pn, pn)
        f_hat, _ = self.vae_quant_proxy[0].get_next_autoregressive_input(si, SN, f_hat, h_BChw)
    for b in self.blocks:
        b.attn.kv_caching(False)
    return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)  # map [-1, 1] to [0, 1]
```
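One way to see how far the greedy teacher-forced predictions are from the encoder's tokens is to compare them scale by scale. A minimal pure-Python sketch for a single image, assuming both sequences are flattened to plain lists of length L in `begin_ends` order (in PyTorch, e.g. `torch.cat(gt_idx, dim=1)[0].tolist()` and, if the method were changed to also return it, `idx_Bl[0].tolist()`). The helpers `split_scales` and `per_scale_agreement` are hypothetical; `split_scales` mirrors the row-major reshape used in the per-scale loop.

```python
def split_scales(flat, patch_nums):
    """Split a flat, row-major token list into one pn x pn grid per scale."""
    grids, cur = [], 0
    for pn in patch_nums:
        chunk = flat[cur:cur + pn * pn]
        grids.append([chunk[r * pn:(r + 1) * pn] for r in range(pn)])
        cur += pn * pn
    return grids


def per_scale_agreement(pred, gt, patch_nums):
    """Fraction of positions per scale where the predicted token equals the GT token."""
    if len(pred) != len(gt):
        raise ValueError("pred and gt must have the same length")
    out, cur = [], 0
    for pn in patch_nums:
        n = pn * pn
        hits = sum(p == g for p, g in zip(pred[cur:cur + n], gt[cur:cur + n]))
        out.append(hits / n)
        cur += n
    return out


# toy example with two scales (1x1 and 2x2)
print(split_scales([0, 1, 2, 3, 4], (1, 2)))                           # [[[0]], [[1, 2], [3, 4]]]
print(per_scale_agreement([5, 1, 2, 3, 4], [5, 1, 0, 3, 0], (1, 2)))  # [1.0, 0.5]
```

Any scale with agreement well below 1.0 contributes wrong residuals to `f_hat`, and errors at coarse scales affect every pixel of the decoded image.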

Again, thanks for your great work!