patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Equinox equivalent of FuncTorch's make_functional #112

Open AlphaBetaGamma96 opened 2 years ago

AlphaBetaGamma96 commented 2 years ago

Hi @patrick-kidger,

Apologies for opening an issue, as this is probably something I've missed, but is it at all possible to separate an eqx.Module object into a pure function and its parameters?

For reference, coming from PyTorch, I can use functorch.make_functional, which takes an nn.Module and returns a pure-function version of the module together with all of its parameters. For example,

import functorch

net = NetworkClass(*args, **kwargs)  # any torch.nn.Module
fnet, params = functorch.make_functional(net)
y = fnet(params, x)

Is there an equivalent function within equinox that mirrors this behaviour? Something along the lines of,

net = NetworkClass(*args, **kwargs)
fnet, params = equinox.make_functional(net)
y = fnet(params, x)

My use case is calculating per-sample gradients of an arbitrary loss function with respect to the parameters of a network, so being able to functionalize my eqx.Module object would solve my problem! Also, would this work recursively for all eqx.Module objects within an eqx.Module object? For example, if I define custom layers as eqx.Module objects, would I be able to extract all the parameters of my network?

Many Thanks!

patrick-kidger commented 2 years ago

To start off by answering your question, yes:

fnet = type(net).__call__  # unbound method: takes the model as its first argument
params = net
y = fnet(params, x)  # same as net(x)

which obviously looks a bit trivial.

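Any model is already a PyTree of parameters, including any eqx.Module submodules nested inside it. In addition, its forward pass as an unbound method (type(net).__call__) is an unparameterised pure function that takes the model as its first argument. As a bound method (net.__call__) it's a parameterised pure function.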

As for per-sample gradients: in JAX you can compute these just by nesting the transformations as jit-vmap-grad rather than jit-grad-vmap. Something like:

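import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit
@eqx.filter_vmap(in_axes=(None, 0, 0))  # model unbatched; x, y batched on axis 0 (an older Equinox API spelled this args=)
@eqx.filter_grad  # differentiates with respect to the first argument, the model
def loss(model, x, y):
    # grad needs a scalar output, so sum in case model(x) is a vector
    return jnp.sum((model(x) - y) ** 2)

model = eqx.nn.MLP(...)
per_sample_grads = loss(model, ...)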

should compute per-sample gradients for you.

AlphaBetaGamma96 commented 2 years ago

Hi @patrick-kidger, this is exactly what I was looking for! Thank you for the clear example code!

If jit-vmap-grad computes per-sample gradients, I assume jit-grad-vmap computes loss-gradients? (or something else?)

Thanks once again!

patrick-kidger commented 2 years ago

Yep, jit-grad-vmap computes gradients in the "normal" way, i.e. as in PyTorch.

jatentaki commented 2 years ago

Regarding jit-grad-vmap, I would remark that this literal chain of operations doesn't work as written. jax.grad computes the gradient of a scalar-output function with respect to its inputs and raises an exception for any non-scalar output. Since vmap turns any function (scalar or not) into one with a batched, non-scalar output, you cannot just run jit-grad-vmap. Patrick's answer makes an implicit leap by introducing a reduction step: jit-grad-reduce-vmap or, to highlight the order, jit(grad(reduce ∘ vmap(func))). With reduce == mean, this gives you the standard .backward() of PyTorch (assuming you were also taking the mean loss over the batch there). This is in contrast to jit(vmap(grad(func))), which is valid as long as func itself returns a scalar.