gautierronan opened this issue 1 month ago (status: Open).
A simple fix would be to replace `is_array` with `is_jax_array` (i.e. `isinstance(..., jax.Array)`) in https://github.com/patrick-kidger/equinox/blob/804d82e0fe65f85dd9f9a21f03d2784c164bbf8d/equinox/_module.py#L585, but I'm not sure this is in line with the intended use.
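For comparison, here is a minimal sketch of the two predicates being discussed. `broad_is_array` mirrors what this thread describes the current `is_array` check as matching (NumPy arrays, NumPy scalars, JAX arrays), and `is_jax_array_only` is the proposed `isinstance(..., jax.Array)` replacement. Both names are hypothetical, not Equinox's API, and the `jax` import is guarded so the sketch also runs without JAX installed.

```python
import numpy as np

try:
    import jax

    _JAX_ARRAY_TYPES = (jax.Array,)
except ImportError:  # JAX not installed: no value can be a JAX array
    _JAX_ARRAY_TYPES = ()


def broad_is_array(x: object) -> bool:
    """Hypothetical broad check: NumPy arrays, NumPy scalars and JAX arrays."""
    return isinstance(x, (np.ndarray, np.generic, *_JAX_ARRAY_TYPES))


def is_jax_array_only(x: object) -> bool:
    """Hypothetical narrow check proposed above: JAX arrays only."""
    return isinstance(x, _JAX_ARRAY_TYPES)


for value in [np.int64(3), np.asarray([3, 2]), 3]:
    print(type(value).__name__, broad_is_array(value), is_jax_array_only(value))
```

Note the trade-off: under the narrow check, plain NumPy arrays would no longer be flagged either, even though they are just as unhashable as JAX arrays.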
I would say this behavior is expected (whether or not it is wanted is perhaps another question). NumPy arrays are in general not hashable, and marking a field static stores it as aux data in the pytree (https://github.com/patrick-kidger/equinox/blob/main/equinox/_module.py#L946). Aux data is expected to be hashable, since "JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons." (https://jax.readthedocs.io/en/latest/pytrees.html). Violating this can cause silent or (speaking from personal experience) confusing errors, which is why I wanted to add the warning.
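To make the hashability point concrete, here is a minimal NumPy-only sketch (the helper names are illustrative). A tuple of Python ints hashes and compares to a single `bool`, as aux data should. A NumPy array does neither: `hash` raises `TypeError`, and `==` is elementwise, so the comparison result has no single truth value.

```python
import numpy as np


def is_hashable(value: object) -> bool:
    """Return True if `hash(value)` succeeds."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


def has_single_truth_value(value: object) -> bool:
    """Return True if `bool(value)` succeeds (fails for multi-element arrays)."""
    try:
        bool(value)
    except ValueError:
        return False
    return True


static_tuple = (3, 2)
static_array = np.asarray(static_tuple)

# A tuple of Python ints behaves like well-formed aux data.
print(is_hashable(static_tuple), has_single_truth_value(static_tuple == (3, 2)))  # True True

# A NumPy array is unhashable, and `==` is elementwise, so it has no single truth value.
print(is_hashable(static_array), has_single_truth_value(static_array == static_array))  # False False
```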
I agree, at the least, that the warning message is wrong: a JAX array isn't being set static (a NumPy array is), so the message should be changed to match the check. As for NumPy arrays, I think they were included originally because of their hashing problems (to quote @/jakevdp: "Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data. If you do include such values in aux_data, you'll get unsupported, poorly-defined behavior."). That being said, there are definitely cases where static arrays are fine and correct (which is why this is a warning that can be ignored rather than an error), and if these cases are very common, the warning could become a burden. WDYT?
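For reference, here is one way a caller could ignore this specific warning with the standard `warnings` module. This is a minimal sketch: `make_static_thing` is a hypothetical stand-in for the code that emits the warning, and only the message text is taken from this issue. A scoped filter like this hides the warning inside the `with` block only. A global `warnings.filterwarnings` call would hide it everywhere, which may not be acceptable for a library.

```python
import warnings

# Message text copied from this issue; the emitting function is hypothetical.
WARNING_TEXT = (
    "A JAX array is being set as static! This can result in unexpected "
    "behavior and is usually a mistake to do."
)


def make_static_thing() -> str:
    """Hypothetical stand-in for library code that emits the warning."""
    warnings.warn(WARNING_TEXT, UserWarning)
    return "created"


def make_static_thing_quietly() -> str:
    # Suppress only this UserWarning, and only inside this block;
    # `message` is a regex matched against the start of the warning text.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="A JAX array is being set as static",
            category=UserWarning,
        )
        return make_static_thing()
```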
But in our example, the NumPy array is only an intermediate in the computation: the computation starts from a static tuple and returns a static tuple, which is why I don't think this warning should be expected here.
Also, I don't really see a way around it. The warning is raised on instance creation, so I don't see how we could filter it. In our library, it would be raised every time we perform an operation on our class, and we really cannot make this class attribute non-static.
Actually, investigating more, the following works without warning:
```python
import numpy as np
import equinox as eqx


class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        x_as_np += 1
        # .item() turns each NumPy scalar back into a plain Python int
        x = tuple([i.item() for i in x_as_np])
        return Foo(x)


x = (3, 2)
foo = Foo(x)
foo.add_one()
# no warning
```
So what's being detected in the first example is that the tuple elements are of type `np.int64` instead of `int`.
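The difference can be reproduced with NumPy alone. This is a minimal sketch, assuming the first example built its tuple by iterating over the array. Iterating yields NumPy scalars, while `.tolist()` (or `.item()` per element, as in the working snippet) yields plain Python ints with equal values.

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x) + 1

# tuple(...) over an array keeps NumPy scalar types (np.int64 on most platforms).
leaky = tuple(x_as_np)
# .tolist() converts every element to the closest built-in Python type.
clean = tuple(x_as_np.tolist())

print([type(i).__name__ for i in leaky])  # e.g. ['int64', 'int64']
print([type(i).__name__ for i in clean])  # ['int', 'int']
print(leaky == clean)                     # True: the values are equal
```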
Hmm, I see it now, I misread it: it's NumPy's `int64` scalar type. That would be a misuse of `is_array`, then, because `np.int64` values are hashable. I can add a flag to exclude basic NumPy scalar types from the check (maybe just exclude NumPy generics, i.e. `np.generic`, from the check; are they all hashable?).
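A quick check of the hashability question for a few common scalar types, plus a narrower check that skips `np.generic`. This is a minimal sketch: both function names are hypothetical, the samples do not prove that every `np.generic` subclass is hashable, and a real check would presumably keep `jax.Array` as well (omitted here so only NumPy is needed).

```python
import numpy as np


def hashes_like_python(scalar: np.generic, py_value: object) -> bool:
    """True if a NumPy scalar equals and hashes like the given Python value."""
    return bool(scalar == py_value) and hash(scalar) == hash(py_value)


def should_warn_when_static(value: object) -> bool:
    """Hypothetical narrower check: NumPy arrays only, np.generic scalars skipped."""
    return isinstance(value, np.ndarray)


samples = [(np.int64(3), 3), (np.int32(3), 3), (np.float64(1.5), 1.5)]
for scalar, py_value in samples:
    print(type(scalar).__name__, hashes_like_python(scalar, py_value))

print(should_warn_when_static(np.int64(3)))         # False
print(should_warn_when_static(np.asarray([3, 2])))  # True
```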
As of Equinox 0.11.6 and https://github.com/patrick-kidger/equinox/pull/800, the following MWE raises:
`UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.`
This means that one cannot perform NumPy operations (which are often simpler than writing the same logic in plain Python) on a static attribute. This is a use case we have in dynamiqs; see for instance the `__mul__` method of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use `numpy` instead of `jax.numpy` to keep this logic "static".
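Here is a minimal sketch of this pattern. It is a hypothetical DIA-like container, not dynamiqs' actual class, and its `__mul__` is assumed to be an elementwise product. The `offsets` tuple plays the role of the static attribute. The offset bookkeeping is done with plain `numpy`, and the result is converted back to Python ints before being stored, so no NumPy scalar ends up in the static field. A frozen dataclass stands in for the Equinox module so that the sketch runs without Equinox.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True, eq=False)
class DiaMatrix:
    """Hypothetical DIA container: `offsets` is static, hashable metadata and
    `diags[k]` is the stored diagonal for `offsets[k]` (the dynamic data)."""

    offsets: tuple[int, ...]
    diags: np.ndarray  # shape (len(offsets), n)

    def __mul__(self, other: "DiaMatrix") -> "DiaMatrix":
        # "Static" logic in plain NumPy: find the offsets present in both operands.
        common, self_idx, other_idx = np.intersect1d(
            np.asarray(self.offsets), np.asarray(other.offsets), return_indices=True
        )
        # Elementwise product: only diagonals stored in both operands survive.
        diags = self.diags[self_idx] * other.diags[other_idx]
        # Convert back to Python ints so the static field stays NumPy-free.
        return DiaMatrix(tuple(common.tolist()), diags)


a = DiaMatrix((-1, 0, 1), np.ones((3, 4)))
b = DiaMatrix((0, 2), np.full((2, 4), 2.0))
c = a * b
print(c.offsets)  # (0,)
```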