In JAX, certain transformations (e.g. `jax.grad`) require all leaf values to be inexact, i.e. of a floating-point or complex dtype. In this issue, I propose handling non-differentiable (non-inexact) tree values with two function transformations.
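As a small illustration of the constraint described above, the following sketch calls `jax.grad` on the same function with a float and with an integer input. The function `square` is a made-up example and is not part of the proposal.

```python
import jax


def square(x):
    # Hypothetical toy function used only to exercise jax.grad.
    return x ** 2


# A floating-point input is inexact, so differentiation succeeds.
float_grad = jax.grad(square)(3.0)

# An integer input is not inexact, so jax.grad raises a TypeError.
try:
    jax.grad(square)(3)
    raised = False
except TypeError as err:
    raised = True
    print(f"jax.grad rejected the integer input: {err}")
```

The same error appears when the integer is a leaf nested inside a larger tree passed as the differentiated argument, which is the case this issue is about.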
The existing approach is to mark each non-inexact field with `pytc.static_field` in the class definition; however, this requires the user to know the type of each field's data in advance.
I propose adding `filter_nondiff`, which marks non-differentiable nodes as static, and `unfilter_nondiff`, which undoes that marking. In essence, `unfilter_nondiff(filter_nondiff(x)) == x`.
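To make the round-trip idea concrete, here is a minimal sketch of how such a pair could be built on plain `jax.tree_util`. It is not the proposed implementation: the `Frozen` wrapper and the `*_sketch` function names are hypothetical, and the sketch assumes the hidden values are hashable (Python ints, strings), since they end up in the tree structure.

```python
import jax
import jax.numpy as jnp
from jax import tree_util as jtu


class Frozen:
    """Hypothetical wrapper that hides a value from JAX transformations."""

    def __init__(self, value):
        self.value = value


# Register Frozen as a pytree node with no children: the wrapped value is
# stored as static auxiliary data, so JAX never sees it as a leaf.
jtu.register_pytree_node(
    Frozen,
    lambda frozen: ((), frozen.value),
    lambda aux, children: Frozen(aux),
)


def is_differentiable(leaf):
    # Python floats/complex and arrays with an inexact dtype count as differentiable.
    if isinstance(leaf, (float, complex)):
        return True
    return hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.inexact)


def filter_nondiff_sketch(tree):
    # Wrap every non-differentiable leaf so that it becomes static.
    return jtu.tree_map(lambda x: x if is_differentiable(x) else Frozen(x), tree)


def unfilter_nondiff_sketch(tree):
    # Unwrap Frozen nodes, restoring the original leaves.
    return jtu.tree_map(
        lambda x: x.value if isinstance(x, Frozen) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Frozen),
    )


params = {"w": jnp.array([1.0, 2.0]), "n": 3, "name": "layer"}


def loss(p):
    p = unfilter_nondiff_sketch(p)  # recover the integer exponent
    return jnp.sum(p["w"] ** p["n"])


grads = jax.grad(loss)(filter_nondiff_sketch(params))
restored = unfilter_nondiff_sketch(filter_nondiff_sketch(params))
```

In this sketch `jax.grad` only sees the float array `w`; the integer and string leaves travel along as part of the tree structure and come back unchanged after unfiltering. Non-inexact arrays (e.g. integer arrays) would need a hashable wrapper before being stored as auxiliary data, which this sketch does not handle.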
Let's demonstrate these two transformations with an example: