astramind-ai / Mixture-of-depths

Unofficial implementation of the paper "Mixture-of-Depths: Dynamically allocating compute in transformer-based language models"

Unable to run model.generate() using converted model #7

Closed: Zkli-hub closed this issue 3 months ago

Zkli-hub commented 4 months ago

I find that I cannot run the generate() function to do inference with the converted model. Can you help me?

Here is the error: [screenshot of the error traceback]

Zkli-hub commented 3 months ago

@mlinmg It seems that it still doesn't work even if I use the latest version and explicitly set model.eval(). You can reproduce it easily in this notebook, which is forked from your latest repo version: https://github.com/Zkli-hub/Mixture-of-depths-test/blob/main/MoD/modeling/models/test.ipynb Initialize the model using your own pretrained HF weight path.

mlinmg commented 3 months ago

I don't have a pretrained MoD model at the moment, and it keeps giving me a NaN-related error; I'll do a small training run in the coming days.

Zkli-hub commented 3 months ago

> I don't have a pretrained MoD model at the moment, and it keeps giving me a NaN-related error; I'll do a small training run in the coming days.

Thanks for your reply! Actually, you don't need a pretrained MoD model to reproduce this; a model converted from the base llama2-7b is enough, so I suspect the problem is in the code rather than the pretrained weights. In my notebook you can initialize the model from the base checkpoint, call model.eval(), and then find that model.generate() still fails, as in the snippet below.
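Roughly the cells I run (the import path for LlamaMoDForCausalLM depends on your checkout, and the prompt and generation kwargs are just placeholders):

```python
import torch
from transformers import AutoTokenizer
# LlamaMoDForCausalLM comes from this repo's MoD modeling code;
# adjust the import path to wherever it lives in your checkout.

model = LlamaMoDForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    # This is the call that fails on the converted model.
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```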

Zkli-hub commented 3 months ago

> I don't have a pretrained MoD model at the moment, and it keeps giving me a NaN-related error; I'll do a small training run in the coming days.

I think I know where the problem is. In eval mode, the model generates output autoregressively, which follows a different logic from training: tokens are produced one at a time, so each forward pass sees only a single input token. I logged the intermediate variables of the original llama model. The first figure shows the log for the prepared (prefill) input, while the second shows the log for the autoregressive inference stage, where tokens come in one by one. To solve it, I think you can refer to the original paper; it lists two approaches for the routing logic during the inference stage.

[screenshots: logged intermediate variables for the prefill stage and for the single-token autoregressive decoding stage]
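Something along these lines is what I mean by a per-token routing decision at decode time. It's just a sketch with made-up names (mod_decode_step, block, router, threshold), not your repo's actual API, and it assumes the router has been trained (e.g. with the auxiliary predictor or auxiliary loss from the paper) so that thresholding its score approximates top-k membership:

```python
import torch

def mod_decode_step(hidden, block, router, threshold=0.5):
    """One autoregressive decode step through a Mixture-of-Depths block.

    hidden: [1, 1, d_model] tensor, i.e. the single new token (batch size 1
    for clarity). `block` is the transformer block, `router` a linear layer
    producing one logit per token.
    """
    score = torch.sigmoid(router(hidden))  # routing probability for this token only
    if score.item() > threshold:
        # Token is routed through the block; the output is weighted by the
        # router score, mirroring the training-time top-k weighting.
        return hidden + score * block(hidden)
    # Token skips the block entirely via the residual connection.
    return hidden
```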
mlinmg commented 2 months ago

Actually there were a couple of problems. The first one is cache management: when the router skips an entire block, the cache gets lost. Also, with llama3 it is very tricky to avoid NaN outputs while fine-tuning a model. I've seen that fa2 works far more reliably than sdpa, which goes to NaN most of the time. All of these issues should be addressed in 1.2.0 now.
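For anyone else hitting the NaNs: the attention backend can be switched at load time with the standard transformers kwarg. This assumes flash-attn is installed and you are loading in bf16; LlamaMoDForCausalLM is this repo's class, so adjust the import to your checkout:

```python
import torch
# LlamaMoDForCausalLM is this repo's class; adjust the import path to your checkout.

model = LlamaMoDForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,               # half precision is where the sdpa NaNs showed up
    attn_implementation="flash_attention_2",  # fa2 has been much more reliable than sdpa here
)
```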

ZHQSimon commented 2 months ago

> I logged the intermediate variables of the original llama model. [...] To solve it, I think you can refer to the original paper; it lists two approaches for the routing logic during the inference stage.

Excuse me, how did you get these correct intermediate variables?
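
For what it's worth, one common way to capture such intermediate variables (not necessarily what was done above) is to register forward hooks on the original HF llama decoder layers; the hooks fire once per forward pass, so you can see the full prompt go through during prefill and then one token at a time during decoding:

```python
from collections import defaultdict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model.eval()

captured = defaultdict(list)

def make_hook(name):
    def hook(module, inputs, output):
        # Decoder layers return a tuple; the hidden states come first.
        hidden = output[0] if isinstance(output, tuple) else output
        captured[name].append(tuple(hidden.shape))
    return hook

# Hook every decoder layer of the base llama model.
for i, layer in enumerate(model.model.layers):
    layer.register_forward_hook(make_hook(f"layer_{i}"))

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=3)

# The first shape is the prefill pass over the full prompt; the remaining
# shapes are the single-token autoregressive decode steps.
print(captured["layer_0"])
```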