lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch

Model distillation #314

Open · HReynaud opened 1 year ago

HReynaud commented 1 year ago

Hello,

Imagen-Video states that they use model distillation to iteratively train student diffusion models that require half the sampling steps of their teacher diffusion model. This seems to be an essential step in making the sampling of videos tractable. From the paper (sec. 2.7), they trained using v-parameterization with 256/128 steps (table 1) and progressively reduced the number of sampling steps to 8, while retaining most of the sample quality.
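For concreteness, halving from a 256-step teacher down to an 8-step student takes five teacher/student rounds; a toy illustration (not the paper's code):

```python
# illustrative only: each progressive distillation round halves the step count
steps, rounds = 256, 0
while steps > 8:
    steps //= 2
    rounds += 1
print(rounds, steps)  # 5 rounds: 256 -> 128 -> 64 -> 32 -> 16 -> 8
```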

Is model distillation a feature that will be added to this repo?

lucidrains commented 1 year ago

@HReynaud Hey Hadrien! Yea I haven't gotten around to all the distillation literature

However, I do offer the v-parameterization objective! just instantiate Imagen like so: `Imagen(pred_objectives = 'v', ...)`
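For anyone landing here later, a minimal sketch of that instantiation, mirroring the README's usual constructor (the unet settings and step count below are placeholders; only `pred_objectives = 'v'` is the relevant part):

```python
from imagen_pytorch import Unet, Imagen

# toy unet config, for illustration only
unet = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

imagen = Imagen(
    unets = (unet,),
    image_sizes = (64,),
    timesteps = 256,         # e.g. the 256-step regime from Imagen-Video's table 1
    pred_objectives = 'v'    # train with the v-parameterization objective
)
```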

HReynaud commented 1 year ago

Hey @lucidrains, the v-parameterization is indeed working well and seems to converge faster (in wall-clock time) than the ElucidatedImagen. But sampling takes around 20 minutes for 250 steps, while with the elucidated version it takes around 5 minutes for 64 steps. Using fewer steps with the v-parameterization would lead to worse image quality.

I'll give distillation a go in the coming days; I might come to you for some help on how to integrate it correctly with your code!

lucidrains commented 1 year ago

@HReynaud oh good to know! distillation is actually still a hot research topic atm. i'm not sure which technique is the best, nor do i have any experience distilling models yet. the person in the open source community to ask may be Katherine Crowson

lucidrains commented 1 year ago

@HReynaud is the speed of sampling a big issue for your project? usually this is only an issue for companies trying to deploy text to image models for commercial purposes

HReynaud commented 1 year ago

@lucidrains The speed is not crucial, but I am reaching good scores on the task I'm targeting, and the sampling speed is still a significant drawback compared to previous methods that used GANs, so I thought I would look into this. Also, evaluating metrics takes a loooot of time right now.

The algorithm described in this paper seems straightforward and logical, so I'll give it a shot!

Thanks for pointing me to Katherine, I'll ping her if I reach something that's worth discussing!

lucidrains commented 1 year ago

@HReynaud you are definitely living on the razor blade cutting edge, doing text to video + distillation

do let me know what technique you find signal with! your experience is super valuable to me

jameshball commented 1 year ago

> is the speed of sampling a big issue for your project?

@lucidrains Just chiming in here to mention that over the next few months I'll be experimenting with extremely high-res images, i.e. above 10000x10000, by patching together many 'low'-res images, so sampling speed will become very important for me!

Would definitely be keen on having some sort of distillation for this :)

lucidrains commented 1 year ago

@jameshball ohh thanks for chipping in your vote! i'll think about it, i had planned to start open sourcing all the latest protein diffusion work coming out for the remainder of this month. maybe i can slot in distillation at the beginning of March

realistically it will take me a month to read all the papers, filter signal from noise, and decide what is best to implement. unless Katherine or another expert can point me to which technique is resoundingly the best

HReynaud commented 1 year ago

Hi @lucidrains, how are those proteins doing?

I was able to implement algorithm 2 of Progressive Distillation and tried to make it work for v-parameterization. So far it looks like it works (I can distill from 256 down to 4 steps, losing some quality along the way), although I am pretty sure I have overlooked some details. Would you know if Katherine or anyone else might be able to help get the math right?

My current implementation hacks your Imagen class and looks terrible, but if you are interested, the main code is here and closely follows the variable names from the paper cited above.

I am trying to get the algorithm right first and will spend more time later making the code more professional. Many of my colleagues are looking into diffusion, and a simple method for reducing sampling time would interest some of them for high-resolution images, 3D volumes, and sequences.
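For readers following along, here is a self-contained sketch of that two-step-teacher / one-step-student target (algorithm 2 of Salimans & Ho's Progressive Distillation) under a continuous cosine schedule with v-parameterization. The `(z_t, t) -> v` model signature, the schedule, and the loss weighting are assumptions for illustration, not HReynaud's actual code:

```python
import math
import torch
import torch.nn.functional as F

def alpha_sigma(t):
    # continuous cosine schedule: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def pred_x0(model, z_t, t):
    # with v-parameterization, x0_hat = alpha_t * z_t - sigma_t * v_hat
    alpha, sigma = alpha_sigma(t)
    return alpha * z_t - sigma * model(z_t, t)

@torch.no_grad()
def ddim_step(model, z_t, t, t_next):
    # one deterministic DDIM step of the teacher, from t down to t_next
    alpha, sigma = alpha_sigma(t)
    alpha_next, sigma_next = alpha_sigma(t_next)
    x0 = pred_x0(model, z_t, t)
    return alpha_next * x0 + (sigma_next / sigma) * (z_t - alpha * x0)

def distill_loss(student, teacher, x0, num_steps):
    # x0: clean images in [-1, 1]; num_steps: the student's step budget N
    b = x0.shape[0]
    i = torch.randint(1, num_steps + 1, (b,), device = x0.device)
    t = (i.float() / num_steps).view(b, 1, 1, 1)   # t = i / N, i in {1..N}

    alpha, sigma = alpha_sigma(t)
    z_t = alpha * x0 + sigma * torch.randn_like(x0)

    # teacher takes two half-size DDIM steps: t -> t - 0.5/N -> t - 1/N
    t_mid, t_end = t - 0.5 / num_steps, t - 1.0 / num_steps
    z_mid = ddim_step(teacher, z_t, t, t_mid)
    z_end = ddim_step(teacher, z_mid, t_mid, t_end)

    # solve for the one-step x-target the student must hit, then map to v-space
    alpha_end, sigma_end = alpha_sigma(t_end)
    x_target = (z_end - (sigma_end / sigma) * z_t) / (alpha_end - (sigma_end / sigma) * alpha)
    v_target = (alpha * z_t - x_target) / sigma

    # plain MSE in v-space amounts to the 'SNR + 1' weighting of the x-space loss
    return F.mse_loss(student(z_t, t), v_target)
```

Each round, the distilled student becomes the next teacher and `num_steps` is halved, which is how 256 steps get down to the 4-8 steps mentioned above.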

lucidrains commented 1 year ago

@HReynaud hey! haven't even started yet, running behind :cry:

thank you for sharing your implementation! i think your best bet is to find someone on the Laion discord who is also working on distillation, if you need a second pair of eyes (or perhaps join forces with @jameshball, maybe meet up at the campus cafeteria and review the paper and code together)

realistically, i can only get to distillation early next month

jameshball commented 1 year ago

I think it's also going to be a next month ordeal for me - but let's make it happen!

lucidrains commented 1 year ago

@HReynaud will be getting back into this later this month, sorry i'm way behind on schedule

new development!

HReynaud commented 1 year ago

Hi Phil, no worries, your work has already helped me tremendously.

I came across this paper this morning and it looks very promising indeed! We'll definitely have 2-step diffusion models by the end of the year 😄