dome272 / MaskGIT-pytorch

Pytorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)
MIT License

Isn't loss only supposed to be calculated on masked tokens? #14

Open EmaadKhwaja opened 1 year ago

EmaadKhwaja commented 1 year ago

In the training loop we have:

imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()

However, the output of the transformer is:

  _, z_indices = self.encode_to_z(x)
  ...
  a_indices = mask * z_indices + (~mask) * masked_indices

  a_indices = torch.cat((sos_tokens, a_indices), dim=1)

  target = torch.cat((sos_tokens, z_indices), dim=1)

  logits = self.transformer(a_indices)

  return logits, target

which means the returned target contains the original, unmasked image tokens, so the cross-entropy above is computed over every position rather than just the masked ones.

The MaskGIT paper, however, seems to suggest that the loss is only calculated on the masked tokens:

[screenshot from the paper showing the loss computed over the masked tokens only]
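For reference, the objective in the paper (reconstructed here from memory, so treat the exact notation as approximate) is a negative log-likelihood summed only over the positions whose mask bit $m_i = 1$, conditioned on the masked sequence $Y_{\bar M}$:

$$
\mathcal{L}_{\text{mask}} = -\mathbb{E}_{Y \in \mathcal{D}} \left[ \sum_{\forall i \in [1,N],\, m_i = 1} \log p\!\left(y_i \mid Y_{\bar{M}}\right) \right]
$$

so tokens that were left visible do not contribute to the loss.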

darius-lam commented 1 year ago

I've attempted both strategies for a simple MaskGIT on CIFAR10, but the generation quality still seems poor. There seem to be tricks in their training scheme that the authors aren't telling us in the paper.

xuesongnie commented 1 year ago

I have the same issue. Why is the loss calculated on all tokens?

EmaadKhwaja commented 1 year ago

@Lamikins I believe the training issues come from an error in the masking formula. I've amended the error: https://github.com/dome272/MaskGIT-pytorch/pull/16.

@xuesongnie
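For anyone skimming the thread, here is a minimal sketch of what restricting the loss to the replaced positions looks like. The function name and shapes are illustrative and not the repo's exact code; it assumes a boolean mask whose True entries mark the tokens that were swapped for the mask token:

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, target, replaced):
    """Cross-entropy over only the positions whose input token was masked out.

    logits:   (B, N, vocab_size) transformer outputs
    target:   (B, N) original (unmasked) token indices
    replaced: (B, N) bool, True where the input token was replaced by [MASK]
    """
    # Boolean indexing flattens the selected positions to
    # (num_replaced, vocab_size) and (num_replaced,).
    return F.cross_entropy(logits[replaced], target[replaced])

# Toy usage with random tensors (in real training, logits come from the model)
B, N, V = 2, 8, 16
logits = torch.randn(B, N, V, requires_grad=True)
target = torch.randint(0, V, (B, N))
replaced = torch.rand(B, N) < 0.5        # pretend roughly half the tokens were masked
loss = masked_cross_entropy(logits, target, replaced)
loss.backward()
```

With the repo's current mask convention (discussed in the comments below), `replaced` corresponds to `~mask`.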

xuesongnie commented 1 year ago

@EmaadKhwaja return logits[~mask], target[~mask] seems a bit problematic; we should calculate the loss on the masked tokens, i.e. return logits[mask], target[mask].

EmaadKhwaja commented 1 year ago

@xuesongnie it's because the calculated mask is applied to the wrong values. The other option would be to use r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it differs from how the formula appears in the paper.
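Reading the quoted line a_indices = mask * z_indices + (~mask) * masked_indices literally, positions where mask is True keep their original token and positions where it is False are swapped for the mask token, so ~mask is what picks out the replaced positions. A toy check (made-up values, not the repo's API):

```python
import torch

z_indices      = torch.tensor([[5, 9, 2, 7]])    # original VQ token ids
mask_token_id  = 0                               # placeholder id for [MASK]
masked_indices = torch.full_like(z_indices, mask_token_id)

# With the convention quoted above, True means "keep the original token".
mask = torch.tensor([[True, False, True, False]])

a_indices = mask * z_indices + (~mask) * masked_indices
print(a_indices)   # tensor([[5, 0, 2, 0]])  -> positions 1 and 3 were replaced
print(~mask)       # tensor([[False,  True, False,  True]])  -> selects exactly those positions
```

So whether logits[mask] or logits[~mask] is right depends on which convention the masking formula ends up using; under the line quoted above, ~mask indexes the masked-out tokens.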

xuesongnie commented 11 months ago

> @xuesongnie it's because the calculated mask is applied to the wrong values. The other option would be to use r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it differs from how the formula appears in the paper.

Hi, bro. I find performance is poor after modifying it to return logits[mask], target[mask], which is weird. I guess the embedding layer also needs to be trained on the corresponding unmasked tokens.