kandorm / CLINE

Lexical Error Correction BERT.

torch.clamp(self.log_vars[0], min=0) #5

Closed: wang304381190 closed this issue 1 year ago

wang304381190 commented 2 years ago

total_loss = torch.exp(-self.log_vars[0]) * mlm_loss + torch.clamp(self.log_vars[0], min=0) + \
             torch.exp(-self.log_vars[1]) * tec_loss + torch.clamp(self.log_vars[1], min=0) + \
             torch.exp(-self.log_vars[2]) * sec_loss + torch.clamp(self.log_vars[2], min=0)

Hello, I wonder why torch.clamp is used here. torch.clamp guarantees the value is >= 0, but I think the values here (self.log_vars) are always positive anyway: we initialize self.log_vars to zero, and it increases monotonically during training, which I found after deriving the gradients by hand. Am I missing something? :)
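For context, here is a minimal sketch of the kind of uncertainty-weighted multi-task loss the snippet appears to implement (in the style of Kendall et al.'s homoscedastic uncertainty weighting). The class name and the assumption of exactly three task losses (mlm_loss, tec_loss, sec_loss) are illustrative and not taken from the repository:

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Combines several task losses with learnable log-variances.
    Sketch only; variable names follow the snippet above, not
    necessarily the actual repository code."""

    def __init__(self, num_tasks=3):
        super().__init__()
        # log_vars are initialized to zero, as noted in the question
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, mlm_loss, tec_loss, sec_loss):
        losses = [mlm_loss, tec_loss, sec_loss]
        total = 0.0
        for i, task_loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])  # 1 / sigma_i^2
            # clamp keeps the additive regularizer non-negative even if
            # the log-variance drifts below zero during training
            total = total + precision * task_loss + torch.clamp(self.log_vars[i], min=0)
        return total
```

Under this sketch, removing the clamp would make the additive term exactly self.log_vars[i], which can only become negative if the parameter itself does; whether that can happen during training is precisely what the question above is asking.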