patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

[Question] Modifying the Static Variable in a Model #806

Open ahahn2813 opened 2 months ago

ahahn2813 commented 2 months ago

Hello,

My advisor and I are attempting something atypical with Equinox: we are trying to figure out how to change the static part of a neural network architecture on the fly.

Suppose we seek to learn the best activation function for a single layer of our neural network (all other architectural features are pre-chosen). The line of code below:

params, static = equinox.partition(model, equinox.is_array)

separates the model into a "static" variable and a "params" variable. The params variable contains all the weights, while the static variable contains the information about the layer's activation function. Since we will need to change the activation function when we update the architecture, we would like to know: is it possible to make modifications within the static variable? In other words, is there an easy way to convert the static variable to a params variable and then back to a static variable?

One approach we can think of is to convert part of the static variable into an array so that we can modify it, but we do not know how to convert it back into the static variable once it has been changed. Thank you!

lockwo commented 2 months ago

If you just want to update one member variable of the module, you can just use eqx.tree_at:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
  w: jax.Array
  b: jax.Array
  act_fn: Callable

  def __call__(self, x):
    return self.act_fn(self.w @ x + self.b)

net = NN(jnp.ones((10, 10)), jnp.ones(10), jax.nn.relu)
print(net(jnp.ones(10)))
print(net)
net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)
print(net(jnp.ones(10)))
print(net)

where act_fn is the attribute that would be partitioned into static in your code above

krm9c commented 2 months ago

Thank you for your response. Just so I understand properly: I can define a variable inside the class for each quantity that I want to change dynamically (for instance, the number of layers or the activation function), and then call eqx.tree_at in a loop to replace those values as needed. In pseudocode:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
  w: jax.Array
  b: jax.Array
  act_fn: Callable
  width: float

  def __init__(self, width, act):
    self.act_fn = act
    self.width = width

  def reinitialize(self):
    self.w = ...
    self.b = ...

  def __call__(self, x):
    return self.act_fn(self.w @ x + self.b)

net = NN(width=10, act=jax.nn.relu)
print(net(jnp.ones(10)))
print(net)

for each architecture-search step:

    net = eqx.tree_at(lambda x: x.width, net, 5)
    net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)

    for each weight-training step:
        .....

As long as I have the right variable names within my net class, I would be able to assign them on the fly with eqx.tree_at(). The pseudocode might be crude, but have I understood the way you meant it?

lockwo commented 2 months ago

Sure, that would work. Note, though, that width seems to determine other variables (such as w), so you may need to reinitialize those as well; but the code as you have written it would run.

krm9c commented 2 months ago

This is exactly what we want to do: determine the architecture/hyperparameters on the fly.

We are trying to build a kind of neural architecture/hyperparameter search setup with Equinox, and this would be helpful in that regard.

Equinox is a wonderful library. Thank you for maintaining it and working on this. Thank you very much.

lockwo commented 2 months ago

I just answer a few issues; all credit goes to Patrick.