jaxleyverse / jaxley

Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU.
https://jaxley.readthedocs.io
Apache License 2.0

A global state for Jaxley modules #476

Open · michaeldeistler opened 1 month ago

michaeldeistler commented 1 month ago

Just playing around with this idea for now:

import jax
import jax.numpy as jnp
from jax import grad

import jaxley as jx

net = jx.Network(...)
net.make_trainable("radius")

# First, we can run `net.set()` between jx.integrate calls and it
# will not recompile.
@jax.jit
def simulate(trainables, others):
    return jx.integrate(trainables, others)

# First run, needs compilation.
v1 = simulate(net.trainables, net.others)

# Now, modify the module, but we need no re-compilation!
net.set("HH_gNa", 0.2)
v2 = simulate(net.trainables, net.others)

# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(value, trainables, others):
    net.set("HH_gK", value)
    return loss_fn(trainables, net.others)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(2.0, net.trainables, net.others)

# Importing the functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(trainables, others):
    return jnp.sum(jx.integrate(trainables, others))

# ...and in the jupyter notebook:
from myfile import loss_fn
loss = loss_fn(net.trainables, net.others)

# We also support `jx.rebuild(others)` if the `net` itself still has to be
# modified within the Python module. E.g., in `myfile.py`:
def modified_loss_in_module(value, trainables, others):
    net = jx.rebuild(others)  # Can also be achieved by reading a pickled net.
    net.set("HH_gK", value)
    return loss_fn(trainables, net.others)

# ...and the following in a jupyter notebook.
from myfile import modified_loss_in_module
gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(2.0, net.trainables, net.others)
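As a side note on why this avoids recompilation: jit only retraces when the pytree structure, shapes, or dtypes of its inputs change, so a `.set()` that merely swaps leaf values reuses the compiled function. A minimal, Jaxley-independent sketch of that behaviour (the `params` dict is a hypothetical stand-in, not the actual Jaxley containers):

import jax
import jax.numpy as jnp

@jax.jit
def simulate(params):
    # Stand-in for jx.integrate: any jitted function over a pytree of arrays.
    return jnp.sum(params["HH_gNa"] * params["radius"])

params = {"HH_gNa": jnp.array(0.12), "radius": jnp.array(1.0)}
v1 = simulate(params)              # First call: traced and compiled.

params["HH_gNa"] = jnp.asarray(0.2)  # Same structure, shapes, and dtypes...
v2 = simulate(params)              # ...so the compiled function is reused, no retrace.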
michaeldeistler commented 1 month ago

Even better would be if the net itself were a pytree that can be passed around:

net = jx.Network(...)
net.make_trainable("radius")

# First, we can run `net.set()` between jx.integrate calls and it
# will not recompile.
@jax.jit
def simulate(net):
    return jx.integrate(net)

# First run, needs compilation.
v1 = simulate(net)

# Now, modify the module, but we need no re-compilation!
net.set("HH_gNa", 0.2)
v2 = simulate(net)

# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Importing the functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(net):
    return jnp.sum(jx.integrate(net))

# ...and in the jupyter notebook:
from myfile import loss_fn
loss = loss_fn(net)

# Modifying the `net` within a Python module also works directly, since the
# pytree `net` is simply passed in (no `jx.rebuild(others)` needed). E.g., in `myfile.py`:
def modified_loss_in_module(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

# ...and the following in a jupyter notebook.
from myfile import modified_loss_in_module
gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Finally, following inox (and more reminiscent of the current API), one
# can also split (or partition) the `net`:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others):
    net = static(params, others)
    return jnp.sum(jx.integrate(net))

# If one wanted to change parameters within the loss function in this
# interface, one would do:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others, value):
    net = static(params, others)
    net.set("radius", value)
    return jnp.sum(jx.integrate(net))
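The pytree variant could be prototyped without any extra dependency via `jax.tree_util.register_pytree_node_class`: array-valued parameters become leaves, everything structural goes into the static aux data, and `jit`/`grad` then accept the module directly. A rough sketch with a hypothetical `ToyNet` (not Jaxley code):

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class ToyNet:
    """Hypothetical stand-in for a Jaxley module registered as a pytree."""

    def __init__(self, params, structure):
        self.params = params        # dict of arrays -> pytree leaves
        self.structure = structure  # static (hashable) description, e.g. connectivity

    def set(self, key, value):
        self.params[key] = jnp.asarray(value)

    def tree_flatten(self):
        return (self.params,), self.structure

    @classmethod
    def tree_unflatten(cls, structure, children):
        return cls(children[0], structure)

@jax.jit
def simulate(net):
    # Stand-in for jx.integrate(net).
    return jnp.sum(net.params["HH_gNa"] * net.params["radius"])

net = ToyNet({"HH_gNa": jnp.array(0.12), "radius": jnp.array(1.0)}, structure="point-neuron")
v1 = simulate(net)       # Traced and compiled once.
net.set("HH_gNa", 0.2)   # Only a leaf value changes, the treedef stays the same...
v2 = simulate(net)       # ...so no recompilation.
# jax.grad(simulate)(net) would return a ToyNet holding d(output)/d(parameter) leaves.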
michaeldeistler commented 1 month ago

For this, we should consider relying on inox or on flax nnx. IMO they are, at this point, quite similar in their API. Obviously, flax nnx is much larger and will surely be maintained, whereas inox is minimal and could allow us to actually get into the weeds.

A small benefit of relying on flax nnx would be that its users are already familiar with part of Jaxley's API.
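Just to make the comparison concrete, here is roughly how the partition/static pattern above would look when expressed with flax nnx's split/merge (toy module, not Jaxley; exact nnx API details may vary across versions):

import jax
import jax.numpy as jnp
from flax import nnx

class NonTrainable(nnx.Variable):
    # Custom variable type for non-trainable state, so filters can target it.
    pass

class ToyCell(nnx.Module):
    def __init__(self):
        self.gNa = nnx.Param(jnp.array(0.12))       # trainable
        self.radius = NonTrainable(jnp.array(1.0))  # non-trainable

model = ToyCell()
# Analogous to `static, params, others = net.partition(nn.Parameter)`:
graphdef, params, others = nnx.split(model, nnx.Param, ...)

def loss_fn(params, others):
    # Analogous to `net = static(params, others)`:
    model = nnx.merge(graphdef, params, others)
    return jnp.sum(model.gNa.value * model.radius.value)

grads = jax.grad(loss_fn)(params, others)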

jnsbck commented 3 weeks ago

Nice! I like the idea! Lemme know if you want to brainstorm specifics! :)