tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

set model to train mode #21

Closed: yuanzhi-zhu closed this issue 5 months ago

yuanzhi-zhu commented 5 months ago

Thanks for the great work! Do you set the model to train mode with model.train() somewhere? The UNet loaded from diffusers is set to eval mode by default.

tianweiy commented 5 months ago

Yes, at https://github.com/tianweiy/DMD2/blob/604040ab7b5ed4bd7d191e9476b908474ee7b24b/main/train_sd.py#L323

yuanzhi-zhu commented 5 months ago

Thanks for your reply, I missed it. By the way, where do you set the fake score model to train mode? I couldn't find that either 👀

tianweiy commented 5 months ago

The fake score model is a submodule of the guidance model, so I assume that setting the model to train mode also puts the fake score model in train mode.
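The assumption holds in PyTorch because nn.Module.train() recursively sets the mode on every registered child module. A minimal sketch, assuming hypothetical stand-in classes (GuidanceModel, FakeScoreNet, fake_unet, real_unet) rather than the actual DMD2 code:

```python
import torch.nn as nn

# Hypothetical stand-ins for the guidance wrapper and its score networks;
# the real DMD2 classes are not reproduced here.
class FakeScoreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.0))

class GuidanceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fake_unet = FakeScoreNet()  # assigned as an attribute, so registered as a child
        self.real_unet = FakeScoreNet()

guidance = GuidanceModel()
guidance.eval()   # mimic the eval state of a freshly loaded model
guidance.train()  # nn.Module.train() recurses into every submodule

# Every module in the tree, including both score networks, is now in train mode.
print(all(m.training for m in guidance.modules()))  # True
```

Only modules registered as attributes (or inside nn.ModuleList / nn.ModuleDict) are reached this way; a network stored in a plain Python list or dict would not be switched.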

yuanzhi-zhu commented 5 months ago

I see, so the teacher model is also in train mode during training. Thanks a lot!

yuanzhi-zhu commented 5 months ago

@tianweiy Although the pre-trained teacher model is in train() mode, all of its dropout layers have p=0 and it has no BatchNorm layers, so setting self.real_unet.requires_grad_(False) is enough :p
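As a rough illustration of what requires_grad_(False) does, here is a sketch that uses a small hypothetical network in place of real_unet (the layer sizes are arbitrary assumptions): the frozen parameters never accumulate gradients, yet gradients can still flow back through the network to its input.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the frozen teacher; not the actual SD UNet.
real_unet = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))
real_unet.train()                # train mode, as discussed in the thread
real_unet.requires_grad_(False)  # freeze every parameter in place

x = torch.randn(2, 8, requires_grad=True)
loss = real_unet(x).pow(2).mean()
loss.backward()

# No parameter of the frozen teacher receives a gradient...
print(all(p.grad is None for p in real_unet.parameters()))  # True
# ...but the gradient still propagates back to the input.
print(x.grad is not None)  # True
```

Unlike wrapping the forward pass in torch.no_grad(), this keeps the autograd graph through the teacher, which matters only if a loss needs gradients with respect to the teacher's input.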

yuanzhi-zhu commented 5 months ago

@tianweiy So all the student models are trained with dropout=0?

tianweiy commented 5 months ago

Yes, all student models are trained without dropout. This is specified explicitly for the ImageNet case: https://github.com/tianweiy/DMD2/blob/604040ab7b5ed4bd7d191e9476b908474ee7b24b/main/edm/edm_network.py#L11

Therefore, my impression is that train vs. eval mode doesn't matter much here.
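The reasoning can be checked on a toy model: when dropout has p=0 and the normalization layers keep no running statistics (LayerNorm here, as an assumed stand-in for the norms in the real networks), train and eval mode produce identical outputs; with p>0 they do not. The network below is hypothetical and only sketches the argument.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_net(p: float) -> nn.Sequential:
    # LayerNorm keeps no running statistics, so it behaves the same in both modes.
    return nn.Sequential(
        nn.Linear(8, 16), nn.LayerNorm(16), nn.SiLU(), nn.Dropout(p=p), nn.Linear(16, 8)
    )

def same_in_both_modes(net: nn.Module, x: torch.Tensor) -> bool:
    # Run the same input once in train mode and once in eval mode.
    with torch.no_grad():
        net.train()
        out_train = net(x)
        net.eval()
        out_eval = net(x)
    return torch.equal(out_train, out_eval)

x = torch.randn(4, 8)
print(same_in_both_modes(make_net(0.0), x))  # True: p=0 dropout is the identity
print(same_in_both_modes(make_net(0.5), x))  # False: dropout is active in train mode
```

A BatchNorm layer would break this equivalence even with p=0, since it normalizes with batch statistics in train mode and running statistics in eval mode.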

yuanzhi-zhu commented 5 months ago

Thanks 😃. I think dropout=0 also holds for the SD cases in your experiments: I checked the loaded SD model, and its dropout layers all have dropout=0 after loading.
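A check like this can be scripted by walking named_modules() and collecting the p of every dropout layer. A sketch, assuming a hypothetical helper dropout_probs; for the SD case one would pass a UNet loaded with diffusers (e.g. UNet2DConditionModel.from_pretrained with subfolder="unet"), which is not loaded here, so a toy model is used instead:

```python
import torch.nn as nn

# Dropout variants to look for; each exposes its drop probability as `.p`.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout, nn.FeatureAlphaDropout)

def dropout_probs(model: nn.Module) -> dict:
    """Hypothetical helper: map each dropout submodule's qualified name to its p."""
    return {
        name: m.p
        for name, m in model.named_modules()
        if isinstance(m, DROPOUT_TYPES)
    }

# Toy model standing in for a loaded UNet.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.0), nn.Linear(4, 4), nn.Dropout(p=0.1))
probs = dropout_probs(toy)
print(probs)  # {'1': 0.0, '3': 0.1}
print(all(p == 0.0 for p in probs.values()))  # False
```

For a model trained without dropout, the last check would print True.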