ahahn2813 opened this issue 2 months ago
If you just want to update one member variable of the module, you can just use eqx.tree_at:
```python
import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)

net = NN(jnp.ones((10, 10)), jnp.ones(10), jax.nn.relu)
print(net(jnp.ones(10)))
print(net)

# tree_at returns a new module with act_fn replaced
net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)
print(net(jnp.ones(10)))
print(net)
```
where act_fn would be partitioned into the static part in your code above.
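For concreteness, here is a minimal sketch of that partition, reusing the net defined above (net2 is just an illustrative name):

```python
params, static = eqx.partition(net, eqx.is_array)
# params keeps the array leaves (w, b); act_fn is replaced by None there.
# static keeps everything that is not an array, including act_fn.
net2 = eqx.combine(params, static)  # recombine into a working module
print(net2(jnp.ones(10)))
```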
Thank you for your response; just so I understand properly: I can define a field inside the class for each quantity that I want to change dynamically, for instance the number of layers or the activation function, and then use eqx.tree_at in a loop to replace those values. In pseudocode:
```python
import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable
    width: int

    def __init__(self, width, act):
        self.act_fn = act
        self.width = width
        # placeholder initialization; real code would use random weights
        self.w = jnp.ones((width, width))
        self.b = jnp.ones(width)

    def reinitialize(self, key):
        # modules are immutable, so return a copy with fresh weights
        w = jax.random.normal(key, (self.width, self.width))
        b = jnp.zeros(self.width)
        return eqx.tree_at(lambda m: (m.w, m.b), self, (w, b))

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)

net = NN(width=10, act=jax.nn.relu)
print(net(jnp.ones(10)))
print(net)

for width, act in [(5, jax.nn.sigmoid)]:  # architecture-search loop
    net = eqx.tree_at(lambda x: x.width, net, width)
    net = eqx.tree_at(lambda x: x.act_fn, net, act)
    net = net.reinitialize(jax.random.PRNGKey(0))  # w, b must match the new width
    # inner loop: train the weights of this architecture
    ...
```
As long as I have the right field names within my net class, I would be able to assign them on the fly with eqx.tree_at(). The pseudocode might be crude, but have I understood it the way you meant?
Sure, that would work. Note that width impacts/determines other variables (such as w), so those need to be rebuilt to match whenever width changes, but the code you describe would work.
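For example, one way to keep things consistent is to update both fields in a single tree_at call and then rebuild the weights (a sketch that assumes the NN class and its reinitialize helper from the pseudocode above):

```python
# update several fields at once by targeting a tuple of leaves,
# then rebuild the weights so their shapes match the new width
net = eqx.tree_at(lambda m: (m.width, m.act_fn), net, (5, jax.nn.sigmoid))
net = net.reinitialize(jax.random.PRNGKey(1))
```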
This is exactly what we want to do: determine the architecture/hyperparameters on the fly.
We are trying to build some sort of neural architecture/hyperparameter search setup with Equinox, so this would be very helpful in that regard.
Equinox is a wonderful library; thank you for maintaining it and working on this.
I just answer a few issues; all the credit goes to Patrick.
Hello,
My advisor and I are attempting something atypical with Equinox: we are trying to figure out how to change the static part of the neural network architecture on the fly.
Suppose we seek to learn the best activation function for a single layer of our neural network (all other architecture features are pre-chosen). The line of code below:
```python
params, static = equinox.partition(model, equinox.is_array)
```
allows one to separate the model into a “static” variable and a “params” variable. The params variable contains all the weights, while the static variable contains the information regarding the activation function for the layer. As we will need to change the activation function when we update the architecture, we would like to know whether it is possible to make modifications within the static variable. In other words, is there an easy way to convert the static variable to a params variable and then back to a static variable?
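For concreteness, a sketch of the situation (model is one of our networks with an act_fn field; the names are illustrative):

```python
import equinox as eqx

params, static = eqx.partition(model, eqx.is_array)
# params: the weight arrays; act_fn shows up as None here
# static: no arrays, but this is where act_fn ends up
# -> how do we modify act_fn inside static, and get a usable model back?
```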
One approach we can think of is to convert part of the static variable into an array so we can modify it, but we do not know how to convert back to the static variable once it has been changed to an array. Thank you!