SanghunYun / UDA_pytorch

UDA (Unsupervised Data Augmentation) implemented in PyTorch
Apache License 2.0
275 stars, 61 forks

KL divergence uses ori_prob instead of ori_log_prob #11

Closed: stellaywu closed this issue 3 years ago

stellaywu commented 3 years ago

Thanks for the nice implementation!

I'm trying to reproduce the results and am wondering why, for the KL divergence loss, you pass the original probability rather than the original log probability as the target that the augmented log probability is compared against. This looks different from the TensorFlow implementation.

This repo:

    unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)

The original TensorFlow implementation:

    per_example_kl_loss = kl_for_log_probs(tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask

Thanks!
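The two lines differ in argument convention, not in the quantity computed. A minimal pure-Python sketch, assuming `kl_for_log_probs(log_p, log_q)` computes KL(p || q) with p = exp(log_p), and assuming the PyTorch criterion uses the pointwise term `target * (log(target) - input)` with `input` as log-probabilities and `target` as plain probabilities. The helper names `kl_from_log_probs` and `kl_input_log_target_prob` and the example logits are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax over one example's logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_from_log_probs(log_p: List[float], log_q: List[float]) -> float:
    # TF-style (assumed): both arguments are log-probabilities.
    # KL(p || q) = sum_i p_i * (log p_i - log q_i), with p_i = exp(log p_i).
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def kl_input_log_target_prob(input_log_q: List[float], target_p: List[float]) -> float:
    # PyTorch-style argument order: input is log q, target is p (plain probabilities).
    # Pointwise term target * (log(target) - input); zero-probability targets contribute 0.
    return sum(t * (math.log(t) - lq) for lq, t in zip(input_log_q, target_p) if t > 0)


ori_log_prob = log_softmax([2.0, 0.5, -1.0])
aug_log_prob = log_softmax([1.5, 0.8, -0.5])
ori_prob = [math.exp(x) for x in ori_log_prob]

tf_style = kl_from_log_probs(ori_log_prob, aug_log_prob)        # (target log p, input log q)
pt_style = kl_input_log_target_prob(aug_log_prob, ori_prob)     # (input log q, target p)
print(tf_style, pt_style)  # the two values agree up to floating-point error
```

Both compute KL(original || augmented); only the order of the arguments and whether the target is given as probabilities or log-probabilities change.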

stellaywu commented 3 years ago

Sorry, I realized it's a PyTorch thing.
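For reference, `torch.nn.KLDivLoss` expects its `input` to be log-probabilities and, by default, its `target` to be plain probabilities, which is why `ori_prob` rather than `ori_log_prob` is passed. The sketch below assumes `unsup_criterion` is `nn.KLDivLoss(reduction='none')` (the per-example `torch.sum(..., dim=-1)` suggests an unreduced loss); the random logits are hypothetical. The `log_target=True` variant, available in PyTorch 1.6 and later, accepts the target as log-probabilities instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
ori_logits = torch.randn(4, 10)  # hypothetical batch: 4 examples, 10 classes
aug_logits = torch.randn(4, 10)

ori_log_prob = F.log_softmax(ori_logits, dim=-1)
aug_log_prob = F.log_softmax(aug_logits, dim=-1)
ori_prob = F.softmax(ori_logits, dim=-1)

# Repo-style call (assumed criterion): input = augmented log-probs, target = original probs.
unsup_criterion = nn.KLDivLoss(reduction='none')
unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)

# Same quantity with the target given as log-probabilities.
log_target_criterion = nn.KLDivLoss(reduction='none', log_target=True)
unsup_loss_log_target = torch.sum(log_target_criterion(aug_log_prob, ori_log_prob), dim=-1)

# Manual per-example KL(original || augmented) from log-probabilities, as in the TF formula.
manual = torch.sum(ori_prob * (ori_log_prob - aug_log_prob), dim=-1)

print(unsup_loss)
print(manual)
```

All three tensors hold the same per-example KL values, so the PyTorch code and the TensorFlow `kl_for_log_probs(tgt_ori_log_probs, aug_log_probs)` call express the same loss with different argument conventions.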