tttc3 opened this issue 1 year ago
Your diagnosis of the problem is correct. What you've written down isn't a pytree, as it's recursively self-referential.
It would be a nice feature to automatically detect this error, e.g. by adding some kind of eqx.is_pytree function that we call to validate module instances. As such I've marked this as a feature request. I've also opened https://github.com/google/jax/issues/15711, as a segfault isn't a very user-friendly error to see!
Relative to simple-pytree: they don't get the same error because they don't make bound methods into pytrees. In this case that happens to break the reference cycle, but in other cases it will result in silent correctness errors. For example, this will erroneously produce zero gradients:
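```python
import equinox as eqx
import jax
import jax.numpy as jnp
from simple_pytree import Pytree


class Model(Pytree):
    def __init__(self, x):
        self.x = x

    def forward(self, y):
        return self.x + y


x = jnp.array(2.0)
y = jnp.array(2.0)
model = Model(x)


@eqx.filter_jit
@eqx.filter_grad
def run(fwd, y):
    # The gradient should be taken with respect to model.x, reached through
    # the bound method `fwd`. simple_pytree does not make bound methods into
    # pytrees, so model.x is never seen as a differentiable leaf.
    return fwd(y)


print(run(model.forward, y))
```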
Unfortunately I don't think there's any change we can make on the Equinox side other than to make this produce a clearer error message. I would suggest trying to re-express what you're doing. For example:
```python
import abc
from typing import Callable

import equinox as eqx


class AbstractComponent(eqx.Module):
    @abc.abstractmethod
    def transform(self, x):
        ...

    @abc.abstractmethod
    def validator(self, f):
        ...

    def __call__(self, x):
        return self.validator(self.transform)(x)


# Concrete subclasses implement the abstract interface and are final:
# nothing subclasses them further.
class ConcreteComponent1(AbstractComponent):
    # Implemented as fields holding plain functions.
    transform: Callable[[float], float] = lambda x: x
    validator: Callable[[Callable], Callable] = lambda f: f


class ConcreteComponent2(AbstractComponent):
    # Implemented as ordinary methods.
    test: Callable[[float], float]

    def transform(self, x):
        return self.test(x)

    def validator(self, f):
        return f
```
This also follows a design pattern I'm a fan of: all concrete classes are final (i.e. cannot be subclassed). This helps keep the inheritance easy to understand.
Thanks for pointing out the potential for silent correctness errors; no doubt this would have caused me some debugging nightmares further down the line! For my case, following your proposed design pattern is probably the best option.
If JAX raised a proper error for the self-referential case, could the validity of an instance then be verified simply by checking that running jtu.tree_flatten and jtu.tree_unflatten on the instance doesn't raise an error (presuming that a PyTree is, by definition, something that can be successfully passed to these functions)? Maybe this check could be run at the end of _ModuleMeta.__call__?
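A minimal sketch of the proposed round-trip check, using a hypothetical `flatten_roundtrips` helper (not an Equinox or JAX API). It can only report failures that surface as Python exceptions; a crash inside JAX's C++ flattening, like the segfault in this issue, would still take down the process, which is exactly the caveat in the question. `BrokenNode` is a toy node type invented here to show a failing case.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def flatten_roundtrips(tree) -> bool:
    """Hypothetical helper: True if `tree` survives tree_flatten + tree_unflatten.

    Only failures raised as Python exceptions are caught; a crash inside
    JAX's C++ flattening code (such as a segfault) is not.
    """
    try:
        leaves, treedef = jtu.tree_flatten(tree)
        jtu.tree_unflatten(treedef, leaves)
    except Exception:
        return False
    return True


class BrokenNode:
    """Toy node type whose unflatten rule always fails (illustration only)."""


jtu.register_pytree_node(
    BrokenNode,
    lambda node: ((), None),                # flatten: no children, no aux data
    lambda aux, children: BrokenNode.oops,  # unflatten: raises AttributeError
)

print(flatten_roundtrips({"a": jnp.array(1.0), "b": [1, 2]}))  # True
print(flatten_roundtrips(BrokenNode()))                         # False
```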
Yep, that'd make sense! The two technical hurdles I see to this are: (a) JAX doesn't provide an API to flatten only one level at a time, and I'm not sure there's any JAX function we can call to iterate over the pytree that won't immediately segfault; (b) avoiding quadratic checking cost when building a tree, as the lower layers get checked again by every layer above them. We'd need some kind of caching, but that in turn runs into issues with edge cases when mutating lists/dictionaries.
If these can be figured out then I'd be happy to take a PR on this.
I (mostly) figured this one out; see #355. This will appear in the next version of Equinox.
This adds a new tree_check function to validate that you don't have any self-references (as is the case here) or duplicate references (usually a bug).
Right now it's not automatically applied to a module, due to the quadratic cost scaling. (That's the bit I haven't figured out.)
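The tree_check source isn't reproduced in this thread. As a minimal sketch of the same kind of check (an assumption about the approach, not Equinox's actual code): one level of a pytree can be flattened by passing an `is_leaf` that returns False only for the root, and a depth-first walk can then track the ids of container nodes. `one_level_children` and `check_tree` are hypothetical names.

```python
import jax.tree_util as jtu


def one_level_children(node):
    """Flatten `node` by exactly one level; return [] if it is a leaf.

    JAX calls `is_leaf` on the root before anything else, so answering False
    only on the first call (True afterwards) stops flattening after one level,
    even when a child is `node` itself.
    """
    calls = 0

    def is_leaf(_):
        nonlocal calls
        calls += 1
        return calls > 1

    children, treedef = jtu.tree_flatten(node, is_leaf=is_leaf)
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return []  # `node` is a leaf, not a container
    return children


def check_tree(tree):
    """Hypothetical checker: raise ValueError on self-references or shared nodes.

    Only containers with at least one child are tracked, so repeated leaves
    (and singletons such as None or ()) are allowed.
    """
    on_path = set()  # ids of containers between the root and the current node
    seen = {}        # id -> node; holding the node stops its id being recycled

    def visit(node):  # recursive for brevity; very deep trees hit Python's limit
        children = one_level_children(node)
        if not children:
            return
        if id(node) in on_path:
            raise ValueError(f"self-reference detected at {type(node).__name__}")
        if id(node) in seen:
            raise ValueError(f"duplicate node detected: {type(node).__name__}")
        seen[id(node)] = node
        on_path.add(id(node))
        for child in children:
            visit(child)
        on_path.remove(id(node))

    visit(tree)
```

A single call is linear in the number of nodes; the quadratic cost mentioned above comes from re-running such a check every time a module wrapping an already-checked subtree is constructed.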
That's great, and a really clever use of is_leaf (I'd not considered this possibility)! For the scaling problem: if JAX (the Python xla_client) raised a proper error for the self-referential case, then the contents of the tree nodes wouldn't need to be explicitly known, since the self-reference could be inferred from an excessive recursion depth. In that case the cost scaling is constant.
For duplicate checking, maybe something can be done in the xla_client that scales linearly, but via the existing Python tree_flatten API I've not got any better ideas.
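The recursion-depth idea can be approximated on the Python side. A minimal sketch, assuming a hypothetical `check_depth` helper (not part of JAX or Equinox) and one-level flattening via a stateful `is_leaf`: it keeps no record of which nodes were visited, only a depth counter, so a self-reference shows up as the depth exceeding a bound. It is a heuristic: a legitimately deeper tree would also be rejected, and duplicates are not caught.

```python
import jax.tree_util as jtu


def one_level_children(node):
    """Flatten `node` by one level (is_leaf is False for the root only)."""
    calls = 0

    def is_leaf(_):
        nonlocal calls
        calls += 1
        return calls > 1

    children, treedef = jtu.tree_flatten(node, is_leaf=is_leaf)
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return []  # `node` is a leaf
    return children


def check_depth(tree, max_depth=1000):
    """Hypothetical heuristic: raise if the tree is deeper than `max_depth`."""
    stack = [(tree, 0)]  # explicit stack, so Python's recursion limit is not involved
    while stack:
        node, depth = stack.pop()
        if depth > max_depth:
            raise ValueError(f"pytree deeper than {max_depth}; possible self-reference")
        stack.extend((child, depth + 1) for child in one_level_children(node))
```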
I have the following code, which causes a segmentation fault when tree_flatten is called on any instance of SubComponent. I think it has to do with the jtu.Partial(method, self) in _wrap_method, so it may be related to #291. In any case, my guess is the following: .transform is a jtu.Partial(self.transform, self), due to _wrap_method; when the .self leaf in the partial is reached, we have a leaf of type SubComponent, which can again be traversed further for flattening.
MWE:
I don't really know a good solution, beyond the fact that I don't have the problem with the Pytree class from simple_pytree. For now I can subclass from the simple_pytree Pytree, but I would much prefer to be using eqx.Module. Of course, I may just be doing something wrong. Any help is greatly appreciated.
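The diagnosis above hinges on jtu.Partial treating its bound arguments as pytree children. A small check of that, using a plain dict `obj` as a stand-in for the module instance (an assumption; the SubComponent MWE itself isn't reproduced here):

```python
import jax.tree_util as jtu


def method(self, x):
    return x


obj = {"weight": 1.0}  # stand-in for a module instance
p = jtu.Partial(method, obj)

# Stop flattening at `obj` so we can see where it sits inside the Partial.
leaves = jtu.tree_leaves(p, is_leaf=lambda x: x is obj)
print(leaves[0] is obj)  # True: the bound `self` is a pytree child of the Partial
```

If `obj` in turn stored `p` as one of its fields (which, per the diagnosis above, is what happens when _wrap_method binds `self` into the partial), flattening would revisit `obj` without end.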