The code on line 80 of model.py:
fisher_diagonals = [(g ** 2).mean() for g in loglikelihood_grads]
should be:
fisher_diagonals = [(g ** 2).mean(dim=0) for g in loglikelihood_grads]
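To illustrate the difference, here is a minimal sketch. It assumes each entry of loglikelihood_grads is a tensor of per-sample gradients stacked along dim 0. The tensor g and its values are made up for the example and are not taken from model.py.

```python
import torch

# Hypothetical stand-in for one entry of loglikelihood_grads:
# 3 samples (dim 0) of gradients for a parameter with 2 elements.
g = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])

# Original line: averages over every element, so the result is a single scalar
# and all parameter elements would share one Fisher value.
buggy = (g ** 2).mean()

# Proposed line: averages over the sample dimension only, keeping one value
# per parameter element.
fixed = (g ** 2).mean(dim=0)

print(buggy.shape)  # torch.Size([])
print(fixed.shape)  # torch.Size([2])
```

Under that assumption, only the mean(dim=0) version yields a per-element diagonal with the same shape as the parameter.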
You are right, and this is a critical bug. I fixed it and updated the change log in the README. I also checked that the experimental results are reproduced. Thanks for reporting.