paganpasta opened this issue 2 years ago

Hi,

Most of the `nn.Module`s (e.g. `MLP`) use `static_field`s for, well, seemingly static attributes. In the documentation it is stated that `static_field` should be used rarely. Are there any do's and don'ts for their usage? As I understand it, these can be filtered out with `filter_*` calls.

Thanks.
So there's basically two main patterns here.

1. Never use `static_field`. Always filter things out using `eqx.filter_{jit, grad, ...}`.
2. Use `static_field` on every leaf that isn't a floating-point JAX array, and then use `jax.{jit, grad, ...}`. This does require that (a) you have full control over the entire model PyTree (e.g. you're not building a library for another user); (b) you have no integer JAX arrays (which are neither differentiable nor static-able).

Approach 1 is what I recommend, as it's more general. Approach 2 is nice in that it works with native JAX operations, though.
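For concreteness, here's a minimal sketch of the two patterns side by side (the class and function names are made up for illustration):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# Pattern 1: no static fields; use the filtered transformations.
class Model1(eqx.Module):
    weight: jnp.ndarray
    name: str  # a non-array leaf

@eqx.filter_jit  # `name` is automatically filtered out and passed statically
def apply1(model, x):
    return model.weight @ x

# Pattern 2: mark every non-array leaf static; then plain jax.jit works.
class Model2(eqx.Module):
    weight: jnp.ndarray
    name: str = eqx.static_field()

@jax.jit
def apply2(model, x):
    return model.weight @ x

x = jnp.ones(2)
print(apply1(Model1(jnp.ones((2, 2)), "m1"), x))
print(apply2(Model2(jnp.ones((2, 2)), "m2"), x))
```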
FWIW `eqx.nn` currently follows neither of these, which is a choice I'm a little dissatisfied with. It certainly can't follow approach 2, because we don't know in advance e.g. whether `eqx.nn.MLP().activation` is a JAX type or not (or even whether it's a leaf type or not). It does however use some `static_field` attributes, so that simple models like `eqx.nn.Linear` (but not `eqx.nn.MLP`) work with `jax.{jit, grad, ...}` without requiring the use of the filtered transformations. I originally felt like this might ease the new-user onboarding process, but I don't know if that's really true.
In an ideal world I'd arrange for `jax.{jit, grad, ...}` to mimic the default behaviour of `eqx.filter_{jit, grad, ...}`: automatically filter out non-JAX types etc.\* (Even if they don't offer customisable filtering like Equinox's do.) Then this distinction would be unimportant and we could avoid using `static_field` entirely. Maybe some day!
\* The filtered transformations actually also fix a couple of other issues with the native transformations: jitting non-function callables like bound methods and class instances, and avoiding re-jitting when calling `jit(grad(f))` twice.
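A sketch of that first point (the `Scale` module is made up; `eqx.filter_jit` is applied directly to a class instance):

```python
import jax.numpy as jnp
import equinox as eqx

class Scale(eqx.Module):
    factor: jnp.ndarray

    def __call__(self, x):
        return self.factor * x

s = Scale(jnp.array(2.0))
# filter_jit accepts a non-function callable (here, a Module instance);
# the native jax.jit doesn't handle these gracefully.
jitted = eqx.filter_jit(s)
print(jitted(jnp.array(3.0)))  # 6.0
```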
Thanks for the detailed response! I am happy to stick with `filter_*` for stuff built on top of `equinox`.
So probably the `p` field in `nn.Dropout` also requires a `static_field`, for consistency with the other `nn.*` modules?

https://github.com/patrick-kidger/equinox/blob/a89d5b486d13588caffc095f172a2ec39fd68278/equinox/nn/dropout.py#L21
So this is deliberately non-static, because the forward pass of `nn.Dropout` is fine with having that as a tracer.

In some sense the same is true of `nn.Linear.in_features` etc. as well -- these aren't used during the forward pass. But in some sense this isn't true of `nn.Linear.in_features` etc., as this corresponds to an array size, and as such might be used in some shape arithmetic during a forward pass (e.g. someone using `nn.Linear.in_features` as part of a `.reshape()`).
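A sketch of that `.reshape()` concern, with a made-up `TinyLinear` whose `in_features` is an ordinary (non-static) leaf, as it would be if the `static_field` annotations were removed:

```python
import jax.numpy as jnp
import equinox as eqx

class TinyLinear(eqx.Module):
    weight: jnp.ndarray
    in_features: int  # deliberately *not* a static_field here

def f(lin, x):
    # .reshape() needs a concrete size: this breaks if `in_features`
    # arrives as a tracer rather than a Python int.
    return x.reshape(-1, lin.in_features) @ lin.weight.T

lin = TinyLinear(jnp.ones((2, 4)), 4)

f_default = eqx.filter_jit(f)  # default spec: the int stays static -- fine
print(f_default(lin, jnp.ones(8)))

f_traced = eqx.filter_jit(f, filter_spec=eqx.is_array_like)  # traces the int too
# f_traced(lin, jnp.ones(8))  # raises: a traced value is used as a shape
```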
FWIW I am tempted to remove the `static_field` annotations from all of `eqx.nn` and see what happens. This debate over when to use static fields keeps coming up, and it doesn't really help that the built-in library is using them in ways we don't advise end users to.
Thanks! I agree on removing the `static_field`s and moving towards the `filter_*` approach for everything. I can give it a shot and see if everything is still passing.
Sounds good!
Unsurprisingly, a couple of tests fail. I just wanted to make sure that the tests need fixing and it's not the behaviour breaking.
A small representative example of a failing test case, adapted from `test_filter_jit.test_filter_jit1`:
```python
import jax
import equinox as eqx

def test_no_static(getkey):
    def h(x):
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    h = eqx.filter_jit(h, filter_spec=eqx.is_array_like)
    og_lin = eqx.nn.Linear(2, 2, key=getkey())
    _lin = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, og_lin)
    new_lin = h(og_lin)
    assert new_lin == _lin
```
This assertion fails: `_lin` contains all `is_array_like` leaves, but somewhere inside `eqx.filter_jit`, `new_lin` seems to have undergone an `is_array` transformation. I tracked the calls up to

https://github.com/google/jax/blob/d5fdd9e2664602430ea0ecfb59fdef7f8692862f/jax/_src/api.py#L525-L530

where `out_flat` is
```python
[
    DeviceArray([[-0.35848978, -0.23153338], [-0.5653428 , -0.46271053]], dtype=float32),
    DeviceArray([0.21585053, 0.14375919], dtype=float32),
    DeviceArray(2, dtype=int32, weak_type=True),
    DeviceArray(2, dtype=int32, weak_type=True),
    DeviceArray(True, dtype=bool, weak_type=True)
]
```
`out_pytree_def` is
```python
PyTreeDef(
    (
        CustomNode(
            <class 'equinox.nn.linear.Linear'>
            [(('weight', 'bias', 'in_features', 'out_features', 'use_bias'), (), ())],
            [*, *, *, *, *]
        ),
        CustomNode(
            <class 'equinox.compile_utils.Static'>
            [((), ('value',), (Linear(
                weight=None,
                bias=None,
                in_features=None,
                out_features=None,
                use_bias=None
            ),))],
            []
        )
    )
)
```
but `out` is
```python
(Linear(
    weight=f32[2,2],
    bias=f32[2],
    in_features=i32[],
    out_features=i32[],
    use_bias=bool[]
), Static(
    value=Linear(
        weight=None,
        bias=None,
        in_features=None,
        out_features=None,
        use_bias=None
    )
))
```
Is this behaviour expected, and is `jit` automagically removing non-JAX types?
A little update on this:

```python
import jax
import equinox as eqx

def test_no_static(getkey):
    def h(x):
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    h = eqx.filter_jit(h, filter_spec=eqx.is_array_like)
    og_lin = eqx.nn.Linear(2, 2, key=getkey())
    _lin = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, og_lin)
    new_lin = h(og_lin)
    assert new_lin != og_lin and type(new_lin) == type(og_lin)

    def H(x):
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    H = jax.jit(H)
    jit_lin = H(og_lin)
    assert jit_lin == new_lin and type(new_lin) == type(jit_lin)
```
This test passes. I suppose this is expected behaviour and everything is working as intended?
So what's going on here is that the previously-static leaves are array-like, so they get selected for tracing by `filter_jit(filter_spec=is_array_like)`. As a result they are converted into JAX arrays.

Meanwhile, equality between Modules asks that all leaves have the same type, which of course is now no longer true.

So this is exactly the kind of breakage I was expecting; it's just a question of whether it's not-too-bad. This requires you to use `filter_jit` with a non-default filter spec (which at least in my own code is unusual).
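The leaf-type change is easy to see in isolation (a tiny sketch; `identity` is a made-up function):

```python
import equinox as eqx

def identity(x):
    return x

identity = eqx.filter_jit(identity, filter_spec=eqx.is_array_like)

out = identity(2)  # a Python int is array-like, so it gets traced...
print(type(out))   # ...and comes back as a (weakly-typed) JAX array
```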
Thanks for clarifying this.
How do you want to proceed with this? Also, as a side note, it would be helpful to add examples using `fn`, `out` and `default` to the documentation of `filter_jit`.
Which tests actually break under this change? (+with what error if it's interesting.)
Also, you might try running the tests for Diffrax under that change and see if any of those are affected?
I only made changes within `nn` and left out the rest, like `experimental`.
These two test cases fail:

```
assert Linear(weight=None, bias=None, in_features=1, out_features=1, use_bias=True)
    == Linear(\n weight=None,\n bias=None,\n in_features=None,\n out_features=None,\n use_bias=None\n)
```

```
> assert _eq(h2[5], _mlp)
E assert False
E  + where False = _eq(
E      MLP(\n layers=[\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=i32[],\n out_features=i32[],\n use_bias=bool[]\n ),\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=i32[],\n out_features=i32[],\n use_bias=bool[]\n ),\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=i32[],\n out_features=i32[],\n use_bias=bool[]\n )\n ],\n activation=None,\n final_activation=None,\n in_size=i32[],\n out_size=i32[],\n width_size=i32[],\n depth=i32[]\n),
E      MLP(\n layers=[\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=2,\n out_features=2,\n use_bias=True\n ),\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=2,\n out_features=2,\n use_bias=True\n ),\n Linear(\n weight=f32[2,2],\n bias=f32[2],\n in_features=2,\n out_features=2,\n use_bias=True\n )\n ],\n activation=None,\n final_activation=None,\n in_size=2,\n out_size=2,\n width_size=2,\n depth=2\n))
```
I'll try Diffrax as soon as I get the chance.
For me, a common use of `static_field()` is for control flow, so please do not remove it entirely.
```python
import jax
import equinox as eqx

class LinearAndActivation(eqx.Module):
    m: eqx.Module
    apply_act: bool = eqx.static_field()

    def __init__(self, n_in, n_out, apply_act, key):
        self.m = eqx.nn.Linear(n_in, n_out, key=key)
        self.apply_act = apply_act

    @jax.jit
    def __call__(self, x):
        x = self.m(x)
        if self.apply_act:
            x = jax.nn.relu(x)
        return x
```
Don't worry, it definitely isn't going away. If nothing else it's a key part of how things work internally.
On balance I think I'm inclined to maintain the status quo. This clearly does break a couple of tests, and whilst they're pretty minor, it's not encouraging. To rebut the argument (my argument) in favour of not using `static_field`: if an end user overzealously applies a `static_field`, then I don't think it should actually matter -- you can't smuggle in JAX arrays this way, as they're not hashable. So extra `static_field`s should never break you silently.
The more I interact with equinox <=> jax/optax, the more I think having `static_field()` is not that bad to have as a default. It saves the trouble of repeating filtering in places.
Hello, I am new to Equinox and I was wondering why it is necessary to declare all internal fields of modules before initializing them and why adding fields at runtime is impossible. I guess this is linked to them being PyTrees, but I believe PyTrees can grow "branches".
Also, in my experience, modules usually have few parameter/module branches (e.g. weight and bias in a linear layer) and many static branches (hyper-parameters, settings, flags, etc.). Hence, wouldn't it be easier to indicate the former (e.g. with `eqx.Parameter` and `eqx.Module`) and consider everything else as static, similar to PyTorch? This would eliminate the need for `static_field` and maybe ease the job of `eqx.filter_*`.
Declaring the fields in advance is a syntax that we inherit from dataclasses. (Each `eqx.Module` is a dataclass.) Likewise, dataclasses don't allow adding additional fields at runtime.
In principle Equinox could have done something slightly different here -- not built on top of dataclasses and done our own thing instead. Which would have been fine too! (Although in practice I do particularly like not being able to add fields at runtime -- mutation is generally pretty dangerous and never actually needed, so this is a good way to avoid bugs.)
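For instance (a minimal sketch; the `Affine` class is made up):

```python
import jax.numpy as jnp
import equinox as eqx

class Affine(eqx.Module):
    # Fields are declared up front, exactly as in a dataclass;
    # __init__ is generated automatically from these declarations.
    weight: jnp.ndarray
    bias: jnp.ndarray

affine = Affine(jnp.ones((2, 2)), jnp.zeros(2))
# affine.scale = 1.0  # would raise an error: Modules are frozen dataclasses,
#                     # so fields can't be added or mutated at runtime.
```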
Regarding flipping static/dynamic: for basically any `eqx.Module`s that you ever write yourself, I'd recommend essentially never using `eqx.static_field`. Just leave every field dynamic. `eqx.filter_{jit, grad, ...}` will simply filter out anything that isn't an array, as usual.
To be honest, if I was going to do Equinox again, I would probably avoid the dynamic/static field distinction -- and just leave everything dynamic. Likewise I might not use dataclasses, and instead have some other syntax for initialisation. Realistically these aren't important enough to justify a breaking change at this point, though.
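A quick sketch of that all-dynamic style (`Block` and `loss` are illustrative names only):

```python
from typing import Callable

import jax
import jax.numpy as jnp
import equinox as eqx

class Block(eqx.Module):
    weight: jnp.ndarray
    activation: Callable  # a non-array leaf; deliberately not a static_field

    def __call__(self, x):
        return self.activation(self.weight @ x)

block = Block(jnp.ones((2, 2)), jax.nn.relu)

@eqx.filter_grad  # differentiates the arrays; `activation` passes through
def loss(block, x):
    return jnp.sum(block(x))

grads = loss(block, jnp.ones(2))
print(grads.weight)      # an f32[2,2] gradient
print(grads.activation)  # None -- filtered out, no gradient taken
```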
Oh, so there is never a need for `static_field` except for JAX arrays that are not "parameters"? Even if the field is used in `__call__`?
> jax arrays that are not "parameters"
JAX arrays can never be `static_field`s. A static field must be hashable, as it's used to form the cache key with `jax.jit` (i.e. to decide whether the already-JIT-compiled version of the function can be reused without recompiling). And JAX arrays aren't hashable.
In practice `static_field` is really only useful for one thing: being able to use non-arrays in your model whilst still using the original `jax.{jit, grad, ...}`. (As opposed to `eqx.filter_{jit, grad, ...}`.) The Equinox `filter_{jit, grad, ...}` operations are smart enough to filter out anything that isn't an array and just pass it through. The original JAX transformations instead try to cast everything to an array (and explode if you pass them non-arraylike objects).
That extra modicum of compatibility is the only reason that you see `eqx.nn.Linear` (and friends) using `eqx.static_field`.
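A sketch of that compatibility difference (both classes here are made up):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class WithoutStatic(eqx.Module):
    weight: jnp.ndarray
    label: str  # an ordinary (dynamic) leaf

class WithStatic(eqx.Module):
    weight: jnp.ndarray
    label: str = eqx.static_field()  # part of the PyTree structure instead

@jax.jit
def f(m):
    return jnp.sum(m.weight)

print(f(WithStatic(jnp.ones(3), "ok")))  # fine: `label` is invisible to jit
# f(WithoutStatic(jnp.ones(3), "no"))    # explodes: jit tries to cast the
#                                        # string "no" to an array
```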
> Even if the field is used in `__call__`?
Yes, you can safely use static fields inside `__call__`!
I'm confused about the `static_field` behavior. My use case is maybe similar to @francois-rozet's: I need a model where I can freeze parameters, which are usually JAX arrays. The filtering approach is too indirect for complex models, so I'd rather work with `static_field`s. However, I assumed that `eqx.partition` would move fields that have static defaults to the static tree. It does not; instead, the element is available in both.
```python
import jax.numpy as jnp
import equinox as eqx

class Test(eqx.Module):
    a: jnp.ndarray
    b: jnp.ndarray = eqx.static_field()

t = Test(jnp.zeros(4), jnp.ones(4, dtype=int))

@eqx.filter_jit
@eqx.filter_grad
def loss(t):
    return jnp.sum((t.a - t.b) ** 2)

params, static = eqx.partition(t, eqx.is_array)
assert params.b is static.b
```
This seems counterintuitive, but not problematic as such. What really confuses me is that `loss` actually returns gradients for `t.b`; they just happen to be the same as `t.b` itself, so the pytree is `Test(a=nabla_a loss, b=b)`:

```python
assert loss(t).b is t.b
```

I assume that the gradient wrt `b` has never actually been computed (right?), but that return value is a problem for downstream use of this gradient. The same happens with standard `jax.grad`, btw.
Also, about your statement:

> JAX arrays can never be `static_field`s.

The above works for me...
I appreciate your guidance above. But can you provide some further clarity on what `static_field` actually does, specifically for standard JAX types?
In JAX, a PyTree consists of two parts: the tree structure, and the leaves. For example, the PyTree `["hi", 2, (jnp.array(3.),)]` has structure `[*, *, (*,)]` and leaves `"hi"`, `2`, `jnp.array(3.)`.
Each attribute of an `eqx.Module` can be thought of either as a leaf, or as part of the structure. By default it's a leaf, but if you want to then you can make it part of the structure, by adding a `static_field` declaration. Using your `Test` example: `Test(a=jnp.array(3.), b=4.0)` will have structure `Test(a=*, b=4.0)` and the single leaf `jnp.array(3.)`.
The static pieces don't interact with any JAX transformation at all: they won't be JIT'd, and can't be differentiated. This is the reason that, in your example, you get a gradient of `Test(a=nabla_a loss, b=b)` with `b` the same as the original input: JAX doesn't ever interact with this value, and treats it as part of the structure.
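You can see this directly with `jax.tree_util.tree_flatten` (a sketch reusing a `Test`-like class):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Test(eqx.Module):
    a: jnp.ndarray
    b: float = eqx.static_field()

t = Test(jnp.array(3.0), 4.0)
leaves, treedef = jax.tree_util.tree_flatten(t)
print(leaves)   # only `a` appears: [DeviceArray(3., dtype=float32)]
print(treedef)  # the structure, with b=4.0 baked into it
```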
> JAX arrays can never be `static_field`s.
>
> The above works for me...
That's an odd quirk of JAX/NumPy, due to using a scalar array outside of JIT. (This specific combination means that `bool(jnp.array(4) == jnp.array(4))` will not throw an error.)
I've just updated the `static_field` documentation to try and more strongly warn against using it.
I'm still not clear on what the intended correct way to avoid `static_field` is. Suppose you have, as above,

```python
class Test(eqx.Module):
    a: jnp.ndarray
    b: jnp.ndarray
```

but want `b` to be static in `__call__`, without using `static_field`.
I think you are saying that we should create a custom `filter_something` decorator to decorate `__call__` with? And the filter should be a filter by name.

And I think to do that we need to fully understand and replicate the logic of `filter_grad`, swapping out the type heuristics for some other partitioning logic?
... reading a bit more and reminding myself about JAX, I suspect the core problem is that one cannot really have "by name" filters, only "by type" ones.

I am now suggesting that perhaps this is the simplest fix (below)?
```python
class MyArray(jnp.array):
    pass

class Test(eqx.Module):
    a: jnp.array
    b: MyArray = eqx.static_field()
```
> If you want to have as above but with `b` static in `__call__` but not use the `static_field`.
What is it you're looking to accomplish here, precisely? "Static" basically just means "invisible to every JAX transformation", i.e. a `static_argnum` to `jax.jit`, like `in_axes=None` for `jax.vmap`, no gradients when computed through `jax.grad`, etc. Which of these transformations are you looking to be static with respect to?

In particular, you should essentially never want a JAX array to be static wrt `jax.jit`, so I suspect the answer isn't "all of them".
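(For reference, a sketch of those three notions of "static" in plain JAX; the functions here are made up:)

```python
import jax
import jax.numpy as jnp

def f(x, flag):
    return x * 2 if flag else x

# Static wrt jit: `flag` becomes part of the compilation cache key, not traced.
f_jit = jax.jit(f, static_argnums=1)

# Static wrt vmap: the second argument is broadcast rather than mapped over.
g = jax.vmap(lambda x, y: x + y, in_axes=(0, None))

# Static wrt grad: no gradient flows through stop_gradient.
h = jax.grad(lambda x: jnp.sum(x * jax.lax.stop_gradient(x)))

print(f_jit(jnp.ones(3), True))
print(g(jnp.arange(3.0), 10.0))
print(h(jnp.ones(3)))
```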
Regarding your code at the bottom: you can't subclass JAX arrays, so don't try to do that. One similar thing you can do is to place them in a wrapper though, see this example. (This is the usual way to get custom per-parameter behaviour.)
> Which of these transformations are you looking to be static with respect to?
This example came to me from a friend who was hitting this. I think their example was some kind of complicated structure where the static thing was a grid or something that happened to be an array -- something that was not being differentiated; almost like data, but static, perhaps. Putting the array in a Module can work as well, but I think the custom array class is the most minimal. Putting the array into a list and casting works, but is slow, I guess.

The correct way might be filters, but that is not well understood either (how to create custom filters based on the names of things... which I think might not be possible, as my impression is that filtering is based on JAX machinery which only allows conditions based on type).
UPDATE: OK, I think I've understood some of this more... at least for the dynamic case you might want this kind of pattern:
```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Test(eqx.Module):
    a: jnp.ndarray
    b: jnp.ndarray

    def __call__(self):
        return jnp.sum(self.a + 3.0 * self.b**2)

def loss(model):
    return model()

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (5,))
b = jax.random.normal(key, (5,))
model = Test(a, b)

def partially_static(fun, filter_spec):
    def inner(model, *args, **kwargs):
        @eqx.filter_grad
        def inner_(diff_model, static_model, *args, **kwargs):
            model = eqx.combine(diff_model, static_model)
            return fun(model, *args, **kwargs)

        diff_model, static_model = eqx.partition(model, filter_spec)
        return inner_(diff_model, static_model, *args, **kwargs)

    return inner

filter_spec = jax.tree_map(lambda _: True, model)
object.__setattr__(filter_spec, 'b', False)

dloss = eqx.filter_grad(loss)
dloss_partial = partially_static(loss, filter_spec)

print(dloss(model))
print(dloss_partial(model))
# results in
# Test(a=f64[5], b=f64[5])
# Test(a=f64[5], b=None)
```
(Adapted from https://docs.kidger.site/equinox/examples/frozen_layer/.)
I'm not sure if there is a more standard way of creating the filter spec than directly reaching for `object.__setattr__`.
> I'm not sure if there is a more standard way of creating the filter spec than directly reaching for `object.__setattr__`.
Yup, use `eqx.tree_at` (which is also used in the example you link).
Other than that, I think what you've done here looks pretty reasonable.
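For instance, a self-contained sketch of building the same filter spec with `eqx.tree_at` (mirroring the frozen-layer example; `Test` as in the snippet above):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Test(eqx.Module):
    a: jnp.ndarray
    b: jnp.ndarray

model = Test(jnp.ones(5), jnp.ones(5))

# Start from an all-True spec with the same structure as the model...
filter_spec = jax.tree_map(lambda _: True, model)
# ...then mark `b` as non-differentiable, without touching object.__setattr__:
filter_spec = eqx.tree_at(lambda spec: spec.b, filter_spec, replace=False)
```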
It sounds like your particular use case is that you want your array not to be differentiated? In that case you could consider using `lax.stop_gradient`:
```python
from jax import lax
import jax.numpy as jnp
import equinox as eqx

class Test(eqx.Module):
    a: jnp.ndarray
    b: jnp.ndarray

    def __call__(self):
        b = lax.stop_gradient(self.b)
        return jnp.sum(self.a + 3.0 * b**2)
```
In addition, you might like to use `eqx.filter_jit(donate=...)` to avoid spurious copies if your array is an unmodified input and output of your computation.