Do you have a minimal example?
"Unfortunately" the method seems to work for a small test network. I think the problem lies in some later transformation that I apply to the output of the described layer. What I don't understand is that the gradient is non-zero for a simple 'single step' evaluation but vanishes as soon as a second iteration is done (setting maxiter=2
for example).
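For concreteness, a minimal sketch of the kind of comparison described here, assuming the solver in question is jaxopt's `AndersonAcceleration` (the library is not named explicitly in this thread); the map `f`, the parameter values, and the shapes are invented stand-ins for the real network:

```python
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def f(z, p):
    # toy stand-in for the real network (a contraction in z for small p)
    return p[0] * jnp.tanh(z) + p[1]

p = jnp.array([0.3, 0.5])
z0 = jnp.ones(4)

# Gradient through a single explicit application of f: non-zero.
g_single = jax.grad(lambda p: jnp.sum(f(z0, p) ** 2))(p)

# Gradient through the solver once a second iteration is allowed; this is the
# one that came back as all zeros in the case described above.
solver = AndersonAcceleration(fixed_point_fun=f, maxiter=2, tol=1e-9)
g_solver = jax.grad(lambda p: jnp.sum(solver.run(z0, p).params ** 2))(p)
print(g_single, g_solver)
```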
OK, I think I know what the problem was. During the backward pass, when the linear system for applying the inverse Jacobian is solved (by default with `solve_cg` in `linear_solve.py`), the solver doesn't find a solution and just returns the initial starting value, which is a zero vector. I guess it would be nice if there were some kind of error message in the case of non-convergence.
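One way this could be surfaced, sketched under the assumption that the solver is jaxopt's `AndersonAcceleration`: pass a custom solver via `implicit_diff_solve` that delegates to GMRES and reports the residual of the backward-pass system. The wrapper name `checked_solve`, the toy map `f`, and the tolerances are hypothetical:

```python
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration, linear_solve

def checked_solve(matvec, b, **kwargs):
    # Solve the backward-pass system with GMRES instead of CG, then report
    # how well matvec(x) = b was actually satisfied before returning x.
    # (Written for array-valued unknowns; pytrees would need flattening.)
    x = linear_solve.solve_gmres(matvec, b, **kwargs)
    residual = jnp.linalg.norm(matvec(x) - b)
    jax.debug.print("implicit-diff solve residual: {r}", r=residual)
    return x

def f(z, p):
    # toy stand-in for the real fixed-point map (a contraction for |p| < 1)
    return p * jnp.tanh(z) + 0.5

solver = AndersonAcceleration(fixed_point_fun=f, maxiter=100, tol=1e-8,
                              implicit_diff_solve=checked_solve)

# Differentiating through the fixed point triggers the backward-pass solve
# and prints its residual.
grad_p = jax.grad(lambda p: jnp.sum(solver.run(jnp.zeros(3), p).params))(0.3)
print(grad_p)
```

If the default really is plain conjugate gradient, as described above, swapping in GMRES (or a normal-equations CG) may also address the root cause: CG assumes a symmetric positive definite system, and the Jacobian of a general fixed-point map usually is not symmetric.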
I have a neural network in Flax that is basically a function expansion in a non-linear basis set, and I want to find parameters p such that for a given X the function f(·, p) has a fixed point at X. The idea was to use Anderson acceleration together with implicit differentiation to tune the parameters. The problem is that the gradient of the fixed point z = f(z, p) w.r.t. p is zero. I checked the output of the solver (`verbose=True`) and it successfully finds a solution in fewer than `maxiter` iterations. The normal gradient operation through the network seems to work just fine, because if I take the returned fixed-point solution and do one further iteration manually, I do get a non-vanishing gradient w.r.t. the parameters. I could look at the generated jaxpr, but the procedure is quite long/complicated, so I don't think it would help much.
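For reference, a minimal sketch of this setup, again assuming jaxopt's `AndersonAcceleration`; the basis functions, shapes, and loss below are invented for illustration:

```python
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def f(z, p):
    # toy "expansion in a non-linear basis set": f(z, p) = sum_i p_i * phi_i(z)
    basis = jnp.stack([jnp.tanh(z), jnp.cos(z), jnp.ones_like(z)])  # hypothetical phi_i
    return p @ basis

solver = AndersonAcceleration(fixed_point_fun=f, maxiter=100, tol=1e-8)

def loss(p, z0, target):
    z_star = solver.run(z0, p).params          # fixed point z* = f(z*, p)
    return jnp.sum((z_star - target) ** 2)

p = jnp.array([0.3, 0.2, 0.1])
z0 = jnp.zeros(4)
target = jnp.full(4, 0.5)

# Gradient through the fixed point via implicit differentiation; this is the
# quantity that came back as all zeros in the case reported above.
print(jax.grad(loss)(p, z0, target))

# Manual check described above: one further application of f at the returned
# fixed point does give a non-vanishing gradient w.r.t. p.
z_star = solver.run(z0, p).params
print(jax.grad(lambda p: jnp.sum((f(z_star, p) - target) ** 2))(p))
```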