patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

How to initialize Optax Optimizer for two neural networks. #532

Open raj-brown opened 12 months ago

raj-brown commented 12 months ago

Dear all, I have a very simple question. I have two neural networks of type MLP and I want to initialize an optimizer via Optax. When I have one neural network I do:

pyt = eqx.filter(bn, eqx.is_array)
opt_state = optimizer.init(pyt)

Is there any way I can concatenate the parameters from both neural networks?

Thank you very much! Raj

patrick-kidger commented 12 months ago

Place your two neural networks together into a tuple:

two_networks = (nn1, nn2)
optimiser.init(eqx.filter(two_networks, eqx.is_inexact_array))

In Equinox, neural networks are just pytrees. Tuples are also pytrees! As far as Optax is concerned there's no difference.

raj-brown commented 12 months ago

Thank you @patrick-kidger. I have another question. I want a pytree copy of the neural network, because I have to change the values of the parameters in this pytree at every epoch. How can I do that? Thank you very much!

patrick-kidger commented 12 months ago

You can make a copy of a pytree by doing:

import jax.tree_util as jtu
pytree = ...
leaves, treedef = jtu.tree_flatten(pytree)
pytree_copy = jtu.tree_unflatten(treedef, leaves)

raj-brown commented 12 months ago

Thank you very much @patrick-kidger. I really appreciate it. A big thank you for creating Equinox. It is awesome.

Thanks!

raj-brown commented 12 months ago

Hi @patrick-kidger, I have another question. I need to compute a JVP where the primal is the neural network parameters and the tangent has the same type and shape as those parameters, but the value of the tangent changes at every epoch. I tried:

l_val, grad = eqx.filter_jvp(l_fn, (net1,), (net2,))

It throws an error because it does not recognize net1 as a PyTree. I am not sure why eqx.filter_jvp does not apply to the JAX arrays in net1 or net2. I would appreciate your help.

Thanks!

patrick-kidger commented 12 months ago

Please provide a MWE.

raj-brown commented 11 months ago

Hi @patrick-kidger, any suggestions or help would be great. Thank you!

patrick-kidger commented 11 months ago

I'm glad you got it working, but absent a MWE of the original problem there isn't much I can do.

raj-brown commented 11 months ago

Hi @patrick-kidger, sure. I will prepare one and post it here. On that note, do you have any suggestions for finding out how much memory a jitted function uses? Some sort of profiling of Equinox-based code for memory and FLOPs. Thank you!

raj-brown commented 9 months ago

Hi @patrick-kidger, I want to take the JVP of the loss function with respect to the NN parameters along a random direction. Here is my code to do that:

import equinox as eqx
import jax

# `dist_type`, `r`, `loss_final` and `optimizer` are assumed to be defined elsewhere in the script.
@eqx.filter_jit
def train_step_fwg(network, state, sub_key):
    is_linear = lambda x: isinstance(x, eqx.nn.Linear)
    get_weights = lambda m: [x.weight
                             for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                             if is_linear(x)]

    get_bias = lambda m: [x.bias
                          for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                          if is_linear(x)]
    weights = get_weights(network)
    biases = get_bias(network)

    if dist_type == "normal":
        new_weights = [r * jax.random.normal(subkey, weight.shape)
                       for weight, subkey in zip(weights, jax.random.split(sub_key, len(weights)))]

        new_biases = [r * jax.random.normal(subkey, bias.shape)
                      for bias, subkey in zip(biases, jax.random.split(sub_key, len(biases)))]

    new_model = eqx.tree_at(get_weights, network, new_weights)
    new_model = eqx.tree_at(get_bias, new_model, new_biases)

    params, static = eqx.partition(network, eqx.is_array)
    params1, static1 = eqx.partition(new_model, eqx.is_array)

    l_fn = lambda p: loss_final(p, static)
    l_val, grad = eqx.filter_jvp(l_fn, (params,), (params1,))
    # grad is the directional derivative (a scalar); scale the tangent by it.
    fw_grad = jax.tree_util.tree_map(lambda p: grad * p, params1)
    updates, new_state = optimizer.update(fw_grad, state, network)
    network = eqx.apply_updates(network, updates)
    return network, new_state, l_val

The tangent must have the same shape as the primal. Therefore I extract the weights and biases, and initialize the tangent from a normal distribution. The method I adopted is a bit long. Is there any way to do this with shorter code? Thank you very much!

raj-brown commented 9 months ago

@patrick-kidger Hi Patrick, any help on this issue would be really appreciated. Thank you so much!

patrick-kidger commented 9 months ago

So I notice that you appear to be reusing sub_key: you should split this before using it with new_weights and new_biases.

Other than that, what you've written looks reasonable.
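To see why the reuse matters: jax.random.split is deterministic, so calling it twice on the same sub_key returns identical keys, and the weight and bias tangents are then generated from the same keys instead of independent ones. This is separate from splitting the key once per epoch in a driver loop. A minimal sketch (the count of 9 matches the nine Linear layers of the NeuralNetwork class in this thread; the rest is illustrative):

```python
import jax

sub_key = jax.random.PRNGKey(0)
n_layers = 9

# Reusing `sub_key`: both calls return exactly the same keys.
weight_keys_bad = jax.random.split(sub_key, n_layers)
bias_keys_bad = jax.random.split(sub_key, n_layers)

# Splitting first gives independent key streams for weights and biases.
w_key, b_key = jax.random.split(sub_key)
weight_keys = jax.random.split(w_key, n_layers)
bias_keys = jax.random.split(b_key, n_layers)
```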

raj-brown commented 9 months ago

Thanks @patrick-kidger. In fact, I split the keys in the for loop of my driver script:

import jax.random as jr
import numpy as np
from tqdm import tqdm

# N_EPOCHS, pinn, opt_state and train_step_fwg are defined elsewhere in the script.
key = jr.PRNGKey(7)
sub_key = jr.split(key, N_EPOCHS)  # one key per epoch
key_count = 0
counter = tqdm(np.arange(N_EPOCHS))

for epoch in counter:
    # pinn, opt_state, loss = train_step_opt(pinn, opt_state)
    pinn, opt_state, loss = train_step_fwg(pinn, opt_state, sub_key[key_count])
    key_count = key_count + 1

Also, my neural network class is defined as follows:

import equinox as eqx
import jax
import jax.numpy as jnp

class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4, key5, key6, key7, key8, key9, _ = jax.random.split(key, 10)
        # These contain trainable parameters.
        self.layers = [eqx.nn.Linear(2, 20, key=key1),
                       eqx.nn.Linear(20, 20, key=key2),
                       eqx.nn.Linear(20, 20, key=key3),
                       eqx.nn.Linear(20, 20, key=key4),
                       eqx.nn.Linear(20, 20, key=key5),
                       eqx.nn.Linear(20, 20, key=key6),
                       eqx.nn.Linear(20, 20, key=key7),
                       eqx.nn.Linear(20, 20, key=key8),
                       eqx.nn.Linear(20, 1, key=key9),
                       ]

    def __call__(self, x, t):
        xt = jnp.hstack((x, t))
        for layer in self.layers[:-1]:
            xt = jax.nn.tanh(layer(xt))
        return self.layers[-1](xt).reshape(())

In the __call__ method I have to invoke .reshape, because otherwise the output is not a scalar when I take the gradient of the output with respect to the input. Is it right to call reshape here?

Thank you very much! Regards Raj

raj-brown commented 7 months ago

Hi @patrick-kidger, I have a question about changing the sign of the gradient, e.g.:

..., data_f, data_bc, data_ic, data_cyl, params_ad)

I want to change the sign of the gradient of only one parameter in grad. How can I do that?

Thanks!

patrick-kidger commented 7 months ago

Given that you explicitly compute the gradient via jax.grad / eqx.filter_grad / etc., you can just do something like grad = -grad for an array. If you just want to flip it for a single scalar inside an array, you can do grad = grad.at[some_index].set(-grad[some_index]). If you have a PyTree, then you can use the usual PyTree manipulation utilities (jax.tree_util.* and eqx.tree_at are the main ones) to help you out.