shyamsn97 / hyper-nn

Easy Hypernetworks in Pytorch and Jax
MIT License

How to generate model parameters based on input "x"? #6

Open leitro opened 1 year ago

leitro commented 1 year ago

Hi Shyam! Thanks for providing the easy-to-use wrapper for hypernetworks, it's amazing!

One question: the parameters are generated by the generate_params function here: https://github.com/shyamsn97/hyper-nn/blob/276572816c6b9fa1cde7b1c4a05b9961d1e2b602/hypernn/torch/linear_hypernet.py#L56-L61

So the starting points are always random embeddings that don't depend on the data, right? If I want to use the input data x as an input to the hypernetwork as well, what should I do? Could you please outline the approach a bit?
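For reference, the linked method looks roughly like this (paraphrased from the permalink, so the exact signature may differ):

def generate_params(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
    # the embeddings come from a learned nn.Embedding lookup over fixed
    # indices, so they never see the data flowing through the network
    embedding = self.embedding_module(
        torch.arange(self.num_embeddings, device=self.device)
    )
    generated_params = self.weight_generator(embedding).view(-1)
    return generated_params, {"embedding": embedding}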

shyamsn97 commented 11 months ago

Hey! Thanks for the interest in the library! I think what you're looking for is the DynamicHypernetwork functionality; I'll post an example here soon.

leitro commented 11 months ago

Thanks for the reference! One question: is there a specific design reason for using an rnn_cell to process the input here? https://github.com/shyamsn97/hyper-nn/blob/276572816c6b9fa1cde7b1c4a05b9961d1e2b602/hypernn/torch/dynamic_hypernet.py#L40-L43 Would it be possible to replace the RNN cell with a simple linear layer? Any suggestions? Thanks!

shyamsn97 commented 11 months ago

Hey! Yeah, you can definitely replace the RNN with a linear layer. I added the rnn cell mainly to replicate the original work (https://blog.otoro.net/2016/09/28/hyper-networks/), but you can basically plug in whatever you want. Here's an example:

from typing import Optional, Any, Tuple, Dict

import torch
import torch.nn as nn

# TorchHyperNetwork is the static hypernetwork base class;
# get_weight_chunk_dims sizes the per-embedding weight chunks
from hypernn.torch import TorchHyperNetwork
from hypernn.torch.utils import get_weight_chunk_dims

class DynamicLinearHypernetwork(TorchHyperNetwork):
    def __init__(
        self,
        inp_dims: int,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
        )
        self.inp_dims = inp_dims
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim

        # each embedding is mapped to a fixed-size chunk of the target
        # network's parameters; infer the chunk size if it isn't given
        if weight_chunk_dim is None:
            self.weight_chunk_dim = get_weight_chunk_dims(
                self.num_target_parameters, num_embeddings
            )

        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()
        # projects the input x to one value per embedding, standing in for
        # the rnn cell used by the dynamic hypernetwork
        self.inp_embedder = nn.Linear(self.inp_dims, self.num_embeddings)

    def make_embedding_module(self) -> nn.Module:
        # learned, input-independent embeddings, one per weight chunk
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        # maps each embedding to one chunk of target-network weights
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(
        self, inp: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # project x to one scalar per embedding: (1, inp_dims) -> (num_embeddings, 1)
        embedded_inp = self.inp_embedder(inp).view(self.num_embeddings, -1)
        # scale the learned embeddings by the input projection so the
        # generated parameters depend on x
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        ) * embedded_inp
        # map each embedding to a weight chunk and flatten into one vector
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}

# usage
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

INP_DIM = 32
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = DynamicLinearHypernetwork(
    inp_dims = INP_DIM,
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)
inp = torch.zeros((1, 32))  # a single dummy input (batch size 1)

# the same x generates the parameters and is fed through the target network
out = hypernetwork(inp, generate_params_kwargs=dict(inp=inp))
print(out.shape)
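Note that generate_params above returns a single flat parameter vector, so the .view(self.num_embeddings, -1) and the broadcasted multiply assume one input at a time; for batched inputs you'd want to generate a separate parameter vector per sample. With the shapes above, this should print torch.Size([1, 32]).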