ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

Dealing with non-differentiable tree values under `jax.{grad,value_and_grad,...}` #21

Closed: ASEM000 closed this issue 2 years ago

ASEM000 commented 2 years ago

In JAX, certain transformations (e.g. jax.grad) require all input values to be inexact, i.e. floating-point or complex. In this issue, I propose a way to handle non-differentiable (non-inexact) tree values with two function transformations.
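To make the restriction concrete, here is a small illustration using only public JAX APIs: jax.grad accepts a float array but raises a TypeError for an integer one. The helper is_inexact and the function f are hypothetical names introduced for this sketch; is_inexact is one plausible way to decide which values count as "inexact".

```python
import jax
import jax.numpy as jnp


def is_inexact(x) -> bool:
    """Hypothetical helper: True for float/complex Python scalars and arrays."""
    return isinstance(x, (float, complex)) or (
        hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact)
    )


def f(x):
    # A simple scalar-valued function to differentiate.
    return jnp.sum(x ** 2)


print(is_inexact(1.0), is_inexact(1))  # True False
print(jax.grad(f)(jnp.array([1.0, 2.0, 3.0])))  # float input: accepted

try:
    jax.grad(f)(jnp.array([1, 2, 3]))  # integer input: rejected by jax.grad
except TypeError as err:
    print("TypeError:", err)
```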

Currently, the approach is to apply pytc.static_field in the class definition to each non-inexact field; however, this requires the user to know the field types in advance. I propose adding a function filter_nondiff that marks non-differentiable nodes as static, and a function unfilter_nondiff that undoes this marking. In essence, unfilter_nondiff(filter_nondiff(x)) == x.

Let's demonstrate these two transformations with an example:

import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Test:
    a: int = 0
    b: float = 1.
    c: jnp.ndarray = jnp.array([1, 2, 3])
    d: jnp.ndarray = jnp.array([1., 2., 3.])

t = Test()
print(t)
# Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])

# filter_nondiff / unfilter_nondiff are the functions proposed above
# * marks a static field
print(filter_nondiff(t))
# Test(*a=0,b=1.0,*c=[1 2 3],d=[1. 2. 3.])

print(unfilter_nondiff(filter_nondiff(t)))
# Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])