Open issue, opened by aleximmer 3 years ago.
I am facing a problem: I want to compute the Laplace approximation of ResNets, which contain BatchNorm layers. I intend to ignore these layers and warn the user about it. More concretely, I will emit a warning and initialize the Kronecker factors of the BatchNorm layers to 0. Since the backend also returns 0 for these layers, they would in effect remain ignored. I plan to update the ASDL interface accordingly.
Is this approach correct? If so, I can implement it and open a pull request.
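A minimal sketch of the proposed warn-and-zero approach, assuming layers are described by hypothetical `(layer_type, in_dim, out_dim)` tuples instead of real modules; the function name is illustrative and this is not the ASDL interface. The identity factors for supported layers are placeholders for whatever curvature the backend actually computes.

```python
import warnings

import numpy as np

# Layer types the (hypothetical) Kronecker backend cannot handle.
UNSUPPORTED = {"BatchNorm1d", "BatchNorm2d", "BatchNorm3d"}


def kron_factors_zeroing_batchnorm(layers):
    """Build one (A, G) pair of Kronecker factors per layer.

    `layers` is a list of (layer_type, in_dim, out_dim) tuples (an
    assumption for illustration). Supported layers get identity factors
    as stand-ins for the backend's curvature; BatchNorm layers trigger a
    warning and get all-zero factors.
    """
    factors = []
    for layer_type, in_dim, out_dim in layers:
        if layer_type in UNSUPPORTED:
            warnings.warn(
                f"{layer_type} is not supported by the Kronecker backend; "
                "its Kronecker factors are set to 0.",
                UserWarning,
            )
            factors.append((np.zeros((in_dim, in_dim)), np.zeros((out_dim, out_dim))))
        else:
            # Placeholder: a real backend would return curvature factors here.
            factors.append((np.eye(in_dim), np.eye(out_dim)))
    return factors
```

For example, `kron_factors_zeroing_batchnorm([("Linear", 3, 2), ("BatchNorm2d", 2, 1)])` emits one `UserWarning` and returns zero factors for the second layer. Note that the zeroed layer is still part of the factor list, so anything downstream that adds a prior term will still see it.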
Just returning them as zero could be fine within the backend but might be problematic for the posterior approximation. The prior will still be non-zero, so when computing the predictive or the marginal likelihood, the BatchNorm block of the posterior precision (curvature plus prior) will be non-zero. I think a safer option is to return a single factor as None and then handle these cases in the Kronecker classes in the matrix.py module. This way you can make sure that they do not affect any important quantity.
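To illustrate why zero factors are not enough, the sketch below computes per-layer marginal variances of a Kronecker-factored posterior whose precision blocks are assumed to be kron(A, G) + prior_prec * I. With zero factors, a BatchNorm block collapses to prior_prec * I, so its parameters keep the full prior variance 1/prior_prec and would still add uncertainty to the predictive; with a None factor the layer is skipped. The function name, the tuple representation and the dense inverse are illustrative assumptions, not the Kron classes in matrix.py.

```python
import numpy as np


def posterior_variances(factors, prior_prec):
    """Per-layer marginal variances under a Kronecker-factored posterior.

    Each entry of `factors` is either an (A, G) pair or None. A layer's
    posterior precision block is kron(A, G) + prior_prec * I. Layers with
    None factors are skipped (treated as fixed, adding no uncertainty).
    """
    variances = []
    for factor in factors:
        if factor is None:
            variances.append(None)  # unsupported layer (e.g. BatchNorm): excluded
            continue
        A, G = factor
        n = A.shape[0] * G.shape[0]
        precision = np.kron(A, G) + prior_prec * np.eye(n)
        # Dense inverse is fine for a toy example; real code would use eigendecompositions.
        variances.append(np.diag(np.linalg.inv(precision)))
    return variances


linear = (np.eye(2), np.eye(3))                      # placeholder curvature factors
bn_zero = (np.zeros((2, 2)), np.zeros((1, 1)))       # "ignored" by zeroing
v_zero = posterior_variances([linear, bn_zero], prior_prec=4.0)
v_none = posterior_variances([linear, None], prior_prec=4.0)
```

Here the zeroed BatchNorm layer ends up with variance 1/4 for each parameter, exactly the prior variance, whereas the None variant leaves it out and the other layers are unaffected.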
Currently, these do not support BatchNorm because of the backends, but this should not fail silently when all-weights Laplace is used on networks with BatchNorm.