lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Faster inference via efficient sampling ("Elucidating the Design Space of Diffusion-Based Generative Models") #81

Closed. Birch-san closed this issue 2 years ago.

Birch-san commented 2 years ago

Hi, thanks for your fantastic work so far. 🙂

Elucidating the Design Space of Diffusion-Based Generative Models, published ~3 weeks ago, describes a new way to design diffusion models (i.e. a more modular way), enabling changes to sampling and training (in particular, making sampling faster).

Probably you're aware of this already (@crowsonkb and @johnowhitaker are attempting to reproduce the work).

@johnowhitaker explains the paper in this video, and demonstrates the model in this colab, achieving face-like images within 40 minutes of training.

any plans to incorporate this diffusion approach into imagen-pytorch, or is it still too early?

lucidrains commented 2 years ago

@Birch-san Hi! I too am reading this paper! Tero Karras needs no introduction

I think I will initially test their proposals and open source the training methods at https://github.com/lucidrains/denoising-diffusion-pytorch before attempting to integrate it into Imagen

Even if I do port some of the lessons over to Imagen (or DALL-E2), it will be done in a separate class, perhaps ElucidatedImagen, since I want the repository to be representative of the original paper to some degree

lucidrains commented 2 years ago

actually, drawing the noise from a log normal seems to be something i can add as a setting without too much extra complexity! apparently the paper claims it works synergistically with the p2 loss weighting too
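
roughly, what that amounts to (as i read the paper) is just sampling the noise level per training example with log(sigma) drawn from a normal distribution. a minimal, untested sketch, using the P_mean / P_std values the paper recommends for images:

import torch

def sample_sigmas_lognormal(batch_size, p_mean=-1.2, p_std=1.2, device='cpu'):
    # Karras et al. draw ln(sigma) ~ N(P_mean, P_std^2) during training,
    # so sigma itself is lognormally distributed; the defaults here are the
    # values the paper recommends for image data
    return (torch.randn(batch_size, device=device) * p_std + p_mean).exp()

sigmas = sample_sigmas_lognormal(16)  # one noise level per sample in the batch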

Birch-san commented 2 years ago

Outstanding! I'd certainly be interested in trying that out. Does that mean that the other ideas of the paper (refactoring into a modular design) can be skipped; the novel sampling technique is the consequential part, and can be added relatively simply to a traditional diffusion model?

lucidrains commented 2 years ago

ugh, i was wrong, this may not be straightforward

everything is now centered around sigma and the new sigma data hyperparam

lucidrains commented 2 years ago

the only new thing that can be immediately applied is the non-leaky augmentations, but it would require the unet be modified to accept additional conditioning

crowsonkb commented 2 years ago

You can draw sigmas from a lognormal and then convert them to equivalent DDPM linear or cosine timesteps, that's the tack I took to use a lognormal or log-logistic sampling density with v-diffusion as a baseline.
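
To spell out the cosine case: with alpha_t = cos(pi*t/2) and sigma_t = sin(pi*t/2), the noise-to-signal ratio is tan(pi*t/2), so the conversion is just an arctangent. A minimal sketch (not my exact code), assuming t is the continuous timestep in [0, 1):

import math
import torch

def karras_sigma_to_cosine_t(sigma):
    # cosine schedule: alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2), so the
    # noise-to-signal ratio is tan(pi*t/2); invert it to recover the timestep
    return sigma.atan() * 2 / math.pi

def cosine_t_to_karras_sigma(t):
    # inverse mapping: continuous timestep back to a noise-to-signal ratio
    return (t * math.pi / 2).tan()

t = karras_sigma_to_cosine_t(torch.tensor([0.5, 1.0, 10.0]))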

Does that mean that the other ideas of the paper (refactoring into a modular design) can be skipped; the novel sampling technique is the consequential part

No IMO, the preconditioner is also consequential: when you set sigma_data=1 it gives you the same model target as v-diffusion but the opposite sign, and has the same beneficial properties. When you set sigma_data other than 1 it gives you -v still but as though you had divided the data by sigma_data first. i.e. their sigma_data=0.5 recommendation for images is the same as rescaling images to -2 to 2 with v-diffusion.
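
For reference, the preconditioner I mean wraps the raw network like this (the c_* formulas are from the Karras paper; a minimal sketch, with net standing in for whatever U-Net you're training):

import torch

def karras_denoise(net, x_noisy, sigma, sigma_data=0.5):
    # sigma: shape (batch,); reshape it so it broadcasts over the image dims
    sigma_b = sigma.view(-1, *([1] * (x_noisy.ndim - 1)))
    c_skip = sigma_data ** 2 / (sigma_b ** 2 + sigma_data ** 2)
    c_out = sigma_b * sigma_data / (sigma_b ** 2 + sigma_data ** 2).sqrt()
    c_in = 1 / (sigma_b ** 2 + sigma_data ** 2).sqrt()
    c_noise = sigma.log() / 4
    # the raw network sees a rescaled input, and its output is blended with a
    # skip connection so the denoised prediction stays well-conditioned at every sigma
    return c_skip * x_noisy + c_out * net(c_in * x_noisy, c_noise)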

lucidrains commented 2 years ago

@crowsonkb Katherine! You probably already know, but your technique got used in Parti!

[screenshot: excerpt from the Parti paper]

samedii commented 2 years ago

Might be good to be aware that the noise schedule in the paper is quite different from a noise schedule that was learned from data in https://arxiv.org/pdf/2107.00630.pdf. It will of course depend on the task, in this case it was unconditional image generation on CIFAR-10.

[image: left, a comparison of different schedules vs. the learned schedule (blue); middle, the schedule from "Elucidating" with rho=7; right, the schedule from "Elucidating" with rho=-3 (not recommended in the paper)]

crowsonkb commented 2 years ago

Might be good to be aware that the noise schedule in the paper is quite different from a noise schedule that was learned from data in https://arxiv.org/pdf/2107.00630.pdf. It will of course depend on the task, in this case it was unconditional image generation on CIFAR-10.

I should note that the Variational Diffusion Models learned noise schedules are learned to minimize loss variance during training when predicting eps and they would probably be pretty different if predicting -v, because v is a higher variance target to begin with especially at the higher noise levels. Also I am not sure they ever showed that their learned variance-minimizing schedules are actually good to use in inference (but their models are continuous so you can use whatever noise schedule in inference). Karras et al decouple the inference noise schedule (the thing with rho=7) from the training noise schedule (the lognormal sampling density) and do not learn either.

samedii commented 2 years ago

Thanks for explaining, I am trying to implement learning the noise schedule today :) I think this graph is showing a slight improvement in BPD at inference time with their variance-minimizing schedule, but it does seem quite insignificant (green vs blue). I guess it depends on what baseline noise schedule you are comparing to though, and how well it fits the task.

[graph: BPD comparison of the schedules, green vs. blue]

jacobwjs commented 2 years ago

the only new thing that can be immediately applied is the non-leaky augmentations, but it would require the unet be modified to accept additional conditioning

Does it though? Remember you've strapped a big, powerful LM on the front of this thing that is doing everything you need. At least I think lol.

I brought this topic up in the discussions section. Basic idea is any augmentation just needs a succinct text description of it appended (or prepended) to the existing text description. And voilà, the embeddings now contain the augmentation conditioning information.

Seems like a very simple way to build very intricate augmentation pipelines and move past basic horizontal flips and center crops. Actually you could probably do anything to an image (within reason) as long as you describe it well.

lucidrains commented 2 years ago

the only new thing that can be immediately applied is the non-leaky augmentations, but it would require the unet be modified to accept additional conditioning

Does it though? Remember you've strapped a big, powerful LM on the front of this thing that is doing everything you need. At least I think lol.

I brought this topic up in the discussions section. Basic idea is any augmentation just needs a succinct text description of it appended (or prepended) to the existing text description. And voilà, the embeddings now contain the augmentation conditioning information.

Seems like a very simple way to build very intricate augmentation pipelines and move past basic horizontal flips and center crops. Actually you could probably do anything to an image (within reason) as long as you describe it well.

this is a really cool idea! we could definitely try prepending "this is an upside down picture of <text>" and see if that improves FID! may even be worth a short paper in itself

jacobwjs commented 2 years ago

the only new thing that can be immediately applied is the non-leaky augmentations, but it would require the unet be modified to accept additional conditioning

Does it though? Remember you've strapped a big, powerful LM on the front of this thing that is doing everything you need. At least I think lol. I brought this topic up in the discussions section. Basic idea is any augmentation just needs a succinct text description of it appended (or prepended) to the existing text description. And voilà, the embeddings now contain the augmentation conditioning information. Seems like a very simple way to build very intricate augmentation pipelines and move past basic horizontal flips and center crops. Actually you could probably do anything to an image (within reason) as long as you describe it well.

this is a really cool idea! we could definitely try prepending "this is an upside down picture of <text>" and see if that improves FID! may even be worth a short paper in itself

Happy to pursue this! Could lead to some really interesting improvements in FID like you mention. Once all things Unet are settled we can test here as a proof of concept, but it would be great to do an apples-to-apples comparison with the above-mentioned work.

Let’s chat more about this in the future.

lucidrains commented 2 years ago

@samedii you had mentioned in another issue that you saw good results switching over to elucidated ddpm, but did you have to fiddle with the sigma data and rho at all? or did it work pretty much on the first try?

crowsonkb commented 2 years ago

@samedii you had mentioned in another issue that you saw good results switching over to elucidated ddpm, but did you have to fiddle with the sigma data and rho at all? or did it work pretty much on the first try?

The question wasn't to me, but: sigma_data=0.5 is good for RGB and sigma_data=1 is good for VAE latents where you regularized the variance to be ~1. sigma_data=1 is equivalent to what I have been doing in v-diffusion. I didn't try messing with rho (I just used the recommended 7) but I think my custom DDPM/cosine spliced schedule, expressed in terms of sigma, might be better.

As far as sigma_data goes, it's the same as the pre-scaling you would do to your input data for a DDPM to make it variance 1. So sigma_data=0.5 is equivalent to rescaling your RGB data to the range -2 to 2 for training.

Other parameters you have to tweak are the sigma_max (double it each time the edge length of your images goes up by 2x as per Song et al.) and the sampling density (so you actually sample the higher sigmas for larger resolution models sometimes).

Some other notes: I was able to write a better deterministic sampler than their proposed second order sampler by using an Adams type ODE integrator (linear multistep) where I computed the optimal coefficients specially for each step (Adams-Bashforth normally requires fixed step sizes and the sigmas very much are not equally spaced). This code is in my k-diffusion repo. Their stochastic second order sampler is pretty good for CLIP guided diffusion and I may backport it to older models or write k-diffusion (elucidated) wrappers for them.
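
For context, the deterministic Heun sampler from the paper that I'm comparing against is roughly the following (a minimal sketch, assuming denoise(x, sigma) returns the model's predicted x_0):

import torch

@torch.no_grad()
def heun_sample(denoise, x, sigmas):
    # sigmas: decreasing noise levels ending in 0; x starts as sigmas[0] * noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma            # dx/dsigma of the probability flow ODE
        x_euler = x + (sigma_next - sigma) * d         # Euler step
        if sigma_next == 0:
            x = x_euler                                # final step is plain Euler
        else:
            d_next = (x_euler - denoise(x_euler, sigma_next)) / sigma_next
            x = x + (sigma_next - sigma) * 0.5 * (d + d_next)  # Heun (trapezoidal) correction
    return x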

lucidrains commented 2 years ago

thank you! your testimony carries a lot of weight! :pray: will plan on hacking out the elucidating ddpm wrapper over at https://github.com/lucidrains/denoising-diffusion-pytorch and see how it can be cleanly integrated into Imagen this week

samedii commented 2 years ago

@samedii you had mentioned in another issue that you saw good results switching over to elucidated ddpm, but did you have to fiddle with the sigma data and rho at all? or did it work pretty much on the first try?

Had the same experience as @crowsonkb. Got improved results on the first try with sigma_data=0.5 and rho=7. On a different task with a lot of information in the conditioning (a bit like super resolution) I got much better results by moving up the training noise though.

crowsonkb commented 2 years ago

@samedii you had mentioned in another issue that you saw good results switching over to elucidated ddpm, but did you have to fiddle with the sigma data and rho at all? or did it work pretty much on the first try?

Had the same experience as @crowsonkb. Got improved results on the first try with sigma_data=0.5 and rho=7. On a different task with a lot of information in the conditioning (a bit like super resolution) I got much better results by moving up the training noise though.

I swept over values of rho just now and it turns out the optimal value really is ~7 (it's so easy to optimize problems when they are one-dimensional!)

I may try learnable noise schedules as in Learning Fast Samplers for Diffusion Models by Differentiating Through Sample Quality at some point (they learn a more complicated thing for sampling but you can apply it to the step locations only - when I tried this before, I had to use an augmented Lagrangian method to soft-constrain the step locations to be monotonically decreasing).
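
(For anyone following along, the rho-parameterized spacing I swept over above is just the paper's inference-time sigma schedule; a minimal sketch:)

import torch

def karras_sigma_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # interpolate linearly in sigma**(1/rho) space and raise back to the rho power;
    # rho=7 concentrates steps at small sigmas, which the sweep above found near-optimal
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])  # append sigma=0 for the last step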

lucidrains commented 2 years ago

made a first pass at elucidating diffusion over at https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/elucidated_diffusion.py#L34. running a few experiments and should iron out any bugs by tomorrow evening

also have to think a bit about how / whether dynamic thresholding fits with the elucidating sampling scheme

lucidrains commented 2 years ago

oh yes... i see what Richard and Katherine are talking about now... looking good

lucidrains commented 2 years ago

[image: sample-11]

11k steps for oxford flowers dataset, batch size 16

lucidrains commented 2 years ago

[image: sample-13]

13k steps

crowsonkb commented 2 years ago

also have to think a bit about how / whether dynamic thresholding fits with the elucidating sampling scheme

It is very easy to implement because with elucidating the ODE is defined in terms of the denoised predicted x_0 and you just have to intervene on it with static or dynamic thresholding and plug it into d = (x - denoised) / sigma, i.e. d_thresh = (x - threshold(denoised)) / sigma. I have tried it and it works well.
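
Concretely, something along these lines (a minimal sketch; the percentile is Imagen's dynamic-thresholding cutoff):

import torch

def thresholded_d(x, denoised, sigma, percentile=0.95):
    # dynamic thresholding applied to the predicted x_0 before forming the
    # ODE derivative d = (x - denoised) / sigma
    s = torch.quantile(denoised.flatten(1).abs(), percentile, dim=1)
    s = s.clamp(min=1.0).view(-1, *([1] * (denoised.ndim - 1)))
    denoised = denoised.clamp(-s, s) / s        # clip to [-s, s], then rescale to [-1, 1]
    return (x - denoised) / sigma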

lucidrains commented 2 years ago

also have to think a bit about how / whether dynamic thresholding fits with the elucidating sampling scheme

It is very easy to implement because with elucidating the ODE is defined in terms of the denoised predicted x_0 and you just have to intervene on it with static or dynamic thresholding and plug it into d = (x - denoised) / sigma, i.e. d_thresh = (x - threshold(denoised)) / sigma. I have tried it and it works well.

thank you! this means ElucidatedImagen is definitely coming soon then!

lucidrains commented 2 years ago

ok, i'm going to down an extra coffee today and get ElucidatedImagen out, since none of us are getting any younger

Birch-san commented 2 years ago

inspiring! but please don't force the pace on our account; stay healthy and proceed in comfort!

lucidrains commented 2 years ago

coffee is healthy! at least according to the studies so far, and it's what i would like to believe :laughing: i just have trouble sleeping if i have more than 2 :cry:

crowsonkb commented 2 years ago

On the subject of Elucidated, I have a deterministic 4th order linear multistep sampler for it (like PLMS, except ordinary Adams type linear multistep because you don't need the "pseudo" part, but which supports uneven step sizes) that outperforms the deterministic Heun sampler from the paper at all the NFEs I tried it at: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L57. So you don't need to worry about losing PLMS-type samplers if you use Elucidated.

lucidrains commented 2 years ago

@crowsonkb 4th order, i can barely understand 2nd order haha

that looks really cool and i wish i could understand it at first glance :laughing:

@Birch-san https://github.com/lucidrains/imagen-pytorch#experimental may be a bug or two in there, but at least everything runs!

Birch-san commented 2 years ago

@lucidrains outstanding! looking forward to trying it out (I still have things to learn and prepare first). just checking I've understood: the faster sampling means that fewer training iterations are needed? and inferences are faster?

lucidrains commented 2 years ago

@Birch-san it means faster inference

the big drawback to ddpms is that producing an image takes many gradual denoising steps, so there is a lot of research being done at the moment to try to speed it up without sacrificing quality

Birch-san commented 2 years ago

okay! faster inference sounds good! especially if that improves the ability to run inferences on consumer hardware.
I'll rename this issue since it sounds like it's unrelated to training.

UdonDa commented 2 years ago

What a quick challenge!! Nice!

I checked ElucidatedImagen and elucidated_diffusion. But I cannot find the augmentation strategy described in Sec. 5 of Elucidating the Design Space of Diffusion-Based Generative Models.

Perhaps you implemented the training and sampling of "Elucidating" but not the augmentation, right? I'm sorry if I am wrong.

Birch-san commented 2 years ago

It's possible to perform augmentation with the functionality that the model already supports:

https://github.com/lucidrains/imagen-pytorch/issues/81#issuecomment-1166571201

but yeah that'd only work for transforms that are described easily in words. otherwise, sounds like the proposal is we'd need to modify the Unet to accept additional conditioning?

lucidrains commented 2 years ago

What a quick challenge!! Nice!

I checked ElucidatedImagen and elucidated_diffusion. But I cannot find the augmentation strategy described in Sec. 5 of Elucidating the Design Space of Diffusion-Based Generative Models.

Perhaps you implemented the training and sampling of "Elucidating" but not the augmentation, right? I'm sorry if I am wrong.

nope you are correct! however, @jacobwjs did bring up conditioning using text, and that seems reasonable, at least for a subset of the augmentations

lucidrains commented 2 years ago

It's possible to perform augmentation with the functionality that the model already supports:

#81 (comment)

but yeah that'd only work for transforms that are described easily in words. otherwise, sounds like the proposal is we'd need to modify the Unet to accept additional conditioning?

yea, that's the way Tero had it in his paper

crowsonkb commented 2 years ago

I should mention that I've implemented the augmentations and that my code for it is MIT licensed: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/augmentation.py, see also the model class https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/models.py#L70 where I use mapping_cond_dim=9 to configure the model to use the augmentation information.

crowsonkb commented 2 years ago

And in my tests they do help with overfitting quite a lot on small datasets. I'm targeting the "point the script at a directory of images and it trains an unconditional diffusion model" use case rn so they were high priority because most user custom datasets are small.

srelbo commented 2 years ago

Thank you @crowsonkb !

lucidrains commented 2 years ago

And in my tests they do help with overfitting quite a lot on small datasets. I'm targeting the "point the script at a directory of images and it trains an unconditional diffusion model" use case rn so they were high priority because most user custom datasets are small.

the non-leaky augmentations were one of the big findings for stylegan-ada that allowed it to work on 100-image datasets! when i read that part of the paper, i knew it would work out for ddpms

@jacobwjs Jacob's idea of doing some of the augmentations through modifying the text and the corresponding image is also quite clever and growing on me! (but it would only work with text-conditioned image generation models)

jacobwjs commented 2 years ago

And in my tests they do help with overfitting quite a lot on small datasets. I'm targeting the "point the script at a directory of images and it trains an unconditional diffusion model" use case rn so they were high priority because most user custom datasets are small.

the non-leaky augmentations were one of the big findings for stylegan-ada that allowed it to work on 100-image datasets! when i read that part of the paper, i knew it would work out for ddpms

@jacobwjs Jacob's idea of doing some of the augmentations through modifying the text and the corresponding image is also quite clever and growing on me! (but it would only work with text-conditioned image generation models)

@lucidrains I'm gone for two days and come back to see you've brought in Elucidated already haha. Well done good sir, and of course your beautiful code @crowsonkb.

I’m wrapped up with a full load until tomorrow, but looks like I can already start playing with the Elucidated model to test, no?

I’ll probably bring in Albumentations for augmentation, and drop the additional augmentation context some percentage of the time (based on the probability of the augmentation).

So quick and dirty (sorry typing on phone),

img = some_image
text = "a banana peeling itself."
aug = " some augmentation."

if some_condition == True:
    img = transform(image=img)
    text = text + aug

model(img, text)

So based on a probability the model gets either original text/image pairs or augmented text/image pairs, and of course the augmented text needs to match the transform.

Not sure how far we can push this idea, but excited to try!

what are your thoughts?
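
Fleshed out a touch, it could look something like this (an untested sketch; the transforms, probabilities and wording are just placeholders):

import random
import torchvision.transforms.functional as TF

# hypothetical augmentation table: (probability, image transform, text to append)
AUGS = [
    (0.25, TF.vflip,                       " this picture is upside down."),
    (0.25, lambda img: TF.rotate(img, 90), " this picture is rotated ninety degrees."),
]

def augment_pair(img, text):
    # with some probability, apply an augmentation and append a matching
    # description to the caption, so the text embedding carries the
    # augmentation conditioning; at inference just omit the extra sentence
    for prob, transform, aug_text in AUGS:
        if random.random() < prob:
            img = transform(img)
            text = text + aug_text
    return img, text

# usage, assuming img is a PIL image or tensor and text is its caption:
# img, text = augment_pair(img, "a banana peeling itself.")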

Birch-san commented 2 years ago

given that the kinds of augmentations we're discussing are already matrices (e.g. describing an affine transform), do we have the option of keeping them as such (instead of turning them into words, then into embeddings… which are just matrices, like we started with)?

like, can we take our augmentation, produce from it an embedding, and concatenate/splice said embedding onto our sentence embedding?

because trying to describe things like "slightly rotated" via the text layer is… subject to interpretation. whereas with an actual matrix we can be pretty precise and describe a variety of more subtle transforms, helping it learn a whole distribution of angles, scales and displacements.

lucidrains commented 2 years ago

@jacobwjs yup, that sounds perfect!

i'll probably also try it the way Tero described over at the ddpm-pytorch repository, and help out a student who was trying to train with low amounts of data https://github.com/lucidrains/denoising-diffusion-pytorch/issues/45

lucidrains commented 2 years ago

given that the kinds of augmentations we're discussing are already matrices (e.g. describing an affine transform), do we have the option of keeping them as such (instead of turning them into words, then into embeddings… which are just matrices, like we started with)?

like, can we take our augmentation, produce from it an embedding, and concatenate/splice said embedding onto our sentence embedding?

because trying to describe things like "slightly rotated" via the text layer is… subject to interpretation. whereas with an actual matrix we can be pretty precise and describe a variety of more subtle transforms, helping it learn a whole distribution of angles, scales and displacements.

yea, but it gets pretty involved

[screenshot: augmentation conditioning details from the paper]

lucidrains commented 2 years ago

besides, the trend is to relinquish feature engineering over to large attention models anyways. if dalle2 and imagen can draw text from tokenized strings, i don't see why it can't be conditioned on augmentations represented as text

jacobwjs commented 2 years ago

given that the kinds of augmentations we're discussing are already matrices (e.g. describing an affine transform), do we have the option of keeping them as such (instead of turning them into words, then into embeddings… which are just matrices, like we started with)?

like, can we take our augmentation, produce from it an embedding, and concatenate/splice said embedding onto our sentence embedding?

because trying to describe things like "slightly rotated" via the text layer is… subject to interpretation. whereas with an actual matrix we can be pretty precise and describe a variety of more subtle transforms, helping it learn a whole distribution of angles, scales and displacements.

@Birch-san thanks for your thoughts. i see your point of not starting from a transformation, then moving to text, and then to an embedding, but I wouldn't say this is our intention.

i view this more along the lines of attempting to maximize mutual information. or perhaps as minimizing the entropy of Y given X (i.e. H(Y | X)), where Y is our image and X is our embedding.

it's hard for me to follow the approach you laid out long-term. where does this go? what are the future possibilities? it seems quite complex to nail down where to "splice" this in, and, as phil mentioned, heavy on the feature engineering side.

another thing to think about is what happens when we strap next year's SOTA language model (LM) on the front of the cascade? would your approach still be needed? of course, without the LM all bets are off, but we have it so why not use it :)

the beauty, power and possibilities come from the explicit intentions we can provide via language. the end game goal (with safety in mind) is to have powerful models fully aligned to our intentions. the sooner we exhaustively explore that, well...

crowsonkb commented 2 years ago

actually, drawing the noise from a log normal seems to be something i can add as a setting without too much extra complexity! apparently the paper claims it works synergistically with the p2 loss weighting too

You might want to double-check this, but I think the P2 loss weighting (with k=1, gamma=1), in terms of sigma and taking into account that the "natural" relative weighting for the Karras loss is SNR+1, is approximately log-logistic with loc 0.154, scale 0.42. This has fatter tails than the Karras lognormal sampling density with loc -1.2, scale 1.2, which I think might be suboptimal at higher resolutions. I need to do some test training runs with different sampling densities and compare FID, I think.
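
(To spell that out: sampling from a log-logistic density just means pushing a uniform sample through the logistic quantile function and exponentiating. A minimal sketch with the loc/scale values above:)

import torch

def sample_sigmas_loglogistic(batch_size, loc=0.154, scale=0.42):
    # draw u ~ Uniform(0, 1), map through the logistic quantile to get ln(sigma),
    # then exponentiate; loc/scale here are the values estimated above to
    # approximate the P2 weighting
    u = torch.rand(batch_size)
    log_sigma = loc + scale * (u.log() - (1 - u).log())  # logit(u) = log(u / (1 - u))
    return log_sigma.exp()

sigmas = sample_sigmas_loglogistic(16)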

lucidrains commented 2 years ago

actually, drawing the noise from a log normal seems to be something i can add as a setting without too much extra complexity! apparently the paper claims it works synergistically with the p2 loss weighting too

You might want to double-check this, but I think the P2 loss weighting (with k=1, gamma=1), in terms of sigma and taking into account that the "natural" relative weighting for the Karras loss is SNR+1, is approximately log-logistic with loc 0.154, scale 0.42. This has fatter tails than the Karras lognormal sampling density with loc -1.2, scale 1.2, which I think might be suboptimal at higher resolutions. I need to do some test training runs with different sampling densities and compare FID, I think.

oh yup, i noticed! i'm using the loss weighting scheme lambda(sigma) that Karras talked about in his paper https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/elucidated_imagen.py#L470

crowsonkb commented 2 years ago

actually, drawing the noise from a log normal seems to be something i can add as a setting without too much extra complexity! apparently the paper claims it works synergistically with the p2 loss weighting too

You might want to double-check this, but I think the P2 loss weighting (with k=1, gamma=1), in terms of sigma and taking into account that the "natural" relative weighting for the Karras loss is SNR+1, is approximately log-logistic with loc 0.154, scale 0.42. This has fatter tails than the Karras lognormal sampling density with loc -1.2, scale 1.2, which I think might be suboptimal at higher resolutions. I need to do some test training runs with different sampling densities and compare FID, I think.

oh yup, i noticed! i'm using the loss weighting scheme lambda(sigma) that Karras talked about in his paper https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/elucidated_imagen.py#L470

I am computing the targets for the model inside the preconditioner and then just using MSE loss on them without explicit reweighting (https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/layers.py#L26). I think this is equivalent and we don't need to worry about it (your loss is ~1 on the first training step if you init the output layer weights+biases to 0, right?). The sampling density has its own independent effect on the loss weighting, and the log-logistic I proposed (I think!) makes the overall weighting roughly equivalent to P2, except by doing importance sampling, so it's lower variance than P2 as written in the paper that proposed it. :)
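
(Schematically, computing the target inside the preconditioner and taking a plain MSE looks like this; a minimal sketch, not my exact code, with net standing in for the raw U-Net:)

import torch

def karras_loss(net, x0, sigma, sigma_data=0.5):
    # sigma: shape (batch,); reshape so it broadcasts over the image dims
    sigma_b = sigma.view(-1, *([1] * (x0.ndim - 1)))
    c_skip = sigma_data ** 2 / (sigma_b ** 2 + sigma_data ** 2)
    c_out = sigma_b * sigma_data / (sigma_b ** 2 + sigma_data ** 2).sqrt()
    c_in = 1 / (sigma_b ** 2 + sigma_data ** 2).sqrt()
    noised = x0 + torch.randn_like(x0) * sigma_b
    # the target is defined inside the preconditioner, so a plain MSE already
    # carries the Karras loss weighting implicitly via the c_out scaling
    target = (x0 - c_skip * noised) / c_out
    pred = net(c_in * noised, sigma.log() / 4)  # c_noise conditioning
    return (pred - target).pow(2).mean()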

crowsonkb commented 2 years ago

I am doing a run with Elucidated+my P2 right now and it is working btw.