Open Eric8932 opened 2 years ago
During training, the target (tgt) passed to the loss function is set to torch.arange(batch_size), which is correct for unsupervised training. However, this overwrites the labels that come with a supervised training set.
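To make the behavior concrete, here is a minimal sketch of an in-batch contrastive loss that builds its target from torch.arange(batch_size). It is not the repository's actual code: the function name in_batch_contrastive_loss, the arguments z1 and z2, and the temperature value are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def in_batch_contrastive_loss(z1, z2, temperature=0.05):
    """Hypothetical sketch of an unsupervised SimCSE-style loss.

    z1, z2: (batch_size, hidden) embeddings of two views of the same sentences.
    The temperature value is an assumption, not taken from the repository.
    """
    # Cosine similarity between every pair (i, j) in the batch.
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    sim = z1 @ z2.t() / temperature  # shape: (batch_size, batch_size)

    # Row i should match column i, so the target is simply 0..batch_size-1.
    # Any label tensor loaded from a supervised dataset is not used here.
    tgt = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, tgt), tgt


if __name__ == "__main__":
    torch.manual_seed(0)
    emb = torch.randn(4, 8)
    loss, tgt = in_batch_contrastive_loss(emb, emb)
    print(tgt.tolist(), loss.item())
```

Because the target depends only on the batch size, it carries no information from the dataset, which is why reusing the same variable for supervised labels would discard them.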
For now, we have only implemented the unsupervised part of SimCSE. If you are interested, you are welcome to contribute your code and open a PR.