Kabibi opened this issue 5 years ago
The code in this repo actually uses the empirical Fisher matrix, so take it with a grain of salt.
To compute the true Fisher matrix, you need to compute the expectation w.r.t. p(y | x, \theta). One way to do that is to approximate it with MC integration: do a forward pass to get the predictive distribution given x, sample y ~ p(y | x, \theta), then backprop and compute the outer product of the gradient. Do this many times and average, and you get an approximation of the true Fisher matrix.
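A minimal sketch of that MC procedure for a logistic model (my own illustration, not code from this repo; all names are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Hypothetical sketch: MC estimate of the true Fisher for a logistic
# model p(y=1 | x, w) = sigmoid(x @ w). Instead of using the observed
# labels, sample y_s from the model's own predictive distribution,
# compute the gradient, and average the outer products.
def true_fisher_mc(X, w, n_samples=100):
    y = sigmoid(X @ w)                          # forward pass: p(y=1 | x, w)
    F = np.zeros((w.size, w.size))
    for _ in range(n_samples):
        y_s = rng.binomial(1, y).astype(float)  # y_s ~ p(y | x, w)
        G = (y - y_s)[:, None] * X              # per-example grad of the NLL w.r.t. w
        F += G.T @ G / X.shape[0]               # average outer products over the data
    return F / n_samples                        # average over MC samples
```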
I also have the same question, and I am trying to figure it out (but still stuck).
Even if the implementation calculates the empirical FIM, I cannot understand where the `y - y*y` term comes from in the formula of `grad_loglik_z`. In my theoretical derivation below, the term `(y - y*y)` should not be there, but if I remove it, the program won't work (`nan` is returned).
```
F = cov(grad_loglik_W)
grad_loglik_W = d_loglik / d_W = grad_loglik_z * X_train
grad_loglik_z = grad_loglik_y * dy_dz   # dy_dz is dz in the program
grad_loglik_y := derivative of the log-likelihood with respect to y
```
If the output `y` follows a normal distribution, `grad_loglik_y` will be `y - t_train`. However, in the code it is `(y - t_train) / (y - y*y)`. I understand the sign, so no problem there, but I cannot figure out why `y - y*y` is there.
Just posting here in case someone can explain it, while I keep trying to work it out myself.
============
I have nearly figured out the answer.
The distribution of `y` is a Bernoulli (binomial) distribution, not a normal distribution. The formula turns out to be the classical form of Fisher scoring used in the GLM training procedure.
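Concretely, in the notation above (my own derivation, assuming a sigmoid output $y = \sigma(z)$ and the Bernoulli negative log-likelihood):

$$L = -\big(t \log y + (1 - t)\log(1 - y)\big)$$

$$\frac{\partial L}{\partial y} = -\frac{t}{y} + \frac{1 - t}{1 - y} = \frac{y - t}{y(1 - y)}, \qquad \frac{\partial y}{\partial z} = y(1 - y)$$

$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial z} = y - t$$

So `y - y*y = y(1 - y)` is just the denominator of the Bernoulli log-likelihood derivative; it cancels against the sigmoid derivative `dz`, leaving the familiar `y - t`.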
Yes, this particular code is for (binary) classification with Bernoulli likelihood. See https://github.com/wiseodd/natural-gradients/blob/7d51f19d315cb393a52a371162aaf2e27d20dbaa/numpy/toy/full_fisher.py#L25-L26
https://github.com/wiseodd/natural-gradients/blob/7d51f19d315cb393a52a371162aaf2e27d20dbaa/numpy/toy/full_fisher.py#L47
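For anyone following along, here is a minimal numpy sketch of what I understand those lines to compute (a hypothetical reconstruction, not a verbatim copy of the repo's code; the function name is mine):

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Hypothetical sketch: empirical Fisher for logistic regression
# with a Bernoulli likelihood, following the thread's notation.
def full_fisher_bernoulli(X_train, t_train, W):
    y = sigmoid(X_train @ W)
    grad_loglik_y = (y - t_train) / (y - y*y)         # dL/dy for the Bernoulli NLL
    dz = y * (1 - y)                                  # dy/dz, the sigmoid derivative
    grad_loglik_z = grad_loglik_y * dz                # cancels to (y - t_train)
    grad_loglik_W = grad_loglik_z[:, None] * X_train  # per-example dL/dW
    # empirical Fisher: average outer product of per-example gradients
    return grad_loglik_W.T @ grad_loglik_W / X_train.shape[0]
```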
I have no idea how you compute the Fisher information matrix. Specifically, I don't know how you compute p(x|θ) without using a prior on your data. Can you explain? Thanks.