joey00072 / ohara

Collection of autoregressive model implementations

Exact command to reproduce the curve in MOD is Vibe? #10

Open · jzhang38 opened 6 months ago

jzhang38 commented 6 months ago

Hi Joey,

Thank you for such wonderful open-source work!

Could you share the exact command to reproduce the curve in your MOD is Vibe blog? For example, did you use DDP and how many GPUs?

joey00072 commented 6 months ago

I trained a 300M model on a single A6000 (from Paperspace Gradient) with bf16-mixed precision.

The training script is at experiments/mixture_of_depth/train_mod.py; you can look into it and change the sizes, or keep the defaults to reproduce.

First install this repo: `git clone` it and `pip install -e .`. Then pretokenize the dataset with `python examples/prepare-dataset.py` (open that file and change the dataset to minipile), and then run `train_mod.py`. I am using Lightning Fabric, so multi-node training should be pretty easy, but I trained on a single 48 GB A6000.

I'll add a README with details under experiments/mixture_of_depth/ tonight or tomorrow 😅

jzhang38 commented 6 months ago

> I'll add a README with details under experiments/mixture_of_depth/ tonight or tomorrow

Thank you so much!

Yeah I've pretty much read your code related to MoD.

One concern for me is that the dataset is implemented as an iterator object, so I am not sure whether Lightning Fabric would handle this correctly in a multi-GPU setup, since we would need a distributed sampler.
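
Roughly what I mean, as a sketch (the `TokenIterableDataset` here is made up, not the repo's actual dataset class): a `DistributedSampler` only works with map-style datasets, so an iterable dataset would have to shard itself by rank, something like this.

```python
import torch
from torch.utils.data import IterableDataset, DataLoader
from lightning.fabric import Fabric


class TokenIterableDataset(IterableDataset):
    """Hypothetical pretokenized stream that shards itself across ranks."""

    def __init__(self, samples, rank: int = 0, world_size: int = 1):
        self.samples = samples
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # DistributedSampler cannot wrap an IterableDataset, so each rank
        # keeps every world_size-th sample to avoid training on duplicates.
        for idx, sample in enumerate(self.samples):
            if idx % self.world_size == self.rank:
                yield sample


fabric = Fabric(accelerator="auto", devices=2)
fabric.launch()

samples = [torch.randint(0, 32000, (512,)) for _ in range(1_000)]  # dummy token ids
dataset = TokenIterableDataset(samples, rank=fabric.global_rank, world_size=fabric.world_size)

# Fabric will not inject a DistributedSampler here because the dataset is
# iterable; the manual rank filter above is doing that job instead.
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))
```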

jzhang38 commented 6 months ago
[Screenshot: Figure 7 from the MoD paper]

Looking at Figure 7 from the paper, I feel they also multiply the router weights to those skipped tokens as well.

joey00072 commented 6 months ago

We are taking a softmax over a long sequence length, so most values at the far end will be close to zero. If we multiplied all tokens by the router weights, the pass-through tokens would become really tiny, like 1e-5 or something.
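
Rough illustration of what I mean (dummy numbers, not the repo's code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len = 2048
router_logits = torch.randn(seq_len)        # one router logit per token
weights = F.softmax(router_logits, dim=-1)  # softmax over the whole sequence

# With 2048 tokens the weights average 1/2048 ~= 5e-4; tokens at the low end
# of the distribution land around 1e-5, and even the largest is only ~1e-2.
print(weights.mean(), weights.min(), weights.max())

# Scaling every token's block output (including the pass-through ones) by
# these weights would shrink the hidden states by several orders of magnitude.
hidden = torch.randn(seq_len, 768)
scaled = weights.unsqueeze(-1) * hidden
print(hidden.norm(dim=-1).mean(), scaled.norm(dim=-1).mean())
```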

WuNein commented 6 months ago

@joey00072 One more thing, about MoD. https://github.com/joey00072/ohara/blob/16941b38f3749d38c9cdaaf178b6d2c5995d5810/experiments/mixture_of_depth/mixture_of_depth.py#L121

Since MoD only routes a small fraction of the tokens through attention, I have concerns about model performance in some extreme cases, such as very short inputs.

https://github.com/joey00072/ohara/blob/16941b38f3749d38c9cdaaf178b6d2c5995d5810/experiments/mixture_of_depth/mixture_of_depth.py#L102

I think topK should have a minimal value, like 10 (or seq_len when seq_len < 10).
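
Something like this sketch is what I have in mind (`capacity_ratio` and `min_tokens` are just stand-in names, not variables from the repo):

```python
def routed_topk(seq_len: int, capacity_ratio: float = 0.125, min_tokens: int = 10) -> int:
    """Number of tokens to route through the block.

    Clamp the usual capacity_ratio * seq_len to at least min_tokens
    (or the whole sequence when it is shorter than min_tokens), so very
    short inputs still send a reasonable number of tokens through attention.
    """
    k = int(seq_len * capacity_ratio)
    return max(min(seq_len, min_tokens), k)


# routed_topk(8)    -> 8    keep every token, sequence shorter than min_tokens
# routed_topk(40)   -> 10   0.125 * 40 = 5 would be too few, clamp up to 10
# routed_topk(2048) -> 256  normal MoD capacity
```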

And for long text, this behaves somewhat like sparse attention or a sliding window, which I think is acceptable.

I would like to hear your thoughts.