The embeddings must be L2-normalized before the logits are computed. This fix corrects both the logged loss and the training loss. The validation loss was already correct, because the model's forward function normalizes the embeddings before computing the logits.
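To illustrate why the normalization step matters, here is a minimal sketch in plain Python. It is not the code from this change: the function names `l2_normalize` and `similarity_logits` and the `scale` parameter are hypothetical, chosen only for the example. It shows that once both sets of embeddings are scaled to unit length, each logit is a cosine similarity and no longer depends on the embedding magnitudes.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def similarity_logits(a_embs, b_embs, scale=1.0):
    """Pairwise logits between two batches of embeddings.

    Both batches are L2-normalized first, so each entry is
    scale * cosine_similarity(a, b). `scale` is a hypothetical
    temperature-style multiplier used only in this sketch.
    """
    a_norm = [l2_normalize(v) for v in a_embs]
    b_norm = [l2_normalize(v) for v in b_embs]
    return [
        [scale * sum(x * y for x, y in zip(a, b)) for b in b_norm]
        for a in a_norm
    ]


# Same direction, different magnitudes: the logits are identical.
small = similarity_logits([[3.0, 4.0]], [[6.0, 8.0], [-4.0, 3.0]])
large = similarity_logits([[30.0, 40.0]], [[6.0, 8.0], [-4.0, 3.0]])
```

Without the normalization, `large` would be ten times `small`, and any loss computed from those logits would change with the embedding scale. In a PyTorch code base the same step is usually written with `torch.nn.functional.normalize(x, dim=-1)`.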