lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Faster inference via efficient sampling ("Elucidating the Design Space of Diffusion-Based Generative Models") #81

Closed · Birch-san closed this issue 2 years ago

Birch-san commented 2 years ago

Hi, thanks for your fantastic work so far. 🙂

Elucidating the Design Space of Diffusion-Based Generative Models, published ~3 weeks ago, describes a new, more modular way to design diffusion models, enabling changes to sampling and training independently of each other (and, in particular, faster sampling).
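For reference, one ingredient behind the faster sampling is the explicit noise-level schedule the paper proposes; a minimal sketch (the defaults sigma_min=0.002, sigma_max=80, rho=7 are the paper's, and the helper name here is just illustrative):

```python
import torch

def karras_sigma_schedule(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Noise levels from Karras et al. 2022 (Eq. 5): spaced so that steps are dense
    # near sigma_min and sparse near sigma_max, which is a big part of why the
    # sampler needs far fewer steps than ancestral DDPM sampling.
    ramp = torch.linspace(0, 1, n_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])  # append sigma = 0 for the final step
```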

Probably you're aware of this already (@crowsonkb and @johnowhitaker are attempting to reproduce the work).

@johnowhitaker explains the paper in this video, and demonstrates the model in this colab, achieving face-like images within 40 minutes of training.

any plans to incorporate this diffusion approach into imagen-pytorch, or is it still too early?

lucidrains commented 2 years ago

actually, drawing the noise from a log normal seems to be something i can add as a setting without too much extra complexity! apparently the paper claims it works synergistically with the p2 loss weighting too

crowsonkb commented 2 years ago

You might want to double-check this, but I think the P2 loss weighting (with k=1, gamma=1), in terms of sigma and taking into account that the "natural" relative weighting for the Karras loss is SNR+1, is approximately log-logistic with loc 0.154, scale 0.42. This has fatter tails than the Karras lognormal sampling density with loc -1.2, scale 1.2, which I think might be suboptimal at higher resolutions. I need to do some test training runs with different sampling densities and compare FID, I think.
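For reference, a minimal sketch of the two training-time sigma samplers being compared, both parameterized in log-sigma space (the lognormal loc/scale of -1.2/1.2 match the paper's P_mean/P_std, the log-logistic loc/scale of 0.154/0.42 is the fit described above, and the helper names are just illustrative):

```python
import torch

def rand_log_normal(shape, loc=-1.2, scale=1.2):
    # Karras et al.'s proposed training density: log(sigma) ~ Normal(loc, scale)
    return (torch.randn(shape) * scale + loc).exp()

def rand_log_logistic(shape, loc=0.154, scale=0.42):
    # log(sigma) ~ Logistic(loc, scale): fatter tails than the lognormal above
    u = torch.rand(shape).clamp(1e-7, 1 - 1e-7)  # uniform sample for inverse-CDF sampling
    return (u.logit() * scale + loc).exp()
```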

lucidrains commented 2 years ago

oh yup, i noticed! i'm using the loss weighting scheme lambda(sigma) that Karras talked about in his paper: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/elucidated_imagen.py#L470
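For reference, the weighting in question is lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 from the paper; a minimal sketch (sigma_data defaults to the paper's 0.5, and the function name is just illustrative, not necessarily the repo's exact code):

```python
def karras_loss_weight(sigma, sigma_data=0.5):
    # lambda(sigma) from Karras et al. 2022: scales the per-sigma MSE so that every
    # noise level contributes roughly equally at initialization
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
```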

crowsonkb commented 2 years ago

I am computing the targets for the model inside the preconditioner and then just using MSE loss on them without explicit reweighting (https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/layers.py#L26). I think this is equivalent, so we don't need to worry about it (your loss is ~1 on the first training step if you init the output layer weights+biases to 0, right?). The sampling density has its own independent effect on the loss weighting, and the log-logistic I proposed (I think!) makes the overall weighting roughly equivalent to P2, except it does so by importance sampling, so it's lower variance than P2 as written in the paper that proposed it. :)
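For reference, a minimal sketch of that construction (the class and argument names are illustrative, not k-diffusion's exact code): the EDM preconditioning coefficients fold the weighting into the regression target, so a plain MSE on the preconditioned output is already the lambda(sigma)-weighted denoising loss.

```python
import torch
from torch import nn

class KarrasDenoiserLoss(nn.Module):
    # Wraps a raw network F_theta(x, c_noise) with the EDM preconditioning and
    # computes the training loss as a plain MSE against a rescaled target.
    def __init__(self, inner_model, sigma_data=0.5):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data

    def forward(self, x, sigma):
        # sigma is assumed to be broadcastable to x, e.g. shape [batch, 1, 1, 1]
        sd2 = self.sigma_data ** 2
        c_skip = sd2 / (sigma ** 2 + sd2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + sd2).sqrt()
        c_in = 1 / (sigma ** 2 + sd2).sqrt()
        c_noise = sigma.log() / 4

        noised = x + torch.randn_like(x) * sigma            # x + n, n ~ N(0, sigma^2 I)
        model_out = self.inner_model(c_in * noised, c_noise)
        target = (x - c_skip * noised) / c_out              # what F_theta should predict
        # plain MSE here equals the lambda(sigma)-weighted denoising loss,
        # since lambda(sigma) * c_out(sigma)^2 == 1 for the Karras weighting
        return (model_out - target).pow(2).mean()
```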

lucidrains commented 2 years ago

yes i think it is! you approached it better than i did haha

lucidrains commented 2 years ago

this is pretty much completed, and people are reporting really good results with the base unet using this

Birch-san commented 2 years ago

yes, thanks very much 🙂

is it the case that the current Elucidated implementation does not incorporate the P2 variation that Katherine describes ("equivalent but lower variance than P2 as written in the paper")?

crowsonkb commented 2 years ago

I tried my log-logistic fit to P2 and it was actually slightly worse FID-wise than the Karras lognormal.

Birch-san commented 2 years ago

oh, my condolences! then, sounds like Elucidated is using the best known strategy for now?