BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

Adding somewhat-mini attention layers into the block? #95

Open PicoCreator opened 1 year ago

PicoCreator commented 1 year ago

Motivation/Theory

I suspect that because the upper block layers can only "see" the output of the block layer directly before them, there is a lot of redundant information being forwarded between the layers.

Current flow

The following is (roughly) our current data flow:

graph LR
    subgraph "Iteration N"
        direction TB
        Emb1(Input Embedding)
        Emb1 --> LN1

        subgraph "Block 1"
            direction TB
            LN1(Layer Norm)
            LN1 --> TM1(Time Mixing)
            TM1 --> CM1(Channel Mixing)
        end

        BlockNX(((Block 2 to X)))
        BlockNX --> BlockNX
        CM1 --> BlockNX

        BlockNX-->LO1(Output Layer Norm & Linear head)
        LO1-->SLogitN(Logit Sampling)
        SLogitN-->OTN(Output Token)
    end

    IN0{{Iteration N-1}}

    IN0 -..-> TM1
    IN0 -..-> CM1
    IN0 -..-> BlockNX

    IN1{{Iteration N+1}}

    TM1 -..-> IN1
    CM1 -..-> IN1
    BlockNX -..-> IN1

So for example, if we have input like

<some long story>
What is the species of the pet in the story

The following would probably be the approximate "level of thought" at each layer:

Layer 0: story
Layer 1-5: What is the species of the pet?
Layer 6-10: (parts of the story), what is the pet species?, maybe it's a goat?
Layer 20-30: (parts of the story), what is the pet species?, maybe it's a goat? probably not, probably a cat

While we do not accurately understand the exact information stored in the latent-space embedding, from an information-theory standpoint, because the "upper layers" will require information from the "lower layers", there will be redundant information flowing through our limited embedding at each layer. And while we probably have not hit the limit yet (at 8k tokens) of compressing information into our current embedding size, we will at some point.

Suggested change

The idea here is to add a mini attention layer from block 2 onwards, which would read the time mixing / channel mixing embeddings from all the previous layers + the input embedding.

The size of this layer would grow as X^2 with the number of layers X, but the computation cost would remain N per token.

This allows the later layers to simply depend on the lower layers for certain key latent information, without needing to "retransmit it to the upper layers".

So something like the following

graph LR
    subgraph "Iteration N"
        direction TB
        Emb1(Input Embedding)
        Emb1 --> LN1
        Emb1 -..-> TR2
        Emb1 -..-> TR3

        subgraph "Block 1"
            direction TB
            LN1(Layer Norm)
            LN1 --> TM1(Time Mixing)
            LN1 --> CM1(Channel Mixing)
        end

        subgraph "Block 2"
            direction TB

            TM1 --> TR2
            CM1 --> TR2
            TR2(Mini Attention Layer + Layer Norm)
            TR2 --> TM2(Time Mixing)
            TR2 --> CM2(Channel Mixing)
        end

        subgraph "Block 3"
            direction TB

            TM2 --> TR3
            CM2 --> TR3
            TM1 -..-> TR3
            CM1 -..-> TR3
            TR3(Mini Attention Layer + Layer Norm)
            TR3 --> TM3(Time Mixing)
            TR3 --> CM3(Channel Mixing)
        end

        BlockNX(((Block 3 to X)))
        BlockNX -..-> BlockNX

        CM3 --> BlockNX
        TM3 --> BlockNX

        CM2 -..-> BlockNX
        TM2 -..-> BlockNX
        CM1 -..-> BlockNX
        TM1 -..-> BlockNX
        Emb1 -..-> BlockNX

        BlockNX-->LA(Output Final Attention Layer)
        LA-->LO1(Output Layer Norm & Linear head)
        LO1-->SLogitN(Logit Sampling)
        SLogitN-->OTN(Output Token)
    end

    %% IN0{{Iteration N-1}}

    %% IN0 -..-> TM1
    %% IN0 -..-> CM1
    %% IN0 -..-> TM2
    %% IN0 -..-> CM2
    %% %% IN0 -..-> BlockNX

    %% IN1{{Iteration N+1}}

    %% TM1 -..-> IN1
    %% CM1 -..-> IN1
    %% TM2 -..-> IN1
    %% CM2 -..-> IN1
    %% %% BlockNX -..-> IN1

So that, instead, it can be more like the following:

Layer 0: story
Layer 1-5: What is the species of the pet?
Layer 6-10: (parts of the story)
Layer 20-30: maybe it's a goat? probably not, probably a cat

This allows the various layers to focus more on a certain set of information, which can then be read by the upper layers.

I removed the data flow to the final attention layer, and to the other iterations, as the arrows were getting ridiculously out of hand.

I also split up the time mixing and channel mixing paths, to allow each of them to be computed in parallel, to hopefully somewhat improve GPU utilization.

Not made clear in the diagram is that the output of the attention layer should be of shape (2, embedding size), meaning it outputs separate embeddings for time mixing and channel mixing respectively.
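
To make the idea concrete, here is a rough PyTorch sketch of what I have in mind for the mini attention layer. This is not code from the repo; names like MiniLayerAttention, att_dim and prev_outputs are placeholders I made up. For each token position it attends across the stack of earlier per-layer embeddings and returns two embeddings, one for the time mixing path and one for the channel mixing path:

import torch
import torch.nn as nn


class MiniLayerAttention(nn.Module):
    # Hypothetical sketch: attention over the outputs of all previous blocks
    # (plus the input embedding), done for each token position independently.
    def __init__(self, n_embd, att_dim=64):
        super().__init__()
        self.ln = nn.LayerNorm(n_embd)
        # two query heads: index 0 feeds time mixing, index 1 feeds channel mixing
        self.q = nn.Linear(n_embd, 2 * att_dim, bias=False)
        self.k = nn.Linear(n_embd, att_dim, bias=False)
        self.v = nn.Linear(n_embd, n_embd, bias=False)
        self.att_dim = att_dim

    def forward(self, x, prev_outputs):
        # x:            (B, T, n_embd) output of the previous block
        # prev_outputs: list of (B, T, n_embd) tensors -- the input embedding plus
        #               the time/channel mixing outputs of every earlier block
        B, T, C = x.shape
        mem = self.ln(torch.stack(prev_outputs, dim=2))            # (B, T, L, C)
        q = self.q(self.ln(x)).view(B, T, 2, self.att_dim)         # (B, T, 2, d)
        k = self.k(mem)                                            # (B, T, L, d)
        v = self.v(mem)                                            # (B, T, L, C)
        att = (q @ k.transpose(-2, -1)) * (self.att_dim ** -0.5)   # (B, T, 2, L)
        out = att.softmax(dim=-1) @ v                              # (B, T, 2, C)
        # out[:, :, 0] is added to the time mixing input,
        # out[:, :, 1] to the channel mixing input
        return x.unsqueeze(2) + out


# usage sketch: 512-dim embedding, block with 5 earlier per-layer outputs
mini_att = MiniLayerAttention(n_embd=512)
x = torch.randn(1, 8, 512)
prev = [torch.randn(1, 8, 512) for _ in range(5)]
tm_in, cm_in = mini_att(x, prev).unbind(dim=2)   # (1, 8, 512) each

Since the attention here is across layers rather than across tokens, the per-token cost grows with the number of layers and not with the context length, which is where the X^2 size / fixed-per-token cost claim above comes from.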


So my question to @BlinkDL: does the above make sense? Is it something you have tried and found that it made things worse?

Or is the basis and theory behind it flawed, and this is just a needless complication?

BlinkDL commented 1 year ago

I already tried similar methods (RWKV-4b) :-)

Please check https://discord.com/channels/992359628979568762/1083107245971226685/1085912612035903549 (in RWKV Discord)

PicoCreator commented 1 year ago

Ahh, you mean the "However some tiny amt of QKV attention (as in RWKV-4b)" part of the message. Glad to see my understanding / theory get some validation in this direction, all in one post.

If I understood right, RWKV-4b is v4neo with RWKV_MY_TESTING enabled, using def jit_funcQKV(self, x): at https://github.com/BlinkDL/RWKV-LM/blob/6fea06e671ecf07d037caf3b8bdf415ddd7f3984/RWKV-v4neo/src/model.py#L221

And it is not part of the current Raven model?

Haha, what would happen if there was even more attention at the upper layers? As far as I can see, jit_funcQKV is quite linear across the layers (unless I read the code wrongly).
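
For anyone else reading along, my rough mental model of that "tiny amt of QKV attention" is something like the sketch below. This is a paraphrase from memory with my own names (TinyQKVAttention, qq/kk/vv, att_dim), not the repo's code, so the actual jit_funcQKV at the linked line is the source of truth:

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyQKVAttention(nn.Module):
    # Paraphrased sketch of a small causal QKV attention bolted onto a block;
    # check the linked model.py for what RWKV-4b actually does.
    def __init__(self, n_embd, att_dim=64, ctx_len=1024):
        super().__init__()
        self.qq = nn.Linear(n_embd, att_dim, bias=False)
        self.kk = nn.Linear(n_embd, att_dim, bias=False)
        self.vv = nn.Linear(n_embd, n_embd, bias=False)
        self.att_dim = att_dim
        # causal mask: each token attends only to itself and earlier tokens
        self.register_buffer("mask", torch.tril(torch.ones(ctx_len, ctx_len)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qq(x), self.kk(x), self.vv(x)
        att = (q @ k.transpose(-2, -1)) * (self.att_dim ** -0.5)   # (B, T, T)
        att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        # small attention output, added on top of the block's normal output
        return F.softmax(att, dim=-1) @ v

"More attention at the upper layers" would then roughly mean a larger att_dim, or enabling something like this in more of the later blocks, which is what I am asking about above.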