Closed: davidweichiang closed this issue 2 years ago
It's the updates that look like `psi_X0[...] = something`. Changing them to not overwrite `psi_X0` fixes the problem (see the latest commit in the gradients branch). However, this isn't the right solution, because we only need to do this for the non-iterative cases. (How much time is saved by overwriting instead of making a new tensor?)
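For illustration, here is a minimal PyTorch sketch of the failure mode (not the repo's actual `sum_product` code; the tensor contents are made up). An in-place indexed assignment on a tensor that autograd still needs raises a `RuntimeError` at backward time, while building a fresh tensor keeps gradients flowing:

```python
import torch

# In-place indexed assignment on a tensor autograd still needs
# raises a RuntimeError at backward time (version counter mismatch).
psi_X0 = torch.ones(3, requires_grad=True).clone()  # non-leaf, so in-place is permitted
loss = (psi_X0 * psi_X0).sum()  # backward needs psi_X0's original values
psi_X0[0] = 2.0                 # overwrite, like psi_X0[...] = something
try:
    loss.backward()
except RuntimeError as err:
    print("in-place update broke autodiff:", err)

# Out-of-place version: build a new tensor instead of overwriting.
psi_X0 = torch.ones(3, requires_grad=True)
mask = torch.arange(3) == 0
psi_new = torch.where(mask, torch.tensor(2.0), psi_X0)  # fresh tensor, old one untouched
(psi_new * psi_new).sum().backward()
print(psi_X0.grad)  # tensor([0., 2., 2.])
```

The out-of-place version allocates a new tensor each update, which is the time/memory cost the parenthetical question above is asking about.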
What we actually need is a special case for when all the equations are linear. The `newton` function will work for this case and terminate after one iteration, but it seems like there are a lot of optimizations possible for this (very common) case. This special case should not do any in-place operations, so standard autodiff should work on it. A sketch of what that could look like follows.
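As a hedged sketch only: if the linear equations have the form x = A x + b, the fixed point can be obtained with a single differentiable solve of (I - A) x = b instead of iterating. `solve_linear` is a hypothetical helper, not an existing function in this repo:

```python
import torch

def solve_linear(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Solve the fixed point x = A @ x + b as (I - A) x = b.

    One dense solve, no iteration, no in-place updates, so
    autograd differentiates through it with respect to A and b.
    """
    I = torch.eye(A.shape[0], dtype=A.dtype)
    return torch.linalg.solve(I - A, b)

A = torch.tensor([[0.0, 0.5],
                  [0.25, 0.0]], requires_grad=True)
b = torch.tensor([1.0, 1.0], requires_grad=True)
x = solve_linear(A, b)
x.sum().backward()    # gradients flow back to A and b
print(x, A.grad, b.grad)
```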
The non-iterative cases of `sum_product` should be differentiable, but they're not working yet. What I have so far is in the branch https://github.com/diprism/fgg-implementation/tree/gradients