raj-brown opened this issue 12 months ago
Place your two neural networks together into a tuple:
```python
import equinox as eqx

two_networks = (nn1, nn2)
opt_state = optimiser.init(eqx.filter(two_networks, eqx.is_inexact_array))
```
In Equinox, neural networks are just pytrees. Tuples are also pytrees! As far as Optax is concerned there's no difference.
Thank you @patrick-kidger. I have another question. I want a pytree copy of the neural network, because I have to change the values of the parameters in this pytree at every epoch. How can I do that? Thank you very much!
You can make a copy of a pytree by doing:
```python
import jax.tree_util as jtu

pytree = ...
leaves, treedef = jtu.tree_flatten(pytree)
pytree_copy = jtu.tree_unflatten(treedef, leaves)
```
Thank you very much @patrick-kidger. I really appreciate it. A big thank you for creating Equinox. It is awesome.
Thanks!
Hi @patrick-kidger, I have another question. I need to use jvp with the primal being the neural network parameters and the tangent having the same type and shape as those parameters, but with the values of the tangent changing at every epoch.
I tried:
```python
l_val, grad = eqx.filter_jvp(l_fn, (net1,), (net2,))
```
It throws an error because it does not recognise net1 as a PyTree. I am not sure why eqx.filter_jvp does not work with the JAX arrays of net1 or net2. I would appreciate your help.
Thanks!
Please provide a MWE.
Hi @patrick-kidger, any suggestion or help would be great. Thank you!
I'm glad you got it working, but absent a MWE of the original problem there isn't much I can do.
Hi @patrick-kidger, sure. I will prepare one and post it here. On that note, do you have any suggestions for finding out how much memory a jitted function uses? I am looking for some way of profiling Equinox-based code for memory and FLOPs. Thank you!
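One possible approach uses JAX's ahead-of-time API, which lowers and compiles a function without running it. The compiled object then exposes XLA's cost and memory estimates. The sketch uses plain `jax.jit` on a hypothetical array-only function `f`. For an Equinox model, you would first split it with `eqx.partition` and close over the static part, so that only arrays are passed in. The numbers are compiler estimates, and the available fields depend on the backend and JAX version. For runtime traces, `jax.profiler` is the other tool worth looking at.

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Hypothetical function standing in for a loss or forward pass.
    return jnp.tanh(x @ w).sum()


w = jnp.ones((256, 256))
x = jnp.ones((64, 256))

# Ahead-of-time lowering and compilation: nothing is executed yet.
compiled = jax.jit(f).lower(w, x).compile()

# XLA's estimates; which fields are populated depends on backend and JAX version.
print(compiled.cost_analysis())    # e.g. FLOP and bytes-accessed estimates
print(compiled.memory_analysis())  # e.g. argument, output and temporary buffer sizes

# The compiled object can still be called like the original function.
result = compiled(w, x)
```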
Hi @patrick-kidger, I want to take the jvp of the loss function with respect to the network parameters along a random direction. Here is my code to do that:
```python
# dist_type, r, loss_final and optimizer are globals defined elsewhere in the script.
@eqx.filter_jit
def train_step_fwg(network, state, sub_key):
    is_linear = lambda x: isinstance(x, eqx.nn.Linear)
    get_weights = lambda m: [x.weight
                             for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                             if is_linear(x)]
    get_bias = lambda m: [x.bias
                          for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                          if is_linear(x)]
    weights = get_weights(network)
    biases = get_bias(network)
    if dist_type == "normal":
        new_weights = [r * jax.random.normal(subkey, weight.shape)
                       for weight, subkey in zip(weights, jax.random.split(sub_key, len(weights)))]
        new_biases = [r * jax.random.normal(subkey, bias.shape)
                      for bias, subkey in zip(biases, jax.random.split(sub_key, len(biases)))]
    new_model = eqx.tree_at(get_weights, network, new_weights)
    new_model = eqx.tree_at(get_bias, new_model, new_biases)
    params, static = eqx.partition(network, eqx.is_array)
    params1, static1 = eqx.partition(new_model, eqx.is_array)
    l_fn = lambda p: loss_final(p, static)
    l_val, grad = eqx.filter_jvp(l_fn, (params,), (params1,))
    fw_grad = jax.tree_util.tree_map(lambda p: grad * p, params1)
    updates, new_state = optimizer.update(fw_grad, state, network)
    network = eqx.apply_updates(network, updates)
    return network, new_state, l_val
```
The tangent must have the same shape as the primal. Therefore I extract the weights and biases and initialise the tangent from a normal distribution. My approach is a bit long. Is there any way to do this with shorter code? Thank you very much!
@patrick-kidger Hi Patrick, any help on this issue would be really helpful. Thank you so much!
So I notice that you appear to be reusing sub_key: you should split this before using it with new_weights and new_biases.
Other than that, what you've written looks reasonable.
Thanks @patrick-kidger. In fact, I split them in the for loop of my driver script:
```python
import jax.random as jr
import numpy as np
from tqdm import tqdm

key = jr.PRNGKey(7)
sub_key = jr.split(key, N_EPOCHS)  # one key per epoch
key_count = 0
counter = tqdm(np.arange(N_EPOCHS))
for epoch in counter:
    # pinn, opt_state, loss = train_step_opt(pinn, opt_state)
    pinn, opt_state, loss = train_step_fwg(pinn, opt_state, sub_key[key_count])
    key_count = key_count + 1
```
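The per-epoch split in the driver loop is fine. The reuse is inside `train_step_fwg`, where the same `sub_key` is split once for the weights and again for the biases. `jax.random.split` is deterministic, so with nine `Linear` layers both calls return exactly the same nine keys, and the bias noise is not independent of the weight noise. The sketch below shows this and one fix: split `sub_key` into two keys first. The key value 7 just mirrors the driver script.

```python
import jax
import jax.numpy as jnp

sub_key = jax.random.PRNGKey(7)

# The pattern inside train_step_fwg: the same key split twice with the same count.
weight_keys = jax.random.split(sub_key, 9)
bias_keys = jax.random.split(sub_key, 9)
reused = bool(jnp.array_equal(weight_keys, bias_keys))  # identical keys

# One fix: split once into two independent keys, then split each of those.
w_key, b_key = jax.random.split(sub_key)
fixed_weight_keys = jax.random.split(w_key, 9)
fixed_bias_keys = jax.random.split(b_key, 9)
```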
Also, my neural network class is defined as follows:
```python
import equinox as eqx
import jax
import jax.numpy as jnp


class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4, key5, key6, key7, key8, key9, _ = jax.random.split(key, 10)
        # These contain trainable parameters.
        self.layers = [eqx.nn.Linear(2, 20, key=key1),
                       eqx.nn.Linear(20, 20, key=key2),
                       eqx.nn.Linear(20, 20, key=key3),
                       eqx.nn.Linear(20, 20, key=key4),
                       eqx.nn.Linear(20, 20, key=key5),
                       eqx.nn.Linear(20, 20, key=key6),
                       eqx.nn.Linear(20, 20, key=key7),
                       eqx.nn.Linear(20, 20, key=key8),
                       eqx.nn.Linear(20, 1, key=key9),
                       ]

    def __call__(self, x, t):
        xt = jnp.hstack((x, t))
        for layer in self.layers[:-1]:
            xt = jax.nn.tanh(layer(xt))
        return self.layers[-1](xt).reshape(())
```
In the __call__ function I have to invoke .reshape, because the output is not a scalar when I take the gradient of the output with respect to the input. Is it right to invoke reshape?
Thank you very much! Regards, Raj
Hi @patrick-kidger, I have a question about changing the sign of the gradient, e.g.
data_f, data_bc, data_ic, data_cyl, params_ad)
I want to change the sign of the gradient of only one parameter in grad. How can I do that?
Thanks!
Given that you explicitly compute the gradient via jax.grad / eqx.filter_grad / etc., you can just do something like grad = -grad for an array. If you just want to flip it for a single scalar inside an array, you can do grad = grad.at[some_index].set(-grad[some_index]). If you have a PyTree, then you can use the usual PyTree manipulation utilities (jax.tree_util.* and eqx.tree_at are the main ones) to help you out.
Dear all, I have a very simple question. I have two neural networks of type MLP and I want to initialize an optimizer via optax. When I have one neural network I do:
```python
pyt = eqx.filter(bn, eqx.is_array)
opt_state = optimizer.init(pyt)
```
Is there any way I can concatenate the parameters from both neural networks? Thank you very much! Raj