shyamsn97 / hyper-nn

Easy Hypernetworks in Pytorch and Jax
MIT License

hyper-nn -- Easy Hypernetworks in Pytorch and Flax

Note: This library is experimental and currently under development - the Flax implementations in particular are far from perfect and can be improved. If you have any suggestions on how to improve this library, please open a GitHub issue or feel free to reach out directly!

hyper-nn gives users the ability to create easily customizable Hypernetworks for almost any generic torch.nn.Module from Pytorch and flax.linen.Module from Flax. Our Hypernetwork objects are also torch.nn.Modules and flax.linen.Modules, allowing for easy integration with existing systems. For Pytorch, we make use of the amazing library functorch.

Generating Policy Weights for Lunar Lander



Dynamic Weights for each character in a name generator


Install

hyper-nn is tested on Python 3.8+

Installing with pip

$ pip install hyper-nn

Installing from source

$ git clone git@github.com:shyamsn97/hyper-nn.git
$ cd hyper-nn
$ python setup.py install

For gpu functionality with Jax, you will need to follow the instructions here


What are Hypernetworks?

Hypernetworks, simply put, are neural networks that generate parameters for another neural network. They can be incredibly powerful, being able to represent large networks while using only a fraction of their parameters.
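
As a rough illustration of the parameter savings, the sketch below counts the parameters of a small MLP and of a linear hypernetwork that could generate them. The layer sizes, the embedding settings, and the helper linear_params are assumptions chosen for illustration. The chunk size is assumed to be ceil(num_target_parameters / num_embeddings), which may differ from what hyper-nn computes internally.

```python
import math


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


# Target network: Linear(32, 64) -> ReLU -> Linear(64, 32)
num_target_parameters = linear_params(32, 64) + linear_params(64, 32)

# Hypothetical hypernetwork settings
num_embeddings = 32
embedding_dim = 4

# Assumption: each embedding row generates one equally sized chunk of target weights
weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings)

# Hypernetwork = embedding table + one linear layer (embedding_dim -> weight_chunk_dim)
num_hyper_parameters = num_embeddings * embedding_dim + linear_params(
    embedding_dim, weight_chunk_dim
)

print(num_target_parameters)  # 4192
print(weight_chunk_dim)       # 131
print(num_hyper_parameters)   # 783
```

Under these assumptions the hypernetwork holds fewer than a fifth of the parameters it generates.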

Hypernetworks generally come in two variants, static or dynamic. Static Hypernetworks have a fixed or learned embedding and weight generator that outputs the target networks’ weights deterministically. Dynamic Hypernetworks instead receive inputs and use them to generate dynamic weights.
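
A toy sketch of the distinction, in plain Python with a one-weight target network y = w * x. The functions and constants are hypothetical and only show where the input enters; they are not part of the hyper-nn API.

```python
# Target "network" with a single weight: y = w * x
def target_forward(w: float, x: float) -> float:
    return w * x


LEARNED_EMBEDDING = 0.5  # hypothetical learned value


# Static: the weight depends only on the learned embedding, never on the input
def static_generate_params() -> float:
    return 2.0 * LEARNED_EMBEDDING


# Dynamic: the weight is also a function of the current input
def dynamic_generate_params(x: float) -> float:
    return 2.0 * LEARNED_EMBEDDING + 0.1 * x


inputs = (1.0, 2.0)
static_weights = [static_generate_params() for _ in inputs]
dynamic_weights = [dynamic_generate_params(x) for x in inputs]
static_outputs = [target_forward(w, x) for w, x in zip(static_weights, inputs)]

print(static_weights)   # the same weight for every input
print(dynamic_weights)  # a different weight per input
print(static_outputs)
```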

hyper-nn allows you to design Hypernetworks with flexibility and ease: all you have to do is implement the generate_params method, which outputs a parameter vector. We also include basic versions of this, composed of two linear components: embedding_module and weight_generator.

Both embedding_module and weight_generator are represented as torch.nn.Module and flax.linen.Module objects. A Module can be passed in as custom_embedding_module or custom_weight_generator, or it can be defined in the methods make_embedding_module or make_weight_generator.

The generate_params method feeds the output from embedding_module into weight_generator to output the target parameters.

The forward method takes in a list of inputs and uses the generated parameters to calculate the output. This method acts as the main method for both Jax and Torch hypernetworks.
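
Because the generated parameters arrive as one flat vector, they have to be sliced back into per-layer shapes before the target network can use them. The sketch below shows one way to compute those slices in plain Python. The shape list and the helper split_offsets are assumptions for illustration, not the library's internals.

```python
import math
from typing import List, Tuple


def split_offsets(shapes: List[Tuple[int, ...]]) -> List[Tuple[int, int]]:
    """Return (start, end) indices of each tensor inside a flat parameter vector."""
    offsets, start = [], 0
    for shape in shapes:
        size = math.prod(shape)
        offsets.append((start, start + size))
        start += size
    return offsets


# Shapes of Linear(32, 64) and Linear(64, 32): weight, bias, weight, bias
target_weight_shapes = [(64, 32), (64,), (32, 64), (32,)]

offsets = split_offsets(target_weight_shapes)
flat_params = list(range(offsets[-1][1]))  # stand-in for generated_params

# Slice out the first bias vector
bias_start, bias_end = offsets[1]
first_bias = flat_params[bias_start:bias_end]

print(offsets)          # [(0, 2048), (2048, 2112), (2112, 4160), (4160, 4192)]
print(len(first_bias))  # 64
```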

Torch Hypernetwork

...
  def make_embedding_module(self) -> nn.Module:
      return nn.Embedding(self.num_embeddings, self.embedding_dim)

  def make_weight_generator(self) -> nn.Module:
      return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

  def generate_params(self, *args, **kwargs) -> Tuple[torch.Tensor, Dict[str, Any]]:
      embedding = self.embedding_module(
          torch.arange(self.num_embeddings, device=self.device)
      )
      generated_params = self.weight_generator(embedding).view(-1)
      return generated_params, {"embedding": embedding}

  def target_forward(
      self,
      *args,
      generated_params: torch.Tensor,
      assert_parameter_shapes: bool = True,
      **kwargs,
  ) -> torch.Tensor:
      if assert_parameter_shapes:
          self.assert_parameter_shapes(generated_params)

      return self.target_network(generated_params, *args, **kwargs)

  def forward(
      self,
      *args,
      generated_params: Optional[torch.Tensor] = None,
      has_aux: bool = False,
      assert_parameter_shapes: bool = True,
      generate_params_kwargs: Dict[str, Any] = {},
      **kwargs,
  ):
      """
      Main method for creating / using generated parameters and passing in input into the target network

      Args:
          generated_params (Optional[torch.Tensor], optional): Generated parameters of the target network. If not provided, the hypernetwork will generate the parameters. Defaults to None.
          has_aux (bool, optional): If True, return the auxiliary output from generate_params method. Defaults to False.
          assert_parameter_shapes (bool, optional): If True, raise an error if generated_params does not have shape (num_target_parameters,). Defaults to True.
          generate_params_kwargs (Dict[str, Any], optional): kwargs to be passed to generate_params method
          *args, **kwargs: arguments to be passed into the target network (also passed into generate_params)
      Returns:
          output (torch.Tensor) | (torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]): returns output from target network and, if has_aux is True, the generated parameters and auxiliary output.
      """
      aux_output = {}
      if generated_params is None:
          generated_params, aux_output = self.generate_params(
              *args, **kwargs, **generate_params_kwargs
          )

      if has_aux:
          return (
              self.target_forward(
                  *args,
                  generated_params=generated_params,
                  assert_parameter_shapes=assert_parameter_shapes,
                  **kwargs,
              ),
              generated_params,
              aux_output,
          )
      return self.target_forward(
          *args,
          generated_params=generated_params,
          assert_parameter_shapes=assert_parameter_shapes,
          **kwargs,
      )

...

Flax Hypernetwork

...
  def make_embedding_module(self):
      return nn.Embed(
          self.num_embeddings,
          self.embedding_dim,
          embedding_init=jax.nn.initializers.uniform(),
      )

  def make_weight_generator(self):
      return nn.Dense(self.weight_chunk_dim)

  def generate_params(self, *args, **kwargs) -> Tuple[jnp.array, Dict[str, Any]]:
      embedding = self.embedding_module(jnp.arange(0, self.num_embeddings))
      generated_params = self.weight_generator(embedding).reshape(-1)
      return generated_params, {"embedding": embedding}

  def target_forward(
      self,
      *args,
      generated_params: jnp.array,
      assert_parameter_shapes: bool = True,
      **kwargs,
  ) -> jnp.array:

      if assert_parameter_shapes:
          self.assert_parameter_shapes(generated_params)

      param_tree = create_param_tree(
          generated_params, self.target_weight_shapes, self.target_treedef
      )

      return self.target_network.apply(param_tree, *args, **kwargs)

  def forward(
      self,
      *args,
      generated_params: Optional[jnp.array] = None,
      has_aux: bool = False,
      assert_parameter_shapes: bool = True,
      generate_params_kwargs: Dict[str, Any] = {},
      **kwargs,
  ) -> Tuple[jnp.array, List[jnp.array]]:
      """
      Main method for creating / using generated parameters and passing in input into the target network

      Args:
          generated_params (Optional[jnp.array], optional): Generated parameters of the target network. If not provided, the hypernetwork will generate the parameters. Defaults to None.
          has_aux (bool, optional): If True, return the auxiliary output from generate_params method. Defaults to False.
          assert_parameter_shapes (bool, optional): If True, raise an error if generated_params does not have shape (num_target_parameters,). Defaults to True.
          generate_params_kwargs (Dict[str, Any], optional): kwargs to be passed to generate_params method

      Returns:
          output (jnp.array) | (jnp.array, jnp.array, Dict[str, jnp.array]): returns output from target network and, if has_aux is True, the generated parameters and auxiliary output.
      """
      aux_output = {}
      if generated_params is None:
          generated_params, aux_output = self.generate_params(
              *args, **kwargs, **generate_params_kwargs
          )

      if has_aux:
          return (
              self.target_forward(
                  *args,
                  generated_params=generated_params,
                  assert_parameter_shapes=assert_parameter_shapes,
                  **kwargs,
              ),
              generated_params,
              aux_output,
          )
      return self.target_forward(
          *args,
          generated_params=generated_params,
          assert_parameter_shapes=assert_parameter_shapes,
          **kwargs,
      )

...

Quick Usage

For detailed examples, see the notebooks.

The main classes to use are TorchHyperNetwork and JaxHyperNetwork and those that inherit from them. Instead of constructing them directly, use the from_target method, shown below. After this you can use the hypernetwork exactly like any other nn.Module!

hyper-nn also makes it easy to create Dynamic Hypernetworks that use inputs to create target weights. Basic implementations (both < 100 lines) are provided with JaxDynamicHyperNetwork and TorchDynamicHyperNetwork, which use an RNN and the current input to generate weights.

To create hypernetworks, it's easier to use the from_target method instead of instantiating them directly, because some parameters are calculated automatically for you.

Pytorch

import torch
import torch.nn as nn

from hypernn.torch import TorchHyperNetwork, TorchLinearHyperNetwork, TorchDynamicHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = TorchLinearHyperNetwork(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

# now we can use the hypernetwork like any other nn.Module
inp = torch.zeros((1, 32))

# by default we only output what we'd expect from the target network
output = hypernetwork(inp)

# return aux_output
output, generated_params, aux_output = hypernetwork(inp, has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.generate_params()
output = hypernetwork(inp, generated_params=generated_params)

### Dynamic Hypernetwork

dynamic_hypernetwork = TorchDynamicHyperNetwork(
    input_dim = 32,
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

output = dynamic_hypernetwork(inp, generate_params_kwargs=dict(x=inp))

# a hidden state can also be passed in explicitly
output = dynamic_hypernetwork(inp, generate_params_kwargs=dict(x=inp, hidden_state=torch.zeros((1,32))))

Jax

import flax.linen as nn
import jax.numpy as jnp
from jax import random

from hypernn.jax import JaxHyperNetwork, JaxLinearHyperNetwork, JaxDynamicHyperNetwork

# any module
target_network = nn.Sequential(
    [
        nn.Dense(64),
        nn.relu,
        nn.Dense(32)
    ]
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = JaxLinearHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    inputs=jnp.zeros((1, 32)) # jax needs this to initialize target weights
)

# now we can use the hypernetwork like any other nn.Module
inp = jnp.zeros((1, 32))
key = random.PRNGKey(0)
hypernetwork_params = hypernetwork.init(key, inp) # flax needs to initialize hypernetwork parameters first

# by default we only output what we'd expect from the target network
output = hypernetwork.apply(hypernetwork_params, inp)

# return aux_output
output, generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp, has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.apply(hypernetwork_params, method=hypernetwork.generate_params)

output = hypernetwork.apply(hypernetwork_params, inp, generated_params=generated_params)

### Dynamic Hypernetwork

dynamic_hypernetwork = JaxDynamicHyperNetwork.from_target(
    input_dim = 32,
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    inputs=jnp.zeros((1, 32)) # jax needs this to initialize target weights
)
dynamic_hypernetwork_params = dynamic_hypernetwork.init(key, inp, generate_params_kwargs=dict(x=inp, hidden_state=jnp.zeros((1,32)))) # flax needs to initialize hypernetwork parameters first

# by default we only output what we'd expect from the target network
output = dynamic_hypernetwork.apply(dynamic_hypernetwork_params, inp, generate_params_kwargs=dict(x=inp, hidden_state=jnp.zeros((1,32))))


Customizing Hypernetworks

hyper-nn makes it easy to customize and create more complex hypernetworks.

The main method to override is generate_params, which gives you complete control over how the hypernetwork generates parameters.

For example, here we extend the linear hypernetwork, which uses the components embedding_module and weight_generator. We implement a hypernetwork that could be useful in a multi-task setting, where a one-hot encoded class embedding is concatenated to every row in the embedding matrix output by the embedding_module. In addition, we override both our make_embedding_module and make_weight_generator methods to output customized modules. This whole class implementation is under 50 lines of code!

from typing import Optional, Iterable, Any, Tuple, Dict
import torch
import torch.nn as nn
# static hypernetwork
from hypernn.torch import TorchHyperNetwork
from hypernn.torch.utils import get_weight_chunk_dims

class MultiTaskHypernetwork(TorchHyperNetwork):
    def __init__(
        self,
        num_tasks: int,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__(
            target_network = target_network,
            num_target_parameters = num_target_parameters,
        )
        self.num_tasks = num_tasks
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim
        if weight_chunk_dim is None:
            self.weight_chunk_dim = get_weight_chunk_dims(
                self.num_target_parameters, num_embeddings
            )
        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()

    def make_embedding_module(self) -> nn.Module:
        embedding = nn.Embedding(self.num_embeddings, 8)
        return nn.Sequential(
            embedding,
            nn.Tanh(),
            nn.Linear(8, self.embedding_dim),
            nn.Tanh(),
        )

    def make_weight_generator(self) -> nn.Module:
        return nn.Sequential(
            nn.Linear(self.embedding_dim + self.num_tasks, 32),
            nn.Tanh(),
            nn.Linear(32, self.weight_chunk_dim)
        )

    def generate_params(
        self, one_hot_task_embedding: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        )
        one_hot_task_embedding = one_hot_task_embedding.repeat(self.num_embeddings, 1) # repeat to concat to embedding
        concatenated = torch.cat((embedding, one_hot_task_embedding), dim=-1)
        generated_params = self.weight_generator(concatenated).view(-1)
        return generated_params, {"embedding": embedding}

# usage
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

NUM_TASKS = 4
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = MultiTaskHypernetwork(
    num_tasks = NUM_TASKS,
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)
inp = torch.zeros((1, 32))
one_hot_task_embedding = torch.tensor([0.0,0.0,1.0,0.0]).view((1,4))

out = hypernetwork(inp, generate_params_kwargs=dict(one_hot_task_embedding=one_hot_task_embedding))
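
To make the shapes in generate_params concrete, the plain-Python sketch below mirrors the repeat and torch.cat steps with nested lists. The sizes and values are made up for illustration and are smaller than those in the example above.

```python
num_embeddings, embedding_dim, num_tasks = 3, 2, 4

# Stand-in for the output of embedding_module: one row per embedding
embedding = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# One-hot vector selecting task 2 of 4
one_hot_task = [0.0, 0.0, 1.0, 0.0]

# "repeat" the task vector for every row, then concatenate along the feature axis
concatenated = [row + one_hot_task for row in embedding]

# Each row now has embedding_dim + num_tasks features, which is why the first
# layer of the weight generator must accept embedding_dim + num_tasks inputs
print(len(concatenated), len(concatenated[0]))  # 3 6
```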

Advanced: Using vmap for batching operations

This is useful when dealing with dynamic hypernetworks that generate different params depending on inputs.

Pytorch

import torch
import torch.nn as nn
from functorch import vmap

# dynamic hypernetwork
from hypernn.torch import TorchDynamicHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(8, 256),
    nn.ReLU(),
    nn.Linear(256, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

# conditioned on input to generate param vector
hypernetwork = TorchDynamicHyperNetwork(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    input_dim = 8
)

# batch of 10 inputs
inp = torch.randn((10, 1, 8))

# use with a for loop
outputs = []
for i in range(10):
    outputs.append(hypernetwork(inp[i], generate_params_kwargs=dict(x=inp[i])))
outputs = torch.stack(outputs)
assert outputs.size() == (10, 1, 32)

# using vmap
from typing import Dict, Any

def forward(
    generated_params,
    *args,
    has_aux: bool = False,
    assert_parameter_shapes: bool = True,
    generate_params_kwargs: Dict[str, Any] = {},
    **kwargs
):
    return hypernetwork.forward(*args,
                                generated_params=generated_params,
                                has_aux=has_aux,
                                assert_parameter_shapes=assert_parameter_shapes,
                                generate_params_kwargs=generate_params_kwargs,
                                **kwargs)

generated_vmap_params, aux_output = vmap(hypernetwork.generate_params)(inp)
outputs = vmap(forward)(generated_vmap_params, inp)

assert outputs.size() == (10, 1, 32)
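
As a mental model, vmap(fn)(batch) gives the same result as applying fn to each slice along the leading dimension and stacking the outputs. The loop-based stand-in below shows that pairing in plain Python. vmap_reference and the toy functions are hypothetical, and this is not how functorch implements vmap.

```python
from typing import Callable, List, Sequence


def vmap_reference(fn: Callable) -> Callable:
    """Loop-based stand-in for vmap: map fn over the leading axis of every argument."""

    def batched(*batched_args: Sequence) -> List:
        return [fn(*args) for args in zip(*batched_args)]

    return batched


# Toy dynamic hypernetwork: the generated weight depends on the input
def generate_params(x: float) -> float:
    return 1.0 + x


# Toy target network: y = w * x
def forward(w: float, x: float) -> float:
    return w * x


inputs = [0.0, 1.0, 2.0]
params = vmap_reference(generate_params)(inputs)   # one weight per input
outputs = vmap_reference(forward)(params, inputs)  # each input uses its own weight

print(params)   # [1.0, 2.0, 3.0]
print(outputs)  # [0.0, 2.0, 6.0]
```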

Future Plans

Here's a list of some stuff that will hopefully be added to the library. If anyone has other suggestions, please reach out / create an issue!


Citing hyper-nn

If you use this software in your publications, please cite it by using the following BibTeX entry.

@misc{sudhakaran2022,
  author = {Sudhakaran, Shyam},
  title = {hyper-nn},
  howpublished = {\url{https://github.com/shyamsn97/hyper-nn}},
  year = {2022},
}