lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

Typical training results #61

Open edend10 opened 3 years ago

edend10 commented 3 years ago

Hi, great repo and thanks for sharing your work!

I'm trying your bird dataset example from the Colab with OpenAI's pretrained VAE. I haven't been able to get meaningful results so far, either on the Colab or on my own VM (Tesla T4 GPU).

13 epochs into train_dalle.py and I'm only seeing these kinds of results: [attached: sample output image]

On my VM I ran $ python train_dalle.py --image_text_folder /parent/to/birds/dataset/directory without changing any of the code (I only replaced wandb with another experiment-tracking framework, but I doubt that should make a difference).
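(For reference, --image_text_folder expects images and captions paired by filename stem, as described in the repo README; an illustrative layout:)

```
birds/
  0001.jpg    <- image
  0001.txt    <- caption for 0001.jpg
  0002.jpg
  0002.txt
  ...
```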

Should the bird dataset work better with the pretrained VAE? Can you share some results or common training parameters/times/number of epochs?

lucidrains commented 3 years ago

@edend10 Hi Eden! Thanks for trying out the repository! I may have found a bug in the pretrained VAE wrapper, fixed in the latest commit: https://github.com/lucidrains/DALLE-pytorch/blob/0.2.2/dalle_pytorch/vae.py#L82 :pray: I'll be training this myself this week and ironing out any remaining issues (other than data and scale, of course).

edend10 commented 3 years ago

Thanks for the response @lucidrains ! Ohh interesting, I'll check out the changes and try it out. Will look out for more updates!

CDitzel commented 3 years ago

What are those two mapping functions for, anyway?

Are they just for transforming the pixel value range of the input data to match what they use over at OpenAI?
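(They appear to be the logit-Laplace pixel mappings from OpenAI's released dVAE code, which keep pixel values away from the extremes of [0, 1]; a sketch of their behavior, modeled on dall_e/utils.py:)

```python
import torch

# Sketch of the pixel mappings in OpenAI's released dVAE (dall_e/utils.py):
# eps squeezes pixels into [eps, 1 - eps] so the logit-Laplace likelihood
# used to train the decoder stays finite at the boundaries.
logit_laplace_eps = 0.1

def map_pixels(x: torch.Tensor) -> torch.Tensor:
    # applied before encoding: [0, 1] -> [eps, 1 - eps]
    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    # applied after decoding: invert the mapping and clamp back to [0, 1]
    return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
```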

AlexanderRayCarlson commented 3 years ago

Hello! Thank you for this excellent work. I seem to be getting something similar - abstract sorts of blue squares - when training in the Colab notebook. It looks like the package (0.2.2) is updated with the latest fix; is there anything else that needs to be done at the moment?

awilson9 commented 3 years ago

This is still happening for me as well with the pretrained VAE on 0.2.2.

afiaka87 commented 3 years ago

This is an early output (2 epochs) from the new code that removes the normalization from train_dalle.py. Was that the necessary fix, @lucidrains?
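(If so, the change presumably amounts to dropping the ImageNet-style normalization from the torchvision preprocessing when the pretrained VAE is used; a minimal sketch of such a transform, not the exact train_dalle.py code:)

```python
from torchvision import transforms as T

# Preprocessing sketch for the pretrained OpenAI VAE: resize/crop to 256x256
# and convert to a [0, 1] float tensor, with no T.Normalize(mean, std) step,
# since the VAE applies its own pixel mapping internally.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(256),
    T.ToTensor(),  # floats in [0, 1]
])
```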

DEPTH = 6
BATCH_SIZE = 8

[attached: generated sample]

"a female mannequin"

Much more cohesive and a much stronger start now. No strange blueness, at the very least.

liuqk3 commented 3 years ago

Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and the number of text and image tokens (8192 image tokens, right?)? Thanks.

afiaka87 commented 3 years ago

> Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and the number of text and image tokens (8192 image tokens, right?)? Thanks.

I should mention that the dataset I'm using includes images released by OpenAI with their DALL-E. The mannequin image is not being generated from text alone; it's from an image-text pair. Anyway, my point is that my dataset is bad and I'm mostly just messing around. Training on images generated by DALL-E itself is probably bound to converge much quicker than usual.

I'm using the defaults in train_dalle.py except for BATCH_SIZE and DEPTH: pretrained OpenAI VAE, top_k=0.9, and reversible=True. I tried mixing attention layers but it adds memory. (Edit: I don't think it does, actually; training with all four attn_types currently.)
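(In DALLE-pytorch terms, that setup looks roughly like the sketch below; the values are the ones mentioned in this thread and attn_types follows the README, so treat it as illustrative rather than the exact training script. The top_k=0.9 corresponds to the filter_thres argument used at sampling time.)

```python
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()  # pretrained OpenAI dVAE, 8192 image tokens

dalle = DALLE(
    dim=512,                # MODEL_DIM
    vae=vae,
    num_text_tokens=10000,  # default text vocab size
    text_seq_len=256,
    depth=16,
    heads=8,
    dim_head=64,
    reversible=True,        # trades compute for memory
    attn_types=('full', 'axial_row', 'axial_col', 'conv_like'),  # all four attn_types
)

# sampling, where filter_thres plays the role of the "top_k=0.9" above:
# images = dalle.generate_images(text, mask=mask, filter_thres=0.9)
```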

I'm currently working on a hyperparameter sweep with wandb. I think a learning rate of 2e-4 might be better for depths greater than 12 or so.

I still can't get a stable learning rate at depth 64.

afiaka87 commented 3 years ago

Edit: you can find the whole training session here: https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

I'm starting over because there have been updates to the main branch.

Original post:

"a professional high quality emoji of a spider starfish chimera . a spider imitating a starfish . a spider made of starfish . a professional emoji ."

[attached: generated sample]

Left it running at depth 16, 8 heads, batch size 12, learning_rate=2e-4. The loss is going down at a steady, consistent rate. (Edit: just kidding! It seems to get stuck around ~6.0 on this run, which seems high?)

DEPTH: 16
HEADS: 8
TOP_K: 0.85
EPOCHS: 27
SHUFFLE: True
DIM_HEAD: 64
MODEL_DIM: 512
BATCH_SIZE: 12
REVERSIBLE: true
TEXT_SEQ_LEN: 256
LEARNING_RATE: 0.0002
GRAD_CLIP_NORM: 0.5

afiaka87 commented 3 years ago

Edit: here I used Weights & Biases to create a report. The link has all the images generated (every 100th iteration) across 27,831 total iterations.

Edit: this one should work, I think: https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

tommy19970714 commented 3 years ago

@afiaka87 Thank you for sharing your Weights & Biases report! But I can't see it because the project is private. Could you make it visible to us?

[attached: screenshot]

afiaka87 commented 3 years ago

Hm, does this work, @tommy19970714? https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

afiaka87 commented 3 years ago

> Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and the number of text and image tokens (8192 image tokens, right?)? Thanks.

Just for more info on the dataset itself: it's roughly 1,100,000 256x256 image-text pairs generated by OpenAI's DALL-E. They presented roughly 30k unique text prompts, posting the top 32 of 512 generations for each at https://openai.com/blog/dall-e/. Many images were corrupt and not every prompt has a full 32 examples, but the total winds up being about 1.1 million images. If you look at the examples on that page, you'll see that DALL-E (in that form, at least) can and will make mistakes, and those mistakes are in this dataset as well. Anyway, I'm just messing around and having fun training. This is definitely not going to produce a good model or anything.

There are also a large number of images in the dataset that are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch, though. Anyway, that can't be helping much.

One valuable thing I've taken from this is that it seems to take at least ~2000 iterations with a batch size of 4 to approach any sort of coherent reproductions. That number probably varies, but in terms of "knowing when to start over", roughly 3000 steps might be a good soft target.

tommy19970714 commented 3 years ago

Thank you for sharing your results! I will refer to your parameters.

afiaka87 commented 3 years ago

@tommy19970714

I did a hyperparameter sweep with Weights & Biases: forty-eight 1,200-iteration runs of dalle-pytorch, varying learning rate, depth, and heads (minimizing the total loss at the end of each run).

https://github.com/lucidrains/DALLE-pytorch/issues/84#issue-830997522
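(A sweep like that can be defined programmatically with wandb; this is a hypothetical reconstruction of the search space, since the exact configuration wasn't posted here:)

```python
import wandb

# Hypothetical sweep config: learning rate, depth, and heads are varied,
# minimizing the final loss of each 1200-iteration run. The ranges and
# search method below are illustrative guesses, not the ones actually used.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'min': 1e-4, 'max': 1e-3},
        'depth': {'values': [8, 16, 24, 32]},
        'heads': {'values': [4, 8, 16]},
    },
}

sweep_id = wandb.sweep(sweep_config, project='dalle-pytorch-sweep')  # hypothetical project name
# wandb.agent(sweep_id, function=train)  # train() reads wandb.config each run
```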

afiaka87 commented 3 years ago

The most important thing to note here is that the learning rate actually needs to go up to about 0.0005 when dealing with a depth of ~26-32.

afiaka87 commented 3 years ago

I've done a much longer training session on that same dataset here:

https://github.com/lucidrains/DALLE-pytorch/issues/86