patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Guideline for using `static_fields`? #154

Open paganpasta opened 2 years ago

paganpasta commented 2 years ago

Hi,

Most of the nn.Modules (e.g. MLP) use static_field for, well, seemingly static attributes. The documentation states that static_field should be used rarely. Are there any do's and don'ts for their usage? As I understand it, these can be filtered out with filter_* calls.

Thanks.

patrick-kidger commented 2 years ago

So there are basically two main patterns here.

  1. Never use static_field. Always filter things out using eqx.filter_{jit, grad, ...}.
  2. Use static_field on every leaf that isn't a floating-point JAX array, and then use jax.{jit, grad, ...}. This does require that (a) you have full control over the entire model PyTree (e.g. you're not building a library for another user); (b) you have no integer JAX arrays (which are neither differentiable nor static-able).

Approach 1 is what I recommend, as it's more general. Approach 2 is nice in that it works with native JAX operations, though.
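
For concreteness, here's a minimal sketch of the two patterns (the Model1/Model2 names are made up for illustration):

import equinox as eqx
import jax
import jax.numpy as jnp
from typing import Callable

# Pattern 1: no static_field; the filtered transformations skip non-arrays.
class Model1(eqx.Module):
    weight: jnp.ndarray
    activation: Callable  # non-array leaf; eqx.filter_grad just ignores it

    def __call__(self, x):
        return self.activation(self.weight @ x)

model1 = Model1(jnp.ones((2, 2)), jax.nn.relu)
grads1 = eqx.filter_grad(lambda m, x: jnp.sum(m(x)))(model1, jnp.ones(2))

# Pattern 2: mark every non-floating-point-array leaf static; plain jax.grad works.
class Model2(eqx.Module):
    weight: jnp.ndarray
    activation: Callable = eqx.static_field()

    def __call__(self, x):
        return self.activation(self.weight @ x)

model2 = Model2(jnp.ones((2, 2)), jax.nn.relu)
grads2 = jax.grad(lambda m, x: jnp.sum(m(x)))(model2, jnp.ones(2))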

FWIW eqx.nn currently follows neither of these, which is a choice I'm a little dissatisfied with. It certainly can't follow approach 2 because we don't know in advance e.g. whether eqx.nn.MLP().activation is a JAX type or not (or even whether it's a leaf type or not). It does however use some static_field attributes so that simple models like eqx.nn.Linear (but not eqx.nn.MLP) work with jax.{jit, grad, ...} without requiring the use of the filtered transformations. I originally felt like this might ease the new-user onboarding process, but I don't know if that's really true.

In an ideal world I'd arrange for jax.{jit, grad, ...} to mimic the default behaviour of eqx.filter_{jit, grad, ...}: automatically filter out non-JAX-types etc. * (Even if they don't offer customisable filtering like Equinox's do.) Then this distinction would be unimportant and we could avoid using static_field entirely. Maybe some day!

* The filtered transformations actually also fix a couple of other issues with the native transformations: jitting non-function callables like bound methods and class instances; avoiding re-jitting when calling jit(grad(f)) twice.

paganpasta commented 2 years ago

Thanks for the detailed response! I am happy to stick with filter_* for stuff built on top of equinox.

So probably the p field in nn.Dropout also requires a static_field for consistency with other nn.*? https://github.com/patrick-kidger/equinox/blob/a89d5b486d13588caffc095f172a2ec39fd68278/equinox/nn/dropout.py#L21

patrick-kidger commented 2 years ago

So this is deliberately non-static because the forward pass of nn.Dropout is fine with having that as a tracer.

In some sense the same is true of nn.Linear.in_features etc. as well -- these aren't used during the forward pass. But in some sense this isn't true of nn.Linear.in_features etc., as this corresponds to an array size and as such might be used in some shape arithmetic during a forward pass. (e.g. someone using nn.Linear.in_features as part of a .reshape()).
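
As a sketch of that shape-arithmetic point (hypothetical usage, not code from eqx.nn): shapes must be concrete at trace time, so the reshape below only works because in_features is not traced.

import equinox as eqx
import jax
import jax.numpy as jnp

linear = eqx.nn.Linear(4, 2, key=jax.random.PRNGKey(0))

@eqx.filter_jit
def forward(model, x):
    # Relies on model.in_features being a concrete Python int, not a tracer.
    x = x.reshape(-1, model.in_features)
    return jax.vmap(model)(x)

forward(linear, jnp.ones(8))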

FWIW I am tempted to remove the static_field annotations from all of eqx.nn and see what happens. This debate over when to use static fields keeps coming up, and it doesn't really help that the built-in library is using them in ways we don't advise end users to.

paganpasta commented 2 years ago

Thanks! I agree on removing static_field and moving towards the `filter_*` approach for everything. I can give it a shot and see if everything still passes.

patrick-kidger commented 2 years ago

Sounds good!

paganpasta commented 2 years ago

Unsurprisingly, a couple of tests fail. I just wanted to make sure that the tests need fixing and it's not the behaviour breaking. Here is a small representative example of a failing test case, adapted from test_filter_jit.test_filter_jit1.

import jax
import equinox as eqx

def test_no_static(getkey):  # getkey: pytest fixture supplying fresh PRNG keys

    def h(x):
        # Replace every leaf that isn't array-like with None.
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    h = eqx.filter_jit(h, filter_spec=eqx.is_array_like)
    og_lin = eqx.nn.Linear(2, 2, key=getkey())
    _lin = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, og_lin)
    new_lin = h(og_lin)
    assert new_lin == _lin  # fails -- see below

This assertion fails: _lin contains all is_array_like leaves, but somewhere inside eqx.filter_jit, new_lin seems to have undergone an is_array transformation. I tracked the calls up to https://github.com/google/jax/blob/d5fdd9e2664602430ea0ecfb59fdef7f8692862f/jax/_src/api.py#L525-L530 where out_flat is

[
  DeviceArray([[-0.35848978, -0.23153338], [-0.5653428 , -0.46271053]], dtype=float32), 
  DeviceArray([0.21585053, 0.14375919], dtype=float32), 
  DeviceArray(2, dtype=int32, weak_type=True), 
  DeviceArray(2, dtype=int32, weak_type=True), 
  DeviceArray(True, dtype=bool, weak_type=True)
]

out_pytree_def is

PyTreeDef(
(
    CustomNode(
        <class 'equinox.nn.linear.Linear'>
                [(('weight', 'bias', 'in_features', 'out_features', 'use_bias'), (), ())], 
                [*, *, *, *, *]), 
    CustomNode(<class 'equinox.compile_utils.Static'>
                [
                    ((), ('value',), 
                    (Linear(
                      weight=None,
                      bias=None,
                      in_features=None,
                      out_features=None,
                      use_bias=None
                    ),))
                 ]
      , [ ])
  )
)

but out is

(Linear(
  weight=f32[2,2],
  bias=f32[2],
  in_features=i32[],
  out_features=i32[],
  use_bias=bool[]
), Static(
  value=Linear(
    weight=None,
    bias=None,
    in_features=None,
    out_features=None,
    use_bias=None
  )
))

Is this behaviour expected, with jit automagically removing non-JAX types?

paganpasta commented 2 years ago

A little update on this:


import jax
import equinox as eqx

def test_no_static(getkey):

    def h(x):
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    h = eqx.filter_jit(h, filter_spec=eqx.is_array_like)
    og_lin = eqx.nn.Linear(2, 2, key=getkey())
    _lin = jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, og_lin)
    new_lin = h(og_lin)

    # The filtered call converted the previously non-array leaves into JAX arrays...
    assert new_lin != og_lin and type(new_lin) == type(og_lin)

    def H(x):
        return jax.tree_map(lambda u: u if eqx.is_array_like(u) else None, x)

    H = jax.jit(H)
    jit_lin = H(og_lin)
    # ...and plain jax.jit performs exactly the same conversion.
    assert jit_lin == new_lin and type(new_lin) == type(jit_lin)

This test passes. I suppose this is expected behaviour and everything is working as intended?

patrick-kidger commented 2 years ago

So what's going on here is that the previously-static leaves are array-like, so they get selected for tracing by filter_jit(filter_spec=is_array_like). As a result they are converted into JAX arrays.

Meanwhile equality between Modules asks that all leaves have the same type, which of course is now no longer true.

So this is exactly the kind of breakage I was expecting; it's just a question of whether it's not-too-bad. Note that this requires you to use filter_jit with a non-default filter spec (which at least in my own code is unusual).
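
To make the distinction concrete, here is a small sketch (assuming the filter_spec API used earlier in this thread):

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit  # default spec: only actual JAX/NumPy arrays are traced
def f(x):
    return x

out = f((jnp.ones(2), 3))
print(type(out[1]))  # int -- the Python scalar passed through statically

g = eqx.filter_jit(lambda x: x, filter_spec=eqx.is_array_like)
out = g((jnp.ones(2), 3))
print(type(out[1]))  # now a JAX array: array-like leaves get traced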

paganpasta commented 2 years ago

Thanks for clarifying this.
How do you want to proceed with this? Also, as a side note, it would be helpful to add examples using fn, out and default to the documentation of filter_jit.

patrick-kidger commented 2 years ago

Which tests actually break under this change? (+with what error if it's interesting.)

Also, you might try running the tests for Diffrax under that change and see if any of those are affected?

paganpasta commented 2 years ago

I only made changes within nn and left out the rest, like experimental. These (2) test cases fail:

assert Linear(weight=None, bias=None, in_features=1, out_features=1, use_bias=True)
    == Linear(
         weight=None,
         bias=None,
         in_features=None,
         out_features=None,
         use_bias=None
       )

>       assert _eq(h2[5], _mlp)
E       assert False
E        +  where False = _eq(
MLP(
  layers=[
    Linear(weight=f32[2,2], bias=f32[2], in_features=i32[], out_features=i32[], use_bias=bool[]),
    Linear(weight=f32[2,2], bias=f32[2], in_features=i32[], out_features=i32[], use_bias=bool[]),
    Linear(weight=f32[2,2], bias=f32[2], in_features=i32[], out_features=i32[], use_bias=bool[])
  ],
  activation=None,
  final_activation=None,
  in_size=i32[],
  out_size=i32[],
  width_size=i32[],
  depth=i32[]
),
MLP(
  layers=[
    Linear(weight=f32[2,2], bias=f32[2], in_features=2, out_features=2, use_bias=True),
    Linear(weight=f32[2,2], bias=f32[2], in_features=2, out_features=2, use_bias=True),
    Linear(weight=f32[2,2], bias=f32[2], in_features=2, out_features=2, use_bias=True)
  ],
  activation=None,
  final_activation=None,
  in_size=2,
  out_size=2,
  width_size=2,
  depth=2
))

I'll try Diffrax as soon as I get the chance.

jiyuuchc commented 2 years ago

For me, a common use of static_field() is for control flow, so please do not remove it entirely.

import jax
import equinox as eqx

class LinearAndActivation(eqx.Module):
    m: eqx.Module
    apply_act: bool = eqx.static_field()  # static: a plain Python bool used for control flow

    def __init__(self, n_in, n_out, apply_act, key):
        self.m = eqx.nn.Linear(n_in, n_out, key=key)
        self.apply_act = apply_act

    @jax.jit
    def __call__(self, x):
        x = self.m(x)
        if self.apply_act:  # fine under jit: apply_act is static, not a tracer
            x = jax.nn.relu(x)
        return x

patrick-kidger commented 2 years ago

Don't worry, it definitely isn't going away. If nothing else it's a key part of how things work internally.

On balance I think I'm inclined to maintain the status quo. This clearly does break a couple of tests, and whilst they're pretty minor it's not encouraging. To rebut the argument (my argument) in favour of not using static_field: if an end-user overzealously applies a static_field then I don't think it should actually matter -- you can't smuggle in JAX arrays this way as they're not hashable. So extra static_fields should never break you silently.

paganpasta commented 2 years ago

The more I interact with equinox <=> jax/optax, the more I think having static_field() is not that bad to have as a default. It saves the trouble of repeating the filtering in various places.

francois-rozet commented 1 year ago

Hello, I am new to Equinox and I was wondering why it is necessary to declare all internal fields of modules before initializing them and why adding fields at runtime is impossible. I guess this is linked to them being PyTrees, but I believe PyTrees can grow "branches". Also, in my experience, modules usually have few parameter/module branches (e.g. weight and bias in a linear layer) and many static branches (hyper-parameters, settings, flags, etc). Hence, wouldn't it be easier to indicate the former (e.g. with eqx.Parameter and eqx.Module) and consider everything else as static, similar to PyTorch? This would eliminate the need for static_field and maybe ease the job of eqx.filter*.

patrick-kidger commented 1 year ago

Declaring the fields in advance is a syntax that we inherit from dataclasses. (Each eqx.Module is a dataclass.) Likewise, dataclasses don't allow adding additional fields at runtime.

In principle Equinox could have done something slightly different here -- not built on top of dataclasses and done our own thing instead. Which would have been fine too! (Although in practice I do particularly like not being able to add fields at runtime -- mutation is generally pretty dangerous and never actually needed, so this is a good way to avoid bugs.)
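
A quick sketch of both points (the Affine class here is hypothetical):

import equinox as eqx
import jax.numpy as jnp

class Affine(eqx.Module):
    weight: jnp.ndarray  # fields are declared up front, dataclass-style
    bias: jnp.ndarray

m = Affine(jnp.ones((2, 2)), jnp.zeros(2))
m.scale = 1.0  # raises an error: Modules are frozen, so no new fields at runtime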

Regarding flipping static/dynamic: for basically any eqx.Modules that you ever write yourself, I'd recommend essentially never using eqx.static_field. Just leave every field dynamic. eqx.filter_{jit, grad, ...} will simply filter out anything that isn't an array, as usual.


To be honest, if I was going to do Equinox again, I would probably avoid the dynamic/static field distinction -- and just leave everything dynamic. Likewise I might not use dataclasses, and instead have some other syntax for initialisation. Realistically these aren't important enough to justify a breaking change at this point, though.

francois-rozet commented 1 year ago

will simply filter out anything that isn't an array, as usual.

Oh so there is never the need for static_field except for jax arrays that are not "parameters"? Even if the field is used in __call__?

patrick-kidger commented 1 year ago

jax arrays that are not "parameters"

JAX arrays can never be static_fields. A static field must be hashable, as it's used to form the cache key with jax.jit. (=you can use the already-JIT-compiled version of this function, without having to recompile). And JAX arrays aren't hashable.
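
For instance, this fails:

import jax.numpy as jnp

hash(jnp.array(1.0))  # TypeError: unhashable type -- so it can't be part of a jit cache key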

In practice static_field is really only useful for one thing: being able to use non-arrays in your model whilst still using the original jax.{jit, grad, ...}. (As opposed to eqx.filter_{jit, grad, ...}.) The Equinox filter_{jit, grad, ...} operations are smart enough to filter out anything that isn't an array and just pass them through. The original JAX transformations instead try to cast everything to an array (and explode if you pass them non-arraylike objects).

That extra modicum of compatibility is the only reason that you see eqx.nn.Linear (and friends) using eqx.static_field.
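
A small sketch of that compatibility point:

import equinox as eqx
import jax
import jax.numpy as jnp

linear = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))

@jax.jit
def apply(model, x):
    return model(x)

apply(linear, jnp.ones(2))  # works: Linear's non-array fields are static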

Even if the field is used in __call__?

Yes, you can safely use static fields inside __call__!

pmelchior commented 1 year ago

I'm confused about the static_field behavior. My use case is maybe similar to @francois-rozet's: I need a model where I can freeze parameters, which are usually jax arrays. The filtering approach is too indirect for complex models, so I'd rather work with static_fields. However, I assumed that eqx.partition would move fields declared with static_field to the static tree. It does not. Instead the element is available in both:

import equinox as eqx
import jax.numpy as jnp

class Test(eqx.Module):
    a: jnp.array
    b: jnp.array = eqx.static_field()

t = Test(jnp.zeros(4), jnp.ones(4, dtype=int))

@eqx.filter_jit
@eqx.filter_grad
def loss(t):
    return jnp.sum((t.a - t.b)**2)

params, static = eqx.partition(t, eqx.is_array)
assert params.b is static.b  # b shows up in both partitions

This seems counterintuitive, but not problematic as such. What really confuses me is that loss actually returns gradients for t.b; they just happen to be the same as t.b itself, so the pytree is Test(a=∇_a loss, b=b):

assert loss(t).b is t.b

I assume that the gradient wrt b has never actually been computed (right?), but that return value is a problem for downstream use of this gradient. The same happens with standard jax.grad, btw.

Also, about your statement:

JAX arrays can never be static_fields.

The above works for me...

I appreciate your guidance above. But can you provide some further clarity on what static_field actually does, specifically for standard jax types?

patrick-kidger commented 1 year ago

In JAX, a PyTree consists of two parts: the tree structure, and the leaves. For example, the PyTree ["hi", 2, (jnp.array(3.),)] has structure [*, *, (*,)], and leaves "hi", 2, jnp.array(3.).

Each attribute of an eqx.Module can either be thought of as a leaf, or as part of the structure. By default it's a leaf, but if you want, you can make it part of the structure by adding a static_field declaration. Using your Test example: Test(a=jnp.array(3.), b=4.0) will have structure Test(a=*, b=4.0), and the single leaf jnp.array(3.).

The static pieces don't interact with any JAX transformation at all: they won't be JIT'd and can't be differentiated. This is the reason that, in your example, you get a gradient of Test(a=∇_a loss, b=b) with b the same object as the original input: JAX never interacts with this value, and treats it as part of the structure.
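
Using your Test example, you can see this directly with a small sketch:

import equinox as eqx
import jax
import jax.numpy as jnp

class Test(eqx.Module):
    a: jnp.ndarray
    b: float = eqx.static_field()

t = Test(jnp.array(3.0), 4.0)
leaves, treedef = jax.tree_util.tree_flatten(t)
print(leaves)   # [Array(3., dtype=float32)] -- only `a` is a leaf
print(treedef)  # b=4.0 is baked into the tree structure itself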

JAX arrays can never be static_fields.

The above works for me...

That's an odd quirk of JAX/NumPy, due to using a scalar array outside of JIT. (This specific combination means that bool(jnp.array(4) == jnp.array(4)) will not throw an error.)


I've just updated the static_field documentation to try and more strongly warn against using it.

cottrell commented 11 months ago

I'm still not clear on what the intended correct way to avoid static_field is. Suppose you want to have, as above,

class Test(eqx.Module):
    a: jnp.array
    b: jnp.array

but with b static in __call__, without using static_field.

I think you are saying that we should create a custom filter_something decorator to decorate __call__ with, and that the filter should filter by name.

And I think to do that we need to fully understand and replicate the logic of filter_grad, swapping out the type heuristics for some other partitioning logic?

... reading a bit more and reminding myself about jax, I suspect the core problem is that one cannot really have "by name" filters, only "by type" ones.

I am now suggesting that perhaps this is the simplest fix (below)?


class MyArray(jnp.array):  # (note: this doesn't actually work -- JAX arrays can't be subclassed; see the reply below)
    pass

class Test(eqx.Module):
    a: jnp.array
    b: MyArray = eqx.static_field()

patrick-kidger commented 11 months ago

If you want to have as above but with b static in __call__ but not use the static_field.

What is it you're looking to accomplish here, precisely? "Static" basically just means "invisible to every JAX transformation", i.e. a static_argnum to jax.jit, like in_axes=None for jax.vmap, no gradients when computed through jax.grad, etc. Which of these transformations are you looking to be static with respect to?

In particular, you should essentially never want a JAX array to be static wrt jax.jit, so I suspect the answer isn't "all of them".

Regarding your code at the bottom: you can't subclass JAX arrays, so don't try to do that. One similar thing you can do is to place them in a wrapper though, see this example. (This is the usual way to get custom per-parameter behaviour.)
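
A minimal sketch of that wrapper idea (the Frozen name here is made up; the linked example is the canonical version):

import equinox as eqx
import jax
from jax import lax

class Frozen(eqx.Module):
    value: jax.Array  # an ordinary dynamic leaf, so jit handles it as usual

    def get(self):
        # Participates in the forward pass, but receives no gradient.
        return lax.stop_gradient(self.value)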

cottrell commented 11 months ago

What is it you're looking to accomplish here, precisely? [...]

This example came to me from a friend who was hitting this. I think their example was some kind of complicated structure where the static thing was a grid or something that happened to be an array -- something that was not being differentiated; almost like data, but static. Putting the array in a Module can work as well, but I think the custom Array is the most minimal. Putting the array into a list and casting works, but is slow I guess.

The correct way might be filters, but that is not well understood either (how to create custom filters based on the names of things ... which I think might not be possible, as my impression is that filtering is built on jax machinery that only allows conditions based on type).

UPDATE:

Ok, I think I've understood some of this more ... at least for the dynamic case, you might want this kind of pattern:

import jax
import jax.numpy as jnp
import equinox as eqx

jax.config.update("jax_enable_x64", True)  # (assumed here, to match the f64 output below)

class Test(eqx.Module):
    a: jnp.array
    b: jnp.array

    def __call__(self):
        return jnp.sum(self.a + 3.0 * self.b**2)

def loss(model):
    return model()

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (5,))
b = jax.random.normal(key, (5,))
model = Test(a, b)

def partially_static(fun, filter_spec):
    def inner(model, *args, **kwargs):
        @eqx.filter_grad
        def inner_(diff_model, static_model, *args, **kwargs):
            # Recombine before calling fun: only diff_model is differentiated.
            model = eqx.combine(diff_model, static_model)
            return fun(model, *args, **kwargs)

        diff_model, static_model = eqx.partition(model, filter_spec)
        return inner_(diff_model, static_model, *args, **kwargs)

    return inner

# Mark everything differentiable except b.
filter_spec = jax.tree_map(lambda _: True, model)
object.__setattr__(filter_spec, 'b', False)

dloss = eqx.filter_grad(loss)
dloss_partial = partially_static(loss, filter_spec)

print(dloss(model))
print(dloss_partial(model))
# results in
# Test(a=f64[5], b=f64[5])
# Test(a=f64[5], b=None)

from https://docs.kidger.site/equinox/examples/frozen_layer/

I'm not sure if there is a more standard way of creating the filter spec than directly reaching for object.__setattr__.

patrick-kidger commented 11 months ago

I'm not sure if there is a more standard way of creating the filter spec than directly reaching for object.__setattr__.

Yup, use eqx.tree_at (which is also used in the example you link).
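
For instance, reusing the Test model and filter_spec from your snippet:

import jax
import equinox as eqx

filter_spec = jax.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(lambda spec: spec.b, filter_spec, replace=False)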

Other than that, I think what you've done here looks pretty reasonable.

It sounds like your particular case is that you want your array not to be differentiated? In that case you could consider using lax.stop_gradient:

import equinox as eqx
import jax.numpy as jnp
from jax import lax

class Test(eqx.Module):
    a: jnp.array
    b: jnp.array

    def __call__(self):
        b = lax.stop_gradient(self.b)  # b still contributes to the output, but receives no gradient
        return jnp.sum(self.a + 3.0 * b**2)

In addition you might like to use eqx.filter_jit(donate=...) to avoid spurious copies if your array is an unmodified input and output of your computation.
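
For example, a short sketch (assuming a recent Equinox in which filter_jit accepts donate, and a hypothetical update function):

import jax
import equinox as eqx

@eqx.filter_jit(donate="all")
def update(params, grads):
    # Donated input buffers may be reused for the outputs, avoiding copies.
    return jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)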