```python
total_loss = torch.exp(-self.log_vars[0]) * mlm_loss + torch.clamp(self.log_vars[0], min=0) + \
             torch.exp(-self.log_vars[1]) * tec_loss + torch.clamp(self.log_vars[1], min=0) + \
             torch.exp(-self.log_vars[2]) * sec_loss + torch.clamp(self.log_vars[2], min=0)
```
Hello, I wonder why torch.clamp is used here. torch.clamp guarantees the value is >= 0, but I think the values here (self.log_vars) are always non-negative anyway: we initialize self.log_vars to zero, and they monotonically increase during training, which I found after deriving the gradients manually.
Am I missing something? :)
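For reference, here is a minimal sketch of the gradient check I mean, for a single term of total_loss (the loss value of 2.0 is a made-up stand-in for mlm_loss / tec_loss / sec_loss):

```python
import torch

# Minimal sketch of the gradient check described above. loss_value is a
# hypothetical stand-in (assumed 2.0) for mlm_loss / tec_loss / sec_loss;
# log_var plays the role of one entry of self.log_vars, initialized to zero.
log_var = torch.zeros(1, requires_grad=True)
loss_value = torch.tensor(2.0)

# One term of total_loss, exactly as in the snippet above.
term = torch.exp(-log_var) * loss_value + torch.clamp(log_var, min=0)
term.backward()

# Analytic gradient for log_var >= 0, where the clamp acts as the identity:
#   d(term)/d(log_var) = -exp(-log_var) * loss_value + 1
with torch.no_grad():
    analytic = -torch.exp(-log_var) * loss_value + 1.0

print(log_var.grad, analytic)  # both tensor([-1.]) when loss_value = 2.0
```

A negative gradient means a gradient-descent step increases log_var, which is why I expect it to keep growing as long as the task loss stays above exp(log_var), starting from the zero initialization.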