lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

VQGanVAE1024 "vae must be an instance of DiscreteVAE" #87

Closed - afiaka87 closed this 3 years ago

afiaka87 commented 3 years ago

@lucidrains I believe the relevant line is here:

https://github.com/lucidrains/DALLE-pytorch/blob/2268864941d8eef2ba73a4488fe05673d447d493/dalle_pytorch/dalle_pytorch.py#L306

I tried adding it in myself, but it needs the taming imports and I'm not familiar with those.
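
For reference, the kind of relaxed type check being discussed - a sketch only, not the actual patch (the real change also needs the taming-transformers wrapper and its imports):

from dalle_pytorch import DiscreteVAE

def check_vae(vae, extra_vae_classes = ()):
    # accept DiscreteVAE plus any pretrained-VAE wrapper classes (e.g. a taming-transformers
    # VQGAN wrapper) instead of hard-requiring DiscreteVAE only
    allowed = (DiscreteVAE, *extra_vae_classes)
    assert isinstance(vae, allowed), f'vae must be an instance of one of {allowed}'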

afiaka87 commented 3 years ago

Here's a branch where I've added it myself:

https://github.com/lucidrains/DALLE-pytorch/pull/88

After remembering to change the dim parameter to 1024, I'm now getting this message:

AssertionError: Sequence length (4096) must be less than the maximum sequence length allowed (1024)

If I then hardcode num_text_tokens = 10000 instead of the VOCAB_SIZE that it uses from simple_tokenizer.py, I get this:

/opt/conda/conda-bld/pytorch_1614378083779/work/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block: 

Which I think is just some sort of problem with my custom DataLoader. I'll keep investigating, but I think there basically just needs to be an assert on the dim parameter when you're using the pretrained VAEs, in addition to my pull request.
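
For context on the 4096: the transformer's sequence length is the text length plus the square of the VAE's feature-map size, so 4096 image tokens would mean a 64 x 64 grid - presumably the feature-map size is being inferred from the wrong number of downsampling layers. A sketch of a sequence-length check along those lines (names here are illustrative):

def check_seq_len(text_seq_len, image_size, downsample_factor, max_seq_len):
    # the f=16 VQGAN turns a 256 x 256 image into a 16 x 16 grid = 256 image tokens;
    # a 64 x 64 grid (4096 tokens) blows past a 1024-token limit immediately
    image_fmap_size = image_size // downsample_factor
    image_seq_len = image_fmap_size ** 2
    total = text_seq_len + image_seq_len
    assert total <= max_seq_len, (
        f'sequence length ({total}) must be less than the maximum allowed ({max_seq_len})'
    )
    return total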

lucidrains commented 3 years ago

@afiaka87 ack yes, you are correct! hopefully just fixed the remaining issue: https://github.com/lucidrains/DALLE-pytorch/commit/187d62ca761a875c31c5834202921f1b676a6243

lucidrains commented 3 years ago

@afiaka87 there's the possibility this could train a lot faster, given it is 256 tokens instead of 1024! (it has an extra layer of downsampling)

afiaka87 commented 3 years ago

@lucidrains have you gotten it working? i was debugging it for a while, but then i realized my actual problem was that the CUDA installation on my cloud instance was borked.

edit: gonna pull your changes down and try them out now

lucidrains commented 3 years ago

@afiaka87 yup, I just verified it works, at least that I can get the codebook indices and reconstruct them. I also did a quick run in conjunction with DALL-E and it seems to at least pass my local test cases (though obviously haven't trained it yet, so whether it converges is another matter)
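
Roughly, the round trip being described - a minimal sketch, shown with DiscreteVAE and assuming the VQGAN wrapper exposes the same get_codebook_indices / decode interface:

import torch
from dalle_pytorch import DiscreteVAE

# images -> codebook indices -> reconstruction
vae = DiscreteVAE(image_size = 256, num_layers = 3, num_tokens = 1024)

images = torch.randn(4, 3, 256, 256)
indices = vae.get_codebook_indices(images)   # (4, 1024) with the default 3-layer VAE (32 x 32 grid)
recon = vae.decode(indices)                  # (4, 3, 256, 256)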

lucidrains commented 3 years ago

@afiaka87 Also realized that if this VAE works well with DALL-E, it opens up the possibility of training on 512 x 512 images :) - but right now I have it set to 256 x 256
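
The token arithmetic behind that, assuming the VQGAN's 16x downsampling:

# a 512 x 512 image through an f=16 VQGAN still yields only a 32 x 32 grid of codes,
# i.e. the same 1024 image tokens the default VAE spends on a 256 x 256 image
for image_size in (256, 512):
    fmap = image_size // 16
    print(image_size, '->', fmap * fmap, 'image tokens')   # 256 -> 256, 512 -> 1024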

afiaka87 commented 3 years ago

i'll be sure to do another hyperparameter sweep on this one if I can get it working.

afiaka87 commented 3 years ago

@lucidrains wow, running at 32 depth, 16 heads (reversible OFF), and it only needs 11 GiB of VRAM!

afiaka87 commented 3 years ago

@lucidrains it's significantly faster as well. Perhaps converging too quickly - need to check. I'll post back here with a live notebook of the session.

afiaka87 commented 3 years ago

@lucidrains here you go! https://wandb.ai/afiaka87/full_attention/reports/Live-Training-of-the-VQGanVAE1024-with-DALLE-pytorch--Vmlldzo1MzQ3MDE

afiaka87 commented 3 years ago

seems to have converged too quickly.

lucidrains commented 3 years ago

@afiaka87 haha yea, agreed - something doesn't look right

afiaka87 commented 3 years ago

starting over. gonna have a non-live go at this lol. be back with my results in a few hours probably.

lucidrains commented 3 years ago

@afiaka87 ok! i'll keep thinking on whether there may be a bug somewhere

afiaka87 commented 3 years ago

@lucidrains just for your sanity: i'm already achieving significantly better results just by switching back to 3e-4 for my learning rate. Probably should have done that in the first place. I found 5e-4 worked well for OpenAI's VAE in another post, but clearly that's not the case for this VAE.

Anyway, my dataset has corrupted pngs in it, so i'm dealing with that right now, but I just had a much more stable training session: --depth=32 --heads=8 --learning_rate=3e-4 --batch_size=32, VRAM usage 22 GiB. The loss falls to ~4.5 to ~5.5 and starts to hover around there, in the same way that OpenAI's VAE would settle around 5.5 to 6.5 during training.

But honestly? I've never been able to use a batch size of 32 until now - even on this pricey A100 with 40 GB of VRAM and reversible set to True, it just doesn't work. So this is crazy.
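
For reference, roughly what that configuration looks like in code (a sketch - DiscreteVAE stands in for the VQGAN wrapper here, and the real run goes through the training script's CLI flags):

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(image_size = 256, num_layers = 3, num_tokens = 8192)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 32,       # --depth=32
    heads = 8         # --heads=8
)

opt = torch.optim.Adam(dalle.parameters(), lr = 3e-4)   # --learning_rate=3e-4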

lucidrains commented 3 years ago

@afiaka87 yes, you are experiencing the Achilles heel of attention networks - their cost explodes quadratically with sequence length. The German group behind taming-transformers was able to get a high-quality VAE by combining it with GAN training, so their image token length is 1/4th of OpenAI's, leading to roughly 1/16th the attention cost when training DALL-E (token length of 256 vs 1024)
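
The arithmetic, spelled out:

# self-attention cost grows with the square of the sequence length,
# so 1/4 the image tokens gives roughly 1/16 the attention cost
openai_image_tokens = 32 * 32    # 1024
vqgan_image_tokens = 16 * 16     # 256
print(openai_image_tokens ** 2 / vqgan_image_tokens ** 2)   # 16.0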

lucidrains commented 3 years ago

@afiaka87 the VQGAN-VAE isn't without its own problems, though: https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb#scrollTo=3aNc2aEEs0BY - you'll notice that it misses certain details in the reconstruction (the pinecone)

afiaka87 commented 3 years ago

Not a panacea, but certainly a welcome improvement.

afiaka87 commented 3 years ago

@lucidrains I'm pretty far out of my depth at this point, but I know how to make a hyperparameter sweep with wandb.ai and I have access to an A100 for the next day. Here's where I'm stuck: the loss is dropping... too quickly? I would assume that means we're moving too fast and should decrease the learning rate. But in my earlier hyperparameter sweeps, I was minimizing the loss over all hyperparameters - doing that here will just give me more of these bad results, but maximizing isn't what I want either.

Any ideas on how to cut down search time for a good learning rate?

Edit: I'm planning on just pinning depth to 64 and heads to 16 (is that the OpenAI count?) and running the first epoch of probably ~50 jobs at different learning rates, so knowing what to optimize for is pretty helpful.
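
Something along these lines is what the sweep would look like (a sketch only - train_one_epoch is a hypothetical stand-in for the existing training loop, and the learning-rate grid is just log-spaced guesses):

import wandb

sweep_config = {
    'method': 'random',
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        # roughly log-spaced candidates around the values discussed above
        'learning_rate': {'values': [1e-4, 2e-4, 3e-4, 5e-4, 7e-4, 1e-3]},
    },
}

def run():
    with wandb.init():
        lr = wandb.config.learning_rate
        # hypothetical helper: train a single first epoch at fixed depth/heads, return the loss
        loss = train_one_epoch(learning_rate = lr, depth = 64, heads = 16)
        wandb.log({'loss': loss})

sweep_id = wandb.sweep(sweep_config, project = 'vqgan_lr_sweep')   # illustrative project name
wandb.agent(sweep_id, function = run, count = 50)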

lucidrains commented 3 years ago

If the loss is dropping precipitously, that's usually a bug, probably with the masking, although I'm not sure why you wouldn't see the same bug when training with OpenAI's VAE 🤔

afiaka87 commented 3 years ago

@lucidrains I think (?) you can see all my runs here. https://wandb.ai/afiaka87/vqgan_frac_1

I'll be dealing with the occasional hiccup and changing parameters and whatnot frequently, but in general I'm trying to stick with what I already know - 32 depth, 8 heads. I'll increase to 64 once I've stabilized that.

lucidrains commented 3 years ago

@afiaka87 are you using any of the sparse attentions? or just full attention throughout?

afiaka87 commented 3 years ago

@lucidrains I can't use sparse attention, unfortunately - the sparse kernels haven't been updated for CUDA 11.0 yet.

Everything until now has been 'full', 'axial_row', 'axial_col', 'conv_like'. I'm going to switch back to just 'full' for run 7, as I had trouble training with axial attention on OpenAI's VAE as well.
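
For anyone reproducing this: the attention mix corresponds to the attn_types argument on DALLE (cycled across layers, as I understand it). A sketch of the two configurations being compared:

from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(image_size = 256, num_layers = 3, num_tokens = 1024)

# the mixed-attention runs so far
dalle_mixed = DALLE(dim = 512, vae = vae, depth = 32, heads = 8,
                    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like'))

# run 7: full attention only
dalle_full = DALLE(dim = 512, vae = vae, depth = 32, heads = 8,
                   attn_types = ('full',))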

lucidrains commented 3 years ago

@afiaka87 ok, full (and sparse) are probably the safest

i'll double check the axial sparse attention code tomorrow morning

afiaka87 commented 3 years ago

> @afiaka87 ok, full (and sparse) are probably the safest
>
> i'll double check the axial sparse attention code tomorrow morning

Awesome - yeah, from what I remember it just seemed like it was chunking a bunch of blue blocks into various columns and rows.

afiaka87 commented 3 years ago

@lucidrains removing axial and conv_like attention has definitely helped.

lucidrains commented 3 years ago

@afiaka87 awesome :) if you could put one back at a time and see which one makes the run fail again (loss drops too much), that would be another big help :pray:

lucidrains commented 3 years ago

@afiaka87 perhaps they both have leakage lol

afiaka87 commented 3 years ago

next run will be ('full', 'axial_row')

afiaka87 commented 3 years ago

@lucidrains and then i'll do ('full', 'axial_col')

lucidrains commented 3 years ago

@afiaka87 ok, and a ('full', 'conv_like') too, if you could :pray:

TheodoreGalanos commented 3 years ago

I can confirm that the loss dropped really quick when using the whole package of attention. Went from 7 to 0.005 after about 400 image/text pairs.

afiaka87 commented 3 years ago

> I can confirm that the loss dropped really quick when using the whole package of attention. Went from 7 to 0.005 after about 400 image/text pairs.

Always good to know it's working on someone else's machine as well, at least. You can check the link above for my current live runs.

@lucidrains 'axial_row' seems more stable.

lucidrains commented 3 years ago

@afiaka87 so the culprit is probably conv_like? (wouldn't be surprised)

afiaka87 commented 3 years ago

@lucidrains yeah - the axial_row run is fine. Moving on to axial_col.

TheodoreGalanos commented 3 years ago

> I can confirm that the loss dropped really quick when using the whole package of attention. Went from 7 to 0.005 after about 400 image/text pairs.
>
> Always good to know it's working on someone else's machine as well, at least. You can check the link above for my current live runs.
>
> @lucidrains 'axial_row' seems more stable.

wish I could run in parallel but my colab hates me right now

afiaka87 commented 3 years ago

Yeah I have to resort to these pricey (but somehow cheapest on the market?) instances from vast.ai.

afiaka87 commented 3 years ago

This is soo much faster. Would've waited an hour before on these same parameters.

edit: got a little out of hand with the number of 'o's originally.

TheodoreGalanos commented 3 years ago

it is insane - it's doing 400 batches in like a couple of minutes :o

lucidrains commented 3 years ago

increase the image size to 512x512, and it'll be just as slow as OpenAI's VAE lol

lucidrains commented 3 years ago

it's all about this elephant in the room, the N squared cost of attention

TheodoreGalanos commented 3 years ago

I'm curious to see if your transganformer will be a possible candidate for this.

afiaka87 commented 3 years ago

@lucidrains alright, this is stable too. Moving on to conv_like. And oh, look at that - we found a corrupted png:

rm -rf --no-preserve-root 3693640717_0.png
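
A quick way to catch these up front instead of mid-run - a sketch with Pillow, where the dataset path is illustrative:

from pathlib import Path
from PIL import Image

# scan the dataset folder for unreadable / corrupted images before training
bad = []
for path in Path('./dataset').rglob('*.png'):
    try:
        with Image.open(path) as img:
            img.verify()    # cheap integrity check; doesn't decode the full image
    except Exception:
        bad.append(path)

print(f'found {len(bad)} corrupted files')   # review the list before deleting anything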

lucidrains commented 3 years ago

@TheodoreGalanos GAN + text conditioning is a totally different approach than DALL-E, but i'm excited to see if the duplex attention from GansFormer will work better for conditioning image generation from text

lucidrains commented 3 years ago

@afiaka87 i'm pretty sure the bug is in the conv like attention, even before you run the experiment lol

TheodoreGalanos commented 3 years ago

> @TheodoreGalanos yea, the GAN + text conditioning is a totally different approach than DALL-E, but i'm excited to see if the duplex attention from GansFormer will work better for conditioning image generation from text

that's alright, I honestly don't even think DALL-E was the breakthrough. I really love the work being put into the parts that are closer to the CLIP + cool_model_here side of things, even though I still have hopes for tiny DALL-Es working in practical situations :)

afiaka87 commented 3 years ago

@lucidrains this seems to be running fine... confusing.

lucidrains commented 3 years ago

@afiaka87 oh what lol

afiaka87 commented 3 years ago

i'll eat those words. it's dropping.

lucidrains commented 3 years ago

@afiaka87 woohoo! i knew it, thank you thank you :pray: