romanodev opened this issue 3 months ago.
I think this is working as intended. We don't support any notion of differentiating with respect to `options`, so a user should explicitly opt out of this (e.g. with `jax.lax.stop_gradient`) rather than potentially getting silently unexpected gradients.
As for using an initial guess for the backward pass: indeed, right now we don't seem to support this. Probably the correct thing to do would be to use the transpose of the initial guess for the forward pass, by filling in these two methods:
I'd be happy to take a PR on this!
I see. I will look into it. Great library, BTW!
The gradient of the solution of an iteratively solved linear system with respect to the initial guess should be zero. Instead, the following snippet gives a non-zero gradient (lineax version 0.0.5). The problem is quickly resolved by wrapping the initial guess in `jax.lax.stop_gradient`.
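Here is a self-contained illustration of the symptom. It uses a toy stand-in and makes no claim about how lineax works internally. A Jacobi solver that is differentiated by unrolling a fixed number of iterations still depends on its starting point after finitely many steps. Its gradient with respect to `x0` is therefore small but non-zero. Wrapping `x0` in `jax.lax.stop_gradient` removes that dependence. `A`, `b` and `unrolled_jacobi` are hypothetical example values and names.

```python
import jax
import jax.numpy as jnp
from jax import lax

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # diagonally dominant, so Jacobi converges
b = jnp.array([1.0, 2.0])


def unrolled_jacobi(x0, num_iters=5):
    # Toy iterative solver; JAX differentiates straight through the loop.
    diag = jnp.diag(A)
    def body(_, x):
        return x + (b - A @ x) / diag
    return lax.fori_loop(0, num_iters, body, x0)


def loss(x0):
    return unrolled_jacobi(x0).sum()


def loss_stopped(x0):
    # Treat the initial guess as a constant, as suggested in the thread.
    return unrolled_jacobi(lax.stop_gradient(x0)).sum()


g = jax.grad(loss)(jnp.ones(2))               # [-1/432, -1/576], non-zero
g_stop = jax.grad(loss_stopped)(jnp.ones(2))  # exactly zero
```

Each step multiplies the error by `M = I - diag(A)^-1 A`. After five steps the gradient is therefore `(M^5)^T @ ones`, which shrinks as `num_iters` grows but never becomes exactly zero.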
For reference, JAX's own iterative solver works fine.
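Below is a minimal comparison using `jax.scipy.sparse.linalg.cg`, which accepts an `x0` argument. As far as I know, this solver is built on `jax.lax.custom_linear_solve`. Its derivative is then computed implicitly from `A x = b`, which is why the initial guess gets a zero gradient. The matrix and right-hand side are arbitrary example values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite, as CG requires
b = jnp.array([1.0, 2.0])


def loss(x0):
    # cg returns (solution, info); only the solution is needed here.
    x, _ = cg(A, b, x0=x0)
    return x.sum()


x, _ = cg(A, b, x0=jnp.ones(2))
grad_x0 = jax.grad(loss)(jnp.ones(2))  # zeros: x0 only steers the iteration
```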
This is a corner case, but the initial guess can end up traced as part of a larger computational graph (as in my use case). It would also be great to be able to specify an initial guess for the backward pass.