flaport / sax

S + Autograd + XLA :: S-parameter based frequency domain circuit simulations and optimizations using JAX.
https://flaport.github.io/sax
Apache License 2.0

Location of klujax and grad-able solves with KLU #5

Closed · ianwilliamson closed this issue 2 years ago

ianwilliamson commented 2 years ago

Sax looks like a very interesting and useful package!

I was wondering where the code for klujax is located, since it doesn't seem to be part of sax.

Additionally, as I was reading the docs on the KLU backend, I was surprised to learn that the backend does not support gradients. If I understand the approach correctly, the calculation is a combination of matrix multiplications and a linear solve (using klujax). The gradient for the linear solve operation should just involve another linear solve, meaning that you should be able to define the gradient rule for JAX to make another call to klujax. Am I missing something?
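The gradient rule described above can be sketched with `jax.custom_vjp`. For the solution x of A x = b, the cotangent with respect to b is obtained by solving the transposed system A^T λ = v, and the cotangent with respect to A is -λ x^T. This is a minimal, self-contained illustration; `jnp.linalg.solve` stands in for a solver without built-in autodiff support (such as a hypothetical KLU-backed sparse solve, whose actual API may differ):

```python
import jax
import jax.numpy as jnp

def _raw_solve(A, b):
    # Stand-in for a solver JAX cannot differentiate through
    # (e.g. an external KLU-backed solve); dense here for simplicity.
    return jnp.linalg.solve(A, b)

@jax.custom_vjp
def solve(A, b):
    """Solve A x = b with a hand-written gradient rule."""
    return _raw_solve(A, b)

def solve_fwd(A, b):
    x = _raw_solve(A, b)
    return x, (A, x)  # residuals needed by the backward pass

def solve_bwd(res, v):
    A, x = res
    # Cotangent w.r.t. b: another linear solve, on the transposed system.
    lam = _raw_solve(A.T, v)
    # Cotangent w.r.t. A: an outer product, -lambda x^T.
    return (-jnp.outer(lam, x), lam)

solve.defvjp(solve_fwd, solve_bwd)

# Sanity check against JAX's built-in gradient of jnp.linalg.solve.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
g_custom = jax.grad(lambda A_: jnp.sum(solve(A_, b) ** 2))(A)
g_ref = jax.grad(lambda A_: jnp.sum(jnp.linalg.solve(A_, b) ** 2))(A)
assert jnp.allclose(g_custom, g_ref)
```

The same two-line backward rule applies regardless of which factorization the forward solve uses, which is why registering it for an external sparse solver only requires one extra (transposed) solve per backward pass.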

flaport commented 2 years ago

Thanks @ianwilliamson ,

klujax is not yet open source, but will be very soon! (I just started the internal review process within my company)

I agree, it should be fairly easy to implement autograd for a linear solver; in fact, I have done this before for my PyTorch solver. The main reason I haven't done it yet for klujax is basically laziness: I haven't needed it so far, since for many of my optimization problems the circuit is small enough to be modeled easily with the default backend.

flaport commented 2 years ago

Hi @ianwilliamson ,

klujax is now open source. Feel free to open an issue (or PR 😉) for autograd there!

Closing this here, as the discussion can continue on the klujax repository.