robvanvolt / DALLE-models

Here is a collection of checkpoints for DALLE-pytorch models, from which you can continue training or start generating images.
MIT License

GumbelVQ 1 Epoch on Open Images Localized Annotations #7

Closed - afiaka87 closed this issue 3 years ago

afiaka87 commented 3 years ago

edit: Well this thing trained for a full 5 epochs and never made a single coherent generation. @rom1504 and I discussed this and I guess it's just really really hard to train this one. 🤷 Don't think a single RTX 2070 is going to cut it.

Hey - I'm trying to force myself to take a few days off, and the holidays are coming up here in the States anyway, so I'll probably be away from the Discord/GitHub for a bit.

Use right away with Python

# pip install wandb requests

import zipfile

import requests
import wandb

# Pull the trained DALL-E checkpoint artifact from the public W&B run
run = wandb.init()
artifact = run.use_artifact('dalle-pytorch-replicate/oi_gumbel_imgloss7/trained-dalle:v0', type='model')
artifact_dir = artifact.download()

# Download the matching 4096-token BPE vocabulary and unzip it next to the checkpoint
openaiblog_openimages_bpe_url = "https://github.com/robvanvolt/DALLE-models/files/6735615/blogoimixer_4096.bpe.zip"
bpe_zip_path = "blogoimixer_4096.bpe.zip"
downloaded_obj = requests.get(openaiblog_openimages_bpe_url)
with open(bpe_zip_path, "wb") as file:
    file.write(downloaded_obj.content)

with zipfile.ZipFile(bpe_zip_path, "r") as zip_ref:
    zip_ref.extractall(artifact_dir)

Download Links

Info

pip3 install youtokentome # will install the cli tool `yttm`
yttm bpe --vocab_size=4096 --coverage=1.0 --model=blogoimixer_4096.bpe --data=blogoi_allcaps.txt
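Same thing from Python, in case that's handier - a quick sketch using youtokentome's Python API (assumes blogoi_allcaps.txt, the caption dump the CLI call above reads, is in the working directory):

import youtokentome as yttm

# Train the 4096-token BPE model (equivalent to the yttm CLI call above)
yttm.BPE.train(data="blogoi_allcaps.txt", vocab_size=4096, coverage=1.0,
               model="blogoimixer_4096.bpe")

# Load it back and encode a caption into token ids for the transformer
bpe = yttm.BPE(model="blogoimixer_4096.bpe")
token_ids = bpe.encode(["in this image i can see cars and a fence"],
                       output_type=yttm.OutputType.ID)
print(token_ids)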
afiaka87 commented 3 years ago

Anyway, I'm still training and I'll be away for a while - but I'll submit a proper PR once I have a resumable DeepSpeed checkpoint ready for upload. In the meantime you can direct people to the public W&B training session. Seems like one epoch takes about a day. I'll try to finish another one, but who knows.

https://wandb.ai/dalle-pytorch-replicate/oi_gumbel_imgloss7/runs/ml9u03ab

afiaka87 commented 3 years ago

Also probably best to keep in mind that the captions in this one are spoken-word transcriptions that were then simplified into a smaller vocabulary. They can sound kind of "dumb", for lack of a better word, if you read them as though they were written by a human. They were not - they were spoken, then transcribed by a machine, then preprocessed by a language AI. Folks from CLIP-land might have to readjust their expectations. It could be a good idea to provide a sort of "caption template":

"In this image I can see and over here and there are and to the left and to the right and at the top and to the bottom "

where possible fillings would be

"In this image I can see cars and over here a fence and there are signs and to the left buildings and to the right a field and at the top the night sky and to the bottom a sidewalk"

I think I'm most interested in seeing how well it manages to learn to compose "left, right, top, and bottom" into the correct positions while maintaining visual fidelity. Should be interesting.

robvanvolt commented 3 years ago

The loss!! jealous

afiaka87 commented 3 years ago

Oh, just to be clear: the loss graph at the bottom of my post is from a different session where I used --loss_img_weight 1, while this one used the default value of 7.

Since we aren't logging the text-predicts-text loss and the text-predicts-image loss separately, it's not really a meaningful comparison - the text-predicts-text loss presumably overfits and goes down very quickly unless your caption dataset is as diverse as, say, CogView's.
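If I'm remembering the dalle-pytorch code right, the two terms get combined roughly like this (a paraphrased sketch, not the actual source):

import torch.nn.functional as F

def combined_loss(text_logits, text_targets, image_logits, image_targets, loss_img_weight=7):
    # Weighted mix of the two cross-entropy terms; with the default
    # loss_img_weight=7 the reported loss is dominated by the image term.
    loss_text = F.cross_entropy(text_logits, text_targets)
    loss_img = F.cross_entropy(image_logits, image_targets)
    return (loss_text + loss_img_weight * loss_img) / (loss_img_weight + 1)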

edit - I'm still foggy today; not sure if I said all that right.

afiaka87 commented 3 years ago

The loss!! jealous

[W&B loss chart, 6/29/2021, 2:24 PM]

afiaka87 commented 3 years ago

edit: @robvanvolt - in case you're confused, I just decided the top post had too much info, so I'm moving this down here.

Details:

I found in another experiment that using a little bit of full attention actually helps a good deal with the GumbelVQ -

Here; I think this shows it. Two runs. Exact same parameters except for this:

desert-plasma-29
depth 8
attention types: full,axial_row,axial_col,conv_like,full,axial_row,axial_col,conv_like

logical-jazz-19
depth 12
attention types: axial_row,axial_row,axial_col,axial_row,axial_row,axial_row,axial_col,axial_row,axial_row,axial_row,axial_col,conv_like

Per usual, using axial attention increased my samples per second from 11 to 22 - so the x-axis here is time instead of step.

[W&B loss comparison chart, 6/29/2021, 1:09 PM]

As you can see - a lower depth of 8 with 2 layers of full attention outperforms 12 layers of sparse axial attention.
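For anyone wanting to reproduce this, here's a rough sketch of how those two schedules map onto dalle-pytorch's DALLE constructor (dim, heads, text_seq_len and the VAE line are placeholders I'm making up, not the actual run config):

from dalle_pytorch import DALLE, VQGanVAE

vae = VQGanVAE()  # placeholder - substitute the GumbelVQ wrapper you're actually using

# desert-plasma-29: depth 8 with two layers of full attention mixed in
attn_types_mixed = ('full', 'axial_row', 'axial_col', 'conv_like',
                    'full', 'axial_row', 'axial_col', 'conv_like')

# logical-jazz-19: depth 12, sparse axial/conv attention only
attn_types_sparse = ('axial_row', 'axial_row', 'axial_col', 'axial_row',
                     'axial_row', 'axial_row', 'axial_col', 'axial_row',
                     'axial_row', 'axial_row', 'axial_col', 'conv_like')

dalle = DALLE(
    dim=512,               # placeholder width
    vae=vae,
    num_text_tokens=4096,  # matches the 4096-token BPE vocab above
    text_seq_len=256,      # placeholder
    depth=8,               # swap in depth=12 and attn_types_sparse to mimic logical-jazz-19
    heads=8,               # placeholder
    attn_types=attn_types_mixed,
)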

afiaka87 commented 3 years ago

@robvanvolt probably worth mentioning (since I haven't seen generations from it yet) that it's perhaps not super likely to work great. This thing might need a bigger DALL-E than I can train. I never once got anything resembling the caption in my fp32 training, and it was very often just GumbelVQ nonsense output. The GumbelVQ seems to have a very large codebook, so the transformer needs a lot more time to find all the codes in it.
