brentyi / jax_dataclasses

Pytrees + dataclasses ❤️
MIT License
61 stars, 6 forks

Delayed initialisation of static fields #3

Closed: lucagrementieri closed this issue 2 years ago

lucagrementieri commented 2 years ago

First of all, thank you for the amazing library! I recently discovered jax_dataclasses and have decided to port my messy functional JAX code to more organised, object-oriented code based on jax_dataclasses.

In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and would like to know if there is a workaround.

Here is a simple example, where the attribute _sum is a derived static field that depends on the constant values of the array a.

```python
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass()
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field(init=False, repr=False)

    def __post_init__(self):
        object.__setattr__(self, "_sum", self.a.sum().item())

def print_pytree(obj):
    print(obj._sum)

obj = PyTreeDataclass(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)
```

The non-jitted version works, but when print_pytree is jitted I get the following error.

```
File "jax_dataclasses_issue.py", line 14, in __post_init__
    object.__setattr__(self, "_sum", self.a.sum().item())
AttributeError: 'bool' object has no attribute 'sum'
```

Is there a way to compute, in __post_init__, the values of static fields that are not initialized in __init__ and that depend on jnp.ndarray attributes of the dataclass?

brentyi commented 2 years ago

Thanks for the details + example!

In general, having a static field depend on the value of a dynamic array is unlikely to be supported. For function transforms to work, it needs to be possible to instantiate each pytree type using only abstract tracer arrays; the error you're seeing is because the behavior of .sum().item() is not well-defined on tracers.
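To see why the constructor has to cope with tracers, here is a minimal sketch that uses plain `jax.tree_util.register_pytree_node_class` instead of jax_dataclasses (a swapped-in mechanism, chosen so the example doesn't depend on a particular jax_dataclasses version). `Container`, `double`, and `constructor_inputs` are hypothetical names. The constructor only records the type of whatever it receives; under `jax.jit`, `tree_unflatten` rebuilds the object from tracer leaves, so any work done in `__init__` (or `__post_init__`) runs on tracers as well.

```python
import jax
import jax.numpy as jnp

# Hypothetical log: the type name of every `a` the constructor receives.
constructor_inputs = []


@jax.tree_util.register_pytree_node_class
class Container:
    def __init__(self, a):
        constructor_inputs.append(type(a).__name__)
        self.a = a

    def tree_flatten(self):
        # `a` is the only child; there is no static (aux) data.
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Unflattening goes back through __init__, so anything computed in
        # the constructor also runs on the leaves JAX substitutes while tracing.
        return cls(*children)


@jax.jit
def double(c: Container) -> Container:
    return Container(c.a * 2)


out = double(Container(jnp.arange(4)))
print(constructor_inputs)
```

The exact list of recorded types may vary between JAX versions, but it should include a tracer type. The traceback above shows that the leaves can even be non-array placeholders (a `bool` there).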

While for your particular case the expected behavior is pretty clear, a more concrete example of why this should be avoided might be:

```python
@jax.jit
def zeros(a: jnp.ndarray) -> jnp.ndarray:
    container = PyTreeDataclass(a)
    return jnp.zeros(container._sum)   # `container._sum` should be an int
```

Note that a will be an abstract tracer during JIT compilation, and it won't make sense to compute _sum. The ultimate implication is that it becomes impossible to instantiate a PyTreeDataclass in any JIT-compiled function, which seems overly limiting. The whole zeros() function is of course also problematic because the shape of the output array depends on values in the input array, which isn't a pattern supported in XLA's static graph-based design.
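A minimal sketch of the value-dependent-shape problem, assuming a recent JAX release in which `jax.errors.ConcretizationTypeError` is raised when a tracer is converted to a Python `int`. `zeros_from_values` and `zeros_with_size` are hypothetical helpers; the second one shows that the same shape is accepted once the size arrives as a static Python int through `static_argnums`.

```python
import functools

import jax
import jax.numpy as jnp


def zeros_from_values(a):
    # The output *shape* depends on the *values* stored in `a`.
    return jnp.zeros(int(a.sum()))


# Eagerly, `a` is concrete, so the shape can be computed.
eager_shape = zeros_from_values(jnp.arange(4)).shape

# Under jit, `a` is an abstract tracer and `int(...)` cannot be evaluated.
try:
    jax.jit(zeros_from_values)(jnp.arange(4))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True


@functools.partial(jax.jit, static_argnums=1)
def zeros_with_size(a, size):
    # `size` is a static (hashable) Python int fixed at trace time, so it can
    # be used as a shape; `a` is still a traced input.
    return jnp.zeros(size, dtype=a.dtype)


a = jnp.arange(4)
static_shape = zeros_with_size(a, int(a.sum())).shape  # size computed eagerly
print(eager_shape, jit_failed, static_shape)
```

Each new `size` value triggers a fresh compilation, and nothing ties `size` to the contents of `a`.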

For workarounds: the best options off the top of my head are to pass the value of _sum as a separate input, or to move its computation out of the __post_init__, both of which I imagine you've thought of. I might have more thoughts if you can share more specifics about the application.
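A minimal sketch of the second workaround, moving the computation out of `__post_init__` into an eager helper, again written with plain `jax.tree_util` rather than jax_dataclasses. `SummedArray`, `setup`, and `zeros_like_total` are hypothetical names, and storing `total` as pytree aux data is assumed here to play the role of a jax_dataclasses static field.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class SummedArray:
    """Hypothetical container: `a` is a traced child, `total` is static aux data."""

    def __init__(self, a, total: int):
        # No computation here: the constructor must accept tracers.
        self.a = a
        self.total = total

    @classmethod
    def setup(cls, a) -> "SummedArray":
        # Eager factory: `a` must be concrete, so call this outside of jit.
        return cls(a, int(a.sum()))

    def tree_flatten(self):
        # Children are traced; aux data is kept as a plain Python value.
        return (self.a,), self.total

    @classmethod
    def tree_unflatten(cls, total, children):
        return cls(*children, total)


@jax.jit
def zeros_like_total(obj: SummedArray):
    # `obj.total` is still a plain Python int inside the jitted function.
    return jnp.zeros(obj.total, dtype=obj.a.dtype)


obj = SummedArray.setup(jnp.arange(4))
print(zeros_like_total(obj).shape)
```

Because `__init__` does no computation, JAX can rebuild the object from tracers; the trade-off is that `setup` has to be called outside `jax.jit`, on a concrete array.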

lucagrementieri commented 2 years ago

Sure, I could pass it as a separate input, but I would like to avoid that to ensure the consistency of the data.

Moving the computation out of __post_init__ could be a good workaround for my use case, but I would like to keep it inside the dataclass, for example as a method. My problem is that if a jitted function needs the method, then its result is traced and I run into the same problem.

Could you please explain how you would move the computation out of __post_init__ while keeping it inside the dataclass?

brentyi commented 2 years ago

It doesn't seem like there are any perfect options here, but two possible suggestions:

1. If _sum needs to be computed automatically, one option is to make a helper that does just that. This should run:

```python
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass()
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field()

    @staticmethod
    def setup(a: jnp.ndarray) -> "PyTreeDataclass":
        # Note that `a` must be a concrete array. This method cannot be called
        # inside a JIT-compiled function.
        return PyTreeDataclass(a, _sum=a.sum().item())

def print_pytree(obj):
    print(obj._sum)

obj = PyTreeDataclass.setup(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)
```


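2. Or, perhaps `a` itself could be static? Arrays aren't hashable, so a tuple might work:

```python
from typing import Tuple

import jax_dataclasses as jdc

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: Tuple[int, ...] = jdc.static_field()

    @property
    def _sum(self) -> int:
        return sum(self.a)
```

Or you could do something hacky to wrap the array and make it hashable:

```python
import dataclasses

import numpy as onp
import jax_dataclasses as jdc

@dataclasses.dataclass
class HashableArrayWrapper:
    inner: onp.ndarray   # Needs to be initialized with a regular numpy array.

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))

    def __eq__(self, other: object) -> bool:
        # Static values are also compared with `==` (e.g. for jit cache hits);
        # the dataclass-generated __eq__ would compare arrays elementwise.
        return isinstance(other, HashableArrayWrapper) and onp.array_equal(
            self.inner, other.inner
        )

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: HashableArrayWrapper = jdc.static_field()

    @property
    def _sum(self) -> int:
        return self.a.inner.sum().item()
```
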
lucagrementieri commented 2 years ago

Thank you very much! They are very good options!