LouisDesdoigts / zodiax

Object-oriented Jax framework extending Equinox for scientific programming
https://louisdesdoigts.github.io/zodiax/
BSD 3-Clause "New" or "Revised" License

Test and Implement optimised `filter_jit` #31

Open LouisDesdoigts opened 1 year ago

LouisDesdoigts commented 1 year ago

Allow `filter_jit` to take a list of parameters, then partition the pytree around them and recombine it, so that everything except the parameters of interest is treated as 'static'.
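A minimal sketch of the partition/combine idea in plain JAX. It is not the zodiax or Equinox API: `param_filter_jit`, `_SCALE` and `squared_nm` are hypothetical names. It assumes the pytree is a flat `dict` and that every non-parameter leaf is hashable, because JAX requires that of static arguments. Any non-parameter array would therefore still need to be traced.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def param_filter_jit(params: Sequence[str]) -> Callable:
    """Hypothetical decorator: jit `fn(tree)`, tracing only the `params` entries."""
    params = tuple(params)

    def decorator(fn: Callable) -> Callable:
        def inner(dynamic: dict, static_items: tuple):
            # "combine": rebuild the full tree from its traced and static parts
            return fn({**dict(static_items), **dynamic})

        # `static_items` is hashed and baked into the compiled function;
        # passing a different static value triggers a retrace.
        jitted = jax.jit(inner, static_argnums=1)

        def wrapper(tree: dict):
            # "partition": parameters of interest are traced, the rest is static
            dynamic = {k: tree[k] for k in params}
            static_items = tuple(
                sorted((k, v) for k, v in tree.items() if k not in params)
            )
            return jitted(dynamic, static_items)

        return wrapper

    return decorator


_SCALE = {'nm': 1.0, 'um': 1e3}  # hypothetical unit table


@param_filter_jit(['leaf'])
def squared_nm(tree: dict):
    # tree['unit'] is a plain Python str here, so the dict lookup works under jit
    value = tree['leaf'] * _SCALE[tree['unit']] / _SCALE['nm']
    return (value ** 2).sum()


result = squared_nm({'leaf': jnp.array([1.0, 2.0]), 'unit': 'um'})
```

Only `leaf` becomes a tracer. The string `unit` is hashed as part of the static tuple, so calls that change only `leaf` reuse the compiled function.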

This needs testing to see whether anything is actually gained, and whether it solves the issue with string parameters (such as `unit` in the example below).
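The string-parameter issue can be reproduced in plain JAX. Tracing a `str` argument fails, while marking it static works at the cost of one trace per distinct value. That cost is one of the trade-offs such a test would need to measure. `scaled`, `_SCALE` and `traces` are hypothetical names. The trace counter works because a jitted function's Python body only runs while JAX is tracing it.

```python
import jax
import jax.numpy as jnp

_SCALE = {'nm': 1.0, 'um': 1e3}  # hypothetical unit table
traces = []  # one entry per (re)trace: the Python body only runs while tracing


def scaled(leaf, unit):
    traces.append(unit)
    return leaf * _SCALE[unit]


# Tracing every argument fails: a str is not a valid JAX type.
try:
    jax.jit(scaled)(jnp.array(1.0), 'um')
    string_traced_ok = True
except TypeError:
    string_traced_ok = False

traces.clear()  # count only the traces of the static-unit version below

# Holding the unit static works, but each new unit compiles a new version.
scaled_jit = jax.jit(scaled, static_argnums=1)
a = scaled_jit(jnp.array(1.0), 'um')  # first trace
b = scaled_jit(jnp.array(2.0), 'um')  # same unit and shape: cached, no retrace
c = scaled_jit(jnp.array(1.0), 'nm')  # new unit: retrace
```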

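```python
from jax import Array
import zodiax as zdx

# Hypothetical helper standing in for "some simple unit conversion function"
_SCALE = {'nm': 1.0, 'um': 1e3}

def convert(value: Array, unit: str, new_unit: str) -> Array:
    return value * _SCALE[unit] / _SCALE[new_unit]

class Foo(zdx.Base):
    leaf : Array
    unit : str  # string leaf: cannot be traced or differentiated

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit)
        return self.set(['leaf', 'unit'], [new_value, new_unit])

param = 'leaf'

# Proposed API: jit with only `param` traced and everything else static
@zdx.filter_jit(param)
@zdx.filter_grad(param)
def f(pytree):
    return (pytree('nm').leaf**2).sum()  # grad needs a scalar output
```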
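For comparison, the gradient that the proposed decorators should reproduce can be written in plain JAX. Here the parameter of interest and the string unit are passed as separate arguments: `jax.grad` differentiates only the first, and `jax.jit` holds the unit static. The split signature, `loss` and the `_SCALE`/`convert` unit table are assumptions made for illustration. For a leaf in `'um'`, the expected gradient is d/dx Σ(1000·x)² = 2·10⁶·x.

```python
import jax
import jax.numpy as jnp

_SCALE = {'nm': 1.0, 'um': 1e3}  # hypothetical unit table


def convert(value, unit, new_unit):
    # Hypothetical stand-in for the issue's "simple unit conversion function"
    return value * _SCALE[unit] / _SCALE[new_unit]


def loss(leaf, unit):
    # Mirrors `f`: convert the leaf to nm, square it, and sum to get a scalar
    return (convert(leaf, unit, 'nm') ** 2).sum()


# grad w.r.t. `leaf` only (argnums=0 by default); `unit` stays a Python str
grad_fn = jax.jit(jax.grad(loss), static_argnums=1)

g = grad_fn(jnp.array([1.0, 2.0]), 'um')
```

A parameter-aware `filter_jit` would essentially automate this split for an arbitrary `zdx.Base` pytree, instead of requiring the static arguments to be threaded through by hand.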