LouisDesdoigts opened 1 year ago
Allow `zdx.filter_jit` and `zdx.filter_grad` to take a parameter list, and partition/combine the pytree around it so that all but the parameters of interest are treated as 'static'.
This needs testing to see whether anything is actually gained, and whether it solves the string parameter issue (e.g. the `unit: str` field below).
```python
from jax import Array
import zodiax as zdx


class Foo(zdx.Base):
    leaf: Array
    unit: str

    def __call__(self, new_unit: str):
        # `convert` is some simple unit conversion function (not defined here)
        new_value = convert(self.leaf, self.unit, new_unit)
        return self.set(('leaf', 'unit'), (new_value, new_unit))


param = 'leaf'


# Proposed API: only `param` is traced/differentiated; all other leaves
# (including the `unit` string) are partitioned off as static.
@zdx.filter_jit(param)
@zdx.filter_grad(param)
def f(pytree):
    return pytree('nm').leaf**2
```