Open WangJYao opened 1 year ago

Hi, thanks for open-sourcing this repository! If the main architecture is an MLP, how should I set the parameters of the HyperNet, or what should I pay attention to?

This is a minimal working example, similar to the convnet example in the documentation:
```python
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import hyperlight as hl


class MLP(nn.Sequential):
    def __init__(self, in_features: int, out_features: int, hidden: list[int]):
        super().__init__()
        # Alternate Linear + LeakyReLU blocks through the hidden sizes
        for n_in, n_out in zip([in_features] + hidden, hidden):
            self.append(nn.Linear(n_in, n_out))
            self.append(nn.LeakyReLU())
        out = nn.Linear(hidden[-1], out_features)
        self.append(out)


class HyperMLP(nn.Module):
    def __init__(self):
        super().__init__()
        mainnet = MLP(100, 10, [32, 64])
        # Replace every nn.Linear in the main network with externally supplied parameters
        modules = hl.find_modules_of_type(mainnet, [nn.Linear])
        self.mainnet = hl.hypernetize(mainnet, modules=modules)
        parameter_shapes = self.mainnet.external_shapes()
        # The hypernet maps the 10-dim input h to all of the externalized parameters
        self.hypernet = hl.HyperNet(
            input_shapes={'h': (10,)},
            output_shapes=parameter_shapes,
            hidden_sizes=[16, 64, 128],
        )

    def forward(self, main_input, hyper_input):
        parameters = self.hypernet(h=hyper_input)
        # Temporarily load the predicted parameters into the main network
        with self.mainnet.using_externals(parameters):
            prediction = self.mainnet(main_input)
        return prediction


x = torch.randn(7, 100)
h = torch.randn(10)
model = HyperMLP()
print(model(x, h).shape)
# torch.Size([7, 10])
```
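To get a feel for the cost, you can count how many parameters the hypernet has to predict. This is a rough check using the `model` defined above; it assumes `external_shapes()` returns a mapping from parameter names to shape tuples, in the same format as the `input_shapes` / `output_shapes` dicts used in the example:

```python
import math

# Assumption: external_shapes() returns {parameter_name: shape_tuple}
shapes = model.mainnet.external_shapes()
total = sum(math.prod(shape) for shape in shapes.values())
print(total)
# 100*32+32 + 32*64+64 + 64*10+10 = 5994 for the MLP(100, 10, [32, 64]) above
```

If the hypernet's last hidden layer (size 128 here) feeds a linear projection producing all of those outputs, that projection alone is on the order of 128 × 5994 ≈ 767k weights, which is what the note below is about.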
The main consideration is that when hypernetizing an MLP you can end up with a very large number of predicted parameters, since there is no weight reuse (unlike convolutions, where a small kernel is shared across spatial positions). You might want to hypernetize only some of the layers, as in the sketch below.
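Here is a minimal sketch of that, reusing the `MLP` class from above. It assumes `find_modules_of_type` returns a mapping from module names to modules (it is passed straight to `hypernetize` in the example), so a filtered mapping of the same form can be passed instead; only the last Linear layer is hypernetized here:

```python
import hyperlight as hl
from torch import nn

mainnet = MLP(100, 10, [32, 64])

# Assumption: find_modules_of_type returns {module_name: module}
linear_modules = hl.find_modules_of_type(mainnet, [nn.Linear])
print(list(linear_modules))  # inspect the names before choosing

# Keep only the final Linear layer; the earlier layers keep ordinary,
# directly trained weights instead of hypernet-predicted ones
last_name = list(linear_modules)[-1]
subset = {last_name: linear_modules[last_name]}

mainnet = hl.hypernetize(mainnet, modules=subset)
print(mainnet.external_shapes())  # now only the last layer's weight and bias
```

The HyperNet is then built exactly as in the example above, with `output_shapes=mainnet.external_shapes()`, which is now much smaller.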