jzhang38 opened 6 months ago
I trained a 300M model on a single A6000 (from Paperspace Gradient) with bf16-mixed precision.
experiments/mixture_of_depth/train_mod.py
This is the location of the training script; you can look into it and change the sizes, or keep the defaults to reproduce.
To reproduce:
1. Install this repo: `git clone` it and run `pip install -e .`
2. Pretokenize the dataset: `python examples/prepare-dataset.py`
3. Open that file and change the dataset to minipile.
4. Then run `train_mod.py`.
I am using Lightning Fabric, so it should be pretty easy to extend to multi-node training, but I trained on a single 48 GB A6000.
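For context, a minimal sketch of what the Fabric setup looks like (the arguments here are illustrative, not copied from `train_mod.py`):

```python
# Hedged sketch: Lightning Fabric handles device placement and precision,
# so scaling out is mostly a matter of changing `devices`/`num_nodes`.
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
fabric.launch()
# model, optimizer = fabric.setup(model, optimizer)  # then the usual train loop
```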
I'll add a detailed README here (experiments/mixture_of_depth/) tonight or tomorrow 😅
Thank you so much!
Yeah I've pretty much read your code related to MoD.
One concern: I noticed the dataset is implemented as an iterator object, so I am not sure whether Lightning Fabric would handle this correctly in a multi-GPU setup, since we would need a distributed sampler.
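To make the concern concrete, here is a hedged sketch (names are hypothetical, not from this repo) of the kind of rank-based sharding an iterable dataset would need, since `DistributedSampler` only works with map-style datasets:

```python
# Hypothetical sketch: each rank skips samples so GPUs see disjoint data.
# Fabric's setup_dataloaders() can inject a DistributedSampler for map-style
# datasets, but an IterableDataset has to shard its own stream like this.
from torch.utils.data import IterableDataset

class ShardedStream(IterableDataset):
    def __init__(self, stream, rank: int, world_size: int):
        self.stream = stream
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Keep every world_size-th sample, offset by this rank's index.
        for idx, sample in enumerate(self.stream):
            if idx % self.world_size == self.rank:
                yield sample
```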
Looking at Figure 7 from the paper, I feel they multiply the skipped tokens by the router weights as well.
We are taking the softmax over a long sequence length, so most values at the other end will be close to zero; if we multiply all tokens by the router weights, the pass-through tokens will become really tiny, like 1e-5 or something.
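A quick numeric sketch of that scaling effect (the numbers are illustrative, not from the repo):

```python
# Softmax over the whole sequence pushes per-token weights toward 1/seq_len,
# so scaling every token's hidden state by its weight would shrink even the
# selected tokens toward zero at long sequence lengths.
import torch

seq_len = 2048
router_logits = torch.randn(seq_len)            # stand-in router scores
weights = torch.softmax(router_logits, dim=-1)

print(weights.mean().item())  # ~1/2048 ≈ 5e-4
print(weights.max().item())   # even the top weight is small at this length
```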
@joey00072 One more thing about MoD: https://github.com/joey00072/ohara/blob/16941b38f3749d38c9cdaaf178b6d2c5995d5810/experiments/mixture_of_depth/mixture_of_depth.py#L121
Since MoD routes only a very small fraction of the tokens through attention, I have concerns about model performance in some extreme cases, such as very short inputs.
I think topK should have a minimal value, like 10 (or seq_len when seq_len < 10); see the sketch below.
And when faced with long text, this behaves kind of like sparse attention or a sliding window, which I think is acceptable.
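A minimal sketch of the suggested floor (the function and parameter names are hypothetical, and the 12.5% capacity is just a stand-in default):

```python
# Hypothetical helper: clamp MoD's top-k so very short inputs keep all
# tokens, while long inputs still get the usual capacity-based selection.
def mod_top_k(seq_len: int, capacity: float = 0.125, min_k: int = 10) -> int:
    k = int(seq_len * capacity)          # capacity-based token budget
    return min(seq_len, max(k, min_k))   # floor at min_k, cap at seq_len

assert mod_top_k(8) == 8       # shorter than the floor: keep every token
assert mod_top_k(64) == 10     # floor kicks in (64 * 0.125 = 8 < 10)
assert mod_top_k(2048) == 256  # long text: sparse, sliding-window-like
```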
I would like to hear your thoughts.
Hi Joey,
Thank you for such wonderful open-source work!
Could you share the exact command to reproduce the curve in your MOD is Vibe blog post? For example, did you use DDP, and how many GPUs?