google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Train multiple hk.nets.MLP with one optimizer #766

Closed · rsmath closed this 7 months ago

rsmath commented 9 months ago

Hello,

I am trying to train two neural networks simultaneously with one optimizer. In PyTorch this is trivial, since each model's parameters() can be concatenated and passed to the optimizer. How do I accomplish this in Haiku? This is assuming I have two parameter variables (one for each network) from the two networks' individual init functions (I also have the accompanying apply functions).
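
For concreteness, the PyTorch pattern I mean is roughly this (toy models, just to illustrate):

import itertools
import torch

# Two toy models standing in for the real networks
model1 = torch.nn.Linear(8, 1)
model2 = torch.nn.Linear(8, 1)

# One optimizer over the concatenated parameter iterables
optimizer = torch.optim.Adam(
    itertools.chain(model1.parameters(), model2.parameters()), lr=1e-3
)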

Thank you.

Ekundayo39283 commented 7 months ago

You can try the format below and see if it works for you:

import haiku as hk
import jax
import jax.numpy as jnp
import optax

# Define your two networks as functions to be transformed
def network1_fn(x):
    # Define network 1 architecture; giving each MLP a distinct
    # name keeps the parameter keys from colliding when merging
    return hk.nets.MLP([64, 32, 1], name="net1")(x)

def network2_fn(x):
    # Define network 2 architecture
    return hk.nets.MLP([64, 32, 1], name="net2")(x)

# Transform the functions into pure init/apply pairs
net1 = hk.transform(network1_fn)
net2 = hk.transform(network2_fn)

# Initialize parameters for both networks from example inputs
rng = jax.random.PRNGKey(42)
params1 = net1.init(rng, jnp.ones([1, 8]))
params2 = net2.init(rng, jnp.ones([1, 8]))

# Merge the two parameter pytrees into a single pytree
all_params = hk.data_structures.merge(params1, params2)

# Pass the merged pytree to a single optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(all_params)
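
From there, a single update step over the merged pytree could look like this, continuing from the snippet above (a sketch, assuming an MSE loss, dummy batches x1/x2/y1/y2, and deterministic networks so apply can take None for the rng):

def loss_fn(params, x1, x2, y1, y2):
    # Each apply reads only the entries of the merged pytree
    # that belong to its own modules
    pred1 = net1.apply(params, None, x1)
    pred2 = net2.apply(params, None, x2)
    return jnp.mean((pred1 - y1) ** 2) + jnp.mean((pred2 - y2) ** 2)

@jax.jit
def update(params, opt_state, x1, x2, y1, y2):
    # One gradient/update step drives both networks at once
    grads = jax.grad(loss_fn)(params, x1, x2, y1, y2)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state
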
rsmath commented 7 months ago

Yes, I managed to make one pytree of parameters to pass to the optimizer and it has been working fine. Thank you.