JJGO / hyperlight

Modular and intuitive Hypernetworks in Pytorch
Apache License 2.0

the main architecture is MLP #3

WangJYao opened 1 year ago

WangJYao commented 1 year ago

Hi, thanks for open-sourcing this repository! If the main architecture is an MLP, how should I set the parameters of the HyperNet, and what should I pay attention to?

JJGO commented 1 year ago

Here is a minimal working example, similar to the convnet one in the documentation:

import torch
from torch import nn
import hyperlight as hl

class MLP(nn.Sequential):
    """Simple multilayer perceptron with LeakyReLU activations."""

    def __init__(self, in_features: int, out_features: int, hidden: list[int]):
        super().__init__()

        # Hidden layers: in_features -> hidden[0] -> ... -> hidden[-1]
        for n_in, n_out in zip([in_features] + hidden, hidden):
            self.append(nn.Linear(n_in, n_out))
            self.append(nn.LeakyReLU())
        # Final projection to out_features, with no trailing activation
        self.append(nn.Linear(hidden[-1], out_features))

class HyperMLP(nn.Module):

    def __init__(self):
        super().__init__()
        mainnet = MLP(100, 10, [32, 64])
        # Select every Linear layer so its weight and bias are predicted
        modules = hl.find_modules_of_type(mainnet, [nn.Linear])

        # Convert the selected parameters into "external" parameters that
        # must be supplied at call time instead of being trained directly
        self.mainnet = hl.hypernetize(mainnet, modules=modules)
        parameter_shapes = self.mainnet.external_shapes()

        # The hypernet maps the 10-dim input h to all externalized parameters
        self.hypernet = hl.HyperNet(
            input_shapes={'h': (10,)},
            output_shapes=parameter_shapes,
            hidden_sizes=[16, 64, 128],
        )

    def forward(self, main_input, hyper_input):
        # Predict the main network's parameters from the hyper input
        parameters = self.hypernet(h=hyper_input)

        # Load the predicted parameters into the main network for this call
        with self.mainnet.using_externals(parameters):
            prediction = self.mainnet(main_input)

        return prediction

x = torch.randn(7, 100)  # batch of 7 inputs for the main network
h = torch.randn(10)      # hyper input that determines the weights
model = HyperMLP()
print(model(x, h).shape)
# torch.Size([7, 10])
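
For reference, you can check how many values the hypernet has to emit by summing over the external shapes. This is a quick sketch that assumes external_shapes() returns a name-to-shape mapping, mirroring the input_shapes dict above:

model = HyperMLP()
# Total number of parameters the hypernet must predict
# (assumes external_shapes() maps parameter names to shape tuples)
total = sum(torch.Size(shape).numel() for shape in model.mainnet.external_shapes().values())
print(total)
# 100*32 + 32 + 32*64 + 64 + 64*10 + 10 = 5994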

The main consideration is that when hypernetizing an MLP you can end up with a very large number of predicted parameters, since fully connected layers have no weight reuse (unlike convolutions). You might want to hypernetize only some of the layers.
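
For instance, here is a minimal sketch that hypernetizes only the final Linear layer; it assumes the modules argument accepts any explicit list of submodules, as the usage above suggests:

mainnet = MLP(100, 10, [32, 64])

# Only the final Linear layer's weight and bias become external parameters;
# the earlier layers keep their ordinary trainable weights.
# (Assumption: `modules` accepts an explicit list of submodules.)
last_linear = mainnet[-1]
partial = hl.hypernetize(mainnet, modules=[last_linear])
print(partial.external_shapes())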