Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.
MIT License

Fix calculation of logits for training loss #3

Closed bob80333 closed 3 years ago

bob80333 commented 3 years ago

The embeddings must be L2-normalized before the logits are calculated. This fix corrects both the logged loss and the training loss. The validation loss was already correct, since the model's forward function normalizes the embeddings before computing the logits.
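For reference, a minimal sketch of the normalized contrastive loss described above; the function and variable names (`clip_loss`, `image_embed`, `text_embed`, `logit_scale`) are illustrative, not the repo's actual identifiers:

```python
import torch
import torch.nn.functional as F

def clip_loss(image_embed, text_embed, logit_scale):
    # L2-normalize both embedding sets before computing logits (the fix in this PR)
    image_embed = F.normalize(image_embed, dim=-1)
    text_embed = F.normalize(text_embed, dim=-1)

    # Cosine-similarity logits scaled by the (learnable) temperature
    logits_per_image = logit_scale * image_embed @ text_embed.t()
    logits_per_text = logits_per_image.t()

    # Symmetric cross-entropy over matching image-text pairs
    targets = torch.arange(image_embed.size(0), device=image_embed.device)
    loss_img = F.cross_entropy(logits_per_image, targets)
    loss_txt = F.cross_entropy(logits_per_text, targets)
    return (loss_img + loss_txt) / 2
```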

Zasder3 commented 3 years ago

Really good catch! My only change was to use torch.nn.functional.normalize to shrink the code a little.
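The two normalization forms are equivalent; a quick comparison, assuming `image_embed` is a batch of embedding vectors:

```python
import torch.nn.functional as F

# Explicit L2 normalization along the embedding dimension
image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)

# Same result via the built-in helper, as suggested above
image_embed = F.normalize(image_embed, dim=-1)
```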