patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

"JAX array is set a static" warning is raised unwantedly #863

Open gautierronan opened 1 month ago

gautierronan commented 1 month ago

As of Equinox 0.11.6 and https://github.com/patrick-kidger/equinox/pull/800, the following MWE raises a UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.

import numpy as np
import equinox as eqx

class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        return Foo(tuple(x_as_np+1))

x = (3, 2)
foo = Foo(x)
foo.add_one()
# UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.

This means that one cannot perform NumPy operations (which are often simpler than writing the equivalent in plain Python) on a static attribute without triggering the warning. This is a use case we have in dynamiqs: see, for instance, the __mul__ method of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use numpy instead of jax.numpy to keep this logic "static".

gautierronan commented 1 month ago

A simple fix would be to replace is_array with is_jax_array (i.e. isinstance(..., jax.Array)) in https://github.com/patrick-kidger/equinox/blob/804d82e0fe65f85dd9f9a21f03d2784c164bbf8d/equinox/_module.py#L585 but I'm not sure this is in line with the intended use.

lockwo commented 1 month ago

I would say this behavior is expected (whether or not it is wanted may be another question). In general, numpy arrays are not hashable, and making a field static means storing it as aux data in the pytree (https://github.com/patrick-kidger/equinox/blob/main/equinox/_module.py#L946). Aux data is expected to be hashable, since "JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons." (https://jax.readthedocs.io/en/latest/pytrees.html). Violating this can cause silent or (speaking from personal experience) confusing errors, which is why I wanted to add the warning.

I agree, at the least, that the warning message is wrong, because a JAX array isn't being set as static (a numpy array is), so the message should be changed to match the check. As for numpy arrays, I think they were included originally because of their hash problems (to quote @jakevdp: "Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data. If you do include such values in aux_data, you'll get unsupported, poorly-defined behavior."). That being said, there definitely are cases where using static arrays can be fine and correct (which is why this is a warning that can be ignored rather than an error), and if these cases are very common then the warning could be a burden. WDYT?
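To make the hashability point concrete, here is a minimal sketch using only NumPy (no JAX or Equinox involved). The helper name is_hashable is made up for illustration. It only shows that a tuple of ints supports hashing while an np.ndarray does not, and that == on arrays is elementwise rather than returning a single bool.

```python
import numpy as np


def is_hashable(obj) -> bool:
    """Illustrative helper (not an Equinox or JAX API): does hash(obj) succeed?"""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


static_tuple = (3, 2)
static_array = np.asarray(static_tuple)

# A tuple of Python ints can be hashed, so it is usable as aux data.
print(is_hashable(static_tuple))

# np.ndarray is mutable and does not support hashing.
print(is_hashable(static_array))

# Equality on arrays is elementwise: the result is another array, not a bool.
eq = static_array == np.asarray((3, 2))
print(type(eq).__name__, eq.shape)
```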

gautierronan commented 1 month ago

But in our example, the numpy array is just an intermediary for the computation. The actual computation starts with a static tuple and returns a static tuple, which is why I don't think this warning should be expected here.

Also, I don't really see a way around it. The warning is raised when the class instance is created, so I don't see how we could filter it. In the case of our library, it will be raised every time we perform an operation on our class, and we really cannot make this class attribute non-static.

gautierronan commented 1 month ago

Actually, investigating more, the following works without warning:

import numpy as np
import equinox as eqx

class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        x_as_np += 1
        x = tuple([i.item() for i in x_as_np])
        return Foo(x)

x = (3, 2)
foo = Foo(x)
foo.add_one()
# no warning

So what's being detected in the first example is that the tuple elements are NumPy integer scalars (e.g. np.int64) rather than Python int.
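The difference between the two examples can be checked directly with NumPy alone. This is a small sketch; the variable names via_tuple, via_item and via_tolist are just for illustration, and the .tolist() variant is an alternative I believe is equivalent to the .item() loop above, not something taken from the examples.

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x) + 1

# Iterating over a 1-D array yields NumPy scalars (np.int64 on most platforms).
via_tuple = tuple(x_as_np)

# .item() converts each NumPy scalar to the matching Python scalar.
via_item = tuple(i.item() for i in x_as_np)

# .tolist() performs the same conversion for the whole array at once.
via_tolist = tuple(x_as_np.tolist())

for name, t in [("tuple", via_tuple), ("item", via_item), ("tolist", via_tolist)]:
    print(name, [type(e).__name__ for e in t])
```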

lockwo commented 1 month ago

Hmm, I see it now; yes, I misread it. The elements are int64 scalars from numpy. That would be a misuse of is_array then, because np.int64 values are hashable. I can add a flag to exclude basic numpy dtypes from the check (maybe just exclude numpy generics from the check; are they all hashable?).
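A rough sketch of what excluding NumPy generics could look like, using NumPy only. Both functions below are hypothetical stand-ins, not Equinox code: warns_today assumes the current check treats NumPy scalars like arrays (as this thread suggests), and warns_with_generics_excluded flags only np.ndarray. The hash loop covers just a few common scalar types, so it does not answer whether every np.generic subclass is hashable (structured void scalars, for example, may behave differently).

```python
import numpy as np


def warns_today(leaf) -> bool:
    # Assumed stand-in for the current check: arrays and NumPy scalars both match.
    return isinstance(leaf, (np.ndarray, np.generic))


def warns_with_generics_excluded(leaf) -> bool:
    # Hypothetical variant: only real arrays match, NumPy scalars pass through.
    return isinstance(leaf, np.ndarray)


# A few common NumPy scalar types: each is an np.generic and hash() succeeds.
for scalar in (np.int64(3), np.float32(1.5), np.bool_(True)):
    assert isinstance(scalar, np.generic)
    hash(scalar)

# The tuple from the first example in this issue: its elements are NumPy scalars.
leaves = tuple(np.asarray((3, 2)) + 1)
print([warns_today(leaf) for leaf in leaves])
print([warns_with_generics_excluded(leaf) for leaf in leaves])

# A real array stored as a static value would still be flagged.
print(warns_with_generics_excluded(np.asarray((3, 2))))
```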