LouisDesdoigts / zodiax

Object-oriented Jax framework extending Equinox for scientific programming
https://louisdesdoigts.github.io/zodiax/
BSD 3-Clause "New" or "Revised" License

Enable the setting of static fields? #29

Open: LouisDesdoigts opened this issue 1 year ago

LouisDesdoigts commented 1 year ago

You can actually set static pytree leaves. This would allow string values to be set on the class while the class is still recognised as a pure-JAX pytree, because static fields are kept in the pytree structure rather than treated as array leaves.
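To make the static/leaf distinction concrete, here is a minimal sketch in plain JAX rather than Zodiax. The class `StaticDemo` is hypothetical, and registering it by hand with `jax.tree_util.register_pytree_node_class` is an assumption standing in for roughly what `zdx.field(static=True)` does through Equinox: the static string lives in the tree structure (the treedef), not among the leaves.

```python
import jax


@jax.tree_util.register_pytree_node_class
class StaticDemo:
    """Hypothetical stand-in: `test` is an array leaf, `value` is static."""

    def __init__(self, test, value):
        self.test = test
        self.value = value

    def tree_flatten(self):
        # (children traced by JAX, aux_data stored as-is in the treedef)
        return (self.test,), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(children[0], value)


leaves, treedef = jax.tree_util.tree_flatten(StaticDemo(1.0, 'foo'))
print(leaves)  # [1.0]

# Rebuild with a new leaf: the string comes back from the treedef, not the leaves
rebuilt = jax.tree_util.tree_unflatten(treedef, [2.0])
print(rebuilt.value, rebuilt.test)  # foo 2.0
```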

The issue is that the method is mutable, which is somewhat dangerous because it behaves differently from the rest of the software. We can make the operations look somewhat like what we expect, but the behavior is unusual.

```python
import zodiax as zdx

class Foo(zdx.Base):
    value: str = zdx.field(static=True)

    def __init__(self):
        self.value = 'foo'


class FooReturned(Foo):
    def set_static(self, key, value):
        # object.__setattr__ bypasses the module's frozen __setattr__ and mutates self in place
        object.__setattr__(self, key, value)
        return self

foo = FooReturned()

# This prints the updated object, like we expect an immutable class to work
print(foo.set_static('value', 'bar'))

# This prints the updated object too, even though we never re-assigned the variable 'foo'
print(foo)
```

```
FooReturned(value='bar')
FooReturned(value='bar')
```
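```python
class FooNotReturned(Foo):
    def set_static(self, key, value):
        # No `return self`, so this method returns None
        object.__setattr__(self, key, value)

foo = FooNotReturned()
print(foo.set_static('value', 'bar'))

# Re-assigning the result now overwrites foo with None
foo = foo.set_static('value', 'bar')
print(foo)
```

```
None
None
```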
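The surprise comes from `object.__setattr__` skipping the immutability guard and editing the shared instance. A minimal pure-Python sketch of the same effect, using a frozen dataclass (Equinox modules likewise reject ordinary attribute assignment after `__init__`) and the hypothetical helpers `set_in_place` and `set_on_copy`, contrasts mutating in place with mutating a shallow copy:

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class Frozen:
    value: str


def set_in_place(obj, key, value):
    # Hypothetical helper: object.__setattr__ bypasses frozen=True,
    # so the caller's object is modified
    object.__setattr__(obj, key, value)
    return obj


def set_on_copy(obj, key, value):
    # Hypothetical helper: shallow-copy first, then mutate only the copy
    new = copy.copy(obj)
    object.__setattr__(new, key, value)
    return new


a = Frozen('foo')
b = set_in_place(a, 'value', 'bar')
print(b is a, a.value)  # True bar

c = Frozen('foo')
d = set_on_copy(c, 'value', 'bar')
print(d is c, c.value, d.value)  # False foo bar
```

The copy-based helper only illustrates the semantics an immutable `set` should have; it is not how Zodiax implements `set`.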

This can be mitigated with a trade-off. If the update function is wrapped in jax.jit with both the key and the value marked as static, we get the desired immutable behavior: inside the jitted function, self is a copy rebuilt from the flattened pytree, so the in-place mutation never reaches the caller's object. The costs are a potential recompile every time the leaf is set to a new value, and the restriction that static leaves can only be set to hashable values. The recompile may not be a problem if higher-level jits are able to cancel it out (needs testing).

```python
from functools import partial

import jax

class FooException(Foo):
    test: float = 1.

    @partial(jax.jit, static_argnums=(1, 2))
    def set_static(self, key, value):
        # Inside jit, self is a copy rebuilt from the flattened pytree,
        # so this mutation does not affect the caller's object
        object.__setattr__(self, key, value)
        return self

    def set(self, key, value):
        try:
            # Ordinary (non-static) leaves go through the normal Zodiax set
            return super().set(key, value)
        except AssertionError:
            # Setting a static leaf this way raises AssertionError; fall back to the jitted setter
            try:
                return self.set_static(key, value)
            except ValueError:
                # jax.jit raises ValueError for non-hashable static arguments
                raise ValueError(f"Static leaves can only be set to hashable data types. Tried to set the leaf '{key}' with a {type(value)} type.")

foo = FooException()
print(foo.set('test', 2.))

foo = FooException()
print(foo.set('test', 2.).test)

print(foo.set('value', 'bar'))
print(foo)

foo = foo.set('value', 'bar')
print(foo)
```

```
FooException(value='foo', test=f32[])
2.0
FooException(value='bar', test=1.0)
FooException(value='foo', test=1.0)
FooException(value='bar', test=1.0)
```
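The same jit trick, and its recompile cost, can be reproduced without Zodiax. In this sketch the pytree class `Params` and the list `TRACES` are hypothetical; `TRACES` is appended to only while JAX traces `set_static`, so it records each (re)trace. The original object is left untouched, and a repeated call with the same static arguments is served from the cache, while a new static value triggers a fresh trace:

```python
from functools import partial

import jax

TRACES = []  # hypothetical counter: appended to only while JAX traces set_static


@jax.tree_util.register_pytree_node_class
class Params:
    """Hypothetical pytree: `test` is a traced leaf, `value` is static aux data."""

    def __init__(self, test, value):
        self.test = test
        self.value = value

    def tree_flatten(self):
        return (self.test,), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(children[0], value)

    @partial(jax.jit, static_argnums=(1, 2))
    def set_static(self, key, value):
        TRACES.append((key, value))  # Python side effect: runs at trace time only
        # self is a copy rebuilt by tree_unflatten, so the caller's object is not touched
        object.__setattr__(self, key, value)
        return self


p = Params(1.0, 'foo')
q = p.set_static('value', 'bar')
print(p.value, q.value)  # foo bar

p.set_static('value', 'bar')  # same static arguments: cached, no retrace
p.set_static('value', 'baz')  # new static value: traced (and compiled) again
print(TRACES)  # [('value', 'bar'), ('value', 'baz')]
```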

The implemented method also raises a helpful error if you try to set a static leaf to a non-hashable value.

```python
foo = FooException()
print(foo.set('value', jax.numpy.ones(2)))
```

```
...
ValueError: Static leaves can only be set to hashable data types. Tried to set the leaf 'value' with a <class 'jaxlib.xla_extension.ArrayImpl'> type.
```
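One possible refinement, sketched here as an assumption rather than the implemented method: check hashability before the value ever reaches jax.jit. Catching every ValueError from set_static could also hide unrelated ValueErrors raised inside it, while an explicit hash() check only reacts to the hashability problem. The helper name check_static_value is hypothetical. It calls hash() instead of testing isinstance(value, Hashable), because a tuple containing a list passes the isinstance test but still fails to hash.

```python
def check_static_value(key, value):
    """Hypothetical helper: reject non-hashable values before they reach jax.jit."""
    try:
        hash(value)
    except TypeError:
        raise ValueError(
            "Static leaves can only be set to hashable data types. "
            f"Tried to set the leaf '{key}' with a {type(value)} type."
        ) from None
    return value


print(check_static_value('value', 'bar'))  # bar

try:
    check_static_value('value', [1, 2])
except ValueError as err:
    print(err)
```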

Notes:

LouisDesdoigts commented 1 year ago

Now you can jit Zodiax classes that have string parameters directly. Both of these examples throw errors if the 'value' leaf is not static!

```python
import jax

@jax.jit
def update_foo(foo_cls):
    # set returns a new object, so the result has to be re-assigned
    foo_cls = foo_cls.set('value', 'baz')
    return foo_cls

foo = FooException()
new_foo = update_foo(foo)
print(foo)
print(new_foo)
```
```python
from jax import jit, grad, vmap, numpy as np
import matplotlib.pyplot as plt

@jit
@grad  # differentiates with respect to the first argument, x
def test_fn(x, foo_cls):
    # 'value' is static, so this Python branch sees a concrete value at trace time
    if foo_cls.value == 'foo':
        return x**2
    else:
        return x**3

test_fn_vmap = vmap(test_fn, in_axes=(0, None))

foo = FooException()
xs = np.linspace(-1, 1, 100)

plt.plot(xs, test_fn_vmap(xs, foo), label='foo')
plt.plot(xs, test_fn_vmap(xs, foo.set('value', 2)), label='not foo')
plt.legend()
plt.savefig("output.png")
```

[Figure: output.png, the 'foo' and 'not foo' curves plotted by the code above]
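To see why the 'value' leaf has to be static, here is a minimal sketch in plain JAX with no Zodiax involved. The classes Labelled and LabelAsLeaf and the function power_grad are hypothetical. When the string lives in the treedef, a Python if on it works under jit and grad. When the same string is flattened as an ordinary leaf, jit rejects it with a TypeError, because a str is not a valid JAX array type.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Labelled:
    """Hypothetical pytree: `scale` is a traced leaf, `label` is static aux data."""

    def __init__(self, scale, label):
        self.scale = scale
        self.label = label

    def tree_flatten(self):
        return (self.scale,), self.label

    @classmethod
    def tree_unflatten(cls, label, children):
        return cls(children[0], label)


@jax.tree_util.register_pytree_node_class
class LabelAsLeaf(Labelled):
    """Same data, but `label` is flattened as an ordinary (traced) leaf."""

    def tree_flatten(self):
        return (self.scale, self.label), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.jit
def power_grad(x, obj):
    # obj.label is a concrete str at trace time, so a Python branch is allowed
    p = 2 if obj.label == 'foo' else 3
    return jax.grad(lambda y: obj.scale * y**p)(x)


print(power_grad(0.5, Labelled(1.0, 'foo')))  # derivative of y**2 at 0.5
print(power_grad(0.5, Labelled(1.0, 'bar')))  # derivative of y**3 at 0.5

try:
    power_grad(0.5, LabelAsLeaf(1.0, 'foo'))
except TypeError as err:
    print('TypeError:', type(err).__name__)
```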