yabufarha / ms-tcn


About the implementation of the KL divergence loss #2

Closed: Rheelt closed this issue 5 years ago.

Rheelt commented 5 years ago

Thank you for your contribution! I used torch.nn.KLDivLoss to implement the KL loss, but the loss I get is NaN. Could you tell me the details of the KL loss implemented in MS-TCN? Thanks.

yabufarha commented 5 years ago

Hi,

We used the following line of code for the KL divergence loss:

self.kldiv(F.log_softmax(p[:, :, 1:], dim=1), F.softmax(p.detach()[:, :, :-1], dim=1))

where self.kldiv = nn.KLDivLoss(). As with the T-MSE loss, this loss is also weighted with λ = 0.15.
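Below is a minimal, self-contained sketch of the loss quoted above, assuming p holds frame-wise logits of shape (batch, num_classes, num_frames). The names kl_smoothing_loss and stage_loss are hypothetical, and the cross-entropy term in stage_loss (with no masking of padded frames) is an assumption about how the λ-weighted term is combined, not code taken from MS-TCN. The sketch makes one point explicit: nn.KLDivLoss expects log-probabilities as its input and plain probabilities as its target. Passing probabilities as the input gives wrong values, and torch.log of a softmax that has underflowed to zero yields -inf, which can turn into NaN; F.log_softmax avoids both.

```python
import torch
import torch.nn.functional as F


def kl_smoothing_loss(p: torch.Tensor) -> torch.Tensor:
    """KL temporal smoothing term on frame-wise logits.

    p is assumed to have shape (batch, num_classes, num_frames).
    """
    # Input: log-probabilities of frames 1..T-1 (log_softmax is numerically stable).
    log_q = F.log_softmax(p[:, :, 1:], dim=1)
    # Target: plain probabilities of frames 0..T-2, detached so no gradient
    # flows through the previous frame (mirrors p.detach() in the snippet).
    target = F.softmax(p.detach()[:, :, :-1], dim=1)
    # Element-wise terms averaged over every element: the same value as the
    # default reduction="mean" of nn.KLDivLoss(), without the UserWarning
    # that recent PyTorch versions print for that reduction.
    return F.kl_div(log_q, target, reduction="none").mean()


def stage_loss(p: torch.Tensor, labels: torch.Tensor, lam: float = 0.15) -> torch.Tensor:
    """Hypothetical per-stage loss: cross-entropy plus lam * smoothing term.

    labels is assumed to have shape (batch, num_frames) with class indices;
    padding/masking of variable-length videos is not handled here.
    """
    ce = F.cross_entropy(p, labels)  # accepts (N, C, T) logits and (N, T) targets
    return ce + lam * kl_smoothing_loss(p)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 10, requires_grad=True)  # 2 videos, 5 classes, 10 frames
    labels = torch.randint(0, 5, (2, 10))
    loss = stage_loss(logits, labels)
    loss.backward()
    print(torch.isfinite(loss).item())  # True
```

Because the target is detached, in this sketch each frame t is pulled toward the prediction of frame t-1, which is treated as a constant, rather than both frames moving toward each other.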

I hope this helps.

Best, Yazan

Rheelt commented 5 years ago

Thanks for your reply.