afiaka87 closed this issue 3 years ago
Here's a branch where I've added it myself:
https://github.com/lucidrains/DALLE-pytorch/pull/88
After remembering to change the dim parameter to 1024, I'm now getting this message:
AssertionError: Sequence length (4096) must be less than the maximum sequence length allowed (1024)
If I then hardcode num_text_tokens = 10000 instead of the VOCAB_SIZE it uses from simple_tokenizer.py, I get this:
/opt/conda/conda-bld/pytorch_1614378083779/work/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block:
Which I think is just some sort of problem with my custom DataLoader. I'll keep investigating, but I think there basically just needs to be an assert on the dim parameter when you're using the pretrained VAEs, in addition to my pull request.
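Something like this sketch is the kind of guard I mean (names are illustrative, not the repo's actual signatures): derive the image token count from the VAE's downsampling instead of trusting the user-supplied dim, and assert before training starts.

```python
# Hypothetical sketch of the suggested check: compute the combined token
# count from the VAE's downsampling and fail fast if it exceeds the model's
# maximum, instead of crashing mid-training. Names here are assumptions.

def check_seq_len(image_size, num_downsamples, text_seq_len, max_seq_len):
    """Return the combined token count, asserting it fits the model."""
    image_fmap_size = image_size // (2 ** num_downsamples)  # e.g. 256 // 2**4 = 16
    image_seq_len = image_fmap_size ** 2                    # 16 * 16 = 256 image tokens
    total = text_seq_len + image_seq_len
    assert total <= max_seq_len, (
        f'Sequence length ({total}) must be less than the maximum '
        f'sequence length allowed ({max_seq_len})'
    )
    return total

# With one downsampling layer too few, a 256px image produces 32*32 = 1024
# image tokens, which overflows once the text tokens are added:
# check_seq_len(256, 3, 256, 1024)  -> AssertionError
```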
@afiaka87 ack yes, you are correct! just fixed the remaining issue hopefully https://github.com/lucidrains/DALLE-pytorch/commit/187d62ca761a875c31c5834202921f1b676a6243
@afiaka87 there's the possibility this could train a lot faster, given it is 256 tokens instead of 1024! (it has an extra layer of downsampling)
@lucidrains have you gotten it working? I was debugging it for a while, but then I realized my actual problem was that the CUDA installation on my cloud instance was borked.
edit: gonna pull your changes down and try them out now
@afiaka87 yup, I just verified it works, at least that I can get the codebook indices and reconstruct them. I also did a quick run in conjunction with DALL-E and it seems to at least pass my local test cases (though obviously haven't trained it yet, so whether it converges is another matter)
@afiaka87 Also realized that if this VAE works well with DALL-E, it opens up the possibility of training 512 by 512 images :) - but right now I have it set to 256 x 256
i'll be sure to do another hyperparameter sweep on this one if I can get it working.
@afiaka87 wow, running at 32 depth, 16 heads (reversible OFF) and it only needs 11 GiB of VRAM!
@lucidrains it's significantly faster as well. Perhaps converging too quickly. need to check. I'll post back here with a live notebook of the session.
seems to have converged too quickly.
@afiaka87 haha yea, agreed - something doesn't look right
starting over. gonna have a non-live go at this lol. be back with my results in a few hours probably.
@afiaka87 ok! i'll keep thinking on whether there may be a bug somewhere
@lucidrains just for your sanity, I'm already achieving significantly better results just switching back to 3e-4 for my learning rate. Probably should have done that in the first place. Found 5e-4 worked well for OpenAI's VAE in another post, but clearly that's not the case for this VAE. Anyway, my dataset has corrupted PNGs in it, so I'm dealing with that right now, but I just had a much more stable training session:
--depth=32 --heads=8 --learning_rate=3e-4 --batch_size=32
VRAM usage: 22 GiB
Loss falls to ~4.5 to ~5.5 and starts to hover around there, in the same way OpenAI's VAE would settle around 5.5 to 6.5 during training. But honestly? I've never been able to use a batch size of 32 until now. Even on this pricey A100 with 40 GB of VRAM and reversible set to True, it just doesn't work. So this is crazy.
@afiaka87 yes, you are experiencing the Achilles heel of attention networks - their cost grows quadratically with sequence length. The German group was able to get a high-quality VAE by combining it with GAN training, so their token length is 1/4th of OpenAI's, leading to 1/16th the cost of training DALL-E (token length of 256 vs 1024)
@afiaka87 the VQGAN-VAE isn't without its own problems though https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb#scrollTo=3aNc2aEEs0BY You'll notice that it misses certain details in the reconstruction (pinecone)
Not a panacea, but certainly a welcome improvement.
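To make the scaling concrete, a quick back-of-the-envelope sketch: attention cost grows with the square of the token count, so quartering the sequence length cuts attention cost roughly sixteenfold.

```python
# Back-of-the-envelope check of the quadratic attention scaling:
# cost ~ O(n^2) in the number of tokens, so 1024 -> 256 tokens is ~16x cheaper.

def attention_cost_ratio(seq_len_a, seq_len_b):
    """Ratio of O(n^2) self-attention costs for two sequence lengths."""
    return (seq_len_a ** 2) / (seq_len_b ** 2)

print(attention_cost_ratio(1024, 256))  # -> 16.0
```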
@lucidrains I'm pretty far out of my depth at this point, but I know how to set up a hyperparameter sweep with wandb.ai and have access to an A100 for the next day. Here's where I'm stuck: the loss is dropping... too quickly? I would assume that means we're moving too fast and should decrease the learning rate. But in my earlier hyperparameter sweeps, I was minimizing every hyperparameter against the loss. Doing that here will just give me more of these bad results, but maximizing isn't what I want either.
Any ideas on how to cut down search time for a good learning rate?
Edit: I'm planning on just pinning depth to 64 and heads to 16 (is that the OpenAI count?) and running the first epoch of probably ~50 jobs at different learning rates, so knowing what to optimize for is pretty helpful.
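Here's roughly the sweep config I have in mind, just a sketch: log-scale learning rates around 3e-4 with depth and heads pinned. The parameter names (learning_rate, depth, heads) mirror the CLI flags above and are assumptions about the training script.

```python
# Hypothetical wandb sweep config for the learning-rate search described
# above: search lr on a log scale, keep depth/heads fixed, minimize loss.

def make_lr_sweep(min_lr=1e-5, max_lr=1e-3):
    """Build a wandb sweep config searching learning rate on a log scale."""
    return {
        'method': 'bayes',
        'metric': {'name': 'loss', 'goal': 'minimize'},
        'parameters': {
            'learning_rate': {
                'distribution': 'log_uniform_values',
                'min': min_lr,
                'max': max_lr,
            },
            'depth': {'value': 64},
            'heads': {'value': 16},
        },
    }

# Launch with something like:
# sweep_id = wandb.sweep(make_lr_sweep(), project='vqgan_frac_1')
# wandb.agent(sweep_id, function=train, count=50)  # ~50 first-epoch runs
```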
If the loss is dropping precipitously, that's usually a bug, probably with the masking, although I'm not sure why you wouldn't see the same bug when training with OpenAI's VAE 🤔
@lucidrains I think (?) you can see all my runs here. https://wandb.ai/afiaka87/vqgan_frac_1
I'll be dealing with the occasional hiccup and changing parameters and whatnot frequently. But in general I'm trying to stick with what I already know - 32 depth, 8 heads. I'll increase to 64 once I've stabilized that.
@afiaka87 are you using any of the sparse attentions? or just full attention throughout?
@lucidrains I can't use sparse attention unfortunately. They haven't updated it for CUDA 11.0.
Everything until now has been 'full', 'axial_row', 'axial_col', 'conv_like'. I'm going to switch back to just 'full' for run 7, as I had trouble training with axial attention on OpenAI's VAE as well.
@afiaka87 ok, full (and sparse) are probably the safest
i'll double check the axial sparse attention code tomorrow morning
Awesome, yeah, from what I remember it just seemed like it was chunking a bunch of blue blocks into various columns and rows.
@lucidrains removing axial and conv_like attention has definitely helped.
@afiaka87 awesome :) if you could put one back and see which one fails the run again (loss drops too much), that would be another big help :pray:
@afiaka87 perhaps they both have leakage lol
next run will be ('full', 'axial_row')
@lucidrains and then i'll do ('full', 'axial_col')
@afiaka87 ok, and a ('full', 'conv_like')
too, if you could :pray:
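The whole ablation could be scripted as something like this, just a sketch: run each attention mix for an epoch and flag the one whose loss collapses. `train_one_epoch` is a stand-in for the real training loop; `attn_types` mirrors the DALLE-pytorch constructor argument.

```python
# Sketch of the attention-type ablation planned above: try each mix and
# return the first one whose epoch loss collapses suspiciously (the
# 7 -> 0.005 pattern that suggests information leakage).

COMBOS = [
    ('full', 'axial_row'),
    ('full', 'axial_col'),
    ('full', 'conv_like'),
]

def find_leaky_combo(train_one_epoch, collapse_threshold=0.05):
    """Return the first attention mix whose final loss collapses, else None."""
    for attn_types in COMBOS:
        final_loss = train_one_epoch(attn_types=attn_types)
        if final_loss < collapse_threshold:
            return attn_types
    return None
```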
I can confirm that the loss dropped really quick when using the whole package of attention. Went from 7 to 0.005 after about 400 image/text pairs.
Always good to know it's working on someone else's machine as well, at least. You can check the link above for my current live runs.
@lucidrains 'axial_row' seems more stable.
@afiaka87 so the culprit is probably conv_like
? (wouldn't be surprised)
@lucidrains yeah, it's fine. Moving on to axial_col.
wish I could run in parallel but my colab hates me right now
Yeah I have to resort to these pricey (but somehow cheapest on the market?) instances from vast.ai.
This is soo much faster. Would've waited an hour before on these same parameters.
edit: got a little out of hand with the number of 'o's originally.
it is insane it's doing 400 batches in like a couple of minutes :o
increase the image size to 512x512, and it'll be just as slow as OpenAI's VAE lol
it's all about this elephant in the room, the N squared cost of attention
I'm curious to see if your transganformer will be a possible candidate for this.
@lucidrains alright, this is stable too. moving to conv_like. And oh look at that we found a corrupted png
rm -rf --no-preserve-root 3693640717_0.png
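For anyone else hitting this: rather than deleting corrupted files one at a time as they crash a run, a quick pre-scan catches them up front. Just a sketch, assuming Pillow is installed and the dataset is a flat folder of PNGs.

```python
# Pre-scan a dataset folder for images Pillow cannot fully decode, so the
# DataLoader never sees them. 'dataset_dir' is whatever folder it reads.

from pathlib import Path

from PIL import Image


def find_corrupt_images(dataset_dir):
    """Return paths of PNGs that fail to decode."""
    bad = []
    for path in Path(dataset_dir).glob('*.png'):
        try:
            with Image.open(path) as img:
                img.load()  # verify() alone misses some truncation errors
        except OSError:
            bad.append(path)
    return bad
```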
@TheodoreGalanos GAN + text conditioning is a totally different approach than DALL-E, but i'm excited to see if the duplex attention from GansFormer will work better for conditioning image generation from text
@afiaka87 i'm pretty sure the bug is in the conv like attention, even before you run the experiment lol
that's alright, I honestly don't even think DALL-E was the breakthrough. I really love the work being put into the parts that are closer to the CLIP + cool_model_here side of things. Even though I still have hopes for tiny DALL-Es working in practical situations :)
@lucidrains this seems to be running fine... confusing.
@afiaka87 oh what lol
i'll eat those words. it's dropping.
@afiaka87 woohoo! i knew it, thank you thank you :pray:
@lucidrains I believe the relevant line is here:
https://github.com/lucidrains/DALLE-pytorch/blob/2268864941d8eef2ba73a4488fe05673d447d493/dalle_pytorch/dalle_pytorch.py#L306
I tried adding it in myself, but it needs the taming imports and I'm not familiar with those.