steven-murray opened this issue 2 years ago
Awesome, working on this now (sorry for the delay!). I think the only thing that might be a little difficult is the rewritten dft library. I'll get everything subclassed on my end for the Powerbox script and see how it shakes out.
Hi @tlmakinen,
Thanks again for putting this awesome package together!
As we discussed, it would be great to have tighter integration of this with the base powerbox code. As I envision it, the best way will be to maintain this as its own package right here. However, instead of copying the `Powerbox` class from `powerbox`, it should inherit from it. We should aim for an exact drop-in replacement, so that people can do `import powerbox_jax as powerbox` and get exactly the same results as doing `import powerbox` (but differentiable!). Where it is not possible to replace a given method with a jaxified version, we should raise either an error or a warning.

To do this, I think the powerbox module in this package should look something like this:
To be clear, you'll only need to define methods where the computation has to change to use specific `jax` commands (e.g. where it relies on boolean indexing). If the method that needs to be changed is more than a few lines long (and you only have to change one line in the middle of it), then please open an issue on powerbox suggesting I split the method up so you can override a smaller method!
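The boolean-indexing case can be illustrated with a small pair of functions. This is a sketch rather than powerbox's code: the power law P(k) = A k^-2 and both function names are illustrative assumptions. NumPy updates the masked entries in place, whereas JAX arrays are immutable and `jax.jit` cannot trace a boolean mask whose result size depends on the data, so the JAX version computes every entry and selects with `jnp.where`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def power_array_numpy(k, amplitude):
    """NumPy style: overwrite only the k != 0 entries through a boolean mask."""
    p = np.array(k, dtype=float)  # copy, so the caller's array is untouched
    mask = p != 0
    p[mask] = amplitude * p[mask] ** -2.0  # illustrative power law P(k) = A k^-2
    return p


def power_array_jax(k, amplitude):
    """JAX style: same values, but no in-place update and no data-dependent shapes."""
    k = jnp.asarray(k, dtype=jnp.float32)
    mask = k != 0
    # Inner where: swap k == 0 for a harmless 1.0 so k**-2 never produces inf.
    # Without it, the masked-out inf still enters the backward pass and the
    # gradient with respect to `amplitude` becomes NaN.
    safe_k = jnp.where(mask, k, 1.0)
    # Outer where: keep P(0) = 0, matching the NumPy version.
    return jnp.where(mask, amplitude * safe_k ** -2.0, 0.0)


# The JAX version can be compiled and differentiated.
compiled = jax.jit(power_array_jax)
grad_wrt_amplitude = jax.grad(lambda a, k: power_array_jax(k, a).sum())
```

For `k = [0, 1, 2, 4]` and `amplitude = 3.0` both functions return `[0, 3, 0.75, 0.1875]`, and the gradient of the summed power with respect to the amplitude is `1 + 1/4 + 1/16 = 1.3125`.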