Aleph-Alpha / magma

MAGMA - a GPT-style multimodal model that can understand any combination of images and language. NOTE: The freely available model from this repo is only a demo. For the latest multimodal and multilingual models from Aleph Alpha check out our website https://app.aleph-alpha.com
MIT License
475 stars, 55 forks

AssertionError: Parameter with name: lm.transformer.wte.weight occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour #32

Closed masoud-monajati closed 2 years ago

masoud-monajati commented 2 years ago

Hi,

I'd like to rerun the code using gpt-neo125M or gpt2-med instead of gpt-neo2.7B, and I'm getting this error:

AssertionError: Parameter with name: lm.transformer.wte.weight occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour.

Any idea why this issue exists for other language models?

UCCME commented 2 years ago

Hello, did you run into the same problem? It seems we have the same problem when loading the checkpoint into the model: size mismatch for lm.lm_head.weight: copying a param with shape torch.Size([50400, 4096]) from checkpoint, the shape in current model is torch.Size([50258, 4096]). This is very strange. I didn't change any code, yet the model and the config seem to be mismatched.

masoud-monajati commented 2 years ago

I haven't run into this issue yet. Check the language model.py code that loads the model. Some parameters might be hard-coded, and that could cause errors.

UCCME commented 2 years ago

Thanks for answering. Can you show me the size of self.lm.lm_head (in the model and in the checkpoint)? I wonder whether I downloaded the wrong checkpoint.
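
A comparison like the one asked for here could be sketched roughly as follows. This is only an illustration: `find_shape_mismatches` is a hypothetical helper, not something from the repo, and it works on plain name-to-shape dictionaries so it runs without any checkpoint. The example shapes are the ones from the error message above.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for entries whose shapes differ.

    Hypothetical helper: both arguments map parameter names to shape tuples.
    Names that are missing from the model are ignored.
    """
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# Shapes taken from the error message in this thread.
checkpoint_shapes = {"lm.lm_head.weight": (50400, 4096)}
model_shapes = {"lm.lm_head.weight": (50258, 4096)}

print(find_shape_mismatches(checkpoint_shapes, model_shapes))
```

With PyTorch, the two dictionaries can be built with something like `{k: tuple(v.shape) for k, v in state_dict.items()}`, once from the loaded checkpoint and once from `model.state_dict()`. Whether the checkpoint stores its weights at the top level or under a nested key is an assumption you would need to check.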

masoud-monajati commented 2 years ago

I haven't used any of the provided checkpoints yet; I'm only loading a gpt2-med language model.

Yuuxii commented 2 years ago

Hello, I also got the same error when I tried to use a smaller model. Have you fixed the error?

Yuuxii commented 2 years ago

Setting config.jax = True solved my error :)
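
For anyone who hits the same assertion and cannot use that setting, a generic workaround is to make sure each parameter object appears only once across the optimizer's param groups. The sketch below assumes the duplicate comes from a weight that is reachable under two names (GPT-2 and GPT-Neo models in Hugging Face Transformers tie the input embedding to the output head by default), which this thread does not confirm. `dedupe_param_groups` is a hypothetical helper, and plain Python objects stand in for real parameters so the example runs on its own.

```python
def dedupe_param_groups(param_groups):
    """Keep only the first occurrence of each parameter object across all groups.

    Hypothetical helper: param_groups is a list of dicts with a "params" list,
    the same layout PyTorch optimizers accept. Groups left empty are dropped.
    """
    seen = set()
    cleaned = []
    for group in param_groups:
        unique = []
        for p in group["params"]:
            # Compare by identity: a tied weight is the same object under two names.
            if id(p) not in seen:
                seen.add(id(p))
                unique.append(p)
        if unique:
            new_group = dict(group)  # shallow copy so the input is not modified
            new_group["params"] = unique
            cleaned.append(new_group)
    return cleaned


# Stand-ins for parameters; with PyTorch these would be nn.Parameter objects.
shared = object()  # e.g. a weight that shows up in two groups
other = object()
groups = [
    {"params": [shared, other], "weight_decay": 0.0},
    {"params": [shared], "weight_decay": 0.1},
]

cleaned = dedupe_param_groups(groups)
print(len(cleaned), len(cleaned[0]["params"]))  # prints: 1 2
```

Note that the first group to list a shared parameter wins, so its settings (here `weight_decay`) are the ones that apply to that weight.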