lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in PyTorch
MIT License

mixed precision training #18

Open nicolas-dufour opened 1 year ago

nicolas-dufour commented 1 year ago

Hey, I've been reimplementing RIN based on the authors' repo and this repo, but I can't get it to work with mixed precision, and I see you do use mixed precision here. When naively switching to bfloat16 or float16, my model gets stuck in a weird state (attached image, left: bfloat16; right: float32).

Did you encounter such issues in your implementation? If so do you have some pointers to make it work?

Thanks!

lucidrains commented 1 year ago

@nicolas-dufour Hey Nicolas, there was an issue with the way I set up mixed precision in accelerate

in case you were using the same logic, it should be fixed in 0.7.6

lucidrains commented 1 year ago

@nicolas-dufour thanks for sharing those float32 results! 😄

nicolas-dufour commented 1 year ago

Hey @lucidrains, I'm using PyTorch Lightning (PL) instead of accelerate, but the behaviour should be the same. After further investigation, it seems genuinely difficult to leverage mixed precision for this network. If I keep the time and class embedding calculations, as well as the qkv linear projections of the cross-attention layers, in fp32, I do get better convergence, but training remains very unstable and blows up midway. Image quality is also noticeably worse than with fp32 throughout training.
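One way to pin specific sub-modules (e.g. the time/class embedding MLPs or the cross-attention qkv projections) to fp32 while the rest of the network runs under autocast is a small wrapper that disables autocast locally. This is a minimal sketch, not the code used in this thread; `ForceFP32` is a hypothetical name, and it assumes the wrapped module takes a single tensor input.

```python
import torch
import torch.nn as nn


class ForceFP32(nn.Module):
    """Hypothetical wrapper: run `fn` in float32 even inside an autocast region."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Turn autocast off for this sub-module only, and upcast the input,
        # because activations produced earlier under autocast may be bf16/fp16.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.fn(x.float())


# Illustrative placements (dimensions are arbitrary assumptions):
time_mlp = ForceFP32(nn.Sequential(nn.Linear(1, 64), nn.SiLU(), nn.Linear(64, 64)))
to_qkv = ForceFP32(nn.Linear(64, 64 * 3, bias=False))
```

The weights themselves stay fp32 in either case; the wrapper only stops autocast from downcasting them for these particular matmuls, at the cost of some speed and memory.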

For now I'm trying to stabilize it without making architecture changes, but I'm not sure that's possible without forcing too many operations into fp32.

Will update here if I make any progress.

lucidrains commented 1 year ago

@nicolas-dufour if you figure out the cause do submit a PR

i'll try to stabilize it later this week once i get back on my deep learning machine

lucidrains commented 1 year ago

@nicolas-dufour decided to offer a way to turn off the linear attention, in case that is the source of the instability you are experiencing

nicolas-dufour commented 1 year ago

@lucidrains thanks! Sadly it's still pretty unstable. The authors said on their official repo that they didn't use mixed precision, so maybe this architecture can't be trained stably in mixed precision without a major redesign =(

lucidrains commented 1 year ago

@nicolas-dufour yes indeed 😢 i'll also share that i tried a similar architecture for some contract work (different domain) a while back and experienced the same

lucidrains commented 1 year ago

@nicolas-dufour did you try the qk norm by any chance?
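For context, "qk norm" here refers to l2-normalizing queries and keys per head before the dot product, so the attention logits become cosine similarities bounded by a scale factor instead of growing without limit, which is a common failure mode in low precision. The sketch below is a generic cross-attention with qk norm under that assumption; it is not this repo's `qk_norm` implementation, and `QKNormCrossAttention`, `qk_logits`, and the fixed `scale=10.0` are illustrative choices (some implementations learn the scale).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def qk_logits(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # Unit-normalize along the head dimension, so each logit lies in [-scale, scale].
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    return (q @ k.transpose(-2, -1)) * scale


class QKNormCrossAttention(nn.Module):
    """Hypothetical cross-attention: tokens `x` attend to `context` with qk norm."""

    def __init__(self, dim: int, heads: int = 4, dim_head: int = 32, scale: float = 10.0):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = scale
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # (b, tokens, heads * dim_head) -> (b, heads, tokens, dim_head)
        q, k, v = (t.reshape(b, t.shape[1], self.heads, -1).transpose(1, 2) for t in (q, k, v))
        attn = qk_logits(q, k, self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```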

nicolas-dufour commented 1 year ago

@lucidrains no, I haven't. Thanks for the pointer, will try it out.

nicolas-dufour commented 1 year ago

@lucidrains Tried the qk_norm but it didn't change a thing. I also tried adding a LayerNorm on the "from" tokens of the cross-attention, but training is still subpar and doesn't converge.

lucidrains commented 1 year ago

@nicolas-dufour bummer! this may be a caveat of this architecture then

thank you for running the experiments and sharing this!

LeeDoYup commented 10 months ago

I ran into the same issue. @nicolas-dufour what if you remove the torch.no_grad() around the self-conditioning inference pass while using mixed precision?

nicolas-dufour commented 10 months ago

This is indeed the only fix I found. It's a mixed-precision bug: parameters that are used while gradients are disabled stop receiving gradients afterwards.

This issue is discussed here https://github.com/pytorch/pytorch/issues/112583
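A sketch of the workaround: run the self-conditioning pass with gradients enabled and cut the graph with `.detach()`, instead of wrapping it in `torch.no_grad()`. `ToyDenoiser` and `self_conditioned_forward` are hypothetical stand-ins for the RIN model and its training step. The assumption, based on the linked PyTorch issue, is that the root cause is autocast caching low-precision weight casts made under `no_grad`. Under that same assumption, passing `cache_enabled=False` to `torch.autocast` should be an alternative.

```python
import torch
import torch.nn as nn


class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for RIN: takes the noisy input plus a self-conditioning estimate."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim * 2, 32), nn.SiLU(), nn.Linear(32, dim))

    def forward(self, x: torch.Tensor, self_cond: torch.Tensor) -> torch.Tensor:
        # The first-pass output may be bf16 under autocast; match the input dtype before concatenating.
        return self.net(torch.cat((x, self_cond.to(x.dtype)), dim=-1))


def self_conditioned_forward(model: nn.Module, x: torch.Tensor, p_self_cond: float = 0.5) -> torch.Tensor:
    self_cond = torch.zeros_like(x)
    if torch.rand(()) < p_self_cond:
        # No torch.no_grad() here: the pass runs with grad enabled and the result is
        # detached, so no gradient flows through the self-conditioning estimate itself.
        self_cond = model(x, self_cond).detach()
    return model(x, self_cond)
```

The cost is that the first pass now builds an autograd graph that is immediately discarded, which uses some extra memory compared with `no_grad`.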

Another key component for making mixed precision work was catching NaNs and Infs in the gradients and skipping those batches.
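Skipping batches with non-finite gradients can be sketched as a check between `backward()` and `optimizer.step()`. This is a minimal plain-PyTorch version; `grads_are_finite` and `guarded_step` are hypothetical names, and in PyTorch Lightning the same check would have to live in whichever hook runs before the optimizer step.

```python
import torch
import torch.nn as nn


def grads_are_finite(model: nn.Module) -> bool:
    # True only if every existing gradient is free of NaN and Inf values.
    return all(
        bool(torch.isfinite(p.grad).all())
        for p in model.parameters()
        if p.grad is not None
    )


def guarded_step(model: nn.Module, optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> bool:
    """Backprop `loss` and apply the optimizer step only if all gradients are finite."""
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    if not grads_are_finite(model):
        # Skip this batch: discard the bad gradients and leave the weights untouched.
        optimizer.zero_grad(set_to_none=True)
        return False
    optimizer.step()
    return True
```

For fp16, `GradScaler` already skips `optimizer.step()` when it finds inf/NaN gradients. With bf16 a scaler is usually not used, so a manual check like this fills that role.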