How many Stan instances will I have? Will the number increase during training?
Here is a Colab notebook that might help with this:
As you can see from the three visualisations (printing out the params dict, tabulating the module method calls, and the graphviz graph showing intermediate operations), with your implementation an N-layer MLP has N-1 instances of Stan. This does not change during training.
Is there a better implementation that is more efficient ..
One obvious optimisation would be to cache tanh(x), since you compute it twice, but with JAX you don't need to do this optimisation yourself: XLA will do it for you when you compile your model.
Otherwise you appear to have implemented what they describe in the paper, so I think better or more efficient variants may not actually be "stan".
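For what it's worth, the manually cached variant is only a small change. The sketch below is illustrative (the class name CachedStan and parameter name are mine, following the hk.get_parameter pattern used later in this thread); it evaluates tanh(x) once and reuses it:

import haiku as hk
import jax.numpy as jnp

class CachedStan(hk.Module):
  # Same maths as Stan, but evaluates tanh(x) a single time.
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    beta = hk.get_parameter("b", [x.shape[-1]], init=jnp.ones)
    t = jnp.tanh(x)            # cache tanh(x)
    return t + beta * x * t    # tanh(x) + beta * x * tanh(x)

Under jax.jit both forms should compile to the same XLA program, so this is mostly a readability choice.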
.. and concise as possible?
I think you can make the Stan module a few lines shorter (drop the constructor, use jnp.ones for the init, and use a more concise name for the parameter):
import haiku as hk
import jax
import jax.numpy as jnp

class Stan(hk.Module):
  def __call__(self, x: jnp.ndarray):
    # One trainable beta per feature, initialised to ones.
    beta = hk.get_parameter("b", [x.shape[-1]], init=jnp.ones)
    return jax.nn.tanh(x) + beta * x * jax.nn.tanh(x)

def stan(x):
  return Stan()(x)

class MyModel(hk.Module):
  def __call__(self, x):
    net = hk.nets.MLP([300, 100, 10], activation=stan)
    return net(x)
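One way to confirm the count (building on the snippet above; the dummy input shape below is arbitrary) is to transform a forward function and print the parameter shapes: for MLP([300, 100, 10]) you should see two Stan parameter groups, one per hidden-layer activation, and this structure stays fixed across training steps.

def forward(x):
  return MyModel()(x)

model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(42), jnp.ones([1, 8]))  # dummy batch, 8 input features
# Expect three Linear layers plus two Stan parameter groups
# (module names containing 'stan' and 'stan_1').
print(jax.tree_util.tree_map(lambda p: p.shape, params))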
Hi @tomhennigan,
Thank you very much for your answer and suggestions!
Hi,
I am not sure whether this is the right place to ask, but I would appreciate advice on implementing an activation function with trainable parameters.
Specifically, I want to implement the activation function described by Equation 10 in the paper below:
Self-scalable Tanh (Stan): Faster Convergence and Better Generalization in Physics-informed Neural Networks. Raghav Gnanasambandam, Bo Shen, Jihoon Chung, Xubo Yue, Zhenyu (James) Kong. https://arxiv.org/abs/2204.12589
Here is Equation 10 (the Stan activation for neuron i of layer k):

\sigma^{i}_{k}(x) = \tanh(x) + \beta^{i}_{k} \, x \, \tanh(x)
The \beta^{i}_{k} are trainable parameters that are initialized to ones.
I managed to implement Stan as a hk.Module with trainable beta parameters; the code was roughly as sketched below.
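A sketch of such a module (the explicit constructor, the parameter name beta, and the hk.initializers.Constant init are assumptions rather than the verbatim original; they match what the reply above suggests simplifying):

from typing import Optional

import haiku as hk
import jax.numpy as jnp

class Stan(hk.Module):
  """Self-scalable Tanh: tanh(x) + beta * x * tanh(x), with beta trainable."""

  def __init__(self, name: Optional[str] = None):
    super().__init__(name=name)

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # One beta per feature, initialised to ones as in the paper.
    beta = hk.get_parameter(
        "beta", shape=[x.shape[-1]], init=hk.initializers.Constant(1.0))
    return jnp.tanh(x) + beta * x * jnp.tanh(x)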
I wonder how to use this activation function with the haiku.nets.MLP API to build a simple MLP network. haiku.nets.MLP accepts activation: Callable[[jnp.ndarray], jnp.ndarray]. To use the stan activation function with haiku.nets.MLP, I tried wrapping the class in a plain function and then calling haiku.nets.MLP with that function, so that (I guess) every call of stan inside hk.transform creates one Stan module. It works, but creating one instance of Stan on every call of stan makes me worry about a loss of efficiency; a sketch of the wrapping is included below.

So my questions are: how many Stan instances will I have, and will the number increase during training? And is there a better implementation that is as efficient and concise as possible?
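For reference, the wrapping is roughly as follows (a sketch using the Stan module above; the layer sizes and function names here are placeholders, not the exact original code):

def stan(x: jnp.ndarray) -> jnp.ndarray:
  # Each call constructs a Stan module; repeated calls inside one transform
  # get distinct names (stan, stan_1, ...) and hence distinct parameters.
  return Stan()(x)

def forward(x):
  # Placeholder layer sizes; the wrapper is passed as the MLP activation.
  return hk.nets.MLP([64, 64, 1], activation=stan)(x)

model = hk.transform(forward)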