patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Segmentation fault in tree_flatten when subclass method is passed to superclass __init__ #327

Open tttc3 opened 1 year ago

tttc3 commented 1 year ago

I have the following code, which causes a segmentation fault when tree_flatten is called on any instance of SubComponent. I think it has to do with the jtu.Partial(method, self) in _wrap_method, so it may be related to #291. In any case, my guess is the following:

  1. tree_flatten goes through the fields of the component and reaches transform.
  2. transform is a jtu.Partial(self._transform, self), due to _wrap_method.
  3. the partial is itself a PyTree, so it is also traversed during flattening.
  4. when the self argument stored in the partial is reached, we have a node of type SubComponent, which is traversed again.
  5. this starts a loop (back to step 1) that ends in either a segmentation fault or a maximum recursion depth error.
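The loop in steps 1-5 can be reproduced without JAX. The sketch below is a pure-Python toy built on a made-up protocol: any object with a `tree_children()` method is a node, everything else is a leaf. `ToyModule` and `ToyPartial` are hypothetical stand-ins for `eqx.Module` and `jtu.Partial`, not their real implementations. In pure Python the runaway recursion surfaces as a `RecursionError` rather than a crash.

```python
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical stand-in for an eqx.Module: its fields are its children."""

    def __init__(self, **fields: Any) -> None:
        self.fields = fields

    def tree_children(self) -> List[Any]:
        return [self.fields[k] for k in sorted(self.fields)]


class ToyPartial:
    """Hypothetical stand-in for jtu.Partial: `func` is static, `args` are children."""

    def __init__(self, func: Callable, *args: Any) -> None:
        self.func = func
        self.args = args

    def tree_children(self) -> List[Any]:
        return list(self.args)


def toy_flatten(tree: Any) -> List[Any]:
    """Recursively collect leaves; anything without tree_children() is a leaf."""
    if not hasattr(tree, "tree_children"):
        return [tree]
    leaves: List[Any] = []
    for child in tree.tree_children():
        leaves.extend(toy_flatten(child))
    return leaves


def double(x: float) -> float:
    return 2 * x


def _transform(self: ToyModule, x: float) -> float:
    return self.fields["test"](x)


# No cycle: the partial binds no arguments, so flattening terminates.
plain = ToyModule(test=double, transform=ToyPartial(double))
plain_leaves = toy_flatten(plain)

# Cycle: the partial stores `sub` itself as an argument, so flattening goes
# sub -> transform -> sub -> transform -> ... and never terminates.
sub = ToyModule(test=double)
sub.fields["transform"] = ToyPartial(_transform, sub)
try:
    toy_flatten(sub)
    cycle_detected = False
except RecursionError:
    cycle_detected = True
```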

MWE:

from typing import Callable
import jax

import equinox as eqx

class Component(eqx.Module):
    transform: Callable[[float], float]
    validator: Callable[[Callable], Callable]

    def __init__(self, transform=lambda x: x, validator=lambda f: f) -> None:
        self.transform = transform  # Base transformation.
        self.validator = validator  # Validation wrapper around the transformation.

    def __call__(self, x):
        return self.validator(self.transform)(x)  # Execute validated transformation.

class SubComponent(Component):
    test: Callable[[float], float]

    def __init__(self, test=lambda x: 2 * x) -> None:
        self.test = test
        # From my understanding this ends up equivalent to
        # self.transform = jtu.Partial(self._transform, self)
        super().__init__(self._transform)

    # Custom implementation of transform that depends on information available in the
    # SubComponent PyTree. If `test` is modified with tree_at, then both this method and
    # the __call__ defined in the parent should pick up the change.
    def _transform(self, x):
        return self.test(x)

a = SubComponent()
jax.tree_util.tree_flatten(a)  # Will cause segmentation fault
print(a)  # Will raise a maximum recursion depth error.

I don't know of a good solution, other than that I don't have this problem with the Pytree class from simple_pytree. For now I can subclass the simple_pytree Pytree, but I would much prefer to use eqx.Module. Of course, I may just be doing something wrong. Any help is greatly appreciated.

patrick-kidger commented 1 year ago

Your diagnosis of the problem is correct. What you've written down isn't a pytree, as it's recursively self-referential.

It would be a nice feature to automatically detect this error -- e.g. by adding some kind of eqx.is_pytree function that we call to validate module instances. As such I've marked this as a feature request. I've also opened https://github.com/google/jax/issues/15711 as a segfault isn't a very user-friendly error to see!

Relative to simple-pytree: the reason it doesn't hit the same error is that it doesn't make bound methods pytrees. In this case that happens to break the reference cycle, but in other cases it results in silent correctness errors. For example, this will erroneously produce zero gradients:

import equinox as eqx
import jax
import jax.numpy as jnp
from simple_pytree import Pytree

class Model(Pytree):
    def __init__(self, x):
        self.x = x

    def forward(self, y):
        return self.x + y

x = jnp.array(2.0)
y = jnp.array(2.0)
model = Model(x)

@eqx.filter_jit
@eqx.filter_grad
def run(fwd, y):
    return fwd(y)

print(run(model.forward, y))

Unfortunately I don't think there's any change we can make on the Equinox side other than to make this produce a clearer error message. I would suggest trying to re-express what you're doing. For example:

import abc
import equinox as eqx
from typing import Callable

class AbstractComponent(eqx.Module):
    @abc.abstractmethod
    def transform(self, x):
        ...

    @abc.abstractmethod
    def validator(self, f):
        ...

    def __call__(self, x):
        return self.validator(self.transform)(x)

class ConcreteComponent1(AbstractComponent):
    transform: Callable[[float], float] = lambda x: x
    validator: Callable[[Callable], Callable] = lambda f: f

class ConcreteComponent2(AbstractComponent):
    test: Callable[[float], float]

    def transform(self, x):
        return self.test(x)

    def validator(self, f):
        return f

This also follows a design pattern I'm a fan of -- that all concrete classes are final (i.e. cannot be subclassed). This helps to ensure easy-to-understand inheritance.

tttc3 commented 1 year ago

Thanks for pointing out the potential for silent correctness errors; no doubt they would have caused me some debugging nightmares further down the line! For my case, following your proposed design pattern is probably the best option.

If JAX raised a proper error for the self-referential case, could the validity of an instance then be verified simply by checking that running jtu.tree_flatten and jtu.tree_unflatten on it doesn't raise an error (presuming that a PyTree is by definition something that can be successfully passed to these functions)? Maybe this check could be run at the end of _ModuleMeta.__call__?

patrick-kidger commented 1 year ago

Yep, that'd make sense! The two technical hurdles I see to this are: (a) JAX doesn't provide an API to flatten only one level at a time. I'm not sure there's any JAX function we can call to iterate over the pytree that won't immediately segfault... (b) avoiding quadratic checking cost when building a tree (as the lower layers get checked multiple times, by every upper layer). We'd need some kind of caching, but that in turn runs into issues with edge cases when mutating lists/dictionaries.
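Hurdle (b) can be made concrete with a toy count. Everything below is a hypothetical pure-Python illustration (`check`, `build_chain` and the `id()`-keyed cache are not Equinox code), restricted to nested lists: validating the whole subtree every time a level is built costs n(n+3)/2 node visits for a chain of depth n, an `id()`-based cache brings that down to 2n, and mutating a list after it was cached lets a self-reference slip through, which is the kind of list/dictionary edge case mentioned above.

```python
from typing import Any, Optional, Set

visits = 0  # number of nodes touched by check(), across all calls


def check(tree: Any, cache: Optional[Set[int]] = None,
          _path: Optional[Set[int]] = None) -> None:
    """Toy validator over nested lists: raises ValueError on a self-reference.

    If `cache` is given, any list whose id() is in it is assumed valid and is
    not walked again (a hypothetical id-based cache)."""
    global visits
    visits += 1
    if not isinstance(tree, list):
        return  # leaf
    if _path is None:
        _path = set()
    if id(tree) in _path:
        raise ValueError("self-referential tree")
    if cache is not None and id(tree) in cache:
        return  # trust the earlier verdict
    _path.add(id(tree))
    for child in tree:
        check(child, cache, _path)
    _path.remove(id(tree))
    if cache is not None:
        cache.add(id(tree))


def build_chain(n: int, cache: Optional[Set[int]] = None) -> Any:
    """Build [[...[0]...]] n levels deep, validating every level as it is
    built, like a check run at the end of each module's construction."""
    tree: Any = 0
    for _ in range(n):
        tree = [tree]
        check(tree, cache)
    return tree


visits = 0
uncached_chain = build_chain(100)
uncached_visits = visits  # 100 * 103 / 2 = 5150, grows like n**2

visits = 0
cached_chain = build_chain(100, cache=set())
cached_visits = visits    # 2 per level = 200, grows like n

# The catch: the cache is keyed on object identity, so mutating a list
# after it was checked leaves a stale "valid" verdict behind.
cache: Set[int] = set()
inner = [1, 2]
check(inner, cache)
inner.append(inner)        # inner is now self-referential
try:
    check([inner], cache)  # stops at inner because it is cached
    missed = True
except ValueError:
    missed = False
```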

If these can be figured out then I'd be happy to take a PR on this.

patrick-kidger commented 1 year ago

I (mostly) figured this one out; see #355. This will appear in the next version of Equinox.

This adds a new tree_check function, to validate that you don't have any self-references (as is the case here) or duplicate references (usually a bug). Right now it's not automatically applied to a module, due to the quadratic cost scaling. (That's the bit I haven't figured out.)
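One way an `is_leaf` trick can sidestep hurdle (a) is sketched below in pure Python. This is a toy, not Equinox's `tree_check`: `toy_flatten`, `flatten_one_level` and `toy_tree_check` are hypothetical names, and only lists, tuples and dicts count as nodes. An `is_leaf` that answers False only on the first visit to the root makes a full flattener stop after one level (a root that reappears as its own child is reported as a leaf instead of being descended into). The checker then walks one level at a time, marking each node as "visiting" (meeting it again means a self-reference) or "done" (meeting it again means a duplicate). Letting leaves and empty containers repeat freely is an assumption of this sketch.

```python
from typing import Any, Callable, Dict, List, Optional

NODE_TYPES = (list, tuple, dict)  # toy registry: everything else is a leaf


def toy_flatten(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Toy flattener over lists/tuples/dicts with a JAX-style is_leaf hook."""
    if (is_leaf is not None and is_leaf(tree)) or not isinstance(tree, NODE_TYPES):
        return [tree]
    children = [tree[k] for k in sorted(tree)] if isinstance(tree, dict) else list(tree)
    leaves: List[Any] = []
    for child in children:
        leaves.extend(toy_flatten(child, is_leaf))
    return leaves


def flatten_one_level(tree: Any) -> List[Any]:
    """Immediate children of `tree`: only the first visit to the root is
    treated as a non-leaf, so nothing below it is ever descended into."""
    seen_root = False

    def is_leaf(node: Any) -> bool:
        nonlocal seen_root
        if node is tree and not seen_root:
            seen_root = True
            return False
        return True

    return toy_flatten(tree, is_leaf=is_leaf)


def toy_tree_check(tree: Any, _state: Optional[Dict[int, str]] = None) -> None:
    """Raise ValueError if a non-empty node is reachable from itself
    (self-reference) or appears more than once (duplicate)."""
    if _state is None:
        _state = {}
    children = flatten_one_level(tree)
    if not isinstance(tree, NODE_TYPES) or not children:
        return  # leaves and empty containers may be shared freely
    status = _state.get(id(tree))
    if status == "visiting":
        raise ValueError("self-referential pytree")
    if status == "done":
        raise ValueError("duplicate node in pytree")
    _state[id(tree)] = "visiting"
    for child in children:
        toy_tree_check(child, _state)
    _state[id(tree)] = "done"


toy_tree_check({"a": [1, 2], "b": [3]})  # well-formed: no error

cyclic: List[Any] = [1]
cyclic.append(cyclic)
shared = [1, 2]
errors = []
for bad in (cyclic, [shared, shared]):
    try:
        toy_tree_check(bad)
    except ValueError as e:
        errors.append(str(e))
print(errors)  # ['self-referential pytree', 'duplicate node in pytree']
```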

tttc3 commented 1 year ago

That's great, and a really clever use of is_leaf (I hadn't considered this possibility)! For the scaling problem: if JAX (the Python xla_client) raised a proper error for the self-referential case, then the contents of the tree nodes wouldn't need to be explicitly known, as the self-reference could be inferred from excessive recursion depth. In that case, the cost scaling would be constant.
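The recursion-depth idea can be sketched in pure Python, assuming the flattener reports runaway recursion as a catchable error (here Python's own `RecursionError`). `toy_flatten`, `flatten_or_report_cycle` and `reported_as_cycle` are hypothetical names for illustration, not JAX or Equinox functions. The last line shows the trade-off of inferring cycles from depth: an acyclic tree nested more deeply than the recursion limit is reported the same way as a genuine self-reference.

```python
import sys
from typing import Any, List


def toy_flatten(tree: Any) -> List[Any]:
    """Toy recursive flattener over nested lists/tuples; other objects are leaves."""
    if not isinstance(tree, (list, tuple)):
        return [tree]
    leaves: List[Any] = []
    for child in tree:
        leaves.extend(toy_flatten(child))
    return leaves


def flatten_or_report_cycle(tree: Any) -> List[Any]:
    """Treat hitting the recursion limit as evidence of a self-reference,
    instead of inspecting node contents."""
    try:
        return toy_flatten(tree)
    except RecursionError:
        raise ValueError("recursion limit hit; pytree is probably self-referential") from None


def reported_as_cycle(tree: Any) -> bool:
    try:
        flatten_or_report_cycle(tree)
    except ValueError:
        return True
    return False


cyclic: List[Any] = [1]
cyclic.append(cyclic)

# Acyclic, but nested more deeply than the interpreter's recursion limit.
deep: Any = 0
for _ in range(sys.getrecursionlimit() + 10):
    deep = [deep]

print(reported_as_cycle([1, [2, 3]]))  # False
print(reported_as_cycle(cyclic))       # True
print(reported_as_cycle(deep))         # True, a false positive
```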

For duplicate checking, maybe something could be done in the xla_client that scales linearly, but via the existing Python tree_flatten API I don't have any better ideas.