kvablack / ddpo-pytorch

DDPO for finetuning diffusion models, implemented in PyTorch with LoRA support
MIT License
397 stars 41 forks

unet keeps producing nan during training #18

Open EYcab opened 7 months ago

EYcab commented 7 months ago

Does anyone know why the unet keeps producing NaN results during training, even though all the settings are configured as instructed and all the other input variables are the same?
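
(For anyone debugging this, a minimal sketch of how to catch where NaNs first appear in the training loop, assuming standard PyTorch; `noise_pred` and `loss` are illustrative names, not necessarily the repo's:)

```python
import torch

def assert_finite(name, tensor):
    # Raise immediately if a tensor contains NaN or Inf, so the first bad step is visible.
    if not torch.isfinite(tensor).all():
        raise RuntimeError(f"{name} contains NaN/Inf")

# Inside the training loop (illustrative):
# noise_pred = unet(noisy_latents, timesteps, prompt_embeds).sample
# assert_finite("noise_pred", noise_pred)
# assert_finite("loss", loss)
```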

junyongyou commented 7 months ago

Yes, I encountered the same issue here: the loss becomes NaN after some epochs. I tried different reward functions, and the result is the same.

junyongyou commented 7 months ago

I figured out the reason. You can change config.mixed_precision to "no" in base.py so that full precision is used; this should stop the unet from producing NaN.
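
For reference, a minimal sketch of the change, assuming base.py builds the config with ml_collections.ConfigDict as in the repo's config/base.py (surrounding lines may differ):

```python
# config/base.py
import ml_collections

def get_config():
    config = ml_collections.ConfigDict()

    # ... other options ...

    # Mixed-precision setting: "fp16" can overflow/underflow in the unet and
    # lead to the NaN loss reported above, so run in full precision instead.
    config.mixed_precision = "no"  # previously "fp16"

    # ... other options ...
    return config
```

Note that full precision uses more memory and trains more slowly; on hardware that supports it, "bf16" may also be worth trying, since it keeps the same exponent range as fp32 and is less prone to overflow than fp16.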