Thanks @ianwilliamson,
klujax is not yet open source, but it will be very soon! (I just started the internal review process within my company.)
I agree... it should be fairly easy to implement autograd for a linear solver. In fact, I have done this before for my PyTorch solver. The main reason I haven't done it yet for klujax is basically laziness: I haven't really needed it so far. For most of my optimization problems I use a smaller circuit anyway, which can easily be modeled with the default backend.
Sax looks like a very interesting and useful package!
I was wondering where the code for klujax is located, since it doesn't seem to be part of sax.
Additionally, as I was reading the docs on the KLU backend, I was surprised to learn that the backend does not support gradients. If I understand the approach correctly, the calculation is a combination of matrix multiplications and a linear solve (using klujax). The gradient of the linear solve should just involve another linear solve, meaning that you should be able to define a custom gradient rule in JAX that makes another call to klujax. Am I missing something?
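
For what it's worth, here is a minimal sketch of that idea. For `A x = b`, the reverse-mode rule is `b_bar = A^{-T} x_bar` and `A_bar = -b_bar x^T`, so the backward pass is just one more (transposed) solve. Since I don't know klujax's actual API, `jnp.linalg.solve` stands in for the sparse solve; the point is only the `jax.custom_vjp` structure, not the exact call:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def solve(A, b):
    # Forward pass: x such that A @ x = b.
    # (Hypothetically, klujax's solve would be called here instead.)
    return jnp.linalg.solve(A, b)


def solve_fwd(A, b):
    x = jnp.linalg.solve(A, b)
    # Save A and x for the backward pass.
    return x, (A, x)


def solve_bwd(res, x_bar):
    A, x = res
    # Cotangent w.r.t. b: another linear solve, with the transposed matrix.
    b_bar = jnp.linalg.solve(A.T, x_bar)
    # Cotangent w.r.t. A: -b_bar x^T (for a single right-hand-side vector b).
    A_bar = -jnp.outer(b_bar, x)
    return A_bar, b_bar


solve.defvjp(solve_fwd, solve_bwd)

# Example: differentiate a scalar loss through the solve.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 0.5])
loss = lambda A, b: jnp.sum(solve(A, b) ** 2)
print(jax.grad(loss, argnums=(0, 1))(A, b))
```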