lucidrains / muse-maskgit-pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
MIT License

Trained weights #20

Open ArielReplicate opened 1 year ago

ArielReplicate commented 1 year ago

Hi, any chance you'll upload any trained weights? I really want to try this out, it looks amazing!


ZeroCool940711 commented 1 year ago

In case anyone wants to check it out, I have been trying to train Muse for a couple of days now, and I will be uploading the weights I get over time to Huggingface and keeping them in this repo. So far I have trained the first VAE for 365K steps on the INE dataset, using a training script I put together based on the readme.md of this repo. The script can be found on my fork of lucidrains' repo, which you can find here. It still needs a lot of work, but as I said, you can use it to train the first VAE; the other parts of the script are still broken or not done correctly, so I would appreciate any help with it, as I still don't properly understand how I should be training the Base and Super Resolution VAEs. I'm pretty sure I am missing the training loop for them, but I'm not sure where to put it or how to go about it. A rough sketch of the first-VAE part is below. I hope someone finds this helpful :)
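For anyone following along, the first-VAE training in that script follows the pattern from the repo's readme roughly as sketched here. This is only an illustrative sketch: the `dim`, `codebook_size`, batch size and trainer arguments are example values, the image folder is a placeholder, and the argument names reflect my reading of the readme, so double-check them against the current repo.

```python
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

# instantiate the VQGAN-VAE that will later be plugged into MaskGit
vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536        # illustrative; must match what the transformer expects later
)

# the trainer wraps the dataloader, optimizer and checkpointing
trainer = VQGanVAETrainer(
    vae = vae,
    image_size = 128,            # start small; the VAE is fully convolutional
    folder = '/path/to/images',  # placeholder path to the training images
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 365_000    # roughly the step count mentioned above
).cuda()

trainer.train()
```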

muximuxi commented 1 year ago

Thank you for your work! When I load the weights you released on Huggingface, I get an error in `base_maskgit.load("vae.pt")`:

```
muse_maskgit_pytorch.py in load
  471 │   path = Path(path)
  472 │   assert path.exists()
  473 │   state_dict = torch.load(str(path))
❱ 474 │   self.load_state_dict(state_dict)

torch/nn/modules/module.py:1497 in load_state_dict
❱ 1497 │   raise RuntimeError('Error(s) in loading state_dict for {} ...')

RuntimeError: Error(s) in loading state_dict for MaskGit:
    Missing key(s) in state_dict: "vae.enc_dec.encoders.0.weight", ... (many weights show here)
```
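The error looks consistent with the checkpoint being a VAE-only state dict while `MaskGit.load` expects a full MaskGit state dict (hence the missing `vae.*` keys). For comparison, the readme's flow loads the VAE checkpoint into the `VQGanVAE` itself before wrapping it in `MaskGit`; a hedged sketch of that flow is below, with illustrative hyperparameters that must match whatever the checkpoint was trained with.

```python
from muse_maskgit_pytorch import VQGanVAE, MaskGitTransformer, MaskGit

# load the VAE checkpoint into the VAE itself, not into MaskGit
vae = VQGanVAE(dim = 256, codebook_size = 65536).cuda()
vae.load('vae.pt')        # the released VAE weights

transformer = MaskGitTransformer(
    num_tokens = 65536,   # must match the VAE codebook size
    seq_len = 256,        # fmap_size ** 2 of the VAE at the chosen image size
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8
)

base_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25
).cuda()

# base_maskgit.load(...) is only for checkpoints saved from MaskGit itself
```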

ZeroCool940711 commented 1 year ago

I also get some errors when trying to load it into the base_maskgit. When using the EMA VAE it gives missing keys, probably the same ones you are getting in your error, and the normal VAE does load, but then I can't train the superres VAE no matter what I try. I would appreciate it if anyone could help me figure out what's wrong.

isamu-isozaki commented 1 year ago

@ZeroCool940711 Sounds good! And thanks for the training script. Lemme look into that today.

isamu-isozaki commented 1 year ago

@ZeroCool940711 Hi. I'm going to double check, but I think we can just get the pre-trained weights for the tokenizers/VAEs from Google's MaskGIT. The link is here; will make sure. Also, we could do the Paella idea here too.

ZeroCool940711 commented 1 year ago

@isamu-isozaki the Google pre-trained weights look interesting, especially if we can reuse them. What would we be doing then: fine-tuning those pre-trained weights from Google on other datasets, to save time on the initial part of the training and add new data to them?

isamu-isozaki commented 1 year ago

@ZeroCool940711 I think for now we can start training the base transformer part, or at least try, since the tokenizer should already be good enough to represent images. We can definitely test that if you want! Then I'm guessing we can start training at 256x256, and once we get good results, we can use that base transformer to train the 512x512 model. A rough sketch of what that training step could look like is below.
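This is only a sketch of what the base-transformer training step might look like, following the readme's pattern of calling the MaskGit on a batch of images and captions to get a loss. The optimizer choice, learning rate, `num_train_steps`, and `train_iter` are all placeholders, not part of the repo.

```python
import torch
from torch.optim import Adam

# assumes `base_maskgit` was built as in the readme, with the pretrained VAE inside;
# only the transformer's parameters are optimized here
optimizer = Adam(base_maskgit.transformer.parameters(), lr = 1e-4)

for step in range(num_train_steps):              # num_train_steps: hypothetical training budget
    images, texts = next(train_iter)             # hypothetical iterator of (image batch, list of captions)
    images = images.cuda()

    loss = base_maskgit(images, texts = texts)   # masked-token loss, as in the readme
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    if step % 1000 == 0:
        print(f'step {step}: loss {loss.item():.4f}')
```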

isamu-isozaki commented 1 year ago

full model

isamu-isozaki commented 1 year ago

To clarify, for MaskGIT they have the transformer weights for both resolutions, but I think the main change from MaskGIT to Muse is that in both transformers a text encoder was added, and for the super-res one the lower-resolution tokens are added as conditioning. So the weights will probably need tuning.
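For reference, the super-res stage in the readme conditions on both text and a low-resolution version of the image, which is the extra input relative to MaskGIT. A sketch under those assumptions follows; `vae` and `superres_transformer` are assumed to be built as in the readme, and the image/caption batch here is dummy data.

```python
import torch
import torch.nn.functional as F
from muse_maskgit_pytorch import MaskGit

superres_maskgit = MaskGit(
    vae = vae,
    transformer = superres_transformer,   # a MaskGitTransformer sized for the 512px token grid
    image_size = 512,                     # target resolution
    cond_image_size = 256,                # low-res conditioning image size
    cond_drop_prob = 0.25
).cuda()

images = torch.randn(4, 3, 512, 512).cuda()   # stand-in for a real batch
texts = ['a photo of a dog'] * 4              # placeholder captions

loss = superres_maskgit(
    images,
    texts = texts,
    cond_images = F.interpolate(images, 256)  # downsampled images as conditioning
)
loss.backward()
```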

isamu-isozaki commented 1 year ago

But we will probably switch to the Paella approach in the end.

ZeroCool940711 commented 1 year ago

Do you think you could help improve the training script we have, so we can use the weights from Google as you mentioned and train the base transformer as well as the super-res one with it? Right now I can't train them with the script we have. We might also have to add a CLI argument to load the weights from Google so we don't have to hard-code them; right now there are only a few weights, but that might change in the future, so it's better to think long term. Something like the sketch below would do.
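As a sketch of that CLI idea: a single argparse option pointing at a checkpoint path would keep the script free of hard-coded weight names. The flag name `--resume_path` and its wiring here are made up for illustration.

```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument('--resume_path', type = str, default = None,
                    help = 'optional path to pretrained weights to load before training')
args = parser.parse_args()

if args.resume_path is not None:
    path = Path(args.resume_path)
    assert path.exists(), f'checkpoint not found at {path}'
    # load into whichever stage is being resumed, e.g. the VAE or the MaskGit
    base_maskgit.load(str(path))
```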

isamu-isozaki commented 1 year ago

@ZeroCool940711 yup, for sure! Oh, and btw I'm Chad Kensington on Discord.

isamu-isozaki commented 1 year ago

Will make PRs within the week.

ZeroCool940711 commented 1 year ago

Sure, take your time with the PR, and nice to know that's your username there; I was wondering how to contact you without spamming this issue too much. So let's continue discussing this on Discord and only post the necessary stuff here on this issue. Hopefully the next message we share here is a link to a properly trained model or weights that others can use :)

isamu-isozaki commented 1 year ago

Sounds good!

captainbadass2 commented 1 year ago

> In case anyone wants to check it out, I have been trying to train Muse for a couple of days now, and I will be uploading the weights I get over time to Huggingface ... I hope someone finds this helpful :)

When I try to load the MaskGit checkpoint I get a load of errors; the first few are shown below. Am I doing something wrong?


```python
from muse_maskgit_pytorch import MaskGit

# `vae` and `transformer` are built beforehand, as in the readme
base_maskgit = MaskGit(
    vae = vae,                 # vqgan vae
    transformer = transformer, # transformer
    image_size = 128,          # image size
    cond_drop_prob = 0.25,     # conditional dropout, for classifier free guidance
).cuda()

base_maskgit.load('./maskgit.3461001.pt')
```

```
Missing key(s) in state_dict: "transformer.transformer_blocks.layers.2.0.null_kv",
"transformer.transformer_blocks.layers.2.0.q_scale",
"transformer.transformer_blocks.layers.2.0.k_scale",
"transformer.transformer_blocks.layers.2.0.norm.gamma",
"transformer.transformer_blocks.layers.2.0.norm.beta",
"transformer.transformer_blocks.layers.2.0.toq.weight", ...
```
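Missing keys like these usually mean the `MaskGitTransformer` hyperparameters (depth, heads, dims, etc.) used to build `transformer` don't match the ones the checkpoint was trained with, so the two state dicts disagree. One way to confirm is to diff the key sets directly; a small diagnostic sketch (not part of the repo, and it assumes the checkpoint file is a plain state dict):

```python
import torch

# load the checkpoint on CPU just to inspect its keys
state_dict = torch.load('./maskgit.3461001.pt', map_location = 'cpu')

model_keys = set(base_maskgit.state_dict().keys())
ckpt_keys  = set(state_dict.keys())

print('expected by model but missing from checkpoint:', sorted(model_keys - ckpt_keys)[:10])
print('present in checkpoint but unexpected by model:', sorted(ckpt_keys - model_keys)[:10])
```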