AlphaBetaGamma96 opened this issue 2 years ago
To start off by answering your question, yes:
```python
fnet = net.__call__
params = net
```
which obviously looks a bit trivial.
Any model is already a PyTree of parameters. In addition, its forward pass as an unbound method is an unparameterised pure function; as a bound method, it's a parameterised pure function.
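For instance, a minimal sketch with a small `eqx.nn.Linear` standing in for `net` (the model and shapes here are made up):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

net = eqx.nn.Linear(3, 1, key=jax.random.PRNGKey(0))

fnet = eqx.nn.Linear.__call__  # unbound method: pure function of (model, x)
params = net                   # the model itself is the PyTree of parameters

x = jnp.ones(3)
assert jnp.allclose(fnet(params, x), net(x))  # identical results
```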
In terms of calculating per-sample gradients, in JAX you can do this just by nesting `jit-vmap-grad` rather than `jit-grad-vmap`. Something like:
```python
import equinox as eqx

@eqx.filter_jit
@eqx.filter_vmap(args=(None, 0, 0))
@eqx.filter_grad
def loss(model, x, y):
    return (model(x) - y) ** 2

model = eqx.nn.MLP(...)
per_sample_grads = loss(model, ...)
```
should compute per-sample gradients for you.
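For instance, filling in the elided arguments with made-up sizes (the shapes and hyperparameters below are assumptions, not from the thread):

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
model = eqx.nn.MLP(in_size=3, out_size="scalar", width_size=8, depth=2, key=key)

xs = jnp.ones((16, 3))  # batch of 16 inputs
ys = jnp.zeros(16)      # batch of 16 targets

# Returns a PyTree shaped like `model`, whose arrays gain a leading batch axis of 16:
per_sample_grads = loss(model, xs, ys)
```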
Hi @patrick-kidger, this is exactly what I was looking for! Thank you for the clear example code!
If `jit-vmap-grad` computes per-sample gradients, I assume `jit-grad-vmap` computes loss gradients? (Or something else?)
Thanks once again!
Yep, `jit-grad-vmap` computes gradients in the "normal" way, e.g. like PyTorch.
Regarding `jit-grad-vmap`, I would remark that this literal chain of operations is syntactically incorrect. `jax.grad` computes the gradient of a scalar-output function w.r.t. its inputs and does not work at all with non-scalar-output functions (you'll get an exception). Since `vmap` will turn any function (scalar or not) into a vector-valued one, you cannot just run `jit-grad-vmap`. Patrick's answer makes the mental leap of introducing some reduction operation, i.e. `jit-grad-reduce-vmap` or, to highlight the order, `jit(grad(reduce(vmap(func))))(inputs)`. With `reduce == mean`, this gives you the standard `.backward()` of PyTorch (assuming you were taking the mean loss along the batch as well). This is in contrast to `jit(vmap(grad(func)))(inputs)`, which is valid as long as `func` itself is a scalar function.
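As a concrete illustration of the two compositions (a minimal sketch; the linear model, `loss_fn`, and shapes are all made up):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Scalar per-sample loss for a made-up linear model.
    return (jnp.dot(params, x) - y) ** 2

# jit(grad(reduce(vmap(...)))): vmap over the batch, reduce to a scalar, then grad.
# With reduce == mean this matches PyTorch's mean-loss .backward().
def mean_loss(params, xs, ys):
    return jnp.mean(jax.vmap(loss_fn, in_axes=(None, 0, 0))(params, xs, ys))

batch_grad = jax.jit(jax.grad(mean_loss))

# jit(vmap(grad(...))): grad of the scalar per-sample loss, then vmap over the batch.
per_sample_grad = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))

params = jnp.ones(3)
xs = jnp.ones((8, 3))
ys = jnp.zeros(8)
print(batch_grad(params, xs, ys).shape)       # (3,): one averaged gradient
print(per_sample_grad(params, xs, ys).shape)  # (8, 3): one gradient per sample
```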
Hi @patrick-kidger,

Apologies for opening an issue as this is probably something I've missed, but is it at all possible to separate an `eqx.Module` object into a pure function and its parameters?

For reference, as someone who's coming from PyTorch, you can use `functorch.make_functional`, which takes in an `nn.Module` and will return a pure-function form of your module along with all of its parameters.
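For example, a minimal sketch of the `functorch` usage (the particular module and shapes here are made up):

```python
import torch
import functorch

net = torch.nn.Linear(3, 1)
# make_functional splits the module into a pure function and a tuple of parameters.
fnet, params = functorch.make_functional(net)

x = torch.randn(3)
out = fnet(params, x)  # call the pure function with explicit parameters
```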
Is there an equivalent function within `equinox` that mirrors this behaviour?
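Something along the lines of the following (a hypothetical `eqx.make_functional`; this function does not exist, it's shown only to illustrate the shape of API being asked about):

```python
import jax
import equinox as eqx

net = eqx.nn.MLP(in_size=3, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))

# Hypothetical functorch-style split, not a real equinox function:
fnet, params = eqx.make_functional(net)
out = fnet(params, jax.numpy.ones(3))
```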
My use case is calculating per-sample gradients of an arbitrary loss function with respect to the parameters of a network, so being able to functionalize my `eqx.Module` object would solve my problem! Also, would this work recursively for all `eqx.Module` objects within an `eqx.Module` object? For example, if I were to define custom layers as `eqx.Module` objects, would I be able to extract all parameters of my network?

Many thanks!