dome272 / MaskGIT-pytorch

Pytorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)
MIT License

sample_good() function in transformer.py #15

Open jeeyung opened 1 year ago

jeeyung commented 1 year ago

Hi!

I think the shape of the logits from self.tokens_to_logits is [batch, 257, 1026], because you defined self.tok_emb = nn.Embedding(args.num_codebook_vectors + 2, args.dim).

However, the codebook only has 1024 embeddings, so sampling a token index of 1024 or 1025 causes an error when looking up the codebook. Haven't you seen these errors during sampling? Did I miss something here?
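A minimal sketch of the mismatch (names like num_codebook_vectors mirror the repo's args, but the logits here are random stand-ins, not the model's output). If the logits cover all 1026 ids, naive sampling can return the two special-token ids that don't exist in the 1024-entry codebook; one common workaround is to mask those logits to -inf before sampling:

```python
import torch

num_codebook_vectors = 1024  # codebook size; embedding table has +2 special ids

# stand-in for the output of tokens_to_logits: [batch, 257, 1026]
logits = torch.randn(2, 257, num_codebook_vectors + 2)

# naive sampling over all 1026 classes can produce ids 1024/1025,
# which would fail a lookup in the 1024-entry VQGAN codebook
naive = torch.distributions.Categorical(logits=logits).sample()

# workaround: zero out the probability of the special-token ids
masked = logits.clone()
masked[..., num_codebook_vectors:] = float("-inf")
sampled = torch.distributions.Categorical(logits=masked).sample()
assert int(sampled.max()) < num_codebook_vectors  # always within the codebook
```

Whether the repo's sample_good() is meant to do this masking (or to slice the logits to the first 1024 entries) is exactly the question here.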