proger / hippogriff

Griffin MQA + Hawk Linear RNN Hybrid
https://pypi.org/project/hippogriff
MIT License

Initialization of lambda incorrect #5

Open ozppupbg opened 3 months ago

ozppupbg commented 3 months ago

Hello,

I noticed a deviation from the Griffin paper in your code.

The Griffin paper states, in the second part of Section 2.4:

We initialize Λ such that a^c is uniformly distributed between 0.9 and 0.999 at the start of training,

and a = sigmoid(Λ).

So the initialization of Lambda should instead be computed as Λ = -log((1 / u^(1/c)) - 1), where u (the target value of a^c) is sampled uniformly between 0.9 and 0.999.

Here, however, Lambda is initialized as follows: https://github.com/proger/hippogriff/blob/7bf573298a09ffedf6a39e156c50281103cd8dd9/hippogriff.py#L45 This is neither random nor uniform: linspace is not random, and the sigmoid applied to it is not linear, so the resulting values are not uniformly distributed either.
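For concreteness, here is a minimal sketch of that proposed initialization (written in JAX purely for illustration; `dim` and `c` are placeholder names, not identifiers from hippogriff.py):

```python
import jax
import jax.numpy as jnp

def init_lambda(key, dim, c=8.0):
    """Sample u = a^c uniformly in [0.9, 0.999), then invert a = sigmoid(Lambda)."""
    u = jax.random.uniform(key, (dim,), minval=0.9, maxval=0.999)
    a = u ** (1.0 / c)                 # a such that a^c = u
    return -jnp.log(1.0 / a - 1.0)     # Lambda = logit(a)

lam = init_lambda(jax.random.PRNGKey(0), dim=16)
print(jnp.sort(jax.nn.sigmoid(lam) ** 8.0))  # recovers the uniform samples in [0.9, 0.999)
```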

proger commented 3 months ago

Hi @ozppupbg, nice catch!

`forget_base` is initialized so that values of `(-alpha_log_scale.exp() * softplus(forget_base)).exp()` are in the range of 0.9...0.999 but are exponentially biased towards 1. This makes the activation distribution more similar to Mamba's forget gates. I'll be adding sweeps to ablate this decision later.

*(chart: distribution of the initialized forget gate values)*

That chart was generated by this notebook: https://gist.github.com/proger/cd6ee302661034b7b8d4685dcad8cc3d
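As a rough check of the stated range, here is a minimal sketch of that expression in JAX (the `linspace` endpoints and `alpha_log_scale` value below are illustrative guesses chosen to land in 0.9...0.999, not the repository's actual constants):

```python
import jax
import jax.numpy as jnp

# Placeholder parameters, for illustration only.
dim = 8
alpha_log_scale = jnp.array(0.0)              # exp(0) = 1
forget_base = jnp.linspace(-6.9, -2.2, dim)   # guessed endpoints

# The effective per-channel decay at init, as described above.
decay = jnp.exp(-jnp.exp(alpha_log_scale) * jax.nn.softplus(forget_base))
print(decay)   # roughly 0.999 down to 0.9, bunched towards 1
```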

BeeGass commented 3 months ago

The paper states: "We initialize Λ such that $a^{c}$ is uniformly distributed between 0.9 and 0.999". Assume $a^{c} = k$, where $k$ is uniformly distributed between 0.9 and 0.999. Given this, we solve for $a$:

$$
\begin{align}
a^{c} &= k \\
\log(a^{c}) &= \log(k) \\
c \cdot \log(a) &= \log(k) \\
\frac{c \cdot \log(a)}{c} &= \frac{\log(k)}{c} \\
\log(a) &= \frac{\log(k)}{c} \\
\log(a) &= \frac{1}{c} \log(k) \\
\log(a) &= \log\!\left(k^{\frac{1}{c}}\right) \\
a &= k^{\frac{1}{c}}
\end{align}
$$

We then solve for $\Lambda$ from $a = \sigma(-\Lambda)$, using this redefined $a$:

$$
\begin{align}
a &= \sigma(-\Lambda) \\
a &= \frac{1}{1 + e^{-(-\Lambda)}} \\
a &= \frac{1}{1 + e^{\Lambda}} \\
a (1 + e^{\Lambda}) &= \frac{1 + e^{\Lambda}}{1 + e^{\Lambda}} \\
a (1 + e^{\Lambda}) &= 1 \\
\frac{a (1 + e^{\Lambda})}{a} &= \frac{1}{a} \\
1 + e^{\Lambda} &= \frac{1}{a} \\
1 + e^{\Lambda} - 1 &= \frac{1}{a} - 1 \\
e^{\Lambda} &= \frac{1}{a} - 1 \\
\ln{e^{\Lambda}} &= \ln\!\left(\frac{1}{a} - 1\right) \\
\Lambda &= \ln\!\left(\frac{1}{a} - 1\right) \\
\Lambda &= \ln\!\left(\frac{1}{k^{\frac{1}{c}}} - 1\right)
\end{align}
$$

Given that we can apply the same substitution with $c \cdot r$ in place of $c$, i.e. $a^{c \cdot r} = k \Rightarrow a = k^{\frac{1}{c \cdot r}}$, we can say:

$$ \Lambda = \ln\!\left(\frac{1}{k^{\frac{1}{c \cdot r}}} - 1\right) $$
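A quick numeric check of that expression (the values of $k$, $c$, and $r$ below are arbitrary), confirming that plugging the resulting $\Lambda$ back into $a = \sigma(-\Lambda)$ recovers $k^{\frac{1}{c \cdot r}}$:

```python
import jax
import jax.numpy as jnp

k = jnp.array([0.9, 0.95, 0.999])   # arbitrary values in the target range
c, r = 8.0, 0.5                     # arbitrary c and gate value r

lam = jnp.log(1.0 / k ** (1.0 / (c * r)) - 1.0)   # Lambda as derived above
a = jax.nn.sigmoid(-lam)                          # a = sigma(-Lambda)
print(jnp.allclose(a, k ** (1.0 / (c * r))))      # True
```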

I develop in JAX/Flax, so apologies if this can't be checked directly against your code. That being said, couldn't we define the cell like the following?

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class RGLRUCell(nn.Module):
    feature: int
    c: float = 8.0
    gate_fn: callable = nn.sigmoid

    def setup(self):
        # k parameterizes the forget gate directly, sampled uniformly in [0.9, 0.999).
        self.k = self.param(
            "k",
            lambda rng, s: jax.random.uniform(rng, s, minval=0.9, maxval=0.999),
            (self.feature,),
        )
        # Joint projection for the recurrence gate r and the input gate i.
        self.ri = nn.Dense(self.feature * 2)

    def __call__(self, carry, x):
        (h_prev, _) = carry
        r, i = jnp.split(self.ri(x), 2, axis=-1)
        r = self.gate_fn(r)
        i = self.gate_fn(i)

        # Per the derivation above; note that exp(log(.)) cancels, so this
        # evaluates to (1 / k ** (1 / (c * r))) - 1, i.e. exp(Lambda).
        a = jnp.exp(jnp.log((1 / (self.k ** (1 / (self.c * r)))) - 1))

        h_new = a * h_prev + jnp.sqrt(1 - a**2) * (i * x)

        return (h_new, h_new), h_new
```
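For reference, a single step of the proposed cell could be exercised like this (shapes and the PRNG seed are arbitrary; in practice the cell would be driven over the time axis, e.g. with `flax.linen.scan`):

```python
import jax
import jax.numpy as jnp

# Assumes the RGLRUCell definition above.
cell = RGLRUCell(feature=16)
x = jnp.ones((4, 16))                           # (batch, feature) input for one step
h0 = (jnp.zeros((4, 16)), jnp.zeros((4, 16)))   # initial carry
params = cell.init(jax.random.PRNGKey(0), h0, x)
(carry, _), h = cell.apply(params, h0, x)
print(h.shape)   # (4, 16)
```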

I'm trying to make sense of where this line and its values are coming from:

https://github.com/proger/hippogriff/blob/92c94f520a379d29b053cdc0a30f7a9902fb0d5a/hippogriff.py#L45