lucagrementieri closed this issue 2 years ago.
Thanks for the details + example!
In general, having a static field depend on the value of a dynamic array is unlikely to be supported. For function transforms to work, it needs to be possible to instantiate each pytree type using only abstract tracer arrays. The error you're seeing comes from `.sum().item()`: `.item()` needs a concrete value to produce a Python scalar, but a tracer only carries a shape and dtype, so the behavior is not well-defined on tracers.
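A minimal sketch of the difference, using plain JAX rather than jax_dataclasses (the function name `sum_as_python_int` is hypothetical): the same call succeeds on a concrete array and fails once `jax.jit` replaces the argument with a tracer. The exact exception type may vary across JAX versions, so the sketch catches `Exception` broadly.

```python
import jax
import jax.numpy as jnp


def sum_as_python_int(a: jnp.ndarray) -> int:
    # `.item()` needs a concrete value to build a Python scalar.
    return a.sum().item()


# Outside of JIT, `a` is a concrete array, so `.item()` works.
eager_result = sum_as_python_int(jnp.arange(4))

# Inside JIT, `a` is an abstract tracer (only shape and dtype are known),
# so `.item()` has no value to return and JAX raises an error at trace time.
try:
    jax.jit(sum_as_python_int)(jnp.arange(4))
    jit_failed = False
except Exception:
    jit_failed = True
```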
While for your particular case the expected behavior is pretty clear, a more concrete example for why this should be avoided might be:
```python
import jax
import jax.numpy as jnp

@jax.jit
def zeros(a: jnp.ndarray) -> jnp.ndarray:
    container = PyTreeDataclass(a)
    return jnp.zeros(container._sum)  # `container._sum` should be an int
```
Note that `a` will be an abstract tracer during JIT compilation, and it won't make sense to compute `_sum`. The ultimate implication is that it becomes impossible to instantiate a `PyTreeDataclass` in any JIT-compiled function, which seems overly limiting. The whole `zeros()` function is of course also problematic, because the shape of the output array depends on values in the input array, which isn't a pattern supported by XLA's static graph-based design.
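To make the shape point concrete, here is a small sketch in plain JAX (both function names are hypothetical): an output shape computed from `a.shape` is fine under `jax.jit`, because shapes are static metadata, while one computed from the values of `a` is rejected at trace time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def zeros_from_shape(a: jnp.ndarray) -> jnp.ndarray:
    # `a.shape` is static metadata known at trace time, so this compiles.
    return jnp.zeros(a.shape[0] * 2)


@jax.jit
def zeros_from_values(a: jnp.ndarray) -> jnp.ndarray:
    # `a.sum()` is a traced value: the output shape would depend on data.
    return jnp.zeros(a.sum())


ok = zeros_from_shape(jnp.arange(4))  # output shape is (8,)

try:
    zeros_from_values(jnp.arange(4))
    value_shape_failed = False
except Exception:
    value_shape_failed = True
```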
For workarounds: the best ones off the top of my head are to pass the value of `_sum` as a separate input, or to move its computation out of `__post_init__`, both of which I imagine you've thought of. I may have more thoughts if you can share more specifics about the application.
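The first workaround could look roughly like this, sketched with `jax.jit`'s `static_argnames` (the function `padded` and the argument `n` are hypothetical stand-ins for the real application): `n` plays the role of `_sum`, is computed eagerly outside of JIT, and is passed as a static argument, so JAX sees a plain Python int and recompiles once per distinct value.

```python
import functools

import jax
import jax.numpy as jnp


# `n` is marked static: JAX treats it as a compile-time Python value and
# compiles a separate version of the function for each distinct `n`.
@functools.partial(jax.jit, static_argnames="n")
def padded(a: jnp.ndarray, n: int) -> jnp.ndarray:
    out = jnp.zeros(n, dtype=a.dtype)
    # `a.shape[0]` is static too, so this slice has a fixed size.
    return out.at[: a.shape[0]].set(a)


a = jnp.arange(4)
n = int(a.sum())  # computed eagerly, outside of JIT
result = padded(a, n=n)
```

The trade-off is the consistency concern: nothing stops a caller from passing an `n` that doesn't match `a`.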
Sure, I could pass it as a separate input, but I would like to avoid that to keep the data consistent.
Moving the computation out of `__post_init__` could be a good workaround for my use case, but I would like to keep it inside the dataclass, for example as a method. My problem is that if a jitted function needs the method, then its result is traced and I run into the same problem.
Could you please explain how you would move the computation out of `__post_init__` while keeping it inside the dataclass?
It doesn't seem like there are any perfect options here, but two possible suggestions:

1. If `_sum` needs to be computed automatically, one option is to make a helper that does just that. This should run:
```python
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc


@jdc.pytree_dataclass
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field()

    @staticmethod
    def setup(a: jnp.ndarray) -> "PyTreeDataclass":
        # Note that `a` must be a concrete array. This method cannot be called
        # inside a JIT-compiled function.
        return PyTreeDataclass(a, _sum=a.sum().item())


def print_pytree(obj: PyTreeDataclass) -> None:
    print(obj._sum)


obj = PyTreeDataclass.setup(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)
```
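For intuition on why `setup()` works, here is a rough equivalent without jax_dataclasses, assuming static fields are stored as pytree auxiliary data. This is a hand-written sketch with `jax.tree_util.register_pytree_node_class`, and `SummedArray` and `zeros_like_sum` are hypothetical names. Under `jax.jit`, the array child becomes a tracer, but the aux data is passed through as a concrete Python value, so `_sum` has to be known before the object enters the jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SummedArray:
    """Hypothetical hand-written analogue of PyTreeDataclass."""

    def __init__(self, a, _sum: int):
        # No computation here: unflattening passes tracers in as `a`.
        self.a = a
        self._sum = _sum

    @classmethod
    def setup(cls, a):
        # Must be called with a concrete array, outside of JIT.
        return cls(a, _sum=int(a.sum()))

    def tree_flatten(self):
        # `a` is a child (traced under JIT); `_sum` is static aux data.
        return (self.a,), self._sum

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilds the object from (possibly traced) children and the
        # concrete aux data, without recomputing `_sum`.
        return cls(children[0], _sum=aux_data)


@jax.jit
def zeros_like_sum(obj: SummedArray) -> jnp.ndarray:
    # `obj._sum` is a plain Python int here, so it can be used as a shape.
    return jnp.zeros(obj._sum)


obj = SummedArray.setup(jnp.arange(4))
out = zeros_like_sum(obj)
```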
2. Or, perhaps `a` itself could be static? Arrays aren't hashable, so a tuple might work:
```python
from typing import Tuple

import jax_dataclasses as jdc


@jdc.pytree_dataclass
class PyTreeDataclass:
    a: Tuple[int, ...] = jdc.static_field()

    @property
    def _sum(self) -> int:
        return sum(self.a)
```
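The same hashability requirement shows up with plain `jax.jit` static arguments, which may be an easier way to experiment with the tuple idea. This swaps in `static_argnames` for `jdc.static_field()`, and `zeros_from_static` is a hypothetical name: a NumPy array is rejected by `hash()`, while a tuple of Python ints is hashable and can be summed at trace time.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as onp

# Static values must be hashable, since JAX uses them to look up compiled code.
try:
    hash(onp.arange(4))
    array_hashable = True
except TypeError:
    array_hashable = False

a_static = tuple(onp.arange(4).tolist())  # a hashable tuple of Python ints


@functools.partial(jax.jit, static_argnames="a")
def zeros_from_static(a):
    # `a` is a concrete Python tuple here, so `sum(a)` is a plain int.
    return jnp.zeros(sum(a))


out = zeros_from_static(a=a_static)
```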
Or you could do something hacky to wrap the array and make it hashable:
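```python
import dataclasses

import jax_dataclasses as jdc
import numpy as onp


@dataclasses.dataclass
class HashableArrayWrapper:
    inner: onp.ndarray  # Needs to be initialized with a regular NumPy array.

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))

    def __eq__(self, other: object) -> bool:
        # Compare contents; the generated dataclass __eq__ would use elementwise ndarray ==.
        return isinstance(other, HashableArrayWrapper) and bool(onp.array_equal(self.inner, other.inner))


@jdc.pytree_dataclass
class PyTreeDataclass:
    a: HashableArrayWrapper = jdc.static_field()

    @property
    def _sum(self) -> int:
        return self.a.inner.sum().item()
```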
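One caveat with a wrapper like this: JAX uses both `__hash__` and `__eq__` of static values when checking whether compiled code can be reused, and a dataclass's generated `__eq__` compares the `inner` arrays with NumPy's elementwise `==`. A minimal sketch of the problem, using only NumPy and `dataclasses` (`NaiveWrapper` is a hypothetical name for a wrapper that defines only `__hash__`):

```python
import dataclasses

import numpy as onp


@dataclasses.dataclass
class NaiveWrapper:
    # The generated __eq__ compares fields, so it ends up calling ndarray ==.
    inner: onp.ndarray

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))


x = NaiveWrapper(onp.arange(4))
y = NaiveWrapper(onp.arange(4))  # equal contents, different array object

same_hash = hash(x) == hash(y)
try:
    bool(x == y)
    eq_failed = False
except ValueError:
    # An elementwise comparison of 4 elements has no single truth value.
    eq_failed = True
```

Defining `__eq__` explicitly, for example with `onp.array_equal` on the two `inner` arrays, avoids this.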
Thank you very much! They are very good options!
First of all, thank you for the amazing library! I recently discovered `jax_dataclasses` and decided to port my messy functional JAX code to more organised, object-oriented code based on `jax_dataclasses`.

In my application, I have some quantities derived from the attributes of the dataclass that are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.
Here is a simple example, where the attribute `_sum` is a derived static field that depends on the constant value of the array `a`. The non-jitted version works, but when `print_pytree` is jitted I get the following error.

Is there a way to compute in `__post_init__` the value of static fields not initialized in `__init__` that depend on `jnp.ndarray` attributes of the dataclass?