google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Best practice of implementing activation functions with trainable parameters? #407

Closed smao-astro closed 2 years ago

smao-astro commented 2 years ago

Hi,

I am not sure whether this is the right place to ask, but I would appreciate it if someone could give advice on implementing an activation function with trainable parameters.

Specifically, I want to implement the activation function described by Equation 10 in the paper below:

Self-scalable Tanh (Stan): Faster Convergence and Better Generalization in Physics-informed Neural Networks Raghav Gnanasambandam, Bo Shen, Jihoon Chung, Xubo Yue, Zhenyu (James)Kong https://arxiv.org/abs/2204.12589

Here is Equation 10:

\sigma^{i}_{k}(x^{i}_{k}) = \tanh(x^{i}_{k}) + \beta^{i}_{k} \, x^{i}_{k} \tanh(x^{i}_{k})

The \beta^{i}_{k} are trainable parameters that are initialized to ones.

I managed to

  1. initialize these parameters with hk.get_parameter (this works), and
  2. wrap the activation function as a hk.Module (this works too).

Here is the code:

import jax
import jax.numpy as jnp
import haiku as hk

class Stan(hk.Module):
    def __init__(self, beta_init=1.0):
        super(Stan, self).__init__(name="stan")
        self.beta_init = beta_init

    def __call__(self, inputs: jnp.ndarray):
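        # One trainable beta per feature (last axis), initialised to beta_init.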
        beta = hk.get_parameter(
            "stan_beta",
            [inputs.shape[-1]],
            init=hk.initializers.Constant(self.beta_init),
        )
        return jax.nn.tanh(inputs) + beta * inputs * jax.nn.tanh(inputs)

I wonder how to use this activation function together with the haiku.nets.MLP API to build a simple MLP network. haiku.nets.MLP accepts activation: Callable[[jnp.ndarray], jnp.ndarray], so to use the Stan activation with haiku.nets.MLP I tried wrapping the class in a function:

def stan(x):
    return Stan(beta_init=1.0)(x)

Then I call haiku.nets.MLP with this function, so that (I guess) every call of stan inside hk.transform creates one Stan instance. It works, but creating a new Stan instance on every call of stan makes me worry about losing efficiency.
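For concreteness, here is a minimal sketch of how I currently wire it up (the layer sizes and input shape are just placeholders):

def forward(x):
    net = hk.nets.MLP([64, 64, 1], activation=stan)
    return net(x)

forward_t = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([8, 2])  # placeholder batch
params = forward_t.init(jax.random.PRNGKey(42), x)  # Stan betas appear alongside the MLP weights
y = forward_t.apply(params, x)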

So my question is:

  1. How many Stan instances will I have? Will the number increase during training?
  2. Is there a better implementation that is as efficient and concise as possible?
tomhennigan commented 2 years ago

How many Stan instances will I have? Will the number increase during training?

Here is a Colab notebook that might help with this:

https://colab.research.google.com/gist/tomhennigan/d1000ec99eb59e92f9b7680f1591ab67/examples-of-using-stan-module.ipynb

As you can see from the three visualisations (printing out the params dict, tabulating module method calls, and the graphviz graph showing intermediate operations), with your implementation an N layer MLP has N-1 instances of Stan. This does not change during training.
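For example, outside of Colab you can make the count explicit by printing the params structure (the layer sizes and input shape below are just placeholders):

forward = hk.transform(lambda x: hk.nets.MLP([300, 100, 10], activation=stan)(x))
params = forward.init(jax.random.PRNGKey(0), jnp.ones([1, 8]))

# A 3 layer MLP applies the activation twice, so two Stan modules show up,
# each owning a single beta vector.
print(jax.tree_util.tree_map(lambda p: p.shape, params))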

Is there a better implementation that is as efficient ..

One obvious optimisation would be to cache tanh(x), since you compute it twice, but with JAX you don't need to do this optimisation by hand: XLA will do it for you when you compile your model.
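For illustration, a version that caches the tanh would look like this (functionally identical once XLA has de-duplicated the two tanh calls):

class Stan(hk.Module):
  def __call__(self, x: jnp.ndarray):
    beta = hk.get_parameter("stan_beta", [x.shape[-1]], init=jnp.ones)
    t = jnp.tanh(x)  # tanh computed once instead of twice
    return t + beta * x * t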

Otherwise you appear to have implemented what they describe in the paper, so I think better or more efficient variants may not actually be "stan".

.. and concise as possible?

I think you can make the Stan module a few lines shorter (drop the constructor, use jnp.ones for the init, and give the parameter a more concise name):

class Stan(hk.Module):
  def __call__(self, x: jnp.ndarray):
    beta = hk.get_parameter("b", [x.shape[-1]], init=jnp.ones)
    return jax.nn.tanh(x) + beta * x * jax.nn.tanh(x)

def stan(x):
  return Stan()(x)

class MyModel(hk.Module):
  def __call__(self, x):
    net = hk.nets.MLP([300, 100, 10], activation=stan)
    return net(x)
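Since the betas live in the params dict like any other weights, they are trained by whatever optimiser you are already using. A minimal sketch (the loss, batch and target shapes are placeholders):

def forward(x):
  return MyModel()(x)

forward = hk.without_apply_rng(hk.transform(forward))

x = jnp.ones([4, 8])    # placeholder inputs
y = jnp.zeros([4, 10])  # placeholder targets
params = forward.init(jax.random.PRNGKey(0), x)

def loss_fn(params, x, y):
  return jnp.mean((forward.apply(params, x) - y) ** 2)

# Gradients flow to the Stan betas alongside the Linear weights.
grads = jax.grad(loss_fn)(params, x, y)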
smao-astro commented 2 years ago

Hi @tomhennigan ,

Thank you very much for your answer and suggestions!