tlmakinen / powerbox-jax

Jax implementation of Steven Murray's powerbox: https://github.com/steven-murray/powerbox
Apache License 2.0

Integration with default powerbox #1

Open · steven-murray opened this issue 2 years ago

steven-murray commented 2 years ago

Hi @tlmakinen,

Thanks again for putting this awesome package together!

As we discussed, it would be great to integrate this more tightly with the base powerbox code. As I envision it, the best approach is to keep maintaining this as its own package, right here. However, instead of copying the `PowerBox` class from powerbox, it should inherit from it. We should aim for an exact drop-in replacement, so that people can do `import powerbox_jax as powerbox` and get exactly the same results as with `import powerbox` (but differentiable!). Where a method cannot be replaced with a jaxified version, we should raise either an error or a warning.
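A minimal sketch of the "error or warning" fallback, assuming a hypothetical stand-in base class `BasePowerBox` with hypothetical methods `draw_samples` and `describe` (this is not the real powerbox API). The subclass keeps every method name, overrides the ones that have a jaxified version, raises for one that has no differentiable equivalent, and warns before falling back for one that still works but is not differentiable.

```python
import warnings


class BasePowerBox:
    """Hypothetical stand-in for the upstream class (not the real powerbox API)."""

    def delta_k(self):
        return "numpy delta_k"

    def draw_samples(self):
        return "numpy samples"

    def describe(self):
        return "numpy describe"


class JaxPowerBox(BasePowerBox):
    """Subclass keeping the same method names, so calling code is unchanged."""

    def delta_k(self):
        # Jaxified override (a placeholder string in this sketch).
        return "jax delta_k"

    def draw_samples(self):
        # No differentiable equivalent: fail loudly rather than silently differ.
        raise NotImplementedError("draw_samples has no JAX implementation yet")

    def describe(self):
        # Still usable, just not differentiable: warn, then defer to the parent.
        warnings.warn(
            "describe() uses the non-differentiable base implementation",
            RuntimeWarning,
            stacklevel=2,
        )
        return super().describe()
```

Raising fits methods whose results would otherwise diverge from the base package; warning fits methods that still return the same values, only without gradients.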

To do this, I think the powerbox module in this package should look something like this:

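```python
# powerbox.py
import powerbox as _pb
import powerbox.powerbox as _pb_module  # the module where PowerBox is defined
from jax import numpy as jnp

# Replace numpy with the jaxified version. The patch has to target the module
# that defines PowerBox, because its methods look up `np` in that module's globals.
_pb_module.np = jnp


class PowerBox(_pb.PowerBox):
    def delta_k(self):
        # <jaxified code goes here>
        raise NotImplementedError
```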
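Reassigning `np` on a module works because Python resolves a module-level name such as `np` each time a function runs, and it looks the name up in the globals of the module where that function was defined, not in the package that re-exports it. A stdlib-only sketch of that mechanism, with a hypothetical throwaway module standing in for `powerbox.powerbox` and a hypothetical `FakeBackend` standing in for `jax.numpy`:

```python
import math
import types

# Hypothetical stand-in for a module like powerbox.powerbox: its methods look up
# the module-level name `np` each time they are called.
_SOURCE = """
import math as np  # stand-in for "import numpy as np"

class PowerBox:
    def amplitude(self, power):
        return np.sqrt(power)
"""

fake_module = types.ModuleType("fake_powerbox")
exec(_SOURCE, fake_module.__dict__)


class FakeBackend:
    """Hypothetical replacement backend that tags its results."""

    @staticmethod
    def sqrt(x):
        return ("patched", math.sqrt(x))


box = fake_module.PowerBox()
before = box.amplitude(4.0)   # uses math.sqrt, giving 2.0

fake_module.np = FakeBackend  # patch the module that defines the class
after = box.amplitude(4.0)    # the existing instance now uses FakeBackend
```

Because the patch mutates a shared module object, any other code in the same process that uses that module also sees the replacement backend.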

To be clear, you'll only need to define the methods whose computation has to be rewritten with JAX-specific operations (e.g. wherever boolean indexing is used).
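Boolean-mask indexing such as `power[k > k_min]` returns an array whose shape depends on the data, which JAX cannot trace inside `jit`. A common workaround is to keep every array at full size and use `where`. The sketch below (the function name `mean_power_above` is hypothetical) is written against a module-level `np`, so the same code runs with NumPy or, after swapping `np` for `jax.numpy`, with JAX:

```python
import numpy as np  # swap for `from jax import numpy as np` to run under JAX


def mean_power_above(power, k, k_min):
    """Mean of `power` over bins with k > k_min, without boolean indexing.

    A NumPy-only version would be `power[k > k_min].mean()`, whose
    intermediate shape depends on the data; here every array keeps its
    full shape, which is what JAX's tracing requires.
    """
    mask = k > k_min
    total = np.where(mask, power, 0.0).sum()  # zero out instead of dropping
    count = mask.sum()
    return total / count
```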

If a method that needs changing is more than a few lines long (and you only have to change one line in the middle of it), please open an issue on powerbox suggesting that I split the method up, so you can override a smaller one!
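Splitting a long method so that the backend-specific step lives in its own small method lets the subclass override only that step. A minimal sketch with hypothetical names (`BaseBox`, `summary`, `_apply_cut`, `JaxBox`), not taken from powerbox itself:

```python
class BaseBox:
    """Hypothetical upstream class after the refactor."""

    def __init__(self, values, threshold):
        self.values = values
        self.threshold = threshold

    def summary(self):
        # Long method: the shared steps stay here, untouched by subclasses.
        scaled = [2.0 * v for v in self.values]
        kept = self._apply_cut(scaled)
        return sum(kept)

    def _apply_cut(self, values):
        # The single backend-specific step, isolated so it can be overridden.
        # Filtering changes the length of the list, like boolean indexing.
        return [v for v in values if v > self.threshold]


class JaxBox(BaseBox):
    """Hypothetical subclass: overrides only the small hook."""

    def _apply_cut(self, values):
        # Fixed-length alternative: zero out elements instead of dropping them.
        return [v if v > self.threshold else 0.0 for v in values]
```

Both classes give the same `summary()`, but only the subclass keeps the intermediate length fixed, and the rest of the long method is inherited unchanged.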

tlmakinen commented 2 years ago

Awesome, working on this now (sorry for the delay!). I think the only part that might be a little difficult is the rewritten dft library. I'll get everything subclassed on my end for the Powerbox script and see how it shakes out.