**michaeldeistler** opened 1 month ago
Even better would be if the net itself were a pytree that could be passed around:
```python
import jax
import jax.numpy as jnp
from jax import grad

import jaxley as jx

net = jx.Network(...)
net.make_trainable("radius")

# First, we can run `net.set()` between `jx.integrate` calls and it
# will not recompile.
@jax.jit
def simulate(net):
    return jx.integrate(net)

# First run, needs compilation.
v1 = simulate(net)

# Now, modify the module, but we need no re-compilation!
net.set("HH_gNa", 0.2)
v2 = simulate(net)

# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Importing the functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(net):
    return jnp.sum(jx.integrate(net))

# ...and in the jupyter notebook:
from myfile import loss_fn

loss = loss_fn(net)

# We also support `jx.rebuild(others)` if the `net` itself is still to be
# modified within the Python module. E.g., in `myfile.py`:
def modified_loss_in_module(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

# ...and the following in a jupyter notebook:
from myfile import modified_loss_in_module

gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Finally, following inox (and more reminiscent of the current API), one
# can also split (or partition) the `net`:
import inox.nn as nn  # `nn.Parameter` as in inox

static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others):
    net = static(params, others)
    return jnp.sum(jx.integrate(net))

# If one wanted to change parameters within the loss function in this
# interface, one would do:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others, value):
    net = static(params, others)
    net.set("radius", value)
    return jnp.sum(jx.integrate(net))
```
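One way this could be implemented is to register the module as a pytree, so that `jax.jit` and `grad` trace through its arrays while the structure stays static. Below is a minimal sketch using `jax.tree_util.register_pytree_node_class`; the `Network` class, its fields, and its `set()` method here are illustrative stand-ins, not Jaxley's actual implementation:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Network:
    """Toy stand-in: parameters are pytree leaves, topology is static."""

    def __init__(self, params, topology):
        self.params = params      # dict of arrays -> traceable leaves
        self.topology = topology  # hashable metadata -> part of the treedef

    def set(self, key, value):
        # Replaces a leaf; shapes and treedef are unchanged.
        self.params = {**self.params, key: jnp.asarray(value)}

    def tree_flatten(self):
        return (self.params,), self.topology

    @classmethod
    def tree_unflatten(cls, topology, children):
        return cls(children[0], topology)


net = Network({"HH_gNa": jnp.array(0.1)}, topology="toy")


@jax.jit
def simulate(net):
    return net.params["HH_gNa"] * 2.0


v1 = simulate(net)      # first call: compiles
net.set("HH_gNa", 0.2)  # mutates a leaf in place
v2 = simulate(net)      # same treedef and shapes: no recompilation
```

Because `set()` only swaps out leaf values and never touches the treedef, the second `simulate(net)` call hits the compilation cache.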
For this, we should consider relying on inox or on flax nnx. IMO they are, at this point, quite similar in their API. Obviously, flax nnx is much larger and will surely be maintained, whereas inox is minimal and could allow us to actually get into the weeds. A small benefit of relying on flax nnx would be that its users will already be familiar with part of the API of Jaxley.
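For comparison, flax nnx already ships the split/merge counterpart of the `partition` sketch above. A minimal example (the `Net` module below is a toy stand-in, not Jaxley's):

```python
import jax
import jax.numpy as jnp
from flax import nnx


class Net(nnx.Module):
    def __init__(self):
        self.radius = nnx.Param(jnp.array(1.0))


net = Net()

# `nnx.split` plays the role of `net.partition(...)`: it separates the
# static graph definition from the (trainable) state.
graphdef, params = nnx.split(net, nnx.Param)


def loss_fn(params):
    # `nnx.merge` plays the role of `static(params, others)`.
    net = nnx.merge(graphdef, params)
    return jnp.sum(net.radius.value ** 2)


grads = jax.grad(loss_fn)(params)
```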
Nice! I like the idea! Lemme know if you want to brainstorm specifics! :)
Just playing around with this idea for now: