Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Proposed improvement/collaboration: removing the O(T^2) training cost #21

Closed: jackd closed this issue 10 months ago

jackd commented 10 months ago

Hi there, I just found this work thanks to @yk's recent video. Nice job! It overlaps with work I've been doing for a few months. I'm a little bummed you beat me to publishing, but I wasn't going to be able to evaluate the architectures properly anyway: this is a side project that is currently thrashing my laptop, and I'm not sure I could justify the cloud costs of training even a moderately sized model out of curiosity. I'm glad the idea is being investigated and released under a permissive license.

I'm not sure if you're looking for suggestions or collaborations, but I thought I'd put my ideas out there and see what happens. I'm happy to provide more details or collaborate on future work if there's interest; otherwise, feel free to point me towards someone else who might be interested, or run with it yourself.

TL;DR

From my understanding of the paper/code (apologies if I've got any of this wrong), computing retention values is still O(T^2) in the sequence length T and prone to underflow (hence the nan replacement). Neither is necessary. The computation you're performing is just a cumulative exponential moving average, i.e. the linear recurrence state_t = gamma * state_{t-1} + x_t. That can be computed in O(T) with a scan over an associative operator, so associative_scan implementations can evaluate it efficiently in parallel.
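
To make the O(T) claim concrete for a single sequence, here is a minimal numpy sketch (numpy stands in for jax, and both function names are hypothetical) comparing the quadratic form, which materialises the decay matrix D[t, s] = gamma^(t-s) for s <= t, with the linear recurrence state_t = gamma * state_{t-1} + x_t:

```python
import numpy as np

def decayed_sum_quadratic(x, gamma):
    """O(T^2): build the full lower-triangular decay matrix, then multiply."""
    T = x.shape[0]
    t = np.arange(T)[:, None]
    s = np.arange(T)[None, :]
    # Only powers with t >= s are used, so no negative exponents are evaluated.
    D = np.where(t >= s, gamma ** np.maximum(t - s, 0), 0.0)  # [T, T]
    return D @ x

def decayed_sum_linear(x, gamma):
    """O(T): carry a running state, decaying it by gamma at every step."""
    out = np.empty(x.shape, dtype=float)
    state = np.zeros(x.shape[1:], dtype=float)
    for t in range(x.shape[0]):
        state = gamma * state + x[t]
        out[t] = state
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
print(np.allclose(decayed_sum_quadratic(x, 0.9), decayed_sum_linear(x, 0.9)))  # True
```

The loop is O(T) work but still sequential; the point of the associative scan is to remove that sequential dependency.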

Details

Unfortunately, we're still waiting on PyTorch's associative_scan implementation, so I'll use jax below, which has a primitive for it (jax.lax.associative_scan). I do have a PyTorch version working that wraps the jax implementation with jax2torch, though I can't make it play nicely with torch.compile, and I'm more comfortable with jax anyway.

The implementation below takes an arbitrary decay factor at each step. To get the same behaviour as in your paper, I think you can just set factors = gamma * ones_like(values).

import typing as tp
import jax
import jax.numpy as jnp

Pair = tp.Tuple[jnp.ndarray, jnp.ndarray]

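def _cumulative_ema_op(a: Pair, b: Pair) -> Pair:
    # `a` summarises an earlier segment and `b` the segment that follows it:
    # decay a's accumulated value by all of b's factors, then add b's value.
    xa, fa = a
    xb, fb = b
    return xa * fb + xb, fa * fb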

def cumulative_ema(
    values: jnp.ndarray, factors: jnp.ndarray, reverse: bool = False, axis: int = 0
) -> jnp.ndarray:
    """
    Compute cumulative exponential moving average.

    If `reverse == False` and `axis == 0`,
        output[0] = values[0]
        output[i+1] = output[i] * factors[i+1] + values[i+1]

    If `reverse == True`, then the result is the reverse of the non-reversed call on
    arguments reversed on the given axis.

    Args:
        values: N-D float values
        factors: same shape/dtype as values
        reverse: if True, perform accumulation in reverse.
        axis: the axis to compute exponential moving average along.

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    if axis < 0:
        axis += len(values.shape)
    assert values.shape == factors.shape, (values.shape, factors.shape)
    f, t = jax.lax.associative_scan(
        _cumulative_ema_op, (values, factors), reverse=reverse, axis=axis
    )
    del t
    return f
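
To see why an associative operator is enough for parallelism, here is a small numpy sketch of a Hillis-Steele inclusive scan built on the same combine rule as `_cumulative_ema_op`. This is one of several parallel scan schemes and not necessarily the one `jax.lax.associative_scan` uses internally; the helper names are hypothetical. Each of the ceil(log2 T) rounds is a single vectorised operation, and the result should match the sequential recurrence output[t] = output[t-1] * factors[t] + values[t]:

```python
import numpy as np

def ema_op(a, b):
    # Same rule as _cumulative_ema_op: `a` summarises an earlier segment, `b` a later one.
    xa, fa = a
    xb, fb = b
    return xa * fb + xb, fa * fb

def sequential_ema(values, factors):
    """Reference: output[t] = output[t-1] * factors[t] + values[t], output[0] = values[0]."""
    out = np.empty(values.shape, dtype=float)
    acc = np.zeros(values.shape[1:], dtype=float)
    for t in range(values.shape[0]):
        acc = acc * factors[t] + values[t]
        out[t] = acc
    return out

def hillis_steele_ema(values, factors):
    """Inclusive scan along axis 0 in ceil(log2(T)) vectorised rounds."""
    x = np.asarray(values, dtype=float)
    f = np.asarray(factors, dtype=float)
    shift = 1
    while shift < x.shape[0]:
        # Element t absorbs the partial result that ends at t - shift.
        new_x, new_f = ema_op((x[:-shift], f[:-shift]), (x[shift:], f[shift:]))
        x = np.concatenate([x[:shift], new_x])
        f = np.concatenate([f[:shift], new_f])
        shift *= 2
    return x

rng = np.random.default_rng(0)
values = rng.normal(size=(37, 4))
factors = rng.uniform(0.5, 1.0, size=(37, 4))
print(np.allclose(hillis_steele_ema(values, factors), sequential_ema(values, factors)))  # True
```

Because the operator is associative, the rounds can group elements in any bracketing and still produce the same prefix results as the left-to-right loop.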

Computing retention from Q, K and V is then:

def retention(Q, K, gamma, V, reverse=False):
    """
    Notation:
      T: time dimension
      A: attention dimension
      C: number of output channels

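    Args:
        Q: [T, A] query
        K: [T, A] key
        gamma: [] decay constant
        V: [T, C] values
        reverse: if True, accumulate from the end of the sequence (anti-causal)

    Returns:
        [T, C]
    """
    # Outer products k_t v_t^T at every step: [T, A, C]
    rhs = jnp.einsum('ta,tc->tac', K, V)
    # Decayed running sum S_t = gamma * S_{t-1} + k_t v_t^T along the time axis
    rhs = cumulative_ema(rhs, jnp.full_like(rhs, gamma), axis=0, reverse=reverse)
    # Read out with the query: out_t = q_t S_t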
    return jnp.einsum('ta,tac->tc', Q, rhs)
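
As a sanity check, a numpy port of the forward (reverse=False) direction should agree with the parallel form (Q K^T * D) @ V, where D[t, s] = gamma^(t-s) for s <= t and 0 otherwise. In this sketch numpy stands in for jax, `retention_quadratic` and `retention_linear` are hypothetical names, the argument order differs from `retention` above, and a plain loop replaces the scan:

```python
import numpy as np

def retention_quadratic(Q, K, V, gamma):
    """O(T^2): materialise the masked, decayed score matrix."""
    T = Q.shape[0]
    t = np.arange(T)[:, None]
    s = np.arange(T)[None, :]
    D = np.where(t >= s, gamma ** np.maximum(t - s, 0), 0.0)  # [T, T]
    return ((Q @ K.T) * D) @ V  # [T, C]

def retention_linear(Q, K, V, gamma):
    """O(TAC): decayed running sum of outer products k_t v_t^T, read out by q_t."""
    kv = np.einsum('ta,tc->tac', K, V)  # [T, A, C]
    state = np.zeros(kv.shape[1:])
    out = np.empty((Q.shape[0], V.shape[1]))
    for t in range(Q.shape[0]):
        state = gamma * state + kv[t]  # the recurrence cumulative_ema evaluates with a scan
        out[t] = Q[t] @ state  # [A] @ [A, C] -> [C]
    return out

rng = np.random.default_rng(0)
T, A, C = 16, 4, 3
Q = rng.normal(size=(T, A))
K = rng.normal(size=(T, A))
V = rng.normal(size=(T, C))
print(np.allclose(retention_quadratic(Q, K, V, 0.9), retention_linear(Q, K, V, 0.9)))  # True
```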

I've left out the batch dimension for simplicity, but it should be easy to add (or, if you decide to use jax, just vmap it). I'll spare you the full derivation of why this computes (Q K.T * D) @ V, but the short version is that we use property 1 from here (see the last slide) and note that D X = cumulative_ema(X, jnp.full_like(X, gamma), axis=0). This is O(TAC) in space and time, rather than O(T^2 (A + C)) in time and O(T(T + C)) in space.
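
Written out, the short version of the derivation looks roughly like this. It is a sketch: I'm assuming "property 1" is the standard Hadamard-product identity on the first line, since the linked slide isn't reproduced here. q_a and k_a denote the a-th columns of Q and K.

```latex
% Hadamard identity (assumed to be the "property 1" referred to above):
(x y^\top \odot M)\,Z = \operatorname{diag}(x)\,M\,\operatorname{diag}(y)\,Z

% Split QK^\top into rank-one terms over the A columns q_a, k_a of Q and K:
(QK^\top \odot D)\,V
  = \sum_{a=1}^{A} (q_a k_a^\top \odot D)\,V
  = \sum_{a=1}^{A} \operatorname{diag}(q_a)\,D\,\operatorname{diag}(k_a)\,V

% With D_{ts} = \gamma^{t-s} for s \le t (0 otherwise), D is a decayed running sum:
(DX)_t = \sum_{s \le t} \gamma^{t-s} X_s = \gamma\,(DX)_{t-1} + X_t

% Entry (t, c) of the output:
\big[(QK^\top \odot D)\,V\big]_{tc}
  = \sum_{a} Q_{ta} \sum_{s \le t} \gamma^{t-s} K_{sa} V_{sc}
```

The inner sum is what the 'ta,tc->tac' einsum followed by cumulative_ema computes, and the outer sum over a is the final 'ta,tac->tc' einsum, so no T-by-T matrix is ever formed.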

This makes a bidirectional encoder trivial: combine two passes, one with reverse=False and the other with reverse=True.
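
One way the combination could look, as a numpy sketch. `bidirectional_retention` and `decayed_scan` are hypothetical names, and subtracting the diagonal is my assumption about how to combine the two directions. Simply summing the causal and anti-causal passes counts the s == t term twice, so one copy is removed. The result should then equal (Q K^T * D_bi) @ V with D_bi[t, s] = gamma^|t-s|:

```python
import numpy as np

def decayed_scan(x, gamma, reverse=False):
    """Running decayed sum along axis 0 (cumulative_ema with constant factors)."""
    if reverse:
        return decayed_scan(x[::-1], gamma)[::-1]
    out = np.empty(x.shape)
    state = np.zeros(x.shape[1:])
    for t in range(x.shape[0]):
        state = gamma * state + x[t]
        out[t] = state
    return out

def bidirectional_retention(Q, K, V, gamma):
    kv = np.einsum('ta,tc->tac', K, V)  # [T, A, C]
    # Causal pass covers s <= t, anti-causal pass covers s >= t; both include s == t.
    both = decayed_scan(kv, gamma) + decayed_scan(kv, gamma, reverse=True) - kv
    return np.einsum('ta,tac->tc', Q, both)

rng = np.random.default_rng(0)
T, A, C = 12, 4, 3
Q = rng.normal(size=(T, A))
K = rng.normal(size=(T, A))
V = rng.normal(size=(T, C))
idx = np.arange(T)
D_bi = 0.8 ** np.abs(idx[:, None] - idx[None, :])
print(np.allclose(bidirectional_retention(Q, K, V, 0.8), ((Q @ K.T) * D_bi) @ V))  # True
```

Other combinations (for example, separate projections per direction, or concatenating the two outputs) are equally plausible; this is only the simplest one that reproduces a symmetric decay mask.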

With that implementation, you might be tempted to play around with the architecture a little. I've experimented with computing only two transformed matrices of the same shape, factors (sigmoid-activated to ensure decay) and values, rather than Q, K and V, and feeding them directly into cumulative_ema. This reduces the O(TAC) memory/time requirement to O(TC). Conceptually, each token embedding at each layer decides, based on the previous layer's embedding, how much of the past to forget and what to add. I don't see any barriers to implementing a complex-valued version to allow for periodic behaviour, but I haven't attempted that.
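
A rough numpy sketch of that two-matrix variant, as I read the description. The weight names `W_f` and `W_v` and the function name `gated_ema_layer` are hypothetical, and biases, normalisation and output projections are omitted. Each token produces a per-channel decay factor in (0, 1) via a sigmoid, plus a value vector, and the layer output is the cumulative EMA of the values under those factors, so there is no A dimension:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cumulative_ema(values, factors):
    """output[t] = output[t-1] * factors[t] + values[t] along axis 0."""
    out = np.empty(values.shape)
    acc = np.zeros(values.shape[1:])
    for t in range(values.shape[0]):
        acc = acc * factors[t] + values[t]
        out[t] = acc
    return out

def gated_ema_layer(x, W_f, W_v):
    """x: [T, D_in] previous-layer embeddings -> [T, C] outputs."""
    factors = sigmoid(x @ W_f)  # [T, C]: how much of the past each channel keeps
    values = x @ W_v            # [T, C]: what each token adds
    return cumulative_ema(values, factors)

rng = np.random.default_rng(0)
T, D_in, C = 10, 8, 5
x = rng.normal(size=(T, D_in))
W_f = rng.normal(size=(D_in, C))
W_v = rng.normal(size=(D_in, C))
y = gated_ema_layer(x, W_f, W_v)
print(y.shape)  # (10, 5)
```

In a jax version, the loop would be replaced by the `cumulative_ema` defined earlier in this issue, and a bidirectional layer would run it once with reverse=False and once with reverse=True.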

My implementation is keras_core-based (so you can use the PyTorch backend, as long as you don't try to compile). It needs a lot of cleaning up before I'm prepared to make it public, but I'm happy to share it privately. Very small-scale experiments, in which I replaced BERT's self-attention mechanism with the bidirectional O(TC) variant discussed above and removed positional embeddings entirely, have been promising (faster training and better performance than BERT). I have no way of validating whether performance scales with model size; I was planning to look for collaborators/sponsors for that, so let me know if you're interested :).

Jamie-Stirling commented 10 months ago

Hi!

Thanks for your interest in this implementation.

I'm not one of the original authors, however. Please see the official implementation and contact the authors for more information (there's an open issue on this repo, opened by one of the original authors, that links to their changelog).

Unfortunately I don't personally have the resources to validate the scaling of your model so I can't help you there.

That said, your idea is very interesting and I'd be interested to take a look and potentially collaborate if you decide to share privately.

jackd commented 10 months ago

(Actually reads the README) Well... this is awkward. My apologies. In that case, thanks for the work :). I'll contact the original authors and get back to you if there's a broader collaborative effort.