LouisDesdoigts opened 1 year ago
Zodiax classes that have string parameters can now be jitted directly. Both of these examples throw errors if the 'value' leaf is not static!
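```python
import jax

# FooException is the example Zodiax class used in this issue (its definition is not shown here).


@jax.jit
def update_foo(foo_cls):
    # Keep the returned object, so this works whether set() mutates or returns a copy
    foo_cls = foo_cls.set('value', 'baz')
    return foo_cls


foo = FooException()
new_foo = update_foo(foo)
print(foo)      # the object that was passed in
print(new_foo)  # the object returned by the jitted update
```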
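To see what "static" means here without depending on Zodiax itself, the following minimal sketch uses plain `jax.tree_util.register_pytree_node_class` as a stand-in. The classes `StaticFoo` and `LeafFoo` and the function `double_x` are hypothetical: the first keeps the string in the pytree's auxiliary (static) data, the second flattens it as an ordinary leaf, which `jax.jit` rejects because a `str` is not a valid JAX array type.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class StaticFoo:
    """Hypothetical stand-in for a Zodiax class: 'value' is a static field."""

    def __init__(self, x, value):
        self.x = x
        self.value = value

    def tree_flatten(self):
        # 'x' is a traced leaf; 'value' is carried in the static aux data.
        return (self.x,), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


@register_pytree_node_class
class LeafFoo:
    """Same fields, but 'value' is flattened as an ordinary (traced) leaf."""

    def __init__(self, x, value):
        self.x = x
        self.value = value

    def tree_flatten(self):
        return (self.x, self.value), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def double_x(foo):
    return foo.x * 2


# Works: jit only traces the array leaf, the string rides along as static data.
static_result = double_x(StaticFoo(jnp.array(1.0), 'foo'))

# Fails: jit tries to turn the string leaf into an abstract array.
leaf_error = None
try:
    double_x(LeafFoo(jnp.array(1.0), 'foo'))
except TypeError as err:
    leaf_error = err
```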
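```python
from jax import jit, grad, vmap, numpy as np
import matplotlib.pyplot as plt


@jit
@grad
def test_fn(x, foo_cls):
    # Python control flow on 'value' is allowed because it is a static leaf
    if foo_cls.value == 'foo':
        return x**2
    else:
        return x**3


# Vectorise over x only; the same class instance is shared across the batch
test_fn_vmap = vmap(test_fn, in_axes=(0, None))

foo = FooException()
xs = np.linspace(-1, 1, 100)
# test_fn returns the gradient: 2x when value == 'foo', otherwise 3x**2
plt.plot(xs, test_fn_vmap(xs, foo), label='foo')
plt.plot(xs, test_fn_vmap(xs, foo.set('value', 2)), label='not foo')
plt.legend()
plt.savefig("output.png")
```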
Static pytree leaves can in fact be set. Overall, this allows string values to be stored on the class while it is still recognised as a pure-JAX pytree.
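One way to picture a static leaf inside a pytree: it is stored in the tree structure rather than among the leaves, so tree utilities only ever see array leaves. A minimal sketch with a hypothetical `Foo` class registered through `jax.tree_util` (an assumption for illustration, not how Zodiax implements it):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Foo:
    """Hypothetical pytree: 'x' is a leaf, 'value' is static."""

    def __init__(self, x, value):
        self.x = x
        self.value = value

    def tree_flatten(self):
        return (self.x,), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


foo = Foo(jnp.arange(3.0), 'foo')

# Only the array is a leaf; the string lives in the treedef.
leaves, treedef = jax.tree_util.tree_flatten(foo)

# tree_map touches the array and passes the string through unchanged.
scaled = jax.tree_util.tree_map(lambda leaf: leaf * 10, foo)

# A different static value means a different tree structure.
same_structure = jax.tree_util.tree_structure(foo) == jax.tree_util.tree_structure(
    Foo(jnp.arange(3.0), 'bar')
)
```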
The issue is that this method is mutable, which is somewhat dangerous because it behaves differently from the rest of the software. We can make the operations behave roughly as expected, but the behavior is unusual.
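The danger of a mutable setter is aliasing: every reference to the object sees the change. A plain-Python sketch contrasting the in-place style with a copy-returning one, using hypothetical `MutableFoo` and `ImmutableFoo` classes rather than any Zodiax API:

```python
import copy


class MutableFoo:
    """Hypothetical class whose set() modifies the object in place."""

    def __init__(self, value):
        self.value = value

    def set(self, key, value):
        setattr(self, key, value)
        return self


class ImmutableFoo(MutableFoo):
    """Hypothetical class whose set() returns a modified shallow copy."""

    def set(self, key, value):
        new = copy.copy(self)
        setattr(new, key, value)
        return new


a = MutableFoo('foo')
b = a.set('value', 'bar')
# a and b are the same object, so a.value is now 'bar' too.

c = ImmutableFoo('foo')
d = c.set('value', 'bar')
# c is untouched; only d carries the new value.
```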
This can be mitigated with a trade-off: the immutable behavior can be achieved by wrapping `jax.jit` around an update function with both the key and the value marked as static. This gives the desired behavior at the cost of a potential re-compile every time the leaf is set to a new value, and the leaf can only be updated with hashable values. This may not be a problem if higher-level `jit`s are able to cancel out the recompile (needs testing). The implemented method also raises a helpful error if you try to set the leaf to a non-hashable value.
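The jit-wrapped update described above can be sketched in pure JAX as follows. Everything here is hypothetical (`Foo`, `set_leaf`, `_set_impl`, and the hashability check), and it assumes the static field lives in the pytree's auxiliary data. Passing `key` and `value` through `static_argnames` makes them part of jit's cache key, so setting the same value again reuses the compiled function while a new value retraces; the explicit `hash()` check raises a readable `TypeError` up front instead of relying on JAX's own error for unhashable static arguments.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Foo:
    """Hypothetical pytree: 'x' is a traced leaf, 'value' is static."""

    def __init__(self, x, value):
        self.x = x
        self.value = value

    def tree_flatten(self):
        return (self.x,), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


traces = []  # one entry per (re)trace of the update function


def _set_impl(obj, key, value):
    traces.append((key, value))  # only runs when JAX traces
    # Inside jit, obj was rebuilt by tree_unflatten, so this in-place
    # mutation never touches the caller's object.
    setattr(obj, key, value)
    return obj


# key and value are static: they become part of the compilation cache key.
_jitted_set = jax.jit(_set_impl, static_argnames=('key', 'value'))


def set_leaf(obj, key, value):
    """Hypothetical immutable setter for static leaves."""
    try:
        hash(value)
    except TypeError:
        raise TypeError(
            f"cannot set {key!r} to unhashable {type(value).__name__}; "
            "static leaves must be hashable"
        ) from None
    return _jitted_set(obj, key=key, value=value)


foo = Foo(jnp.array(1.0), 'foo')
bar = set_leaf(foo, 'value', 'bar')   # traces once
n_first = len(traces)
set_leaf(foo, 'value', 'bar')         # same key, value and structure: cached
n_repeat = len(traces)
baz = set_leaf(foo, 'value', 'baz')   # new static value: retrace
n_new = len(traces)

hash_error = None
try:
    set_leaf(foo, 'value', ['not', 'hashable'])
except TypeError as err:
    hash_error = err
```

Because `_set_impl` only ever sees the copy that `jax.jit` rebuilds from the flattened pytree, even its in-place `setattr` leaves `foo` untouched, which is the immutable behavior the trade-off is aiming for.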
Notes: