Sara-Ahmed / SiT

Self-supervised vIsion Transformer (SiT)

Passing token logits to the loss #8

Closed — julian-carpenter closed this issue 3 years ago

julian-carpenter commented 3 years ago

First of all, I want to thank you for making the code available. It is well written and easily understandable.

I have a question about the rotation token and the tokens used for reconstructing the original image. I am not an expert in PyTorch, as I have always used TensorFlow, so please forgive me if I ask obvious things.

As far as I can see, you define all the heads for the SSL losses here: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/vision_transformer_SiT.py#L193-L206 And as far as I understand it, there is no final activation on these heads; they return logits, correct?
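For reference, here is a minimal sketch of what I mean by a head without a final activation (hypothetical dimensions and layer names, not the actual code from vision_transformer_SiT.py):

```python
import torch
import torch.nn as nn

# Hypothetical rotation head: stacked Linear layers with no softmax/sigmoid
# at the end, so the forward pass returns raw, unbounded scores (logits).
rot_head = nn.Sequential(
    nn.Linear(768, 768),
    nn.GELU(),
    nn.Linear(768, 4),  # e.g. 4 rotation classes (0/90/180/270 degrees)
)

x = torch.randn(2, 768)       # two dummy token embeddings
logits = rot_head(x)          # shape (2, 4), raw scores, no activation applied
print(tuple(logits.shape))
```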

In the train_SSL routine, you then pass these logits to your criterion routine: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/engine.py#L146-L148

Which, if the train_mode is SSL, is the MTL_loss routine. There the logits are passed directly into their respective loss functions: https://github.com/Sara-Ahmed/SiT/blob/27cc31d2206168bdab7b9f4f9e10412ff3641d69/losses.py#L32-L51

Am I missing something? Especially in the reconstruction case, I think this cannot work: you have normalized original images and unnormalized reconstructed image logits, and you compute the L1 loss between them. The contrastive loss should be fine, since you normalize the logits inside the loss function. However, the CE loss for the rotation token is also computed without a prior activation, and I wonder why.
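To illustrate why I think the contrastive case is fine, a toy check (assumed shapes, not the actual loss in losses.py): once the embeddings are L2-normalized inside the loss, any scaling left over from the missing activation has no effect on the similarity.

```python
import torch
import torch.nn.functional as F

# Two batches of raw, unnormalized embeddings (stand-ins for the head outputs).
z1 = torch.randn(4, 128)
z2 = torch.randn(4, 128)

# Cosine similarity after L2 normalization, as a contrastive loss would compute it.
sim = (F.normalize(z1, dim=1) * F.normalize(z2, dim=1)).sum(dim=1)

# Scaling the raw outputs (e.g. by a missing activation) changes nothing,
# because normalization removes the magnitude.
sim_scaled = (F.normalize(5.0 * z1, dim=1) * F.normalize(z2, dim=1)).sum(dim=1)

assert torch.allclose(sim, sim_scaled, atol=1e-6)
```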

Could you point me to the error in my thinking? Thanks

Sara-Ahmed commented 3 years ago

Happy that you liked the work. For the rotation loss, I am using CrossEntropyLoss, which has the softmax embedded. The reconstruction loss is a regression problem, so there is no need to normalize or apply any activation before the loss function. I am not sure if it will help, but the Transformer is quite strong anyway; it is able to reconstruct the image in a few epochs if you do not apply severe corruptions.
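For example (a toy check with made-up values, not code from the repo), nn.CrossEntropyLoss expects raw logits because it applies log-softmax internally; adding a softmax before it would effectively apply softmax twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5, 0.3]])  # raw scores, e.g. 4 rotation classes
target = torch.tensor([0])

# CrossEntropyLoss on raw logits ...
ce = nn.CrossEntropyLoss()(logits, target)

# ... is exactly log-softmax followed by negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

assert torch.allclose(ce, manual)
```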

julian-carpenter commented 3 years ago

Alrighty, thank you for the explanations! I'll close the issue.