Open TKassis opened 4 months ago
PEER looks like an interesting approach, and thanks for implementing it so cleanly! I do have a quick question, though, about recommended usage with x-transformers. Would something like this be a good way of using it?

```python
self.pre_peer = ContinuousTransformerWrapper(
    max_seq_len = 0,
    attn_layers = Encoder(
        dim = 768,
        depth = 6,
        heads = 12,
        attn_flash = True,
    ),
    scaled_sinu_pos_emb = True,
)

self.peer = PEER(
    dim = 768,
    heads = 8,                  # tested up to 32 - (hk = heads * num_experts_per_head (16))
    num_experts = 2500,         # he chose 1 million
    num_experts_per_head = 16,  # he settled on 16, but it was 32 in the PKM paper
    dim_key = 128,
    pre_rmsnorm = True
)

self.post_peer = ContinuousTransformerWrapper(
    max_seq_len = 0,
    attn_layers = Encoder(
        dim = 768,
        depth = 6,
        heads = 12,
        attn_flash = True,
    ),
    use_abs_pos_emb = False,
)
```

oh hey Tim! good to hear from you again
yea, so there's no way to easily slot this into x-transformers atm, but if you'd like, we could discuss how to build some feature to do so on discord
@TKassis in the original product key memory paper (Lample et al.), i think the optimal placement was in the middle of the network, so you are doing it right
Got it, thanks, I'll try it out!
@TKassis sg, let me know if you see (or don't see) anything!
Unfortunately, I'm running out of memory even with only 2500 experts on a 48 GB A6000 Ada.
@TKassis want to give the ChunkedPEER wrapper a try?
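A sketch of how the wrapper might be used, assuming ChunkedPEER simply wraps an existing PEER module and runs its forward pass chunk by chunk along the sequence (the import path and the `seq_chunk_size` argument name here are guesses, so check the repo for the exact signature). Since PEER acts position-wise, chunking over the sequence should not change the output, only when the intermediate activations get materialized:

```python
import torch
from PEER_pytorch import PEER, ChunkedPEER

peer = PEER(
    dim = 768,
    heads = 8,
    num_experts = 2500,
    num_experts_per_head = 16,
    dim_key = 128,
    pre_rmsnorm = True
)

# wrap the PEER layer so its forward runs over sequence chunks,
# aiming for a smaller activation-memory peak
# (seq_chunk_size is an assumed argument name)
chunked_peer = ChunkedPEER(peer, seq_chunk_size = 128)

x = torch.randn(1, 512, 768)
out = chunked_peer(x) + x  # same residual usage as the plain PEER block
```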
Thank you, I gave it a try this morning with the ChunkedPEER wrapper on v0.1.9; unfortunately I'm still running out of memory with 2500 experts. The original unsplit model (were I to combine pre_peer and post_peer) works without any issues. I guess this is designed for DeepMind compute resources :-)
@TKassis ah ok, thanks for testing it out!
how long are the sequences you are working with?
512
@TKassis ok, i'll do some profiling later this weekend, thank you!
I use PEER and PKAttention in the middle layer of a 12-layer transformer:

```python
pk_attn = PKAttention(
    dim = 1536,
    num_key_values = 200 * 200,
    pre_rmsnorm = True
)

peer_mlp = PEER(
    dim = 1536,
    heads = 8,
    num_experts = 200 * 200,
    num_experts_per_head = 16,
    dim_key = 128,
    pre_rmsnorm = True
)
```

with the forward pass being:

```python
x = x + pk_attn(x)
x = x + peer_mlp(x)
```

The good news is that memory does not run out on the 32 GB V100, and the FLOPs are fine.
The bad news is that the ppl curves are not so smooth and ideal!
The question, then, is whether pk_attn and peer_mlp can be used together?
@junphine thanks for testing it out
could you try this improvisation (the PEERLora variant) and see if it is any more stable?
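That is, swap the PEER block for PEERLora. A minimal sketch, assuming PEERLora is a drop-in replacement whose constructor mirrors PEER's (the import path and argument names are assumptions; check the repo for the exact signature):

```python
import torch
from PEER_pytorch import PEERLora

# assumed drop-in replacement for the peer_mlp block above; these arguments
# just mirror the earlier PEER config and may not match the actual constructor
peer_mlp = PEERLora(
    dim = 1536,
    heads = 8,
    num_experts = 200 * 200,
    num_experts_per_head = 16,
    pre_rmsnorm = True
)

x = torch.randn(1, 512, 1536)
x = x + peer_mlp(x)  # same residual placement as before
```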
@lucidrains Yes, PEERLora is much more stable, with this init:

```python
self.proj_in.weight.normal_(std = dim ** -0.5)
self.proj_out.weight.normal_(std = dim_inner ** -0.5)
self.proj_in_lora_a.weight.normal_(std = dim ** -0.5)
self.proj_in_lora_b.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_a.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_b.weight.normal_(std = dim ** -0.5)
```
But it will take longer training to verify, because I find the value of lora_in_hidden tends to get very large.
@junphine nice, i added in some better init as well, thanks for reporting these results!
@lucidrains Unfortunately, the PEERLora layer didn't seem to be beneficial: whether I removed it (replaced it with an MLP) or added it, the ppl curve didn't change at all. The two curves coincide perfectly.
@junphine ah, that is unfortunate
how about the original formulation, once stabilized of course?
@lucidrains I see the benefits of the original formulation.
By increasing the number of experts from 200x200 to 500x500 and decreasing dim_key from 768 to 128, the curve is converging.
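For reference, a sketch of what that updated config amounts to, reconstructed from the numbers above with the rest kept as in the earlier snippet (import path assumed from the repo layout):

```python
from PEER_pytorch import PEER

peer_mlp = PEER(
    dim = 1536,
    heads = 8,
    num_experts = 500 * 500,    # up from 200 * 200
    num_experts_per_head = 16,
    dim_key = 128,              # down from 768
    pre_rmsnorm = True
)
```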
The green curve is the base model, which has 24 layers and 0.9B params (limited by GPU memory); purple is the PEER MLP model, which has 16 layers and 1.2B params.
The base model converges faster early on, but lags behind PEER at a later stage, which seems intuitive.
@junphine hey that's great! thank you for sharing this! 🚀