I'm trying to reproduce the results and am wondering why, for the KL divergence loss, you used the original probability instead of the original log probability against the augmented log probability. It looks different from the TensorFlow implementation.
unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)
The original TensorFlow implementation uses:
per_example_kl_loss = kl_for_log_probs(tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
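For context, my understanding is that the two conventions may actually compute the same quantity: PyTorch's `nn.KLDivLoss` expects (input = log-probs, target = probs) by default, while `kl_for_log_probs` takes two log-prob tensors, but both evaluate KL(p_ori || p_aug). A small sketch with hypothetical toy logits (the tensor values are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical toy logits standing in for the model outputs.
ori_logits = torch.tensor([[1.0, 2.0, 0.5]])
aug_logits = torch.tensor([[0.8, 1.5, 1.0]])

ori_prob = F.softmax(ori_logits, dim=-1)
ori_log_prob = F.log_softmax(ori_logits, dim=-1)
aug_log_prob = F.log_softmax(aug_logits, dim=-1)

# PyTorch convention: KLDivLoss(input=log-probs, target=probs),
# elementwise target * (log(target) - input).
unsup_criterion = torch.nn.KLDivLoss(reduction='none')
kl_pt = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)

# TF-style convention: KL(p || q) = sum p * (log p - log q),
# computed here from the two log-prob tensors.
kl_tf = torch.sum(ori_prob * (ori_log_prob - aug_log_prob), dim=-1)

print(torch.allclose(kl_pt, kl_tf))
```

If this is right, the PyTorch code just passes the target in probability space because that is what `KLDivLoss` expects, rather than computing a different loss.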
Thanks for the nice implementation!