**Open** · ozppupbg opened this issue 3 months ago
Hi @ozppupbg, nice catch!
`forget_base` is initialized so that the values of `(-alpha_log_scale.exp() * softplus(forget_base)).exp()` are in the range 0.9...0.999 but are exponentially biased towards 1. This makes the activation distribution more similar to Mamba's forget gates. I'll be adding sweeps to ablate this decision later.
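A quick numeric sketch of that claim. The `alpha_log_scale` value and the `forget_base` range below are assumptions chosen for illustration, not the repository's actual initialization: with `alpha_log_scale = 0`, a linspace of `forget_base` over roughly [-6.9, -2.2] lands the decays in 0.9...0.999, and because `softplus` is nearly exponential for negative inputs, the values cluster towards 1:

```python
import jax
import jax.numpy as jnp

# Assumed values for illustration only; the repository's actual
# initialization may differ.
alpha_log_scale = jnp.array(0.0)
forget_base = jnp.linspace(-6.9, -2.2, 256)

# Decay values as computed by the expression quoted above.
decay = jnp.exp(-jnp.exp(alpha_log_scale) * jax.nn.softplus(forget_base))

print(float(decay.min()), float(decay.max()))  # stays within ~0.9...0.999
print(float(jnp.median(decay)))                # well above the midpoint 0.95
```

The median sitting far above the midpoint of the interval is what "exponentially biased towards 1" looks like numerically.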
That chart was generated by this notebook: https://gist.github.com/proger/cd6ee302661034b7b8d4685dcad8cc3d
The paper states: "We initialize Λ such that $a^{c}$ is uniformly distributed between 0.9 and 0.999". So set $a^{c} = k$, where $k$ is uniformly distributed between 0.9 and 0.999, and solve for $a$:
$$
\begin{align}
a^{c} &= k \\
\log(a^{c}) &= \log(k) \\
c \cdot \log(a) &= \log(k) \\
\log(a) &= \frac{\log(k)}{c} \\
\log(a) &= \log\left(k^{\frac{1}{c}}\right) \\
a &= k^{\frac{1}{c}}
\end{align}
$$
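A one-line numeric check of this step (the values of $k$ below are arbitrary):

```python
import jax.numpy as jnp

c = 8.0
k = jnp.array([0.9, 0.95, 0.999])  # arbitrary targets for a**c
a = k ** (1.0 / c)

# Raising a back to the power c should recover k.
print(a ** c)
```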
We then solve for $\Lambda$ with this redefined $a$, using the convention $a = \sigma(-\Lambda)$:
$$
\begin{align}
a &= \sigma(-\Lambda) \\
a &= \frac{1}{1 + e^{\Lambda}} \\
a \left(1 + e^{\Lambda}\right) &= 1 \\
1 + e^{\Lambda} &= \frac{1}{a} \\
e^{\Lambda} &= \frac{1}{a} - 1 \\
\Lambda &= \ln\left(\frac{1}{a} - 1\right) \\
\Lambda &= \ln\left(\frac{1}{k^{\frac{1}{c}}} - 1\right)
\end{align}
$$
Since the recurrence raises $a$ to the gated exponent $c \cdot r$ rather than $c$, the effective decay becomes

$$ a^{c \cdot r} = \left(k^{\frac{1}{c}}\right)^{c \cdot r} = k^{r} $$

or, expressed through $\Lambda$, $\sigma(-\Lambda)^{c \cdot r} = k^{r}$.
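A quick check of the derivation, under the $a = \sigma(-\Lambda)$ convention used above: with $\Lambda = \ln\left(1/k^{1/c} - 1\right)$, the gated decay $\sigma(-\Lambda)^{c \cdot r}$ comes out to $k^{r}$, which for $r = 1$ recovers $k$:

```python
import jax
import jax.numpy as jnp

c = 8.0
k = jnp.array([0.9, 0.95, 0.999])          # target values of a**c
lam = jnp.log(1.0 / k ** (1.0 / c) - 1.0)  # Lambda from the derivation above
r = 0.5                                    # an arbitrary recurrence-gate value

# sigmoid(-Lambda) recovers the base decay a = k**(1/c),
# so the gated decay a**(c * r) equals k**r.
gated = jax.nn.sigmoid(-lam) ** (c * r)
print(gated)  # approximately k ** r
```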
I develop in jax/flax, so apologies if this can't be checked directly against the repo. That said, couldn't we define the cell like the following?
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable


class RGLRUCell(nn.Module):
    feature: int
    c: float = 8.0
    gate_fn: Callable = nn.sigmoid

    def setup(self):
        # k parametrizes the per-channel decay: the base decay is a = k**(1/c),
        # so a**c = k is uniform in [0.9, 0.999] at initialization.
        self.k = self.param(
            "k",
            lambda rng, s: jax.random.uniform(rng, s, minval=0.9, maxval=0.999),
            (self.feature,),
        )
        self.ri = nn.Dense(self.feature * 2)

    def __call__(self, carry, x):
        h_prev, _ = carry
        r, i = jnp.split(self.ri(x), 2, axis=-1)
        r = self.gate_fn(r)
        i = self.gate_fn(i)
        # Effective decay a**(c * r) = (k**(1/c))**(c * r) = k**r.
        a = self.k ** r
        h_new = a * h_prev + jnp.sqrt(1 - a**2) * (i * x)
        return (h_new, h_new), h_new
```
I'm trying to make sense of where this line and its values come from:
https://github.com/proger/hippogriff/blob/92c94f520a379d29b053cdc0a30f7a9902fb0d5a/hippogriff.py#L45
Hello,
I noticed a deviation from the Griffin paper in your code.
The Griffin paper states in the second part of section 2.4:

> and a = sigmoid(Λ).
So the initialization for Lambda should actually be calculated as

Λ = -log((1 / a^(1/c)) - 1)

with a uniformly distributed between 0.9 and 0.999.

Here, Lambda is instead initialized as: https://github.com/proger/hippogriff/blob/7bf573298a09ffedf6a39e156c50281103cd8dd9/hippogriff.py#L45 which is neither random nor uniform, since linspace is not random and the sigmoid operation is not linear.
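A minimal sketch of the initialization the paper describes, assuming a = sigmoid(Λ) and c = 8 (the variable names below are illustrative, not from the repo):

```python
import jax
import jax.numpy as jnp

c = 8.0
key = jax.random.PRNGKey(0)

# Sample the target decay a**c uniformly in [0.9, 0.999], as the paper prescribes.
k = jax.random.uniform(key, (1024,), minval=0.9, maxval=0.999)

# Invert a = sigmoid(Lambda) with a = k**(1/c):
lam = -jnp.log(1.0 / k ** (1.0 / c) - 1.0)

# Round-trip: sigmoid(Lambda)**c should recover the uniform samples k.
recovered = jax.nn.sigmoid(lam) ** c
print(float(jnp.max(jnp.abs(recovered - k))))  # small reconstruction error
```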