mehdidc / feed_forward_vqgan_clip

Feed forward VQGAN-CLIP model, where the goal is to eliminate the need for optimizing the latent space of VQGAN for each input prompt
MIT License

New Checkpoint Idea #22

Closed · afiaka87 closed this issue 2 years ago

afiaka87 commented 2 years ago

The Gumbel VQGAN from ruDALLE may prove to be the best VQGAN available. Might be worth training a checkpoint on.

mehdidc commented 2 years ago

Hi @afiaka87, thanks, that would be really cool indeed; I will look into it. I was also thinking, since the text-to-image part is also available, whether it would be possible to distill it into a feed-forward model, since generation with DALL-E-like transformers is done token by token and is thus still quite slow. Since we have the model, we can generate as many images as we want and use them as training data. I have seen you had a similar idea, training StyleGAN3 on a blog captions dataset; have you had any success already? What would be especially interesting for me in this context is being able to generate different images given the same text.
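
A rough sketch of what that distillation loop could look like is below; `generate_with_teacher` and `student` are hypothetical placeholders standing in for a ruDALLE sampling wrapper and the feed-forward generator, not code that exists in this repo:

```python
# Hypothetical sketch: distill a slow autoregressive text-to-image teacher
# (e.g. ruDALLE) into a fast feed-forward student by regressing its samples.
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
perceptor, _ = clip.load("ViT-B/32", device=device, jit=False)

def build_distillation_batch(prompts, generate_with_teacher):
    """Run the (slow) teacher once per prompt to produce target images."""
    with torch.no_grad():
        tokens = clip.tokenize(prompts).to(device)
        text_features = perceptor.encode_text(tokens).float()
        # Assumed to return images in [0, 1] with shape (len(prompts), 3, H, W);
        # with ruDALLE this would wrap its token-by-token sampling loop.
        target_images = generate_with_teacher(prompts)
    return text_features, target_images

def distillation_step(student, optimizer, text_features, target_images):
    """One step: the feed-forward student learns to reproduce the teacher's output."""
    optimizer.zero_grad()
    predicted = student(text_features)  # a single fast forward pass
    loss = torch.nn.functional.mse_loss(predicted, target_images)
    loss.backward()
    optimizer.step()
    return loss.item()
```

One caveat: a plain pixel MSE like this would collapse to a single image per prompt, so getting different images for the same text (the point raised above) would need something extra, e.g. a noise input to the student or a perceptual/adversarial term.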

afiaka87 commented 2 years ago

@mehdidc

I wasn't able to train it to completion, but it did indeed work out. The checkpoint is fast enough to do interpolation at 30 fps on my 2070.

The original DALL-E dVAE clearly learned a bit of aliasing too, which resulted in some strange artifacts being learned by the StyleGAN3. I've since switched to training on the Country221 dataset (a balanced dataset of 64k images of different countries' wildlife, landscapes, people, etc.). It's going pretty well.

At a given seed, you can apply a variety of transformations to specific layers in order to display tons of variants near that latent. In practice this winds up being e.g. tree placement, time-of-day, skin color. I think we could figure out a way to automate this or incorporate it into distillation somehow to achieve the desired diversity-per-text
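
For reference, the kind of per-layer manipulation described above can be written against the official NVlabs stylegan3 code roughly as below. This is only a sketch: the pickle path and the layer cutoff are arbitrary placeholders, and it mirrors what the repo's style-mixing tooling does rather than anything specific to this project:

```python
# Sketch of per-layer style mixing with the official stylegan3 codebase.
# Assumes the stylegan3 repo is on PYTHONPATH and a trained pickle exists;
# "network-snapshot.pkl" and the cutoff index below are arbitrary examples.
import torch
import dnnlib
import legacy

device = torch.device("cuda")
with dnnlib.util.open_url("network-snapshot.pkl") as f:
    G = legacy.load_network_pkl(f)["G_ema"].to(device)

z_base = torch.randn(1, G.z_dim, device=device)     # the fixed seed
z_variant = torch.randn(1, G.z_dim, device=device)  # source of the variation

w_base = G.mapping(z_base, None, truncation_psi=0.7)      # (1, num_ws, w_dim)
w_variant = G.mapping(z_variant, None, truncation_psi=0.7)

# Swap only some layers' styles: the untouched early layers keep the coarse
# layout from w_base, while the swapped later layers pull finer attributes
# (lighting, palette, ...) from w_variant. The cutoff controls what changes.
cutoff = 8
w_mixed = w_base.clone()
w_mixed[:, cutoff:] = w_variant[:, cutoff:]

img = G.synthesis(w_mixed)  # (1, 3, H, W), values roughly in [-1, 1]
```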

Using StyleGAN3 as a backbone for this is very intriguing to me because of the FID you can get at such a fast inference time.

afiaka87 commented 2 years ago

Here's a (far too long) recording of me messing with the GUI provided in the stylegan3 repo. You can see me occasionally flip through styles with the "StyleMixer".

https://vimeo.com/639405448

This dataset was highly homogeneous though and it's tough to tell from that how well generalization would work on the vastly more complex dataset ruDALLE would generate from however many captions.

afiaka87 commented 2 years ago

And here are some interpolations from my current checkpoint on Country221. It has a while to go before it matches the quality of the checkpoints available from NVIDIA, but just to give you an idea.

(attached video: 0-127_sg_big_grid_latest.mp4)

mehdidc commented 2 years ago

@afiaka87 Thanks for the links; cool that it already seems to work.

"This dataset was highly homogeneous though and it's tough to tell from that how well generalization would work on the vastly more complex dataset ruDALLE would generate from however many captions."

Indeed, I was also wondering about that, because StyleGAN architectures are known to have problems with diverse datasets (e.g. https://old.reddit.com/r/MachineLearning/comments/hwf093/d_what_are_the_recent_papers_that_address_the/fz0imnj/). I don't know if StyleGAN3 would be different; I haven't read the paper properly yet.

"At a given seed, you can apply a variety of transformations to specific layers in order to display tons of variants near that latent. In practice this winds up being e.g. tree placement, time-of-day, skin color. I think we could figure out a way to automate this or incorporate it into distillation somehow to achieve the desired diversity-per-text"

Wow, this is really cool! What kind of transformations did you use to get such modifications (tree placement, time of day, skin color, etc.)? Is it like what is shown in the GUI in https://vimeo.com/639405448 ?

mehdidc commented 2 years ago

@afiaka87 By the way, I tried ruDALLE. One thing that is clear is that I see some watermarks appearing in the outputs; I guess this has to do with the dataset they trained it on. I will push it into the code to make it available.

(attached image: ruDALLE samples)
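
For reference, the sampling roughly follows the rudalle package's README; the prompt and sampling settings below are arbitrary, and the exact signatures may differ between versions:

```python
# Sketch of sampling from ruDALLE, following the usage shown in the rudalle
# package's README at the time; argument names/values may vary by version.
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images

device = "cuda"
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)  # the Gumbel VQGAN discussed earlier in this thread

text = "закат на море"  # ruDALLE expects Russian captions ("sunset over the sea")
pil_images, scores = generate_images(
    text, tokenizer, dalle, vae,
    top_k=1024, top_p=0.99, images_num=4,  # arbitrary sampling settings
)
```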

afiaka87 commented 2 years ago

@mehdidc Great!

Ignoring the watermarks, the outputs seem to capture a bit more detail than the prior VQGAN attempts. I'm certain many will find this quite useful, considering the inference time is low enough to experiment with prompt engineering.

afiaka87 commented 2 years ago

(aside) I have largely given up after seeing consistent and continued success from guided diffusion at this task, with "none" [citation needed] of the stability/over-fitting issues presented by StyleGAN3.

SG3 is interesting because they don't actually improve FID (they match StyleGAN2); rather, their motivation is to have a purely continuous representation (wavelets, I believe) of the visual signal throughout the network, to avoid learning features that are really pixel-level artifacts, including aliasing and fine detail tied to pixel coordinates. This works pretty well, until it doesn't. My experience is that if your dataset still has any sort of JPEG artifacts/aliasing/what-have-you, the network still learns these things just fine, and they are perhaps even more pronounced than before: rather than appearing as "normal" aliasing/artifacts, you see a distinctly Mandelbrot-esque transparency over certain regions.

The inference times are of course the impressive bit, and, speculating, I wager that NVIDIA ultimately would like to pair this technology with the offerings they provide for game developers. The representations learned by StyleGAN3, where pixels are explicitly drawn onto a malleable grid of sorts, may lend themselves well to the 3D representations we find in videogames. Further, with a videogame your dataset is well defined as effectively "the graphics of the engine as a player uses it". I believe they already do per-game tuning for their realtime super-resolution effort, DLSS 2.0 (which is incredibly impressive), delivered via updates to their driver software.