leitro opened this issue 1 year ago
Hey! Thanks for the interest in the library! I think what you're looking for is the DynamicHypernetwork functionality; I'll post an example here soon.
Thanks for the reference! One question: is there a specific reason an `rnn_cell` is used here to process the input?
https://github.com/shyamsn97/hyper-nn/blob/276572816c6b9fa1cde7b1c4a05b9961d1e2b602/hypernn/torch/dynamic_hypernet.py#L40-L43
Is it possible to replace the RNN cell with a simple linear layer here? Any suggestions? Thanks!
Hey! Yeah, you can definitely replace the RNN cell with a linear layer. I added it mainly to replicate the original work (https://blog.otoro.net/2016/09/28/hyper-networks/), but you can basically put whatever you want in it. Here's an example:
```python
from typing import Optional, Any, Tuple, Dict

import torch
import torch.nn as nn

# static hypernetwork
from hypernn.torch import TorchHyperNetwork
from hypernn.torch.utils import get_weight_chunk_dims


class DynamicLinearHypernetwork(TorchHyperNetwork):
    def __init__(
        self,
        inp_dims: int,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
        )

        self.inp_dims = inp_dims
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.weight_chunk_dim = weight_chunk_dim
        if weight_chunk_dim is None:
            self.weight_chunk_dim = get_weight_chunk_dims(
                self.num_target_parameters, num_embeddings
            )

        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()
        # projects the input x to one scaling factor per embedding
        self.inp_embedder = nn.Linear(self.inp_dims, self.num_embeddings)

    def make_embedding_module(self) -> nn.Module:
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(
        self, inp: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # scale the learned embeddings by the embedded input, then generate
        # one chunk of target-network weights per embedding
        embedded_inp = self.inp_embedder(inp).view(self.num_embeddings, -1)
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        ) * embedded_inp
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}


# usage
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
)

INP_DIM = 32
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = DynamicLinearHypernetwork(
    inp_dims=INP_DIM,
    target_network=target_network,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

inp = torch.zeros((1, 32))
out = hypernetwork(inp, generate_params_kwargs=dict(inp=inp))
print(out.shape)
```
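With the target network above, this should print `torch.Size([1, 32])`: the forward call first builds the target network's parameters from `inp` (via `generate_params`), then runs the target network on `inp` with those generated weights.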
Hi Shyam! Thanks for providing this easy-to-use wrapper for hypernetworks, it's amazing!
One question: the parameters are generated by the `generate_params` function:
https://github.com/shyamsn97/hyper-nn/blob/276572816c6b9fa1cde7b1c4a05b9961d1e2b602/hypernn/torch/linear_hypernet.py#L56-L61
Thus the starting points are always random embeddings, right? If I want to use the input data `x` as the input of the hypernetwork as well, what should I do? Could you kindly outline it a bit?
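(For reference, the DynamicLinearHypernetwork example above shows exactly this pattern: `generate_params` takes the data as an argument, and the forward call routes it there through `generate_params_kwargs`. A stripped-down sketch of the same idea follows; the `InputConditionedHypernetwork` class is illustrative, not part of the library, and it assumes batch size 1 like the example above.)

```python
import torch
import torch.nn as nn

from hypernn.torch import TorchHyperNetwork


class InputConditionedHypernetwork(TorchHyperNetwork):
    # Illustrative sketch: map the data x directly to the flattened
    # parameters of the target network, instead of starting from
    # learned embeddings alone.
    def __init__(self, inp_dims, target_network, num_target_parameters=None):
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
        )
        self.weight_generator = nn.Linear(inp_dims, self.num_target_parameters)

    def generate_params(self, inp: torch.Tensor):
        generated_params = self.weight_generator(inp).view(-1)
        return generated_params, {}


target_network = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
hypernetwork = InputConditionedHypernetwork(32, target_network)

x = torch.zeros((1, 32))
# the forward call routes x into generate_params via generate_params_kwargs
out = hypernetwork(x, generate_params_kwargs=dict(inp=x))
```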