facebookresearch / SingleModelUncertainty

Learning error bars for neural network predictions

Difficulty understanding the penalty on weights #1

Closed akhilpandey95 closed 4 years ago

akhilpandey95 commented 4 years ago

I was going through the code used for the experiments on the toy data set, and I am having trouble understanding the penalty that is applied while training the orthonormal certificates.

opt.zero_grad()
error = c(xi[0]).pow(2).mean()
penalty = (c.weight @ c.weight.t() - torch.eye(k)).pow(2).mean()
(error + penalty).backward()
opt.step()

So, could you please explain what the penalty line does? Is it some sort of regularization?

PS: I am new to PyTorch.

lopezpaz commented 4 years ago

Dear @akhilpandey95,

That penalty is the orthonormality constraint appearing in Equation (4) of the paper: http://papers.nips.cc/paper/8870-single-model-uncertainties-for-deep-learning.pdf
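
If it helps build intuition, here is a small self-contained sketch of the same training loop on synthetic data (the dimensions `d`, `r`, `k` and the rank-deficient features are made up for illustration, not the repository's actual setup). The `penalty` term pushes `c.weight @ c.weight.t()` toward the identity, i.e. it encourages the `k` certificate vectors to be orthonormal, while the `error` term asks the certificates to map in-domain features to zero:

```python
import torch

torch.manual_seed(0)
d, r, k = 32, 16, 8  # feature dim, data rank, number of certificates (illustrative values)

# Synthetic stand-in for in-domain features phi(x): rank-deficient on purpose,
# so there exists an orthonormal set of k certificates that maps the data to ~0.
basis = torch.randn(r, d)
x = torch.randn(512, r) @ basis

c = torch.nn.Linear(d, k, bias=False)  # the k linear certificates
opt = torch.optim.Adam(c.parameters(), lr=1e-2)

def loss():
    # certificates should output ~0 on in-domain data
    error = c(x).pow(2).mean()
    # orthonormality constraint of Equation (4): C C^T should be close to I_k
    penalty = (c.weight @ c.weight.t() - torch.eye(k)).pow(2).mean()
    return error + penalty

init_loss = loss().item()
for _ in range(2000):
    opt.zero_grad()
    loss().backward()
    opt.step()
final_loss = loss().item()
```

The data is made rank-deficient on purpose: with full-rank data the two terms compete (shrinking the outputs fights keeping the weights unit-norm), whereas here the certificates can settle into the data's null space, so both terms can be driven toward zero together.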