lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Simplifying Transformer Blocks (https://arxiv.org/abs/2311.01906) #214

Closed · Froskekongen closed this issue 10 months ago

Froskekongen commented 10 months ago

Would be nice to have this one here (https://arxiv.org/abs/2311.01906).

lucidrains commented 10 months ago

hmm, could probably do a separate repository for that, with some makeover (relative positions, etc.)

have you tried it? does it work?

lucidrains commented 10 months ago

@Froskekongen so there is research out there that suggests the parallel block architecture leads to instability at scale (paper out of salesforce). however, i'm game for the serial version if you let me know how it fares and share some successful experiments on your end

Froskekongen commented 10 months ago

I think the serial version was the most interesting (Figure 1, top right). And I think you are right - probably easier with a separate repo for this since a lot of the content is about dealing with initialization and whatnot.

Will report if I find some time to experiment with it. Closing the issue for now.

lucidrains commented 10 months ago

or maybe you can get Bobby to make his repo pip installable?

lucidrains commented 10 months ago

ah, figure 1 top right is still a parallel block. i don't think they ever did experiments on a serial version. i don't know if i believe in parallel blocks anymore; i can link that salesforce paper later once i find it
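
For concreteness, here is a minimal PyTorch sketch of the two layouts being discussed (the class names and hyperparameters are illustrative, not x-transformers' actual modules): the serial block feeds the attention output into the feedforward, while the parallel (GPT-J / PaLM style) block computes attention and feedforward from the same normalized input and sums both into one residual update.

```python
# minimal sketch of serial vs parallel transformer blocks (illustrative only)
import torch
from torch import nn

def FeedForward(dim, mult = 4):
    return nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

class SerialBlock(nn.Module):
    # standard pre-norm: x = x + attn(norm(x)); x = x + ff(norm(x))
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = FeedForward(dim)

    def forward(self, x):
        normed = self.attn_norm(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        return x + self.ff(self.ff_norm(x))

class ParallelBlock(nn.Module):
    # GPT-J / PaLM style: attention and feedforward read the same normalized
    # input, and both outputs are summed into a single residual update
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.ff = FeedForward(dim)

    def forward(self, x):
        normed = self.norm(x)
        attn_out, _ = self.attn(normed, normed, normed)
        return x + attn_out + self.ff(normed)

x = torch.randn(1, 1024, 512)
assert SerialBlock(512)(x).shape == ParallelBlock(512)(x).shape == x.shape
```

The parallel form is what allows the attention and feedforward matmuls to be fused or overlapped (the source of the speedup PaLM reports), at the cost of the feedforward no longer seeing the attention output within the same layer.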

zaptrem commented 8 months ago

Didn't PaLM do parallel blocks to great effect? "This approach is also used by PaLM (Chowdhery et al., 2022), where this technique sped up the largest model's training by 15% without performance degradation." (ViT-22B paper)

lucidrains commented 8 months ago

@zaptrem yea.. so first you need some behind-the-scenes context. the parallel block originated from the open source community. it was devised by a precocious college student, Ben Wang, for the training of GPT-J. It was then adopted by Brain for the training of PaLM, with a lot of the code probably taken verbatim from GPT-J (as it is in jax). However, what you need to know is that Ben confided in me that during the tail end of training GPT-J, he actually faced insurmountable instability. Luckily it was near the end and the model was good enough, so he stopped training a bit early and just open sourced it. The rest is history. That bit never made it into the paper afaict.

If you read the PaLM paper, they actually documented this instability. In fact, they had a really hacky way of getting around it: rewinding to just a bit before each divergence and trying different batches. In other words, I have no doubt parallel blocks have some performance benefits, but I don't think the instability is worth the cost.
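
Roughly, the mitigation they describe amounts to a rewind-and-skip training loop like the sketch below. This is just an illustrative, self-contained sketch of the idea, not PaLM's actual code; the `training_step` callable, spike threshold, and step counts are all assumptions for the example.

```python
# rough sketch: checkpoint periodically, and on a loss spike rewind to an
# earlier checkpoint and skip the data batches seen around the divergence
import copy

def train_with_rewind(model, optimizer, batches, training_step,
                      checkpoint_every = 100, spike_threshold = 2.0, skip_on_spike = 300):
    checkpoints = []       # (step, model state, optimizer state) snapshots
    recent_losses = []     # running window used as a baseline for spike detection
    skip_until = -1

    for step, batch in enumerate(batches):
        if step <= skip_until:
            continue       # drop the batches implicated in the divergence

        if step % checkpoint_every == 0:
            checkpoints.append((step,
                                copy.deepcopy(model.state_dict()),
                                copy.deepcopy(optimizer.state_dict())))
            checkpoints = checkpoints[-5:]   # keep only a few recent snapshots

        loss = training_step(model, optimizer, batch)   # assumed to return a float loss
        recent_losses = (recent_losses + [loss])[-100:]
        baseline = sum(recent_losses) / len(recent_losses)

        if len(checkpoints) > 1 and loss > spike_threshold * baseline:
            # rewind to a checkpoint from before the spike, then try different batches
            _, model_state, opt_state = checkpoints[-2]
            model.load_state_dict(model_state)
            optimizer.load_state_dict(opt_state)
            skip_until = step + skip_on_spike
```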

lucidrains commented 8 months ago

@zaptrem who knows, maybe there is a solution if enough researchers work on the problem, but why do so when the serial architecture already works so well? (Llama)

anyways, just to show i have put some thought into this.

lucidrains commented 8 months ago

here's the salesforce paper, where the author addresses the instability issue of parallel blocks directly: https://blog.salesforceairesearch.com/xgen/