Open harish-kamath opened 8 months ago
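PyTorch's `KLDivLoss` can be configured to take the targets either as probabilities or as log-probabilities (the `log_target` argument), while the input is always expected in log-space. By default the targets are not in log-space, so that is what I used: `log_softmax` on the model logits and plain `softmax` on the teacher logits. Log-space targets may have numerical stability benefits, but honestly I don't remember whether there was a specific rationale behind this choice.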
There are examples of both in the docs: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
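For anyone comparing the two conventions, here is a minimal sketch (not the WhisperSpeech code itself) showing that both give the same loss value. The tensor names and shapes are made up for illustration, and `reduction="batchmean"` is my assumption, not necessarily what the repo uses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical logits: batch of 4, 10 classes (illustrative only).
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

# The input (model side) is always passed as log-probabilities.
log_p_student = F.log_softmax(student_logits, dim=-1)

# Default convention: target given as probabilities (log_target=False).
loss_prob_target = F.kl_div(
    log_p_student,
    F.softmax(teacher_logits, dim=-1),
    reduction="batchmean",
    log_target=False,
)

# Alternative convention: target given as log-probabilities (log_target=True).
loss_log_target = F.kl_div(
    log_p_student,
    F.log_softmax(teacher_logits, dim=-1),
    reduction="batchmean",
    log_target=True,
)

# Both compute the same KL divergence, up to floating-point error.
same = torch.allclose(loss_prob_target, loss_log_target, atol=1e-6)
print(loss_prob_target.item(), loss_log_target.item(), same)
```

So the asymmetry in the call is just the API's default convention; any difference between the two would only show up as numerical precision in extreme cases.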
https://github.com/collabora/WhisperSpeech/blob/80b268b74900b2f7ca7a36a3c789607a3f4cd912/whisperspeech/vq_stoks.py#L344
Why use log softmax on the model logits, but softmax on the teacher logits?