tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

What is the purpose of the _hook_compute_diag function? #46

Closed zengjie617789 closed 2 years ago

zengjie617789 commented 2 years ago

I am confused about the details of _hook_compute_diag: why does it multiply the gradient by x, the input to the layer that is captured before backpropagation? Other implementations of the Fisher matrix look like this:

for n, p in self.model.named_parameters():
    precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

Could anyone explain this? Thank you in advance.

tfjgeorge commented 2 years ago

Hi

In the piece of code that you provide, there is an outer loop over each individual example of the dataset used to compute the Fisher. That is inefficient: you can obtain the individual (per-example) gradients from a single backward pass over a whole batch using tricks such as the one described in https://arxiv.org/abs/1510.01799

In NNGeometry we leverage such tricks, in order to improve compute efficiency.
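To make the role of x concrete: for a linear layer, the gradient of one example's loss with respect to the weights is the outer product of the gradient with respect to the layer's output and the layer's input x. Its elementwise square is therefore the outer product of the two squared vectors, so the whole batch can be handled with a single matrix product. The snippet below is a minimal sketch of that idea in plain PyTorch, not NNGeometry's actual hook; the toy layer, the sizes, and the names `diag_loop` and `diag_fast` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: one bias-free linear layer and a batch of 5 examples.
batch, d_in, d_out = 5, 4, 3
layer = torch.nn.Linear(d_in, d_out, bias=False)
x = torch.randn(batch, d_in)
target = torch.randint(0, d_out, (batch,))

# Slow reference: one backward pass per example, then square and average.
diag_loop = torch.zeros_like(layer.weight)
for i in range(batch):
    layer.zero_grad()
    loss_i = torch.nn.functional.cross_entropy(layer(x[i:i + 1]), target[i:i + 1])
    loss_i.backward()
    diag_loop += layer.weight.grad ** 2 / batch

# Fast version: a single backward pass over the summed loss.
out = layer(x)
loss = torch.nn.functional.cross_entropy(out, target, reduction="sum")
(gy,) = torch.autograd.grad(loss, out)  # gy[i] = d loss_i / d out[i]

# Per-example weight gradient is outer(gy[i], x[i]), so its elementwise
# square is outer(gy[i] ** 2, x[i] ** 2); summing over i is one matmul.
diag_fast = (gy ** 2).t() @ (x ** 2) / batch

print(torch.allclose(diag_loop, diag_fast, atol=1e-6))
```

Both versions compute the same averaged squared gradients; the second needs only one backward pass and never materializes the per-example gradients.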

zengjie617789 commented 2 years ago

Thank you for your quick response. The repository is awesome. I have another problem: I want to apply this code to a model such as YOLOX, which is anchor-free and has more than 403200 outputs before decoding. I am not sure how to set n_output, since setting it to such a large number is obviously unwise. Could you give some suggestions on this? Thank you in advance.

tfjgeorge commented 2 years ago

If I understand your needs correctly, I recommend using the FIM_MonteCarlo metric instead of the FIM one. With the latter you would need to loop through all 403200 outputs, whereas with the former outputs are sampled, so only those with non-negligible probability will be drawn.

https://nngeometry.readthedocs.io/en/latest/api/metrics.html#nngeometry.metrics.FIM_MonteCarlo
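The difference between the two approaches can be sketched without NNGeometry at all. The snippet below is a rough illustration of a Monte Carlo estimate of the diagonal Fisher in plain PyTorch, assuming a classification-style model whose outputs are logits: labels are sampled from the model's own predictive distribution, so each example costs `trials` backward passes rather than one per output. It is not the library's implementation, and the toy model, the sizes, and the variable names are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy classifier with many outputs (1000 here for illustration).
n_classes, d_in, batch, trials = 1000, 8, 4, 2
model = torch.nn.Linear(d_in, n_classes)
x = torch.randn(batch, d_in)

params = list(model.parameters())
diag = [torch.zeros_like(p) for p in params]

for i in range(batch):
    log_probs = torch.log_softmax(model(x[i:i + 1]), dim=1)
    # Sample labels from the model's own predictive distribution; outputs
    # with negligible probability are essentially never drawn.
    probs = log_probs.detach().exp()
    sampled = torch.multinomial(probs, trials, replacement=True)[0]
    for y in sampled:
        # Gradient of log p(y | x) with respect to every parameter.
        grads = torch.autograd.grad(log_probs[0, y], params, retain_graph=True)
        for d, g in zip(diag, grads):
            d += g ** 2 / (batch * trials)

print([tuple(d.shape) for d in diag])
```

An exact computation would instead weight the squared gradient of every one of the `n_classes` outputs by its probability, which is what becomes impractical at several hundred thousand outputs.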