lucidrains / PEER-pytorch

PyTorch implementation of the PEER block from the paper Mixture of A Million Experts, by Xu Owen He at DeepMind
MIT License

Usage with x-transformers #2

Open TKassis opened 4 months ago

TKassis commented 4 months ago

PEER looks like an interesting approach, and thanks for implementing it so cleanly! I do have a quick question, though, about the recommended usage with x-transformers. Would something like this be a good way of using it?


```python
import torch
from PEER_pytorch import PEER
from x_transformers import ContinuousTransformerWrapper, Encoder

peer = PEER(
    dim = 512,
    heads = 8,                   
    num_experts = 1_000_000,     
    num_experts_per_head = 16,   
    dim_key = 128,
    pre_rmsnorm = True
).cuda()

pre_peer = ContinuousTransformerWrapper(
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

post_peer = ContinuousTransformerWrapper(
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(2, 1024, 512).cuda()

out = pre_peer(x)
out = peer(out) + out
out = post_peer(out)
```
lucidrains commented 4 months ago

oh hey Tim! good to hear from you again

yea, so there's no way to easily slot this into x-transformers atm, but if you'd like, we could discuss how to build some feature to do so on discord

lucidrains commented 4 months ago

@TKassis in the original product key memory paper (Lample et al.), i think the optimal placement was in the middle of the network, so you are doing it right

TKassis commented 4 months ago

Got it, thanks I'll try it out!

lucidrains commented 4 months ago

@TKassis sg, let me know if you see (or don't see) anything!

TKassis commented 4 months ago

Unfortunately, I'm running out of memory even with only 2500 experts on a 48 GB A6000 Ada.


```python
        self.pre_peer = ContinuousTransformerWrapper(
            max_seq_len=0,
            attn_layers=Encoder(
                dim=768,
                depth=6,
                heads=12,
                attn_flash=True,
            ),
            scaled_sinu_pos_emb=True,
        )

        self.peer = PEER(
            dim = 768,
            heads = 8,                   # tested up to 32 - (hk = heads * num_experts_per_head (16))
            num_experts = 2500,     # he chose 1 million
            num_experts_per_head = 16,   # he settled on 16, but was 32 in PKM paper
            dim_key = 128,
            pre_rmsnorm = True
        )

        self.post_peer = ContinuousTransformerWrapper(
            max_seq_len=0,
            attn_layers=Encoder(
                dim=768,
                depth=6,
                heads=12,
                attn_flash=True,
            ),
            use_abs_pos_emb=False,
        )
```
lucidrains commented 4 months ago

@TKassis want to give this wrapper a try?
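
Something along these lines, roughly (a sketch only: it is assumed here that ChunkedPEER takes the PEER module directly and processes the sequence in chunks to reduce peak activation memory):

```python
import torch
from PEER_pytorch import PEER, ChunkedPEER

# assumption: ChunkedPEER wraps an existing PEER module and chunks along the
# sequence dimension to trade compute for memory
peer = PEER(
    dim = 768,
    heads = 8,
    num_experts = 2500,
    num_experts_per_head = 16,
    dim_key = 128,
    pre_rmsnorm = True
).cuda()

chunked_peer = ChunkedPEER(peer)   # wrap the existing module

x = torch.randn(2, 512, 768).cuda()
out = chunked_peer(x) + x          # same residual placement as before
```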

TKassis commented 4 months ago

Thank you, I gave it a try this morning with the ChunkedPEER wrapper on v0.1.9, but unfortunately it still runs out of memory with 2500 experts. The original unsplit model (were I to combine pre_peer and post_peer) works without any issues. I guess this is designed for DeepMind compute resources :-)

lucidrains commented 4 months ago

@TKassis ah ok, thanks for testing it out!

how long are the sequences you are working with?

TKassis commented 4 months ago

512

lucidrains commented 4 months ago

@TKassis ok, i'll do some profiling later this weekend, thank you!

junphine commented 2 months ago

I use PEER and PKAttention in the middle layers of a 12-layer transformer:

```python
pk_attn = PKAttention(
    dim = 1536,
    num_key_values = 200 * 200,
    pre_rmsnorm = True
)

peer_mlp = PEER(
    dim = 1536,
    heads = 8,
    num_experts = 200 * 200,
    num_experts_per_head = 16,
    dim_key = 128,
    pre_rmsnorm = True
)
```

with the forward pass:

```python
x = x + pk_attn(x)
x = x + peer_mlp(x)
```
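
Packaged as a single residual block, that looks roughly like this (a sketch: it assumes PKAttention is importable from PEER_pytorch alongside PEER, and that both sublayers preserve the (batch, seq, dim) shape, per the constructor calls above):

```python
import torch
from torch import nn
from PEER_pytorch import PEER, PKAttention

# sketch of the residual block described above; hyperparameters mirror the
# constructor calls in this comment, and both sublayers are assumed to map
# (batch, seq, dim) -> (batch, seq, dim)
class PKBlock(nn.Module):
    def __init__(self, dim = 1536):
        super().__init__()

        self.pk_attn = PKAttention(
            dim = dim,
            num_key_values = 200 * 200,
            pre_rmsnorm = True
        )

        self.peer_mlp = PEER(
            dim = dim,
            heads = 8,
            num_experts = 200 * 200,
            num_experts_per_head = 16,
            dim_key = 128,
            pre_rmsnorm = True
        )

    def forward(self, x):
        x = x + self.pk_attn(x)    # product-key attention, with residual
        x = x + self.peer_mlp(x)   # PEER feedforward, with residual
        return x

block = PKBlock(dim = 1536)
x = torch.randn(2, 512, 1536)
out = block(x)   # (2, 512, 1536)
```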

The good news is that memory does not run out on the 32GB V100, and the FLOPs are reasonable.

The bad news is that the ppl curves are not as smooth and ideal as I'd like!

The question, then, is whether pk_attn and peer_mlp can be used together?

junphine commented 2 months ago

[image: ppl training curves]

lucidrains commented 2 months ago

@junphine thanks for testing it out

could you try this improvisation and see if it is any more stable?

junphine commented 2 months ago

@lucidrains Yes, PEERLora is much more stable, with the following init:

```python
self.proj_in.weight.normal_(std = dim ** -0.5)
self.proj_out.weight.normal_(std = dim_inner ** -0.5)
self.proj_in_lora_a.weight.normal_(std = dim ** -0.5)
self.proj_in_lora_b.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_a.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_b.weight.normal_(std = dim ** -0.5)
```

But it will take longer training to verify, because I find the value of lora_in_hidden tends to get very large.
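
As a standalone helper, that init looks roughly like this (the attribute names follow the comment above and are assumptions about PEERLora's internals; the no_grad decorator makes the in-place normal_ calls legal on parameters that require grad):

```python
import torch
from torch import nn

# sketch of the init described above; attribute names are assumed from the
# comment and may not match the library exactly
@torch.no_grad()
def init_peer_lora_(peer_lora: nn.Module, dim: int, dim_inner: int):
    peer_lora.proj_in.weight.normal_(std = dim ** -0.5)
    peer_lora.proj_out.weight.normal_(std = dim_inner ** -0.5)
    peer_lora.proj_in_lora_a.weight.normal_(std = dim ** -0.5)
    peer_lora.proj_in_lora_b.weight.normal_(std = dim_inner ** -0.5)
    peer_lora.proj_out_lora_a.weight.normal_(std = dim_inner ** -0.5)
    peer_lora.proj_out_lora_b.weight.normal_(std = dim ** -0.5)
```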

junphine commented 2 months ago

[image: training curves]

lucidrains commented 2 months ago

@junphine nice, i added in some better init as well, thanks for reporting these results!

junphine commented 2 months ago

@lucidrains Unfortunately, the PEERLora layer didn't seem to be beneficial. Whether I removed it (replaced it with an MLP) or added it, the ppl curve didn't change at all; the two curves coincide perfectly.

lucidrains commented 2 months ago

@junphine ah, that is unfortunate

how about the original formulation, once stabilized of course?

junphine commented 2 months ago

@lucidrains I do see the benefits of the original formulation.

By increasing the number of experts from 200x200 to 500x500 and decreasing dim_key from 768 to 128, the curve is converging.

[image: ppl curves for the base model vs. the PEER MLP model]

The green curve is the base model, which has 24 layers and 0.9B params (limited by GPU memory); the purple curve is the PEER MLP model, which has 16 layers and 1.2B params.

The base model converges faster at first, but lags behind PEER at a later stage, which seems intuitive.

lucidrains commented 2 months ago

@junphine hey that's great! thank you for sharing this! 🚀