Haiyang-W / TokenFormer

Official Implementation of TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
https://haiyang-w.github.io/tokenformer.github.io/
Apache License 2.0

Maybe we can apply `ring attention` to scale up TokenFormer infinitely? #1

Closed: reyoung closed this issue 1 week ago

reyoung commented 4 weeks ago

Nice paper on making LLMs fully attention-based. However, I noticed that the largest model discussed in the paper is a 1.5B model.

I wonder whether the Pattention layer is difficult to tensor-parallelize. For comparison, a dense 20B model can be trained on a single node with 8xA100 40G GPUs.

I also think tensor parallelism might not even be necessary for the Pattention layer. If the ring attention technique can be applied over the key-value parameter tokens, the layer should scale to multi-GPU parallelism quite naturally. A rough sketch of what I have in mind is below.
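
To make the idea concrete, here is a toy single-process sketch (my own code, not from this repo): the key-value parameter tokens are sharded across a simulated ring of devices, and partial outputs accumulate as the shards rotate. I use a plain elementwise GeLU in place of the paper's modified softmax; with the paper's actual score function, the per-row normalizer would also need to be reduced across the ring.

```python
# Toy single-process sketch of ring-style parallelism over the Pattention
# parameter tokens (key/value *parameters*, not sequence tokens).
# Assumptions: an elementwise GeLU stands in for the paper's modified softmax;
# world_size, shard sizes, and dims are made up for the demo.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_in, d_out = 64, 64          # hidden dims (illustrative)
num_param_tokens = 1024       # total key/value parameter tokens
world_size = 4                # simulated number of devices in the ring

# Full parameter-token bank (what a single device would hold).
K_P = torch.randn(num_param_tokens, d_in)
V_P = torch.randn(num_param_tokens, d_out)

# Shard the parameter tokens across the ring, one block per "device".
K_shards = list(K_P.chunk(world_size, dim=0))
V_shards = list(V_P.chunk(world_size, dim=0))

X = torch.randn(8, 16, d_in)  # a batch of input activations (B, seq, d_in)

def pattention_block(x, k, v):
    # Score the activations against this shard of key parameter tokens,
    # then take the weighted sum of this shard's value parameters.
    scores = F.gelu(x @ k.t())  # elementwise nonlinearity -> additive over shards
    return scores @ v

# Reference: Pattention over the full parameter bank on one device.
ref = pattention_block(X, K_P, V_P)

# Ring simulation: each step, every rank works on the shard currently resident
# on it, then the shards rotate one hop around the ring (the send/recv that real
# ring attention overlaps with compute). Partial outputs simply add up because
# the nonlinearity here is elementwise; the paper's modified softmax would
# additionally require an all-reduce of the per-row normalizer.
out = torch.zeros_like(ref)
for step in range(world_size):
    out = out + pattention_block(X, K_shards[step], V_shards[step])

print(torch.allclose(ref, out, rtol=1e-4, atol=1e-3))  # True: sharding reproduces the full layer
```
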

Since Pattention can be trained incrementally according to the paper, would it be possible to pretrain a 100B-level model in about one month using 16 machines with 8xA100 each? (According to the paper, only about 10% of the pretraining data is needed when scaling up the model size.) A toy illustration of the incremental growth I am imagining follows.
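
To illustrate the incremental scaling I have in mind (again my own toy code; I am not sure of the exact initialization the paper uses), new key-value parameter pairs are appended with zero-initialized keys and small random values, so the enlarged layer initially computes the same function as the old one and training can resume from the existing checkpoint:

```python
# Toy sketch (not from the repo) of growing a Pattention layer by appending
# new key/value parameter tokens while preserving the layer's output.
import torch
import torch.nn.functional as F

def pattention(x, keys, values):
    # Elementwise GeLU as a stand-in for the paper's modified-softmax score function.
    return F.gelu(x @ keys.t()) @ values

def grow_param_tokens(keys, values, num_new):
    # Zero-init the new keys so the new tokens contribute nothing at first;
    # small random new values (LoRA-style) so gradients can still reach them.
    # (I am not sure which initialization the paper actually uses.)
    new_keys = torch.zeros(num_new, keys.shape[1])
    new_values = torch.randn(num_new, values.shape[1]) * 0.02
    return torch.cat([keys, new_keys], dim=0), torch.cat([values, new_values], dim=0)

torch.manual_seed(0)
keys = torch.randn(512, 64)    # existing key parameter tokens
values = torch.randn(512, 64)  # existing value parameter tokens
x = torch.randn(4, 64)         # some input activations

big_keys, big_values = grow_param_tokens(keys, values, num_new=512)

# Output is unchanged right after growing, so training can continue from the
# old checkpoint instead of restarting from scratch.
print(torch.allclose(pattention(x, keys, values),
                     pattention(x, big_keys, big_values), atol=1e-5))  # True
```
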

Could you please advise if my understanding is correct?

Haiyang-W commented 3 weeks ago

Yes, your idea is very clever and has given me some inspiration. Personally, I believe TokenFormer has significant application potential in edge-cloud collaboration for product-level FMs. Imagine you have a large knowledge base composed of numerous key-value parameter pairs distributed across different hosts. Using ring attention could be a good approach.

TokenFormer is essentially an extension of the Transformer, so any techniques used on Transformers can also be applied to TokenFormer.

kroggen commented 1 week ago

This new method is related:

MemoryFormer: Minimize Transformer Computation by Removing Fully-Connected Layers

https://arxiv.org/abs/2411.12992

In TokenFormer, the input is processed against each key parameter to generate attention scores, and then the highest ones are selected.

In MemoryFormer, the input is hashed and the hash is used as an index into a lookup table of "values".

It is in effect an MoE, but also applied to the Q, K, V projections. A schematic comparison is sketched below.
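
To make the contrast concrete, here is a purely schematic sketch, not the actual implementation of either method: the real MemoryFormer uses a more elaborate hashing scheme described in its paper, and TokenFormer uses its modified softmax rather than a plain GeLU.

```python
# Schematic contrast only; neither function is the real implementation.
# Pattention-style: score the input against every key parameter, then take a
# weighted sum of the value parameters (cost: a matmul over all keys).
# Hash-lookup-style: hash the input to a bucket index and read one row from a
# value table (cost: a small projection plus a table read).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, num_tokens = 64, 64, 256

keys = torch.randn(num_tokens, d_in)    # key parameter tokens
values = torch.randn(num_tokens, d_out) # value parameter tokens

def pattention_lookup(x):
    scores = F.gelu(x @ keys.t())  # every key participates in the score
    return scores @ values

num_bits = 8                            # 2**8 = 256 buckets
proj = torch.randn(d_in, num_bits)      # random hyperplanes for a toy hash
table = torch.randn(2 ** num_bits, d_out)

def hash_lookup(x):
    bits = (x @ proj > 0).long()                           # sign pattern of the input
    idx = (bits * (2 ** torch.arange(num_bits))).sum(-1)   # pack bits into a bucket id
    return table[idx]                                      # one table row per input

x = torch.randn(4, d_in)
print(pattention_lookup(x).shape, hash_lookup(x).shape)  # both (4, d_out)
```
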